Commit f6600807 authored by Dr.李's avatar Dr.李

update tree models

parent 1e8f9e1b
......@@ -9,6 +9,8 @@ from alphamind.model.modelbase import ModelBase
from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LinearRegression
from alphamind.model.linearmodel import LassoRegression
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import XGBRegressor
def load_model(model_desc: dict) -> ModelBase:
......@@ -22,5 +24,9 @@ def load_model(model_desc: dict) -> ModelBase:
return LinearRegression.load(model_desc)
elif 'LassoRegression' in model_name_parts:
return LassoRegression.load(model_desc)
elif 'RandomForestRegressor' in model_name_parts:
return RandomForestRegressor.load(model_desc)
elif 'XGBRegressor' in model_name_parts:
return XGBRegressor.load(model_desc)
else:
raise ValueError('{0} is not currently supported in model loader.'.format(model_name))
......@@ -9,8 +9,8 @@ from typing import List
import numpy as np
from distutils.version import LooseVersion
from sklearn import __version__ as sklearn_version
from xgboost import __version__ as xgbboot_version
from sklearn.ensemble import RandomForestRegressor as RandomForestRegressorImpl
from xgboost import __version__ as xgbboot_version
from xgboost import XGBRegressor as XGBRegressorImpl
from alphamind.model.modelbase import ModelBase
from alphamind.utilities import alpha_logger
......
......@@ -8,6 +8,7 @@ Created on 2017-9-4
import unittest
import numpy as np
from sklearn.linear_model import LinearRegression as LinearRegression2
from alphamind.model.loader import load_model
from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LinearRegression
......@@ -36,7 +37,7 @@ class TestLinearModel(unittest.TestCase):
weights=weights)
desc = model.save()
new_model = ConstLinearModel.load(desc)
new_model = load_model(desc)
self.assertEqual(model.features, new_model.features)
np.testing.assert_array_almost_equal(model.weights, new_model.weights)
......@@ -58,7 +59,7 @@ class TestLinearModel(unittest.TestCase):
model.fit(self.train_x, self.train_y)
desc = model.save()
new_model = LinearRegression.load(desc)
new_model = load_model(desc)
calculated_y = new_model.predict(self.predict_x)
expected_y = model.predict(self.predict_x)
......
# -*- coding: utf-8 -*-
"""
Created on 2018-1-5
@author: cheng.li
"""
import unittest
import numpy as np
from alphamind.model.loader import load_model
from alphamind.model.treemodel import RandomForestRegressor
from alphamind.model.treemodel import XGBRegressor
class TestTreeModel(unittest.TestCase):
def test_random_forest_regress(self):
model = RandomForestRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.random.randn(1000)
model.fit(x, y)
desc = model.save()
new_model = load_model(desc)
self.assertEqual(model.features, new_model.features)
sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
def tes_xgb_regress(self):
model = XGBRegressor(features=list(range(10)))
x = np.random.randn(1000, 10)
y = np.random.randn(1000)
model.fit(x, y)
desc = model.save()
new_model = load_model(desc)
self.assertEqual(model.features, new_model.features)
sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
\ No newline at end of file
......@@ -28,6 +28,7 @@ from alphamind.tests.analysis.test_perfanalysis import TestPerformanceAnalysis
from alphamind.tests.analysis.test_factoranalysis import TestFactorAnalysis
from alphamind.tests.analysis.test_quantilieanalysis import TestQuantileAnalysis
from alphamind.tests.model.test_linearmodel import TestLinearModel
from alphamind.tests.model.test_treemodel import TestTreeModel
from alphamind.tests.model.test_loader import TestLoader
from alphamind.tests.execution.test_naiveexecutor import TestNaiveExecutor
from alphamind.tests.execution.test_thresholdexecutor import TestThresholdExecutor
......@@ -54,6 +55,7 @@ if __name__ == '__main__':
TestFactorAnalysis,
TestQuantileAnalysis,
TestLinearModel,
TestTreeModel,
TestLoader,
TestNaiveExecutor,
TestThresholdExecutor,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment