Commit 22137a83 authored by Dr.李

added random forest classifier

parent e4487d31
......@@ -30,6 +30,7 @@ from alphamind.model import LassoRegression
from alphamind.model import ConstLinearModel
from alphamind.model import LogisticRegression
from alphamind.model import RandomForestRegressor
from alphamind.model import RandomForestClassifier
from alphamind.model import XGBRegressor
from alphamind.model import XGBClassifier
from alphamind.model import load_model
......@@ -70,6 +71,7 @@ __all__ = [
'ConstLinearModel',
'LogisticRegression',
'RandomForestRegressor',
'RandomForestClassifier',
'XGBRegressor',
'XGBClassifier',
'load_model',
......
......@@ -11,6 +11,7 @@ from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LogisticRegression
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import RandomForestClassifier
from alphamind.model.treemodel import XGBRegressor
from alphamind.model.treemodel import XGBClassifier
......@@ -22,6 +23,7 @@ __all__ = ['LinearRegression',
'ConstLinearModel',
'LogisticRegression',
'RandomForestRegressor',
'RandomForestClassifier',
'XGBRegressor',
'XGBClassifier',
'load_model']
\ No newline at end of file
......@@ -11,6 +11,7 @@ from alphamind.model.linearmodel import LinearRegression
from alphamind.model.linearmodel import LassoRegression
from alphamind.model.linearmodel import LogisticRegression
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import RandomForestClassifier
from alphamind.model.treemodel import XGBRegressor
from alphamind.model.treemodel import XGBClassifier
......@@ -30,6 +31,8 @@ def load_model(model_desc: dict) -> ModelBase:
return LogisticRegression.load(model_desc)
elif 'RandomForestRegressor' in model_name_parts:
return RandomForestRegressor.load(model_desc)
elif 'RandomForestClassifier' in model_name_parts:
return RandomForestClassifier.load(model_desc)
elif 'XGBRegressor' in model_name_parts:
return XGBRegressor.load(model_desc)
elif 'XGBClassifier' in model_name_parts:
......
......@@ -9,6 +9,7 @@ from typing import List
from distutils.version import LooseVersion
from sklearn import __version__ as sklearn_version
from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
from sklearn.ensemble import RandomForestClassifier as RandomForestClassifierImpl
from xgboost import __version__ as xgbboot_version
from xgboost import XGBRegressor as XGBRegressorImpl
from xgboost import XGBClassifier as XGBClassifierImpl
......@@ -18,9 +19,49 @@ from alphamind.utilities import alpha_logger
class RandomForestRegressor(ModelBase):
    """Random-forest regression model backed by ``sklearn.ensemble.RandomForestRegressor``.

    The scikit-learn estimator is held in ``self.impl``; persistence is
    delegated to ``ModelBase`` with the scikit-learn version and the feature
    importances added to the saved description.
    """

    def __init__(self,
                 n_estimators: int=100,
                 max_features: str='auto',
                 features: List=None,
                 **kwargs):
        """Build the wrapped estimator.

        :param n_estimators: forwarded to the scikit-learn estimator.
        :param max_features: forwarded to the scikit-learn estimator.
        :param features: passed to ``ModelBase.__init__``.
        :param kwargs: any further keyword arguments for the scikit-learn estimator.
        """
        super().__init__(features)
        # Arguments are passed by keyword so they cannot be bound to the wrong
        # positional parameter of the scikit-learn constructor.
        # NOTE(review): scikit-learn deprecated max_features='auto' in 1.1 and
        # removed it in 1.3 -- confirm the default against the pinned version.
        self.impl = RandomForestRegressorImpl(n_estimators=n_estimators,
                                              max_features=max_features,
                                              **kwargs)
        self.trained_time = None

    def save(self) -> dict:
        """Return the base model description extended with the scikit-learn
        version in use and the feature importances of the wrapped estimator.
        """
        model_desc = super().save()
        model_desc['sklearn_version'] = sklearn_version
        # presumably requires a fitted estimator, since it reads
        # ``feature_importances_`` -- verify against callers
        model_desc['importances'] = self.importances
        return model_desc

    @classmethod
    def load(cls, model_desc: dict):
        """Rebuild a model from ``model_desc`` via ``ModelBase.load``.

        Logs a warning when the running scikit-learn version is lower than the
        version recorded in ``model_desc['sklearn_version']``.
        """
        obj_layout = super().load(model_desc)

        if LooseVersion(sklearn_version) < LooseVersion(model_desc['sklearn_version']):
            alpha_logger.warning('Current sklearn version {0} is lower than the model version {1}. '
                                 'Loaded model may work incorrectly.'.format(sklearn_version,
                                                                             model_desc['sklearn_version']))
        return obj_layout

    @property
    def importances(self):
        """Feature importances of the wrapped estimator, converted with ``tolist()``."""
        return self.impl.feature_importances_.tolist()
class RandomForestClassifier(ModelBase):
def __init__(self,
n_estimators: int=100,
max_features: str='auto',
features: List = None,
**kwargs):
super().__init__(features)
self.impl = RandomForestClassifierImpl(n_estimators=n_estimators,
max_features=max_features,
**kwargs)
self.trained_time = None
def save(self) -> dict:
......
......@@ -9,13 +9,14 @@ import unittest
import numpy as np
from alphamind.model.loader import load_model
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import RandomForestClassifier
from alphamind.model.treemodel import XGBRegressor
from alphamind.model.treemodel import XGBClassifier
class TestTreeModel(unittest.TestCase):
def test_random_forest_regress(self):
def test_random_forest_regress_persistence(self):
model = RandomForestRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.random.randn(1000)
......@@ -29,6 +30,21 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
def test_random_forest_classify_persistence(self):
    """A saved-then-reloaded random forest classifier keeps its features
    and produces the same predictions as the original on fresh samples.
    """
    classifier = RandomForestClassifier(features=list(range(10)))

    # Binary labels: 1 where a standard-normal draw is positive, else 0.
    train_x = np.random.randn(1000, 10)
    raw_scores = np.random.randn(1000)
    labels = np.where(raw_scores > 0, 1, 0)
    classifier.fit(train_x, labels)

    # Round-trip through the serialised description.
    restored = load_model(classifier.save())
    self.assertEqual(classifier.features, restored.features)

    probe_x = np.random.randn(100, 10)
    np.testing.assert_array_almost_equal(classifier.predict(probe_x), restored.predict(probe_x))
def test_xgb_regress_persistence(self):
model = XGBRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment