Commit 65b0a5d9 authored by Dr.李

simplified the data fetching logic

parent b17ff59f
...@@ -271,8 +271,7 @@ class SqlEngine(object): ...@@ -271,8 +271,7 @@ class SqlEngine(object):
factors: Union[Transformer, Iterable[object]], factors: Union[Transformer, Iterable[object]],
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None, dates: Iterable[str] = None) -> pd.DataFrame:
warm_start: int = 0) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -282,37 +281,20 @@ class SqlEngine(object): ...@@ -282,37 +281,20 @@ class SqlEngine(object):
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency) factor_cols = _map_factors(dependency)
fast_path_optimization = False cond = universe.query_range(start_date, start_date, dates)
for name in transformer.expressions:
if not isinstance(name, SecurityLatestValueHolder) and not isinstance(name, str):
break
else:
fast_path_optimization = True
if fast_path_optimization:
real_start_date = start_date
real_end_date = end_date
real_dates = dates
else:
if dates:
real_start_date = advanceDateByCalendar('china.sse', dates[0], str(-warm_start) + 'b').strftime(
'%Y-%m-%d')
real_end_date = dates[-1]
else:
real_start_date = advanceDateByCalendar('china.sse', start_date, str(-warm_start) + 'b').strftime(
'%Y-%m-%d')
real_end_date = end_date
real_dates = None
cond = universe.query_range(real_start_date, real_end_date, real_dates)
big_table = FullFactorView big_table = FullFactorView
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name != FullFactorView.__table__.name: if t.__table__.name != FullFactorView.__table__.name:
big_table = outerjoin(big_table, t, and_(FullFactorView.trade_date == t.trade_date, if dates is not None:
FullFactorView.code == t.code)) big_table = outerjoin(big_table, t, and_(FullFactorView.trade_date == t.trade_date,
FullFactorView.code == t.code,
FullFactorView.trade_date.in_(dates)))
else:
big_table = outerjoin(big_table, t, and_(FullFactorView.trade_date == t.trade_date,
FullFactorView.code == t.code,
FullFactorView.trade_date.between(start_date, end_date)))
big_table = join(big_table, UniverseTable, big_table = join(big_table, UniverseTable,
and_(FullFactorView.trade_date == UniverseTable.trade_date, and_(FullFactorView.trade_date == UniverseTable.trade_date,
...@@ -328,10 +310,7 @@ class SqlEngine(object): ...@@ -328,10 +310,7 @@ class SqlEngine(object):
for col in res.columns: for col in res.columns:
if col not in set(['code', 'isOpen']) and col not in df.columns: if col not in set(['code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values df[col] = res[col].values
if dates:
df = df[df.index.isin(dates)]
else:
df = df[start_date:end_date]
return df.reset_index() return df.reset_index()
def fetch_benchmark(self, def fetch_benchmark(self,
......
...@@ -29,17 +29,17 @@ re-balance - 1 week ...@@ -29,17 +29,17 @@ re-balance - 1 week
training - every 4 week training - every 4 week
''' '''
engine = SqlEngine('postgresql+psycopg2://postgres:we083826@localhost/alpha') engine = SqlEngine('postgresql+psycopg2://postgres:we083826@192.168.0.101/alpha')
universe = Universe('hs300', ['hs300']) universe = Universe('zz500', ['zz500'])
neutralize_risk = ['SIZE'] + industry_styles neutralize_risk = ['SIZE'] + industry_styles
portfolio_risk_neutralize = ['SIZE'] portfolio_risk_neutralize = ['SIZE']
portfolio_industry_neutralize = True portfolio_industry_neutralize = True
alpha_factors = ['RVOL', 'EPS', 'CFinc1', 'BDTO', 'VAL', 'CHV', 'GREV', 'ROEDiluted'] # ['BDTO', 'RVOL', 'CHV', 'VAL', 'CFinc1'] # risk_styles alpha_factors = ['RVOL', 'EPS', 'CFinc1', 'BDTO', 'VAL', 'CHV', 'GREV', 'ROEDiluted'] # ['BDTO', 'RVOL', 'CHV', 'VAL', 'CFinc1'] # risk_styles
benchmark = 300 benchmark = 905
n_bins = 5 n_bins = 5
frequency = '1w' frequency = '2w'
batch = 8 batch = 8
start_date = '2012-01-01' start_date = '2017-01-01'
end_date = '2017-08-31' end_date = '2017-08-31'
method = 'risk_neutral' method = 'risk_neutral'
use_rank = 100 use_rank = 100
......
...@@ -11,6 +11,9 @@ from typing import Iterable ...@@ -11,6 +11,9 @@ from typing import Iterable
from typing import Union from typing import Union
from PyFin.api import makeSchedule from PyFin.api import makeSchedule
from PyFin.api import BizDayConventions from PyFin.api import BizDayConventions
from PyFin.api import advanceDateByCalendar
from PyFin.DateUtilities import Period
from PyFin.Enums import TimeUnits
from alphamind.data.transformer import Transformer from alphamind.data.transformer import Transformer
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
...@@ -19,14 +22,15 @@ from alphamind.utilities import alpha_logger ...@@ -19,14 +22,15 @@ from alphamind.utilities import alpha_logger
def _map_horizon(frequency: str) -> int:
    """Convert a frequency rule string into an integer horizon.

    The rule (e.g. ``'1d'``, ``'2w'``, ``'1m'``) is parsed with PyFin's
    ``Period``; the horizon is derived from the parsed unit and length:

    * business days / calendar days: ``length - 1``
    * weeks: ``5 * length - 1``
    * months: ``22 * length - 1``

    The multipliers 5 and 22 are presumably the number of trading days per
    week and per month -- confirm against the callers that consume the horizon.
    Note that calendar days are treated exactly like business days here.

    :param frequency: period rule understood by ``PyFin.DateUtilities.Period``.
    :return: the horizon computed from the rule, as described above.
    :raises ValueError: if the parsed unit is not days, business days,
        weeks or months.
    """
    parsed_period = Period(frequency)
    unit = parsed_period.units()
    length = parsed_period.length()

    if unit in (TimeUnits.BDays, TimeUnits.Days):
        return length - 1
    if unit == TimeUnits.Weeks:
        return 5 * length - 1
    if unit == TimeUnits.Months:
        return 22 * length - 1

    raise ValueError('{0} is an unrecognized frequency rule'.format(frequency))
...@@ -39,6 +43,9 @@ def prepare_data(engine: SqlEngine, ...@@ -39,6 +43,9 @@ def prepare_data(engine: SqlEngine,
universe: Universe, universe: Universe,
benchmark: int, benchmark: int,
warm_start: int = 0): warm_start: int = 0):
if warm_start > 0:
start_date = advanceDateByCalendar('china.sse', start_date, str(-warm_start) + 'b').strftime('%Y-%m-%d')
dates = makeSchedule(start_date, end_date, frequency, calendar='china.sse', dateRule=BizDayConventions.Following) dates = makeSchedule(start_date, end_date, frequency, calendar='china.sse', dateRule=BizDayConventions.Following)
horizon = _map_horizon(frequency) horizon = _map_horizon(frequency)
...@@ -50,8 +57,7 @@ def prepare_data(engine: SqlEngine, ...@@ -50,8 +57,7 @@ def prepare_data(engine: SqlEngine,
factor_df = engine.fetch_factor_range(universe, factor_df = engine.fetch_factor_range(universe,
factors=transformer, factors=transformer,
dates=dates, dates=dates).sort_values(['trade_date', 'code'])
warm_start=warm_start).sort_values(['trade_date', 'code'])
return_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon) return_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon)
industry_df = engine.fetch_industry_range(universe, dates=dates) industry_df = engine.fetch_industry_range(universe, dates=dates)
benchmark_df = engine.fetch_benchmark_range(benchmark, dates=dates) benchmark_df = engine.fetch_benchmark_range(benchmark, dates=dates)
...@@ -61,7 +67,7 @@ def prepare_data(engine: SqlEngine, ...@@ -61,7 +67,7 @@ def prepare_data(engine: SqlEngine,
df = pd.merge(df, industry_df, on=['trade_date', 'code']) df = pd.merge(df, industry_df, on=['trade_date', 'code'])
df['weight'] = df['weight'].fillna(0.) df['weight'] = df['weight'].fillna(0.)
return df[['trade_date', 'code', 'dx']], df[ return dates, df[['trade_date', 'code', 'dx']], df[
['trade_date', 'code', 'weight', 'isOpen', 'industry_code', 'industry'] + transformer.names] ['trade_date', 'code', 'weight', 'isOpen', 'industry_code', 'industry'] + transformer.names]
...@@ -140,15 +146,14 @@ def fetch_data_package(engine: SqlEngine, ...@@ -140,15 +146,14 @@ def fetch_data_package(engine: SqlEngine,
alpha_logger.info("Starting data package fetching ...") alpha_logger.info("Starting data package fetching ...")
transformer = Transformer(alpha_factors) transformer = Transformer(alpha_factors)
dates = makeSchedule(start_date, end_date, frequency, calendar='china.sse', dateRule=BizDayConventions.Following) dates, return_df, factor_df = prepare_data(engine,
return_df, factor_df = prepare_data(engine, transformer,
transformer, start_date,
start_date, end_date,
end_date, frequency,
frequency, universe,
universe, benchmark,
benchmark, warm_start)
warm_start)
if neutralized_risk: if neutralized_risk:
risk_df = engine.fetch_risk_model_range(universe, dates=dates, risk_model=risk_model)[1] risk_df = engine.fetch_risk_model_range(universe, dates=dates, risk_model=risk_model)[1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment