Commit b17ff59f authored by Dr.李's avatar Dr.李

added tests on linear models

parent 35584fa2
...@@ -18,7 +18,7 @@ from alphamind.utilities import alpha_logger ...@@ -18,7 +18,7 @@ from alphamind.utilities import alpha_logger
class ConstLinearModel(ModelBase): class ConstLinearModel(ModelBase):
def __init__(self, def __init__(self,
features: np.ndarray=None, features: list=None,
weights: np.ndarray=None): weights: np.ndarray=None):
super().__init__(features) super().__init__(features)
if features is not None and weights is not None: if features is not None and weights is not None:
...@@ -46,7 +46,7 @@ class ConstLinearModel(ModelBase): ...@@ -46,7 +46,7 @@ class ConstLinearModel(ModelBase):
class LinearRegression(ModelBase): class LinearRegression(ModelBase):
def __init__(self, features, fit_intercept: bool=False): def __init__(self, features: list=None, fit_intercept: bool=False):
super().__init__(features) super().__init__(features)
self.impl = LinearRegressionImpl(fit_intercept=fit_intercept) self.impl = LinearRegressionImpl(fit_intercept=fit_intercept)
......
...@@ -13,8 +13,9 @@ from alphamind.utilities import alpha_logger ...@@ -13,8 +13,9 @@ from alphamind.utilities import alpha_logger
class ModelBase(metaclass=abc.ABCMeta): class ModelBase(metaclass=abc.ABCMeta):
def __init__(self, features: np.ndarray=None): def __init__(self, features: list=None):
self.features = features if features is not None:
self.features = list(features)
@abc.abstractmethod @abc.abstractmethod
def fit(self, x, y): def fit(self, x, y):
......
# -*- coding: utf-8 -*-
"""
Created on 2017-9-4
@author: cheng.li
"""
import unittest
import numpy as np
from sklearn.linear_model import LinearRegression as LinearRegression2
from alphamind.model.linearmodel import ConstLinearModel
from alphamind.model.linearmodel import LinearRegression
class TestLinearModel(unittest.TestCase):
def setUp(self):
self.n = 3
self.train_x = np.random.randn(1000, self.n)
self.train_y = np.random.randn(1000, 1)
self.predict_x = np.random.randn(10, self.n)
def test_const_linear_model(self):
weights = np.array([1., 2., 3.])
model = ConstLinearModel(features=['a', 'b', 'c'],
weights=weights)
calculated_y = model.predict(self.predict_x)
expected_y = self.predict_x @ weights
np.testing.assert_array_almost_equal(calculated_y, expected_y)
def test_const_linear_model_persistence(self):
weights = np.array([1., 2., 3.])
model = ConstLinearModel(features=['a', 'b', 'c'],
weights=weights)
desc = model.save()
new_model = ConstLinearModel.load(desc)
self.assertEqual(model.features, new_model.features)
np.testing.assert_array_almost_equal(model.weights, new_model.weights)
def test_linear_regression(self):
model = LinearRegression(['a', 'b', 'c'], fit_intercept=False)
model.fit(self.train_x, self.train_y)
calculated_y = model.predict(self.predict_x)
expected_model = LinearRegression2(fit_intercept=False)
expected_model.fit(self.train_x, self.train_y)
expected_y = expected_model.predict(self.predict_x)
np.testing.assert_array_almost_equal(calculated_y, expected_y)
def test_linear_regression_persistence(self):
model = LinearRegression(['a', 'b', 'c'], fit_intercept=False)
model.fit(self.train_x, self.train_y)
desc = model.save()
new_model = LinearRegression.load(desc)
calculated_y = new_model.predict(self.predict_x)
expected_y = model.predict(self.predict_x)
np.testing.assert_array_almost_equal(calculated_y, expected_y)
...@@ -26,6 +26,7 @@ from alphamind.tests.analysis.test_riskanalysis import TestRiskAnalysis ...@@ -26,6 +26,7 @@ from alphamind.tests.analysis.test_riskanalysis import TestRiskAnalysis
from alphamind.tests.analysis.test_perfanalysis import TestPerformanceAnalysis from alphamind.tests.analysis.test_perfanalysis import TestPerformanceAnalysis
from alphamind.tests.analysis.test_factoranalysis import TestFactorAnalysis from alphamind.tests.analysis.test_factoranalysis import TestFactorAnalysis
from alphamind.tests.analysis.test_quantilieanalysis import TestQuantileAnalysis from alphamind.tests.analysis.test_quantilieanalysis import TestQuantileAnalysis
from alphamind.tests.model.test_linearmodel import TestLinearModel
if __name__ == '__main__': if __name__ == '__main__':
...@@ -43,6 +44,7 @@ if __name__ == '__main__': ...@@ -43,6 +44,7 @@ if __name__ == '__main__':
TestRiskAnalysis, TestRiskAnalysis,
TestPerformanceAnalysis, TestPerformanceAnalysis,
TestFactorAnalysis, TestFactorAnalysis,
TestQuantileAnalysis], TestQuantileAnalysis,
TestLinearModel],
alpha_logger) alpha_logger)
runner.run() runner.run()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment