Commit baab31a8 authored by Dr.李's avatar Dr.李

update data preparing

parent 03e6435a
...@@ -128,12 +128,10 @@ def train_model(ref_date: str, ...@@ -128,12 +128,10 @@ def train_model(ref_date: str,
return base_model return base_model
def predict_by_model(ref_date: str, def fetch_predict_data(ref_date: str,
alpha_model: ModelBase, alpha_model: ModelBase,
data_meta: DataMeta=None, data_meta):
x_values: pd.DataFrame=None,
codes: Iterable[int]=None):
if x_values is None:
predict_data = fetch_predict_phase(data_meta.engine, predict_data = fetch_predict_phase(data_meta.engine,
alpha_model.formulas, alpha_model.formulas,
ref_date, ref_date,
...@@ -144,10 +142,18 @@ def predict_by_model(ref_date: str, ...@@ -144,10 +142,18 @@ def predict_by_model(ref_date: str,
data_meta.risk_model, data_meta.risk_model,
data_meta.pre_process, data_meta.pre_process,
data_meta.post_process, data_meta.post_process,
data_meta.warm_start) data_meta.warm_start,
fillna=True)
return predict_data['predict']['code'], predict_data['predict']['x']
x_values = predict_data['predict']['x']
codes = predict_data['predict']['code'] def predict_by_model(ref_date: str,
alpha_model: ModelBase,
data_meta: DataMeta=None,
x_values: pd.DataFrame=None,
codes: Iterable[int]=None):
if x_values is None:
codes, x_values = fetch_predict_data(ref_date, alpha_model, data_meta)
return pd.DataFrame(alpha_model.predict(x_values).flatten(), index=codes) return pd.DataFrame(alpha_model.predict(x_values).flatten(), index=codes)
......
...@@ -335,7 +335,8 @@ def fetch_predict_phase(engine, ...@@ -335,7 +335,8 @@ def fetch_predict_phase(engine,
risk_model: str = 'short', risk_model: str = 'short',
pre_process: Iterable[object] = None, pre_process: Iterable[object] = None,
post_process: Iterable[object] = None, post_process: Iterable[object] = None,
warm_start: int = 0): warm_start: int = 0,
fillna: str=None):
if isinstance(alpha_factors, Transformer): if isinstance(alpha_factors, Transformer):
transformer = alpha_factors transformer = alpha_factors
else: else:
...@@ -352,7 +353,12 @@ def fetch_predict_phase(engine, ...@@ -352,7 +353,12 @@ def fetch_predict_phase(engine,
dateRule=BizDayConventions.Following, dateRule=BizDayConventions.Following,
dateGenerationRule=DateGeneration.Backward) dateGenerationRule=DateGeneration.Backward)
factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates).dropna() factor_df = engine.fetch_factor_range(universe, factors=transformer, dates=dates)
if fillna:
factor_df = factor_df.groupby('trade_date').apply(lambda x: x.fillna(x.median())).reset_index(drop=True).dropna()
else:
factor_df = factor_df.dropna()
names = transformer.names names = transformer.names
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment