Commit 4f20b165 authored by Dr.李's avatar Dr.李

added more composer functionality

parent fc024e0d
...@@ -9,6 +9,7 @@ import copy ...@@ -9,6 +9,7 @@ import copy
import bisect import bisect
from typing import Iterable from typing import Iterable
import pandas as pd import pandas as pd
from simpleutils.miscellaneous import list_eq
from alphamind.model.modelbase import ModelBase from alphamind.model.modelbase import ModelBase
from alphamind.model.data_preparing import fetch_train_phase from alphamind.model.data_preparing import fetch_train_phase
from alphamind.model.data_preparing import fetch_predict_phase from alphamind.model.data_preparing import fetch_predict_phase
...@@ -16,6 +17,8 @@ from alphamind.data.engines.universe import Universe ...@@ -16,6 +17,8 @@ from alphamind.data.engines.universe import Universe
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.data.winsorize import winsorize_normal from alphamind.data.winsorize import winsorize_normal
from alphamind.data.standardize import standardize from alphamind.data.standardize import standardize
from alphamind.model.loader import load_model
PROCESS_MAPPING = { PROCESS_MAPPING = {
'winsorize_normal': winsorize_normal, 'winsorize_normal': winsorize_normal,
...@@ -53,15 +56,26 @@ class DataMeta(object): ...@@ -53,15 +56,26 @@ class DataMeta(object):
self.post_process = _map_process(post_process) self.post_process = _map_process(post_process)
self.warm_start = warm_start self.warm_start = warm_start
def __eq__(self, rhs):
return self.data_source == rhs.data_source \
and self.freq == rhs.freq \
and self.universe == rhs.universe \
and self.batch == rhs.batch \
and list_eq(self.neutralized_risk, rhs.neutralized_risk) \
and self.risk_model == rhs.risk_model \
and list_eq(self.pre_process, rhs.pre_process) \
and list_eq(self.post_process, rhs.post_process) \
and self.warm_start == rhs.warm_start
def save(self) -> dict: def save(self) -> dict:
return dict( return dict(
freq=self.freq, freq=self.freq,
universe=self.universe.save(), universe=self.universe.save(),
batch=self.batch, batch=self.batch,
neutralized_risk=neutralized_risk, neutralized_risk=self.neutralized_risk,
risk_model=self.risk_model, risk_model=self.risk_model,
pre_process=[p.__name__ for p in self.pre_process] if pre_process else None, pre_process=[p.__name__ for p in self.pre_process] if self.pre_process else None,
post_process=[p.__name__ for p in self.post_process] if pre_process else None, post_process=[p.__name__ for p in self.post_process] if self.pre_process else None,
warm_start=self.warm_start, warm_start=self.warm_start,
data_source=self.data_source data_source=self.data_source
) )
...@@ -131,7 +145,7 @@ def predict_by_model(ref_date: str, ...@@ -131,7 +145,7 @@ def predict_by_model(ref_date: str,
return pd.DataFrame(alpha_model.predict(x_values).flatten(), index=codes) return pd.DataFrame(alpha_model.predict(x_values).flatten(), index=codes)
class ModelComposer(object): class Composer(object):
def __init__(self, def __init__(self,
alpha_model: ModelBase, alpha_model: ModelBase,
data_meta: DataMeta): data_meta: DataMeta):
...@@ -143,7 +157,7 @@ class ModelComposer(object): ...@@ -143,7 +157,7 @@ class ModelComposer(object):
self.sorted_keys = None self.sorted_keys = None
def train(self, ref_date: str): def train(self, ref_date: str):
self.models[ref_date] = train_model(ref_date, self.alpha_model, self.data_meta) self.models[ref_date] = train_model(ref_date, copy.deepcopy(self.alpha_model), self.data_meta)
self.is_updated = False self.is_updated = False
def predict(self, ref_date: str, x: pd.DataFrame = None) -> pd.DataFrame: def predict(self, ref_date: str, x: pd.DataFrame = None) -> pd.DataFrame:
...@@ -166,6 +180,18 @@ class ModelComposer(object): ...@@ -166,6 +180,18 @@ class ModelComposer(object):
latest_index = bisect.bisect_left(sorted_keys, ref_date) - 1 latest_index = bisect.bisect_left(sorted_keys, ref_date) - 1
return self.models[sorted_keys[latest_index]] return self.models[sorted_keys[latest_index]]
def save(self) -> dict:
return dict(
alpha_model=self.alpha_model.save(),
data_meta=self.data_meta.save()
)
@classmethod
def load(cls, comp_desc):
alpha_model = load_model(comp_desc['alpha_model'])
data_meta = DataMeta.load(comp_desc['data_meta'])
return cls(alpha_model, data_meta)
if __name__ == '__main__': if __name__ == '__main__':
import numpy as np import numpy as np
...@@ -194,7 +220,7 @@ if __name__ == '__main__': ...@@ -194,7 +220,7 @@ if __name__ == '__main__':
pos_process, pos_process,
data_source=data_source) data_source=data_source)
composer = ModelComposer(alpha_model, data_meta) composer = Composer(alpha_model, data_meta)
composer.train('2017-09-20') composer.train('2017-09-20')
composer.train('2017-09-22') composer.train('2017-09-22')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment