Commit f6600807 authored by Dr.李's avatar Dr.李

update tree models

parent 1e8f9e1b
...@@ -9,6 +9,8 @@ from alphamind.model.modelbase import ModelBase ...@@ -9,6 +9,8 @@ from alphamind.model.modelbase import ModelBase
from alphamind.model.linearmodel import ConstLinearModel from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LinearRegression from alphamind.model.linearmodel import LinearRegression
from alphamind.model.linearmodel import LassoRegression from alphamind.model.linearmodel import LassoRegression
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import XGBRegressor
def load_model(model_desc: dict) -> ModelBase: def load_model(model_desc: dict) -> ModelBase:
...@@ -22,5 +24,9 @@ def load_model(model_desc: dict) -> ModelBase: ...@@ -22,5 +24,9 @@ def load_model(model_desc: dict) -> ModelBase:
return LinearRegression.load(model_desc) return LinearRegression.load(model_desc)
elif 'LassoRegression' in model_name_parts: elif 'LassoRegression' in model_name_parts:
return LassoRegression.load(model_desc) return LassoRegression.load(model_desc)
elif 'RandomForestRegressor' in model_name_parts:
return RandomForestRegressor.load(model_desc)
elif 'XGBRegressor' in model_name_parts:
return XGBRegressor.load(model_desc)
else: else:
raise ValueError('{0} is not currently supported in model loader.'.format(model_name)) raise ValueError('{0} is not currently supported in model loader.'.format(model_name))
...@@ -9,8 +9,8 @@ from typing import List ...@@ -9,8 +9,8 @@ from typing import List
import numpy as np import numpy as np
from distutils.version import LooseVersion from distutils.version import LooseVersion
from sklearn import __version__ as sklearn_version from sklearn import __version__ as sklearn_version
from xgboost import __version__ as xgbboot_version
from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
from xgboost import __version__ as xgbboot_version
from xgboost import XGBRegressor as XGBRegressorImpl from xgboost import XGBRegressor as XGBRegressorImpl
from alphamind.model.modelbase import ModelBase from alphamind.model.modelbase import ModelBase
from alphamind.utilities import alpha_logger from alphamind.utilities import alpha_logger
......
...@@ -8,6 +8,7 @@ Created on 2017-9-4 ...@@ -8,6 +8,7 @@ Created on 2017-9-4
import unittest import unittest
import numpy as np import numpy as np
from sklearn.linear_model import LinearRegression as LinearRegression2 from sklearn.linear_model import LinearRegression as LinearRegression2
from alphamind.model.loader import load_model
from alphamind.model.linearmodel import ConstLinearModel from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LinearRegression from alphamind.model.linearmodel import LinearRegression
...@@ -36,7 +37,7 @@ class TestLinearModel(unittest.TestCase): ...@@ -36,7 +37,7 @@ class TestLinearModel(unittest.TestCase):
weights=weights) weights=weights)
desc = model.save() desc = model.save()
new_model = ConstLinearModel.load(desc) new_model = load_model(desc)
self.assertEqual(model.features, new_model.features) self.assertEqual(model.features, new_model.features)
np.testing.assert_array_almost_equal(model.weights, new_model.weights) np.testing.assert_array_almost_equal(model.weights, new_model.weights)
...@@ -58,7 +59,7 @@ class TestLinearModel(unittest.TestCase): ...@@ -58,7 +59,7 @@ class TestLinearModel(unittest.TestCase):
model.fit(self.train_x, self.train_y) model.fit(self.train_x, self.train_y)
desc = model.save() desc = model.save()
new_model = LinearRegression.load(desc) new_model = load_model(desc)
calculated_y = new_model.predict(self.predict_x) calculated_y = new_model.predict(self.predict_x)
expected_y = model.predict(self.predict_x) expected_y = model.predict(self.predict_x)
......
# -*- coding: utf-8 -*-
"""
Created on 2018-1-5
@author: cheng.li
"""
import unittest
import numpy as np
from alphamind.model.loader import load_model
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import XGBRegressor
class TestTreeModel(unittest.TestCase):
def test_random_forest_regress(self):
model = RandomForestRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.random.randn(1000)
model.fit(x, y)
desc = model.save()
new_model = load_model(desc)
self.assertEqual(model.features, new_model.features)
sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
def tes_xgb_regress(self):
model = XGBRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.random.randn(1000)
model.fit(x, y)
desc = model.save()
new_model = load_model(desc)
self.assertEqual(model.features, new_model.features)
sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
\ No newline at end of file
...@@ -28,6 +28,7 @@ from alphamind.tests.analysis.test_perfanalysis import TestPerformanceAnalysis ...@@ -28,6 +28,7 @@ from alphamind.tests.analysis.test_perfanalysis import TestPerformanceAnalysis
from alphamind.tests.analysis.test_factoranalysis import TestFactorAnalysis from alphamind.tests.analysis.test_factoranalysis import TestFactorAnalysis
from alphamind.tests.analysis.test_quantilieanalysis import TestQuantileAnalysis from alphamind.tests.analysis.test_quantilieanalysis import TestQuantileAnalysis
from alphamind.tests.model.test_linearmodel import TestLinearModel from alphamind.tests.model.test_linearmodel import TestLinearModel
from alphamind.tests.model.test_treemodel import TestTreeModel
from alphamind.tests.model.test_loader import TestLoader from alphamind.tests.model.test_loader import TestLoader
from alphamind.tests.execution.test_naiveexecutor import TestNaiveExecutor from alphamind.tests.execution.test_naiveexecutor import TestNaiveExecutor
from alphamind.tests.execution.test_thresholdexecutor import TestThresholdExecutor from alphamind.tests.execution.test_thresholdexecutor import TestThresholdExecutor
...@@ -54,6 +55,7 @@ if __name__ == '__main__': ...@@ -54,6 +55,7 @@ if __name__ == '__main__':
TestFactorAnalysis, TestFactorAnalysis,
TestQuantileAnalysis, TestQuantileAnalysis,
TestLinearModel, TestLinearModel,
TestTreeModel,
TestLoader, TestLoader,
TestNaiveExecutor, TestNaiveExecutor,
TestThresholdExecutor, TestThresholdExecutor,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment