Commit 1fbc1bf1 authored by wegamekinglc's avatar wegamekinglc

added sinle model training api

parent c374c0be
......@@ -74,6 +74,7 @@ class Strategy(object):
self.total_data = None
self.index_return = None
self.risk_models = None
self.alpha_models = None
def prepare_backtest_data(self):
total_factors = self.engine.fetch_factor_range(self.universe,
......@@ -121,6 +122,25 @@ class Strategy(object):
offset=1).set_index('trade_date')
self.total_data = total_data
def prepare_backtest_models(self):
if self.total_data is None:
self.prepare_backtest_data()
total_data_groups = self.total_data.groupby('trade_date')
if self.dask_client is None:
models = {}
for ref_date, _ in total_data_groups:
models[ref_date] = train_model(ref_date.strftime('%Y-%m-%d'), self.alpha_model, self.data_meta)
else:
def worker(parameters):
new_model = train_model(parameters[0].strftime('%Y-%m-%d'), parameters[1], parameters[2])
return parameters[0], new_model
l = self.dask_client.map(worker, [(d[0], self.alpha_model, self.data_meta) for d in total_data_groups])
results = self.dask_client.gather(l)
models = dict(results)
self.alpha_models = models
alpha_logger.info("alpha models training finished ...")
@staticmethod
def _create_lu_bounds(running_setting, codes, benchmark_w):
......@@ -169,22 +189,12 @@ class Strategy(object):
executor = copy.deepcopy(running_setting.executor)
positions = pd.DataFrame()
if self.dask_client is None:
models = {}
for ref_date, _ in total_data_groups:
models[ref_date] = train_model(ref_date.strftime('%Y-%m-%d'), self.alpha_model, self.data_meta)
else:
def worker(parameters):
new_model = train_model(parameters[0].strftime('%Y-%m-%d'), parameters[1], parameters[2])
return parameters[0], new_model
l = self.dask_client.map(worker, [(d[0], self.alpha_model, self.data_meta) for d in total_data_groups])
results = self.dask_client.gather(l)
models = dict(results)
if self.alpha_models is None:
self.prepare_backtest_models()
for ref_date, this_data in total_data_groups:
risk_model = self.risk_models[ref_date]
new_model = models[ref_date]
new_model = self.alpha_models[ref_date]
codes = this_data.code.values.tolist()
if previous_pos.empty:
......@@ -249,7 +259,6 @@ class Strategy(object):
ret_df = ret_df.shift(1)
ret_df.iloc[0] = 0.
ret_df['excess_return'] = ret_df['returns'] - ret_df['benchmark_returns'] * ret_df['leverage']
return ret_df, positions
def _calculate_pos(self, running_setting, er, data, constraints, benchmark_w, lbound, ubound, risk_model,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment