Commit dc469751 authored by Dr.李's avatar Dr.李

added sql engine tests

parent 69b2d0d7
......@@ -10,6 +10,7 @@ from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlalchemy.orm as orm
......@@ -43,8 +44,10 @@ from alphamind.data.engines.utilities import _map_industry_category
from alphamind.data.engines.utilities import _map_risk_model_table
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import industry_list
from alphamind.data.processing import factor_processing
from PyFin.api import advanceDateByCalendar
risk_styles = ['BETA',
'MOMENTUM',
'SIZE',
......@@ -196,7 +199,10 @@ class SqlEngine(object):
codes: Iterable[int],
expiry_date: str = None,
horizon: int = 0,
offset: int = 0) -> pd.DataFrame:
offset: int = 0,
neutralized_risks: list = None,
pre_process=None,
post_process=None) -> pd.DataFrame:
start_date = ref_date
if not expiry_date:
......@@ -216,6 +222,15 @@ class SqlEngine(object):
df = pd.read_sql(query, self.session.bind).dropna()
df = df[df.trade_date == ref_date]
if neutralized_risks:
_, risk_exp = self.fetch_risk_model(ref_date, codes)
df = pd.merge(df, risk_exp, on='code').dropna()
df[['dx']] = factor_processing(df[['dx']].values,
pre_process=pre_process,
risk_factors=df[neutralized_risks].values,
post_process=post_process)
return df[['code', 'dx']]
def fetch_dx_return_range(self,
......@@ -257,7 +272,7 @@ class SqlEngine(object):
if dates:
df = df[df.trade_date.isin(dates)]
return df
return df.sort_values(['trade_date', 'code'])
def fetch_dx_return_index(self,
ref_date: str,
......@@ -273,8 +288,7 @@ class SqlEngine(object):
else:
end_date = expiry_date
stats = self._create_stats(IndexMarket, horizon, offset, code_attr='indexCode')
stats = self._create_stats(IndexMarket, horizon, offset, code_attr='indexCode')
query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]).where(
and_(
IndexMarket.trade_date.between(start_date, end_date),
......@@ -302,7 +316,6 @@ class SqlEngine(object):
str(1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime('%Y-%m-%d')
stats = self._create_stats(IndexMarket, horizon, offset, code_attr='indexCode')
query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]) \
.where(
and_(
......@@ -355,13 +368,17 @@ class SqlEngine(object):
.select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date),
Market.code.in_(codes)))
df = pd.read_sql(query, self.engine).sort_values(['trade_date', 'code']).set_index('trade_date')
res = transformer.transform('code', df)
df = pd.read_sql(query, self.engine) \
.replace([-np.inf, np.inf], np.nan) \
.sort_values(['trade_date', 'code']) \
.set_index('trade_date')
res = transformer.transform('code', df).replace([-np.inf, np.inf], np.nan)
for col in res.columns:
if col not in set(['code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values
df.dropna(inplace=True)
df['isOpen'] = df.isOpen.astype(bool)
df = df.loc[ref_date]
df.index = list(range(len(df)))
......@@ -415,7 +432,7 @@ class SqlEngine(object):
)
).distinct()
df = pd.read_sql(query, self.engine)
df = pd.read_sql(query, self.engine).replace([-np.inf, np.inf], np.nan)
if universe.is_filtered:
df = pd.merge(df, universe_df, how='inner', on=['trade_date', 'code'])
......@@ -424,12 +441,13 @@ class SqlEngine(object):
df.sort_values(['trade_date', 'code'], inplace=True)
df.set_index('trade_date', inplace=True)
res = transformer.transform('code', df)
res = transformer.transform('code', df).replace([-np.inf, np.inf], np.nan)
for col in res.columns:
if col not in set(['code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values
df.dropna(inplace=True)
df['isOpen'] = df.isOpen.astype(bool)
df = df.reset_index()
return pd.merge(df, universe_df[['trade_date', 'code']], how='inner')
......@@ -440,7 +458,6 @@ class SqlEngine(object):
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None):
if isinstance(factors, Transformer):
transformer = factors
else:
......@@ -480,7 +497,10 @@ class SqlEngine(object):
)
)
df = pd.read_sql(query, self.engine).sort_values(['trade_date', 'code'])
df = pd.read_sql(query, self.engine) \
.replace([-np.inf, np.inf], np.nan) \
.dropna() \
.sort_values(['trade_date', 'code'])
return pd.merge(df, codes[['trade_date', 'code']], how='inner')
def fetch_benchmark(self,
......@@ -553,7 +573,7 @@ class SqlEngine(object):
RiskExposure.code.in_(codes)
)).distinct()
risk_exp = pd.read_sql(query, self.engine)
risk_exp = pd.read_sql(query, self.engine).dropna()
return risk_cov, risk_exp
......@@ -608,7 +628,7 @@ class SqlEngine(object):
special_risk_table.SRISK.label('srisk')] + risk_exposure_cols).select_from(big_table) \
.distinct()
risk_exp = pd.read_sql(query, self.engine).sort_values(['trade_date', 'code'])
risk_exp = pd.read_sql(query, self.engine).sort_values(['trade_date', 'code']).dropna()
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
......@@ -637,7 +657,7 @@ class SqlEngine(object):
)
).distinct()
return pd.read_sql(query, self.engine)
return pd.read_sql(query, self.engine).dropna()
def fetch_industry_matrix(self,
ref_date: str,
......@@ -687,7 +707,7 @@ class SqlEngine(object):
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).select_from(big_table).distinct()
df = pd.read_sql(query, self.engine)
df = pd.read_sql(query, self.engine).dropna()
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
df = pd.merge(df, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code'])
......@@ -1037,5 +1057,5 @@ if __name__ == '__main__':
codes = engine.fetch_codes(ref_date, universe)
dates = makeSchedule('2018-01-01', '2018-02-01', '10b', 'china.sse')
factor_data = engine.fetch_factor_range_forward(universe, ['roe_q'], dates=dates)
factor_data = engine.fetch_dx_return('2018-01-30', codes, neutralized_risks=risk_styles+industry_styles)
print(factor_data)
This diff is collapsed.
......@@ -15,6 +15,7 @@ from alphamind.tests.data.test_neutralize import TestNeutralize
from alphamind.tests.data.test_standardize import TestStandardize
from alphamind.tests.data.test_winsorize import TestWinsorize
from alphamind.tests.data.test_quantile import TestQuantile
from alphamind.tests.data.engines.test_sql_engine import TestSqlEngine
from alphamind.tests.data.engines.test_universe import TestUniverse
from alphamind.tests.portfolio.test_constraints import TestConstraints
from alphamind.tests.portfolio.test_evolver import TestEvolver
......@@ -45,6 +46,7 @@ if __name__ == '__main__':
TestStandardize,
TestWinsorize,
TestQuantile,
TestSqlEngine,
TestUniverse,
TestConstraints,
TestEvolver,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment