Commit 1e8f9e1b authored by Dr.李

update tree models

parent 88f50881
@@ -9,8 +9,9 @@ from typing import List
 import numpy as np
 from distutils.version import LooseVersion
 from sklearn import __version__ as sklearn_version
+from xgboost import __version__ as xgboost_version
 from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
-# from xgboost import XGBRegressor as XGBRegressorImpl
+from xgboost import XGBRegressor as XGBRegressorImpl
 from alphamind.model.modelbase import ModelBase
 from alphamind.utilities import alpha_logger
@@ -28,6 +29,7 @@ class RandomForestRegressor(ModelBase):
     def save(self) -> dict:
         model_desc = super().save()
         model_desc['sklearn_version'] = sklearn_version
+        return model_desc

     @classmethod
     def load(cls, model_desc: dict):
@@ -40,17 +42,34 @@ class RandomForestRegressor(ModelBase):
         return obj_layout


-# class XGBRegressor(ModelBase):
-#
-#     def __init__(self,
-#                  n_estimators: int=100,
-#                  learning_rate: float=0.1,
-#                  max_depth: int=3,
-#                  features: List=None, **kwargs):
-#         super().__init__(features)
-#         self.impl = XGBRegressorImpl(n_estimators=n_estimators,
-#                                      learning_rate=learning_rate,
-#                                      max_depth=max_depth,
-#                                      **kwargs)
+class XGBRegressor(ModelBase):
+
+    def __init__(self,
+                 n_estimators: int=100,
+                 learning_rate: float=0.1,
+                 max_depth: int=3,
+                 features: List=None, **kwargs):
+        super().__init__(features)
+        self.impl = XGBRegressorImpl(n_estimators=n_estimators,
+                                     learning_rate=learning_rate,
+                                     max_depth=max_depth,
+                                     **kwargs)
+
+    def save(self) -> dict:
+        model_desc = super().save()
+        model_desc['xgboost_version'] = xgboost_version
+        return model_desc
+
+    @classmethod
+    def load(cls, model_desc: dict):
+        obj_layout = super().load(model_desc)
+        if LooseVersion(xgboost_version) < LooseVersion(model_desc['xgboost_version']):
+            alpha_logger.warning('Current xgboost version {0} is lower than the model version {1}. '
+                                 'Loaded model may work incorrectly.'.format(xgboost_version,
+                                                                             model_desc['xgboost_version']))
+        return obj_layout
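
For context, a short round-trip sketch of the new version-stamping behaviour (not part of this commit): it assumes the class lives at alphamind.model.treemodel, that ModelBase.save() returns a plain dict which ModelBase.load() accepts, and that an unfitted model can be serialized; the feature names below are made up for illustration.

from alphamind.model.treemodel import XGBRegressor  # assumed import path

# Hypothetical feature names, purely illustrative.
model = XGBRegressor(n_estimators=200,
                     learning_rate=0.05,
                     max_depth=4,
                     features=['factor_a', 'factor_b'])

model_desc = model.save()                  # embeds the running xgboost version
restored = XGBRegressor.load(model_desc)   # silent: versions match

# Pretend the model was saved with a newer xgboost to exercise the warning path.
model_desc['xgboost_version'] = '999.0'
XGBRegressor.load(model_desc)              # logs the "may work incorrectly" warning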