Commit 189ee654 authored by Dr.李's avatar Dr.李

update api

parent 238dcfe5
......@@ -45,6 +45,10 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
if f not in excluded and f in t.__table__.columns:
factor_cols[t.__table__.columns[f]] = t
break
if not factor_cols:
raise ValueError(f"some factors in <{factors}> can't be find")
return factor_cols
......
......@@ -261,7 +261,8 @@ def fetch_train_phase(engine,
risk_model: str = 'short',
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None,
warm_start: int = 0) -> dict:
warm_start: int = 0,
fitting_target: Union[Transformer, object] = None) -> dict:
if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
......@@ -281,7 +282,13 @@ def fetch_train_phase(engine,
horizon = map_freq(frequency)
factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates)
target_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon)
if fitting_target is None:
target_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon)
else:
one_more_date = advanceDateByCalendar('china.sse', dates[-1], frequency)
target_df = engine.fetch_factor_range_forward(universe, factors=fitting_target, dates=dates + [one_more_date])
target_df = target_df[target_df.trade_date.isin(dates)]
target_df = target_df.groupby('code').apply(lambda x: x.fillna(method='pad'))
df = pd.merge(factor_df, target_df, on=['trade_date', 'code']).dropna()
......@@ -336,7 +343,7 @@ def fetch_predict_phase(engine,
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None,
warm_start: int = 0,
fillna: str=None):
fillna: str = None):
if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
......@@ -356,7 +363,8 @@ def fetch_predict_phase(engine,
factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates)
if fillna:
factor_df = factor_df.groupby('trade_date').apply(lambda x: x.fillna(x.median())).reset_index(drop=True).dropna()
factor_df = factor_df.groupby('trade_date').apply(lambda x: x.fillna(x.median())).reset_index(
drop=True).dropna()
else:
factor_df = factor_df.dropna()
......@@ -420,9 +428,10 @@ if __name__ == '__main__':
universe = Universe('zz500', ['hs300', 'zz500'])
neutralized_risk = ['SIZE']
res = fetch_train_phase(engine, ['ep_q'],
'2012-01-05',
'5b',
universe,
16,
neutralized_risk=neutralized_risk)
'2012-01-05',
'5b',
universe,
16,
neutralized_risk=neutralized_risk,
fitting_target='closePrice')
print(res)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment