Commit 6e2b07e7 authored by Dr.李's avatar Dr.李

directly use model's member

parent f18521f6
...@@ -21,7 +21,6 @@ class DataMeta(object): ...@@ -21,7 +21,6 @@ class DataMeta(object):
def __init__(self, def __init__(self,
engine, engine,
alpha_factors: Union[Transformer, Iterable[object]],
freq: str, freq: str,
universe: Universe, universe: Universe,
batch: int, batch: int,
...@@ -32,7 +31,6 @@ class DataMeta(object): ...@@ -32,7 +31,6 @@ class DataMeta(object):
warm_start: int = 0): warm_start: int = 0):
self.engine = engine self.engine = engine
self.alpha_model = alpha_model self.alpha_model = alpha_model
self.alpha_factors = alpha_factors
self.freq = freq self.freq = freq
self.universe = universe self.universe = universe
self.batch = batch self.batch = batch
...@@ -47,7 +45,7 @@ def train_model(ref_date: str, ...@@ -47,7 +45,7 @@ def train_model(ref_date: str,
alpha_model: ModelBase, alpha_model: ModelBase,
data_meta: DataMeta): data_meta: DataMeta):
train_data = fetch_train_phase(data_meta.engine, train_data = fetch_train_phase(data_meta.engine,
data_meta.alpha_factors, alpha_model.formulas,
ref_date, ref_date,
data_meta.freq, data_meta.freq,
data_meta.universe, data_meta.universe,
...@@ -68,7 +66,7 @@ def predict_by_model(ref_date: str, ...@@ -68,7 +66,7 @@ def predict_by_model(ref_date: str,
alpha_model: ModelBase, alpha_model: ModelBase,
data_meta): data_meta):
predict_data = fetch_predict_phase(data_meta.engine, predict_data = fetch_predict_phase(data_meta.engine,
data_meta.alpha_factors, alpha_model.formulas,
ref_date, ref_date,
data_meta.freq, data_meta.freq,
data_meta.universe, data_meta.universe,
...@@ -129,7 +127,7 @@ if __name__ == '__main__': ...@@ -129,7 +127,7 @@ if __name__ == '__main__':
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.model.linearmodel import ConstLinearModel from alphamind.model.linearmodel import ConstLinearModel
engine = SqlEngine() engine = SqlEngine("postgres+psycopg2://postgres:we083826@localhost/alpha")
alpha_model = ConstLinearModel(['EPS'], np.array([1.])) alpha_model = ConstLinearModel(['EPS'], np.array([1.]))
alpha_factors = ['EPS'] alpha_factors = ['EPS']
freq = '1w' freq = '1w'
...@@ -141,7 +139,6 @@ if __name__ == '__main__': ...@@ -141,7 +139,6 @@ if __name__ == '__main__':
pos_process = [winsorize_normal, standardize] pos_process = [winsorize_normal, standardize]
data_meta = DataMeta(engine, data_meta = DataMeta(engine,
alpha_factors,
freq, freq,
universe, universe,
batch, batch,
......
...@@ -247,7 +247,7 @@ def fetch_data_package(engine: SqlEngine, ...@@ -247,7 +247,7 @@ def fetch_data_package(engine: SqlEngine,
def fetch_train_phase(engine, def fetch_train_phase(engine,
alpha_factors: Iterable[object], alpha_factors: Union[Transformer, Iterable[object]],
ref_date, ref_date,
frequency, frequency,
universe, universe,
...@@ -257,7 +257,10 @@ def fetch_train_phase(engine, ...@@ -257,7 +257,10 @@ def fetch_train_phase(engine,
pre_process: Iterable[object] = None, pre_process: Iterable[object] = None,
post_process: Iterable[object] = None, post_process: Iterable[object] = None,
warm_start: int = 0) -> dict: warm_start: int = 0) -> dict:
transformer = Transformer(alpha_factors) if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
transformer = Transformer(alpha_factors)
p = Period(frequency) p = Period(frequency)
p = Period(length=-(warm_start + batch + 1) * p.length(), units=p.units()) p = Period(length=-(warm_start + batch + 1) * p.length(), units=p.units())
...@@ -317,7 +320,7 @@ def fetch_train_phase(engine, ...@@ -317,7 +320,7 @@ def fetch_train_phase(engine,
def fetch_predict_phase(engine, def fetch_predict_phase(engine,
alpha_factors: Iterable[object], alpha_factors: Union[Transformer, Iterable[object]],
ref_date, ref_date,
frequency, frequency,
universe, universe,
...@@ -327,7 +330,10 @@ def fetch_predict_phase(engine, ...@@ -327,7 +330,10 @@ def fetch_predict_phase(engine,
pre_process: Iterable[object] = None, pre_process: Iterable[object] = None,
post_process: Iterable[object] = None, post_process: Iterable[object] = None,
warm_start: int = 0): warm_start: int = 0):
transformer = Transformer(alpha_factors) if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
transformer = Transformer(alpha_factors)
p = Period(frequency) p = Period(frequency)
p = Period(length=-(warm_start + batch) * p.length(), units=p.units()) p = Period(length=-(warm_start + batch) * p.length(), units=p.units())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment