Commit d57dfaa7 authored by Dr.李's avatar Dr.李

begin the model stuff

parent 30758656
......@@ -61,7 +61,8 @@ def prepare_data(engine: SqlEngine,
df = pd.merge(df, industry_df, on=['trade_date', 'code'])
df['weight'] = df['weight'].fillna(0.)
return df[['trade_date', 'code', 'dx']], df[['trade_date', 'code', 'weight', 'isOpen', 'industry_code', 'industry'] + transformer.names]
return df[['trade_date', 'code', 'dx']], df[
['trade_date', 'code', 'weight', 'isOpen', 'industry_code', 'industry'] + transformer.names]
def batch_processing(x_values,
......@@ -75,6 +76,7 @@ def batch_processing(x_values,
train_x_buckets = {}
train_y_buckets = {}
predict_x_buckets = {}
predict_y_buckets = {}
for i, start in enumerate(groups[:-batch]):
end = groups[i + batch]
......@@ -110,7 +112,16 @@ def batch_processing(x_values,
risk_factors=this_risk_exp,
post_process=post_process)
predict_x_buckets[end] = ne_x[sub_dates == end]
return train_x_buckets, train_y_buckets, predict_x_buckets
this_raw_y = y_values[index]
if len(this_raw_y) > 0:
ne_y = factor_processing(this_raw_y,
pre_process=pre_process,
risk_factors=this_risk_exp,
post_process=post_process)
predict_y_buckets[end] = ne_y[sub_dates == end]
return train_x_buckets, train_y_buckets, predict_x_buckets, predict_y_buckets
def fetch_data_package(engine: SqlEngine,
......@@ -170,7 +181,7 @@ def fetch_data_package(engine: SqlEngine,
alpha_logger.info("Loading data is finished")
train_x_buckets, train_y_buckets, predict_x_buckets = batch_processing(x_values,
train_x_buckets, train_y_buckets, predict_x_buckets, predict_y_buckets = batch_processing(x_values,
y_values,
dates,
date_label,
......@@ -185,7 +196,7 @@ def fetch_data_package(engine: SqlEngine,
ret['x_names'] = transformer.names
ret['settlement'] = return_df
ret['train'] = {'x': train_x_buckets, 'y': train_y_buckets}
ret['predict'] = {'x': predict_x_buckets}
ret['predict'] = {'x': predict_x_buckets, 'y': predict_y_buckets}
return ret
......
......@@ -5,66 +5,49 @@ Created on 2017-5-10
@author: cheng.li
"""
from typing import Tuple
from typing import Union
import pickle
import numpy as np
import numba as nb
from alphamind.utilities import groupby
from alphamind.data.neutralize import ls_fit
from distutils.version import LooseVersion
from sklearn import __version__ as sklearn_version
from sklearn.linear_model import LinearRegression as LinearRegressionImpl
from alphamind.model.modelbase import ModelBase
from alphamind.utilities import alpha_logger
class LinearModel(object):
class LinearRegression(ModelBase):
def __init__(self, init_param=None):
self.model_parameter = init_param
def __init__(self, features, fit_intercept: bool=False):
super().__init__(features)
self.impl = LinearRegressionImpl(fit_intercept=fit_intercept)
def calibrate(self, x, y, groups=None):
self.model_parameter = _train(x, y, groups)
def fit(self, x: np.ndarray, y: np.ndarray):
self.impl.fit(x, y)
def predict(self, x, groups=None):
if groups is not None and isinstance(self.model_parameter, tuple):
names = np.unique(groups)
return _prediction_impl(self.model_parameter[0], self.model_parameter[1], groups, names, x)
elif self.model_parameter is None:
raise ValueError("linear model is not calibrated yet")
elif groups is None:
return x @ self.model_parameter
else:
raise ValueError("grouped x value can't be used for vanilla linear model")
def predict(self, x: np.ndarray) -> np.ndarray:
return self.impl.predict(x)
def save(self) -> dict:
model_desc = super().save()
model_desc['desc'] = pickle.dumps(self.impl)
model_desc['sklearn_version'] = sklearn_version
return model_desc
@nb.njit(nogil=True, cache=True)
def _prediction_impl(calibrated_names, model_parameter, groups, names, x):
places = np.searchsorted(calibrated_names, names)
pred_v = np.zeros(x.shape[0])
for k, name in zip(places, names):
this_param = model_parameter[k]
idx = groups == name
pred_v[idx] = x[idx] @ this_param
return pred_v
def load(self, model_desc: dict):
super().load(model_desc)
if LooseVersion(sklearn_version) < LooseVersion(model_desc['sklearn_version']):
alpha_logger.warning('Current sklearn version {0} is lower than the model version {1}. '
'Loaded model may work incorrectly.'.format(
sklearn_version, model_desc['sklearn_version']))
def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
if groups is None:
return ls_fit(x, y)
else:
index_diff, order = groupby(groups)
res_beta = _train_loop(index_diff, order, x, y)
return np.unique(groups), res_beta
self.impl = pickle.loads(model_desc['desc'])
@nb.njit(nogil=True, cache=True)
def _train_loop(index_diff, order, x, y):
res_beta = np.zeros((len(index_diff)+1, x.shape[1]))
start = 0
for k, diff_loc in enumerate(index_diff):
res_beta[k] = _train_sub_group(x, y, order[start:diff_loc + 1])
start = diff_loc + 1
return res_beta
if __name__ == '__main__':
import pprint
ls = LinearRegression(['a', 'b'])
@nb.njit(nogil=True, cache=True)
def _train_sub_group(x, y, curr_idx):
curr_x = x[curr_idx]
curr_y = y[curr_idx]
return ls_fit(curr_x, curr_y)
model_desc = ls.save()
new_model = ls.load(model_desc)
pprint.pprint(model_desc)
# -*- coding: utf-8 -*-
"""
Created on 2017-9-4
@author: cheng.li
"""
import abc
import arrow
import numpy as np
from alphamind.utilities import alpha_logger
class ModelBase(metaclass=abc.ABCMeta):
def __init__(self, features: list):
self.features = features
@abc.abstractmethod
def fit(self, x, y):
pass
@abc.abstractmethod
def predict(self, x) -> np.ndarray:
pass
@abc.abstractmethod
def save(self) -> dict:
if self.__class__.__module__ == '__main__':
alpha_logger.warning("model is defined in a main module. The model_name may not be correct.")
model_desc = dict(internal_model=self.impl.__class__.__module__ + "." + self.impl.__class__.__name__,
model_name=self.__class__.__module__ + "." + self.__class__.__name__,
language='python',
timestamp=arrow.now().format(),
features=self.features)
return model_desc
@abc.abstractmethod
def load(self, model_desc: dict):
self.features = model_desc['features']
arrow >= 0.10.0
cython >= 0.25.2
finance-python >= 0.5.7
mysqlclient >= 1.3.10
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment