Commit 22137a83 authored by Dr.李

added random forest classifier

parent e4487d31
@@ -30,6 +30,7 @@ from alphamind.model import LassoRegression
 from alphamind.model import ConstLinearModel
 from alphamind.model import LogisticRegression
 from alphamind.model import RandomForestRegressor
+from alphamind.model import RandomForestClassifier
 from alphamind.model import XGBRegressor
 from alphamind.model import XGBClassifier
 from alphamind.model import load_model
@@ -70,6 +71,7 @@ __all__ = [
     'ConstLinearModel',
     'LogisticRegression',
     'RandomForestRegressor',
+    'RandomForestClassifier',
     'XGBRegressor',
     'XGBClassifier',
     'load_model',
...
@@ -11,6 +11,7 @@ from alphamind.model.linearmodel import ConstLinearModel
 from alphamind.model.linearmodel import LogisticRegression
 from alphamind.model.treemodel import RandomForestRegressor
+from alphamind.model.treemodel import RandomForestClassifier
 from alphamind.model.treemodel import XGBRegressor
 from alphamind.model.treemodel import XGBClassifier
@@ -22,6 +23,7 @@ __all__ = ['LinearRegression',
            'ConstLinearModel',
            'LogisticRegression',
            'RandomForestRegressor',
+           'RandomForestClassifier',
            'XGBRegressor',
            'XGBClassifier',
            'load_model']
\ No newline at end of file
@@ -11,6 +11,7 @@ from alphamind.model.linearmodel import LinearRegression
 from alphamind.model.linearmodel import LassoRegression
 from alphamind.model.linearmodel import LogisticRegression
 from alphamind.model.treemodel import RandomForestRegressor
+from alphamind.model.treemodel import RandomForestClassifier
 from alphamind.model.treemodel import XGBRegressor
 from alphamind.model.treemodel import XGBClassifier
@@ -30,6 +31,8 @@ def load_model(model_desc: dict) -> ModelBase:
         return LogisticRegression.load(model_desc)
     elif 'RandomForestRegressor' in model_name_parts:
         return RandomForestRegressor.load(model_desc)
+    elif 'RandomForestClassifier' in model_name_parts:
+        return RandomForestClassifier.load(model_desc)
     elif 'XGBRegressor' in model_name_parts:
         return XGBRegressor.load(model_desc)
     elif 'XGBClassifier' in model_name_parts:
...
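A minimal usage sketch of the save/load round trip this dispatch enables, mirroring the new test in this commit (the random data and feature list are illustrative, and it assumes alphamind and scikit-learn are installed):

import numpy as np

from alphamind.model.loader import load_model
from alphamind.model.treemodel import RandomForestClassifier

# Fit the new classifier wrapper, serialize it to a plain dict, then rebuild it
# through load_model, which dispatches on the model name carried in the description.
model = RandomForestClassifier(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.where(np.random.randn(1000) > 0, 1, 0)
model.fit(x, y)

desc = model.save()
restored = load_model(desc)
assert restored.features == model.features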
@@ -9,6 +9,7 @@ from typing import List
 from distutils.version import LooseVersion
 from sklearn import __version__ as sklearn_version
 from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
+from sklearn.ensemble import RandomForestClassifier as RandomForestClassifierImpl
 from xgboost import __version__ as xgbboot_version
 from xgboost import XGBRegressor as XGBRegressorImpl
 from xgboost import XGBClassifier as XGBClassifierImpl
@@ -18,9 +19,49 @@ from alphamind.utilities import alpha_logger
 class RandomForestRegressor(ModelBase):
 
-    def __init__(self, n_estimators: int=100, features: List=None, **kwargs):
+    def __init__(self,
+                 n_estimators: int=100,
+                 max_features: str='auto',
+                 features: List=None,
+                 **kwargs):
+        super().__init__(features)
+        self.impl = RandomForestRegressorImpl(n_estimators=n_estimators,
+                                              max_features=max_features,
+                                              **kwargs)
+        self.trained_time = None
+
+    def save(self) -> dict:
+        model_desc = super().save()
+        model_desc['sklearn_version'] = sklearn_version
+        model_desc['importances'] = self.importances
+        return model_desc
+
+    @classmethod
+    def load(cls, model_desc: dict):
+        obj_layout = super().load(model_desc)
+
+        if LooseVersion(sklearn_version) < LooseVersion(model_desc['sklearn_version']):
+            alpha_logger.warning('Current sklearn version {0} is lower than the model version {1}. '
+                                 'Loaded model may work incorrectly.'.format(sklearn_version,
+                                                                             model_desc['sklearn_version']))
+        return obj_layout
+
+    @property
+    def importances(self):
+        return self.impl.feature_importances_.tolist()
+
+
+class RandomForestClassifier(ModelBase):
+
+    def __init__(self,
+                 n_estimators: int=100,
+                 max_features: str='auto',
+                 features: List = None,
+                 **kwargs):
         super().__init__(features)
-        self.impl = RandomForestRegressorImpl(n_estimators, **kwargs)
+        self.impl = RandomForestClassifierImpl(n_estimators=n_estimators,
+                                               max_features=max_features,
+                                               **kwargs)
         self.trained_time = None
 
     def save(self) -> dict:
...
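The new wrapper passes n_estimators and max_features explicitly and forwards any remaining keyword arguments to scikit-learn's RandomForestClassifier. A short, illustrative sketch (the feature names and extra parameters below are placeholders, not values from the diff):

from alphamind.model.treemodel import RandomForestClassifier

# max_depth and n_jobs are ordinary sklearn RandomForestClassifier options,
# forwarded through **kwargs to the wrapped estimator.
model = RandomForestClassifier(n_estimators=200,
                               max_features='sqrt',
                               features=['factor_a', 'factor_b', 'factor_c'],
                               max_depth=5,
                               n_jobs=-1)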
@@ -9,13 +9,14 @@ import unittest
 import numpy as np
 
 from alphamind.model.loader import load_model
 from alphamind.model.treemodel import RandomForestRegressor
+from alphamind.model.treemodel import RandomForestClassifier
 from alphamind.model.treemodel import XGBRegressor
 from alphamind.model.treemodel import XGBClassifier
 
 
 class TestTreeModel(unittest.TestCase):
 
-    def test_random_forest_regress(self):
+    def test_random_forest_regress_persistence(self):
         model = RandomForestRegressor(features=list(range(10)))
         x = np.random.randn(1000, 10)
         y = np.random.randn(1000)
@@ -29,6 +30,21 @@ class TestTreeModel(unittest.TestCase):
         sample_x = np.random.randn(100, 10)
         np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
 
+    def test_random_forest_classify_persistence(self):
+        model = RandomForestClassifier(features=list(range(10)))
+        x = np.random.randn(1000, 10)
+        y = np.random.randn(1000)
+        y = np.where(y > 0, 1, 0)
+
+        model.fit(x, y)
+        desc = model.save()
+        new_model = load_model(desc)
+        self.assertEqual(model.features, new_model.features)
+
+        sample_x = np.random.randn(100, 10)
+        np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
+
     def test_xgb_regress_persistence(self):
         model = XGBRegressor(features=list(range(10)))
         x = np.random.randn(1000, 10)
...
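One way to exercise the new test with the standard library runner, sketched below (the dotted path is an assumption about where this test module lives in the repository and may need adjusting):

import unittest

# Load and run only the new classifier persistence test.
suite = unittest.defaultTestLoader.loadTestsFromName(
    'alphamind.tests.model.test_treemodel.TestTreeModel.test_random_forest_classify_persistence')
unittest.TextTestRunner(verbosity=2).run(suite)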