Commit 5bf23eee authored by Dr.李

fixed bug for xgb encode

parent 8cf4bb87
@@ -35,8 +35,15 @@ class ModelBase(metaclass=abc.ABCMeta):
         self.impl = None
         self.trained_time = None
 
+    def model_encode(self):
+        return encode(self.impl)
+
+    @classmethod
+    def model_decode(cls, model_desc):
+        return decode(model_desc)
+
     def __eq__(self, rhs):
-        return encode(self.impl) == encode(rhs.impl) \
+        return self.model_encode() == rhs.model_encode() \
             and self.trained_time == rhs.trained_time \
             and list_eq(self.features, rhs.features) \
             and encode(self.formulas) == encode(rhs.formulas) \
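The hunk above moves the inline `encode(self.impl)` call behind two overridable hooks, `model_encode` and `model_decode`, so a subclass can change how `impl` is serialized without touching `__eq__` or the save/load paths below. A minimal sketch of that override point, with pickle-based stand-ins for alphamind's `encode`/`decode` (which this diff does not show):

```python
import base64
import pickle


def encode(obj):
    # stand-in for alphamind's encoder; the real implementation is not shown here
    return base64.b64encode(pickle.dumps(obj)).decode("ascii")


def decode(payload):
    # stand-in for alphamind's decoder
    return pickle.loads(base64.b64decode(payload.encode("ascii")))


class Base:
    """Mirrors the two hooks added to ModelBase in this hunk."""

    def __init__(self, impl=None):
        self.impl = impl

    def model_encode(self):
        return encode(self.impl)

    @classmethod
    def model_decode(cls, model_desc):
        return decode(model_desc)


class CustomModel(Base):
    """Hypothetical subclass: only the hooks change; equality and persistence stay generic."""

    def model_encode(self):
        return encode({"impl": self.impl, "format": "custom"})

    @classmethod
    def model_decode(cls, model_desc):
        return decode(model_desc)["impl"]


m = CustomModel(impl=[1, 2, 3])
assert CustomModel.model_decode(m.model_encode()) == m.impl
```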
@@ -67,7 +74,7 @@ class ModelBase(metaclass=abc.ABCMeta):
                     saved_time=arrow.now().format("YYYY-MM-DD HH:mm:ss"),
                     features=list(self.features),
                     trained_time=self.trained_time,
-                    desc=encode(self.impl),
+                    desc=self.model_encode(),
                     formulas=encode(self.formulas),
                     fit_target=encode(self.fit_target),
                     internal_model=self.impl.__class__.__module__ + "." + self.impl.__class__.__name__)
@@ -80,7 +87,7 @@ class ModelBase(metaclass=abc.ABCMeta):
         obj_layout.features = model_desc['features']
         obj_layout.formulas = decode(model_desc['formulas'])
         obj_layout.trained_time = model_desc['trained_time']
-        obj_layout.impl = decode(model_desc['desc'])
+        obj_layout.impl = cls.model_decode(model_desc['desc'])
         if 'fit_target' in model_desc:
             obj_layout.fit_target = decode(model_desc['fit_target'])
         else:
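The two hunks above apply the same substitution along the persistence path: the saved description fills `desc` with `self.model_encode()`, and loading rebuilds `impl` through `cls.model_decode(model_desc['desc'])`, so an override of either hook affects saving and loading symmetrically. For orientation, the description dict assembled here has roughly this shape (field names taken from the hunk; all values below are illustrative, not captured from alphamind):

```python
import arrow

# Illustrative layout of the saved model description.
model_desc = {
    "saved_time": arrow.now().format("YYYY-MM-DD HH:mm:ss"),
    "features": ["roe", "ep"],                     # example feature names
    "trained_time": "2017-12-04 09:00:00",         # example timestamp
    "desc": "<result of self.model_encode()>",
    "formulas": "<result of encode(self.formulas)>",
    "fit_target": "<result of encode(self.fit_target)>",
    "internal_model": "xgboost.sklearn.XGBRegressor",
}
```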
@@ -5,7 +5,6 @@ Created on 2017-12-4
 @author: cheng.li
 """
 
-from distutils.version import LooseVersion
 import arrow
 import numpy as np
 import pandas as pd
@@ -13,11 +12,9 @@ from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
 from sklearn.ensemble import RandomForestClassifier as RandomForestClassifierImpl
 from sklearn.model_selection import train_test_split
 import xgboost as xgb
-from xgboost import __version__ as xgbboot_version
 from xgboost import XGBRegressor as XGBRegressorImpl
 from xgboost import XGBClassifier as XGBClassifierImpl
 from alphamind.model.modelbase import create_model_base
-from alphamind.utilities import alpha_logger
 
 
 class RandomForestRegressor(create_model_base('sklearn')):
@@ -65,12 +62,14 @@ class XGBRegressor(create_model_base('xgboost')):
                  features=None,
                  fit_target=None,
                  n_jobs: int=1,
+                 missing: float=np.nan,
                  **kwargs):
         super().__init__(features=features, fit_target=fit_target)
         self.impl = XGBRegressorImpl(n_estimators=n_estimators,
                                      learning_rate=learning_rate,
                                      max_depth=max_depth,
                                      n_jobs=n_jobs,
+                                     missing=missing,
                                      **kwargs)
 
     @property
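`missing` is a standard constructor argument of xgboost's scikit-learn wrappers that marks which value in the feature matrix denotes a missing entry; pinning it to `np.nan` here, rather than leaving it to the library default, presumably keeps the wrapped estimator's parameters stable when the model is encoded and decoded. A minimal usage sketch, assuming numpy and xgboost are installed:

```python
import numpy as np
from xgboost import XGBRegressor as XGBRegressorImpl

# Pass `missing` explicitly, as the wrapper now does, so np.nan entries in the
# feature matrix are treated as missing values.
impl = XGBRegressorImpl(n_estimators=10,
                        learning_rate=0.1,
                        max_depth=3,
                        n_jobs=1,
                        missing=np.nan)

x = np.array([[1.0, np.nan],
              [2.0, 3.0],
              [np.nan, 4.0],
              [5.0, 6.0]])
y = np.array([1.0, 2.0, 3.0, 4.0])

impl.fit(x, y)
print(impl.predict(x))
```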
@@ -87,13 +86,16 @@ class XGBClassifier(create_model_base('xgboost')):
                  features=None,
                  fit_target=None,
                  n_jobs: int=1,
+                 missing: float=np.nan,
                  **kwargs):
         super().__init__(features=features, fit_target=fit_target)
         self.impl = XGBClassifierImpl(n_estimators=n_estimators,
                                       learning_rate=learning_rate,
                                       max_depth=max_depth,
                                       n_jobs=n_jobs,
+                                      missing=missing,
                                       **kwargs)
+        self.impl = XGBClassifier.model_decode(self.model_encode())
 
     @property
     def importances(self):
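The one behavioural addition in this hunk is `self.impl = XGBClassifier.model_decode(self.model_encode())`, which round-trips the freshly built classifier through the new hooks right away. Read together with `__eq__`, which now compares `model_encode()` outputs, the apparent intent is that a newly constructed model and a saved-then-reloaded copy end up in the same normalized state; that reading is an inference from the diff, not stated in the commit message. The round-trip idea in isolation, with pickle standing in for alphamind's `encode`/`decode`:

```python
import pickle


def encode(obj):
    # stand-in serializer; alphamind's real encode/decode are not shown in this diff
    return pickle.dumps(obj)


def decode(payload):
    return pickle.loads(payload)


class Wrapper:
    def __init__(self, impl):
        # Normalize through a serialization round trip, as the hunk above does,
        # so the in-memory state matches what decode(encode(...)) produces.
        self.impl = decode(encode(impl))

    def __eq__(self, rhs):
        # mirrors the __eq__ change earlier in this commit
        return encode(self.impl) == encode(rhs.impl)


a = Wrapper(impl={"params": {"max_depth": 3}})
b = decode(encode(a))   # simulate saving and reloading the whole wrapper
assert a == b
```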