Commit 6e2b07e7 authored by Dr.李's avatar Dr.李

directly use model's member

parent f18521f6
......@@ -21,7 +21,6 @@ class DataMeta(object):
def __init__(self,
engine,
alpha_factors: Union[Transformer, Iterable[object]],
freq: str,
universe: Universe,
batch: int,
......@@ -32,7 +31,6 @@ class DataMeta(object):
warm_start: int = 0):
self.engine = engine
self.alpha_model = alpha_model
self.alpha_factors = alpha_factors
self.freq = freq
self.universe = universe
self.batch = batch
......@@ -47,7 +45,7 @@ def train_model(ref_date: str,
alpha_model: ModelBase,
data_meta: DataMeta):
train_data = fetch_train_phase(data_meta.engine,
data_meta.alpha_factors,
alpha_model.formulas,
ref_date,
data_meta.freq,
data_meta.universe,
......@@ -68,7 +66,7 @@ def predict_by_model(ref_date: str,
alpha_model: ModelBase,
data_meta):
predict_data = fetch_predict_phase(data_meta.engine,
data_meta.alpha_factors,
alpha_model.formulas,
ref_date,
data_meta.freq,
data_meta.universe,
......@@ -129,7 +127,7 @@ if __name__ == '__main__':
from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.model.linearmodel import ConstLinearModel
engine = SqlEngine()
engine = SqlEngine("postgres+psycopg2://postgres:we083826@localhost/alpha")
alpha_model = ConstLinearModel(['EPS'], np.array([1.]))
alpha_factors = ['EPS']
freq = '1w'
......@@ -141,7 +139,6 @@ if __name__ == '__main__':
pos_process = [winsorize_normal, standardize]
data_meta = DataMeta(engine,
alpha_factors,
freq,
universe,
batch,
......
......@@ -247,7 +247,7 @@ def fetch_data_package(engine: SqlEngine,
def fetch_train_phase(engine,
alpha_factors: Iterable[object],
alpha_factors: Union[Transformer, Iterable[object]],
ref_date,
frequency,
universe,
......@@ -257,7 +257,10 @@ def fetch_train_phase(engine,
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None,
warm_start: int = 0) -> dict:
transformer = Transformer(alpha_factors)
if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
transformer = Transformer(alpha_factors)
p = Period(frequency)
p = Period(length=-(warm_start + batch + 1) * p.length(), units=p.units())
......@@ -317,7 +320,7 @@ def fetch_train_phase(engine,
def fetch_predict_phase(engine,
alpha_factors: Iterable[object],
alpha_factors: Union[Transformer, Iterable[object]],
ref_date,
frequency,
universe,
......@@ -327,7 +330,10 @@ def fetch_predict_phase(engine,
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None,
warm_start: int = 0):
transformer = Transformer(alpha_factors)
if isinstance(alpha_factors, Transformer):
transformer = alpha_factors
else:
transformer = Transformer(alpha_factors)
p = Period(frequency)
p = Period(length=-(warm_start + batch) * p.length(), units=p.units())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment