Commit 36960c27 authored by Dr.李

added model composer

parent d18d36f1
# -*- coding: utf-8 -*-
"""
Created on 2017-9-27
@author: cheng.li
"""
import copy
import bisect
from typing import Union
from typing import Iterable
import pandas as pd
from alphamind.model.modelbase import ModelBase
from alphamind.model.data_preparing import fetch_train_phase
from alphamind.model.data_preparing import fetch_predict_phase
from alphamind.data.transformer import Transformer
from alphamind.data.engines.universe import Universe
class DataMeta(object):
    """Container for the data-preparation settings of a model run.

    The constructor only stores each argument on the instance under the
    same name; no validation or data access happens here.
    """

    def __init__(self,
                 engine,
                 alpha_factors: Union['Transformer', Iterable[object]],
                 freq: str,
                 universe: 'Universe',
                 batch: int,
                 neutralized_risk: Iterable[str] = None,
                 risk_model: str = 'short',
                 pre_process: Iterable[object] = None,
                 post_process: Iterable[object] = None,
                 warm_start: int = 0):
        """Store the settings.

        :param engine: data engine handed to the fetch functions.
        :param alpha_factors: factor definitions (a ``Transformer`` or an
            iterable of factor descriptions).
        :param freq: frequency string, e.g. ``'1w'``.
        :param universe: universe the data is fetched for.
        :param batch: batch size setting.
        :param neutralized_risk: names of risk factors to neutralize, or
            ``None``.
        :param risk_model: name of the risk model (default ``'short'``).
        :param pre_process: processing callables applied before
            neutralization, or ``None``.  # presumably -- confirm in fetch code
        :param post_process: processing callables applied after
            neutralization, or ``None``.  # presumably -- confirm in fetch code
        :param warm_start: warm start setting (default ``0``).
        """
        # The alpha model is intentionally not stored here: it is not a
        # constructor argument, and the previous assignment from an undefined
        # name raised NameError unless a global ``alpha_model`` happened to
        # exist.
        self.engine = engine
        self.alpha_factors = alpha_factors
        self.freq = freq
        self.universe = universe
        self.batch = batch
        self.neutralized_risk = neutralized_risk
        self.risk_model = risk_model
        self.pre_process = pre_process
        self.post_process = post_process
        self.warm_start = warm_start
class ModelComposer(object):
    """Train snapshots of an alpha model per reference date and predict
    with the most recent snapshot trained strictly before a given date.
    """

    def __init__(self,
                 alpha_model: 'ModelBase',
                 data_meta: 'DataMeta'):
        """
        :param alpha_model: model that is (re)fitted on every ``train`` call.
        :param data_meta: settings forwarded to the data fetch functions.
        """
        self.alpha_model = alpha_model
        self.data_meta = data_meta
        # ref_date -> deep copy of the model as fitted at that date
        self.models = {}
        # True while ``sorted_keys`` reflects the current keys of ``models``
        self.is_updated = False
        self.sorted_keys = None

    def _phase_args(self, ref_date: str) -> tuple:
        """Return the positional arguments shared by both fetch functions."""
        meta = self.data_meta
        return (meta.engine,
                meta.alpha_factors,
                ref_date,
                meta.freq,
                meta.universe,
                meta.batch,
                meta.neutralized_risk,
                meta.risk_model,
                meta.pre_process,
                meta.post_process,
                meta.warm_start)

    def train(self, ref_date: str):
        """Fit the alpha model on the training data of ``ref_date`` and
        keep a snapshot of the fitted model under that date.
        """
        train_data = fetch_train_phase(*self._phase_args(ref_date))
        x_values = train_data['train']['x']
        y_values = train_data['train']['y']
        self.alpha_model.fit(x_values, y_values)
        # Deep copy so that later refits do not alter this snapshot.
        self.models[ref_date] = copy.deepcopy(self.alpha_model)
        # The set of keys may have changed; invalidate the sorted-key cache.
        self.is_updated = False

    def predict(self, ref_date: str, x: pd.DataFrame = None) -> pd.DataFrame:
        """Predict with the latest model trained before ``ref_date``.

        :param ref_date: date the prediction is made for.
        :param x: optional features; when given, its values are used as the
            model input and its index labels the output rows. When ``None``
            the features and codes are fetched through ``fetch_predict_phase``.
        :return: single-column frame of flattened predictions.
        :raises IndexError: if no model was trained before ``ref_date``.
        """
        if x is None:
            predict_data = fetch_predict_phase(*self._phase_args(ref_date))
            x_values = predict_data['predict']['x']
            codes = predict_data['predict']['code']
        else:
            x_values = x.values
            codes = x.index

        model = self._fetch_latest_model(ref_date)
        return pd.DataFrame(model.predict(x_values).flatten(), index=codes)

    def _fetch_latest_model(self, ref_date) -> 'ModelBase':
        """Return the snapshot with the greatest key strictly below ``ref_date``.

        Keys are compared with their natural ordering (for ``str`` dates this
        is lexicographic, so a sortable format such as ``YYYY-MM-DD`` is
        assumed).

        :raises IndexError: if no snapshot has a key below ``ref_date``
            (including when nothing has been trained yet).
        """
        if not self.is_updated:
            self.sorted_keys = sorted(self.models.keys())
            self.is_updated = True
        sorted_keys = self.sorted_keys

        latest_index = bisect.bisect_left(sorted_keys, ref_date) - 1
        if latest_index < 0:
            # Without this guard index -1 would wrap around and silently
            # return the *newest* snapshot, i.e. a model trained on data
            # from after ``ref_date``.
            raise IndexError(
                "no model has been trained before {0}".format(ref_date))
        return self.models[sorted_keys[latest_index]]
if __name__ == '__main__':
    # Manual example run: train at three dates, then predict at a date that
    # lies between the first two. Needs a reachable SqlEngine backend.
    import numpy as np
    from alphamind.data.standardize import standardize
    from alphamind.data.winsorize import winsorize_normal
    from alphamind.data.engines.sqlengine import industry_styles
    from alphamind.data.engines.sqlengine import SqlEngine
    from alphamind.model.linearmodel import ConstLinearModel

    engine = SqlEngine()
    alpha_model = ConstLinearModel(['EPS'], np.array([1.]))

    data_meta = DataMeta(engine=engine,
                         alpha_factors=['EPS'],
                         freq='1w',
                         universe=Universe('zz500', ['zz500']),
                         batch=4,
                         neutralized_risk=['SIZE'] + industry_styles,
                         risk_model='short',
                         pre_process=[winsorize_normal, standardize],
                         post_process=[winsorize_normal, standardize])

    composer = ModelComposer(alpha_model, data_meta)

    for train_date in ('2017-09-20', '2017-09-22', '2017-09-25'):
        composer.train(train_date)

    composer.predict('2017-09-21')
......@@ -189,7 +189,7 @@ def fetch_data_package(engine: SqlEngine,
neutralized_risk: Iterable[str] = None,
risk_model: str = 'short',
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None):
post_process: Iterable[object] = None) -> dict:
alpha_logger.info("Starting data package fetching ...")
transformer = Transformer(alpha_factors)
......@@ -243,7 +243,7 @@ def fetch_train_phase(engine,
risk_model: str = 'short',
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None,
warm_start: int = 0):
warm_start: int = 0) -> dict:
transformer = Transformer(alpha_factors)
p = Period(frequency)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment