Commit 189ee654 authored by Dr.李's avatar Dr.李

update api

parent 238dcfe5
...@@ -45,6 +45,10 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict: ...@@ -45,6 +45,10 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
if f not in excluded and f in t.__table__.columns: if f not in excluded and f in t.__table__.columns:
factor_cols[t.__table__.columns[f]] = t factor_cols[t.__table__.columns[f]] = t
break break
if not factor_cols:
raise ValueError(f"some factors in <{factors}> can't be find")
return factor_cols return factor_cols
......
...@@ -261,7 +261,8 @@ def fetch_train_phase(engine, ...@@ -261,7 +261,8 @@ def fetch_train_phase(engine,
risk_model: str = 'short', risk_model: str = 'short',
pre_process: Iterable[object] = None, pre_process: Iterable[object] = None,
post_process: Iterable[object] = None, post_process: Iterable[object] = None,
warm_start: int = 0) -> dict: warm_start: int = 0,
fitting_target: Union[Transformer, object] = None) -> dict:
if isinstance(alpha_factors, Transformer): if isinstance(alpha_factors, Transformer):
transformer = alpha_factors transformer = alpha_factors
else: else:
...@@ -281,7 +282,13 @@ def fetch_train_phase(engine, ...@@ -281,7 +282,13 @@ def fetch_train_phase(engine,
horizon = map_freq(frequency) horizon = map_freq(frequency)
factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates) factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates)
if fitting_target is None:
target_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon) target_df = engine.fetch_dx_return_range(universe, dates=dates, horizon=horizon)
else:
one_more_date = advanceDateByCalendar('china.sse', dates[-1], frequency)
target_df = engine.fetch_factor_range_forward(universe, factors=fitting_target, dates=dates + [one_more_date])
target_df = target_df[target_df.trade_date.isin(dates)]
target_df = target_df.groupby('code').apply(lambda x: x.fillna(method='pad'))
df = pd.merge(factor_df, target_df, on=['trade_date', 'code']).dropna() df = pd.merge(factor_df, target_df, on=['trade_date', 'code']).dropna()
...@@ -336,7 +343,7 @@ def fetch_predict_phase(engine, ...@@ -336,7 +343,7 @@ def fetch_predict_phase(engine,
pre_process: Iterable[object] = None, pre_process: Iterable[object] = None,
post_process: Iterable[object] = None, post_process: Iterable[object] = None,
warm_start: int = 0, warm_start: int = 0,
fillna: str=None): fillna: str = None):
if isinstance(alpha_factors, Transformer): if isinstance(alpha_factors, Transformer):
transformer = alpha_factors transformer = alpha_factors
else: else:
...@@ -356,7 +363,8 @@ def fetch_predict_phase(engine, ...@@ -356,7 +363,8 @@ def fetch_predict_phase(engine,
factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates) factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates)
if fillna: if fillna:
factor_df = factor_df.groupby('trade_date').apply(lambda x: x.fillna(x.median())).reset_index(drop=True).dropna() factor_df = factor_df.groupby('trade_date').apply(lambda x: x.fillna(x.median())).reset_index(
drop=True).dropna()
else: else:
factor_df = factor_df.dropna() factor_df = factor_df.dropna()
...@@ -424,5 +432,6 @@ if __name__ == '__main__': ...@@ -424,5 +432,6 @@ if __name__ == '__main__':
'5b', '5b',
universe, universe,
16, 16,
neutralized_risk=neutralized_risk) neutralized_risk=neutralized_risk,
fitting_target='closePrice')
print(res) print(res)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment