Commit 4f20b165 authored by Dr.李's avatar Dr.李

added more composer functionality

parent fc024e0d
......@@ -9,6 +9,7 @@ import copy
import bisect
from typing import Iterable
import pandas as pd
from simpleutils.miscellaneous import list_eq
from alphamind.model.modelbase import ModelBase
from alphamind.model.data_preparing import fetch_train_phase
from alphamind.model.data_preparing import fetch_predict_phase
......@@ -16,6 +17,8 @@ from alphamind.data.engines.universe import Universe
from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.data.winsorize import winsorize_normal
from alphamind.data.standardize import standardize
from alphamind.model.loader import load_model
PROCESS_MAPPING = {
'winsorize_normal': winsorize_normal,
......@@ -53,15 +56,26 @@ class DataMeta(object):
self.post_process = _map_process(post_process)
self.warm_start = warm_start
def __eq__(self, rhs):
return self.data_source == rhs.data_source \
and self.freq == rhs.freq \
and self.universe == rhs.universe \
and self.batch == rhs.batch \
and list_eq(self.neutralized_risk, rhs.neutralized_risk) \
and self.risk_model == rhs.risk_model \
and list_eq(self.pre_process, rhs.pre_process) \
and list_eq(self.post_process, rhs.post_process) \
and self.warm_start == rhs.warm_start
def save(self) -> dict:
return dict(
freq=self.freq,
universe=self.universe.save(),
batch=self.batch,
neutralized_risk=neutralized_risk,
neutralized_risk=self.neutralized_risk,
risk_model=self.risk_model,
pre_process=[p.__name__ for p in self.pre_process] if pre_process else None,
post_process=[p.__name__ for p in self.post_process] if pre_process else None,
pre_process=[p.__name__ for p in self.pre_process] if self.pre_process else None,
post_process=[p.__name__ for p in self.post_process] if self.pre_process else None,
warm_start=self.warm_start,
data_source=self.data_source
)
......@@ -131,7 +145,7 @@ def predict_by_model(ref_date: str,
return pd.DataFrame(alpha_model.predict(x_values).flatten(), index=codes)
class ModelComposer(object):
class Composer(object):
def __init__(self,
alpha_model: ModelBase,
data_meta: DataMeta):
......@@ -143,7 +157,7 @@ class ModelComposer(object):
self.sorted_keys = None
def train(self, ref_date: str):
self.models[ref_date] = train_model(ref_date, self.alpha_model, self.data_meta)
self.models[ref_date] = train_model(ref_date, copy.deepcopy(self.alpha_model), self.data_meta)
self.is_updated = False
def predict(self, ref_date: str, x: pd.DataFrame = None) -> pd.DataFrame:
......@@ -166,6 +180,18 @@ class ModelComposer(object):
latest_index = bisect.bisect_left(sorted_keys, ref_date) - 1
return self.models[sorted_keys[latest_index]]
def save(self) -> dict:
return dict(
alpha_model=self.alpha_model.save(),
data_meta=self.data_meta.save()
)
@classmethod
def load(cls, comp_desc):
alpha_model = load_model(comp_desc['alpha_model'])
data_meta = DataMeta.load(comp_desc['data_meta'])
return cls(alpha_model, data_meta)
if __name__ == '__main__':
import numpy as np
......@@ -194,7 +220,7 @@ if __name__ == '__main__':
pos_process,
data_source=data_source)
composer = ModelComposer(alpha_model, data_meta)
composer = Composer(alpha_model, data_meta)
composer.train('2017-09-20')
composer.train('2017-09-22')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment