Commit 88a49650 authored by Dr.李's avatar Dr.李

update strategy part

parent 8c1e9cab
......@@ -4,12 +4,12 @@
"data_process":
{
"pre_process": ["winsorize", "standardize"],
"neutralize_risk": ["SIZE"],
"neutralize_risk": ["SIZE", "industry_styles"],
"post_process": ["winsorize", "standardize"]
},
"risk_model": "short",
"model": "LinearRegression",
"alpha_model": "LinearRegression",
"features": ["EPS", "ROEDiluted"],
"parameters":
......@@ -21,5 +21,11 @@
"batch": 4,
"warm_start": 0,
"universe": ["zz500", ["zz500"]],
"benchmark": 905
"benchmark": 905,
"optimizer":
{
"build_type": "risk_neutral",
"neutralize_risk": ["SIZE", "industry_styles"]
}
}
\ No newline at end of file
......@@ -60,7 +60,7 @@ class Strategy(object):
self.neutralize_risk = load_neutralize_risks(strategy_desc['data_process']['neutralize_risk'])
self.risk_model = strategy_desc['risk_model']
self.model_type = load_model_meta(strategy_desc['model'])
self.model_type = load_model_meta(strategy_desc['alpha_model'])
self.parameters = strategy_desc['parameters']
self.features = strategy_desc['features']
self.model = self.model_type(features=self.features, **self.parameters)
......@@ -91,6 +91,11 @@ class Strategy(object):
self.risk_model,
self.pre_process,
self.post_process)
# some cached data to fast processing
settlement_data = self.cached_data['settlement']
self.settle_dfs = settlement_data.set_index('code').groupby('trade_date')
self.scheduled_dates = set(k.strftime('%Y-%m-%d') for k in self.cached_data['train']['x'].keys())
else:
self.cached_data = None
......@@ -149,6 +154,9 @@ class Strategy(object):
return pd.DataFrame({'prediction': prediction,
'code': codes})
def settlement(self, ref_date: str, prediction: pd.DataFrame) -> float:
settlement_data = self.settle_dfs.get_group(ref_date)[['dx', 'weight']]
if __name__ == '__main__':
import json
......@@ -158,7 +166,7 @@ if __name__ == '__main__':
engine = SqlEngine()
start_date = '2012-01-01'
start_date = '2017-06-01'
end_date = '2017-09-14'
with open("sample_strategy.json", 'r') as fp:
......@@ -168,7 +176,7 @@ if __name__ == '__main__':
dates = strategy.cached_dates()
print(dates)
# for date in dates:
# strategy.model_train(date)
# prediction = strategy.model_predict(date)
# print(date)
for date in dates:
strategy.model_train(date)
prediction = strategy.model_predict(date)
strategy.settlement(date, prediction)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment