Commit f18521f6 authored by Dr.李's avatar Dr.李

remove formula parameters from model base

parent 561e1046
......@@ -32,9 +32,8 @@ class ConstLinearModel(ModelBase):
def __init__(self,
features: list = None,
formulas: dict = None,
weights: np.ndarray = None):
super().__init__(features, formulas=formulas)
super().__init__(features)
if features is not None and weights is not None:
pyFinAssert(len(features) == len(weights),
ValueError,
......@@ -57,8 +56,8 @@ class ConstLinearModel(ModelBase):
class LinearRegression(ModelBase):
def __init__(self, features: list = None, formulas: dict = None, fit_intercept: bool = False, **kwargs):
super().__init__(features, formulas=formulas)
def __init__(self, features: list = None, fit_intercept: bool = False, **kwargs):
super().__init__(features)
self.impl = LinearRegressionImpl(fit_intercept=fit_intercept, **kwargs)
self.trained_time = None
......@@ -85,8 +84,8 @@ class LinearRegression(ModelBase):
class LassoRegression(ModelBase):
def __init__(self, alpha=0.01, features: list = None, formulas: dict = None, fit_intercept: bool = False, **kwargs):
super().__init__(features, formulas=formulas)
def __init__(self, alpha=0.01, features: list = None, fit_intercept: bool = False, **kwargs):
super().__init__(features)
self.impl = Lasso(alpha=alpha, fit_intercept=fit_intercept, **kwargs)
self.trained_time = None
......@@ -113,8 +112,8 @@ class LassoRegression(ModelBase):
class LogisticRegression(ModelBase):
def __init__(self, features: list = None, formulas: dict = None, fit_intercept: bool = False, **kwargs):
super().__init__(features, formulas=formulas)
def __init__(self, features: list = None, fit_intercept: bool = False, **kwargs):
super().__init__(features)
self.impl = LogisticRegressionImpl(fit_intercept=fit_intercept, **kwargs)
def save(self) -> dict:
......
......@@ -6,7 +6,6 @@ Created on 2017-9-4
"""
import abc
import copy
import arrow
import numpy as np
from alphamind.utilities import alpha_logger
......@@ -17,13 +16,13 @@ from alphamind.data.transformer import Transformer
class ModelBase(metaclass=abc.ABCMeta):
def __init__(self, features: list=None, formulas: dict=None):
def __init__(self, features: list=None):
if features is not None:
self.features = Transformer(features).names
self.formulas = Transformer(features)
self.features = self.formulas.names
else:
self.features = None
self.impl = None
self.formulas = copy.deepcopy(formulas)
self.trained_time = None
def fit(self, x: np.ndarray, y: np.ndarray):
......
......@@ -28,7 +28,7 @@ class RandomForestRegressor(ModelBase):
max_features: str='auto',
features: List=None,
**kwargs):
super().__init__(features, **kwargs)
super().__init__(features)
self.impl = RandomForestRegressorImpl(n_estimators=n_estimators,
max_features=max_features,
**kwargs)
......@@ -61,9 +61,8 @@ class RandomForestClassifier(ModelBase):
n_estimators: int=100,
max_features: str='auto',
features: List = None,
formulas: dict = None,
**kwargs):
super().__init__(features, formulas=formulas)
super().__init__(features)
self.impl = RandomForestClassifierImpl(n_estimators=n_estimators,
max_features=max_features,
**kwargs)
......@@ -97,10 +96,9 @@ class XGBRegressor(ModelBase):
learning_rate: float=0.1,
max_depth: int=3,
features: List=None,
formulas: dict = None,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
super().__init__(features)
self.impl = XGBRegressorImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
......@@ -135,10 +133,9 @@ class XGBClassifier(ModelBase):
learning_rate: float=0.1,
max_depth: int=3,
features: List = None,
formulas: dict = None,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
super().__init__(features)
self.impl = XGBClassifierImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
......@@ -180,11 +177,10 @@ class XGBTrainer(ModelBase):
subsample=1.,
colsample_bytree=1.,
features: List = None,
formulas: dict = None,
random_state: int=0,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
super().__init__(features)
self.params = {
'silent': 1,
'objective': objective,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment