Commit 5bf23eee authored by Dr.李's avatar Dr.李

fixed bug for xgb encode

parent 8cf4bb87
......@@ -35,8 +35,15 @@ class ModelBase(metaclass=abc.ABCMeta):
self.impl = None
self.trained_time = None
def model_encode(self):
return encode(self.impl)
@classmethod
def model_decode(cls, model_desc):
return decode(model_desc)
def __eq__(self, rhs):
return encode(self.impl) == encode(rhs.impl) \
return self.model_encode() == rhs.model_encode() \
and self.trained_time == rhs.trained_time \
and list_eq(self.features, rhs.features) \
and encode(self.formulas) == encode(rhs.formulas) \
......@@ -67,7 +74,7 @@ class ModelBase(metaclass=abc.ABCMeta):
saved_time=arrow.now().format("YYYY-MM-DD HH:mm:ss"),
features=list(self.features),
trained_time=self.trained_time,
desc=encode(self.impl),
desc=self.model_encode(),
formulas=encode(self.formulas),
fit_target=encode(self.fit_target),
internal_model=self.impl.__class__.__module__ + "." + self.impl.__class__.__name__)
......@@ -80,7 +87,7 @@ class ModelBase(metaclass=abc.ABCMeta):
obj_layout.features = model_desc['features']
obj_layout.formulas = decode(model_desc['formulas'])
obj_layout.trained_time = model_desc['trained_time']
obj_layout.impl = decode(model_desc['desc'])
obj_layout.impl = cls.model_decode(model_desc['desc'])
if 'fit_target' in model_desc:
obj_layout.fit_target = decode(model_desc['fit_target'])
else:
......
......@@ -5,7 +5,6 @@ Created on 2017-12-4
@author: cheng.li
"""
from distutils.version import LooseVersion
import arrow
import numpy as np
import pandas as pd
......@@ -13,11 +12,9 @@ from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
from sklearn.ensemble import RandomForestClassifier as RandomForestClassifierImpl
from sklearn.model_selection import train_test_split
import xgboost as xgb
from xgboost import __version__ as xgbboot_version
from xgboost import XGBRegressor as XGBRegressorImpl
from xgboost import XGBClassifier as XGBClassifierImpl
from alphamind.model.modelbase import create_model_base
from alphamind.utilities import alpha_logger
class RandomForestRegressor(create_model_base('sklearn')):
......@@ -65,12 +62,14 @@ class XGBRegressor(create_model_base('xgboost')):
features=None,
fit_target=None,
n_jobs: int=1,
missing: float=np.nan,
**kwargs):
super().__init__(features=features, fit_target=fit_target)
self.impl = XGBRegressorImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
n_jobs=n_jobs,
missing=missing,
**kwargs)
@property
......@@ -87,13 +86,16 @@ class XGBClassifier(create_model_base('xgboost')):
features=None,
fit_target=None,
n_jobs: int=1,
missing: float=np.nan,
**kwargs):
super().__init__(features=features, fit_target=fit_target)
self.impl = XGBClassifierImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
n_jobs=n_jobs,
missing=missing,
**kwargs)
self.impl = XGBClassifier.model_decode(self.model_encode())
@property
def importances(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment