Commit c29dfebb authored by Dr.李's avatar Dr.李

update tests

parent 4df5acb4
...@@ -496,7 +496,6 @@ class SqlEngine(object): ...@@ -496,7 +496,6 @@ class SqlEngine(object):
df = pd.read_sql(query, self.engine) \ df = pd.read_sql(query, self.engine) \
.replace([-np.inf, np.inf], np.nan) \ .replace([-np.inf, np.inf], np.nan) \
.dropna() \
.sort_values(['trade_date', 'code']) .sort_values(['trade_date', 'code'])
return pd.merge(df, codes[['trade_date', 'code']], how='inner') return pd.merge(df, codes[['trade_date', 'code']], how='inner')
......
...@@ -5,13 +5,14 @@ Created on 2018-4-17 ...@@ -5,13 +5,14 @@ Created on 2018-4-17
@author: cheng.li @author: cheng.li
""" """
import random
import unittest import unittest
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from sqlalchemy import select, and_ from sqlalchemy import select, and_
from PyFin.api import adjustDateByCalendar
from PyFin.api import makeSchedule from PyFin.api import makeSchedule
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from PyFin.api import bizDatesList
from alphamind.tests.test_suite import SKIP_ENGINE_TESTS from alphamind.tests.test_suite import SKIP_ENGINE_TESTS
from alphamind.data.dbmodel.models import Universe as UniverseTable from alphamind.data.dbmodel.models import Universe as UniverseTable
from alphamind.data.dbmodel.models import Market from alphamind.data.dbmodel.models import Market
...@@ -20,8 +21,10 @@ from alphamind.data.dbmodel.models import IndexComponent ...@@ -20,8 +21,10 @@ from alphamind.data.dbmodel.models import IndexComponent
from alphamind.data.dbmodel.models import Uqer from alphamind.data.dbmodel.models import Uqer
from alphamind.data.dbmodel.models import RiskCovShort from alphamind.data.dbmodel.models import RiskCovShort
from alphamind.data.dbmodel.models import RiskExposure from alphamind.data.dbmodel.models import RiskExposure
from alphamind.data.dbmodel.models import Industry
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.utilities import alpha_logger
@unittest.skipIf(SKIP_ENGINE_TESTS, "Omit sql engine tests") @unittest.skipIf(SKIP_ENGINE_TESTS, "Omit sql engine tests")
...@@ -29,9 +32,12 @@ class TestSqlEngine(unittest.TestCase): ...@@ -29,9 +32,12 @@ class TestSqlEngine(unittest.TestCase):
def setUp(self): def setUp(self):
self.engine = SqlEngine() self.engine = SqlEngine()
dates_list = bizDatesList('china.sse', '2010-10-01', '2018-04-27')
self.ref_date = random.choice(dates_list).strftime('%Y-%m-%d')
alpha_logger.info("Test date: {0}".format(self.ref_date))
def test_sql_engine_fetch_codes(self): def test_sql_engine_fetch_codes(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe) codes = self.engine.fetch_codes(ref_date, universe)
...@@ -67,7 +73,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -67,7 +73,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_dx_return(self): def test_sql_engine_fetch_dx_return(self):
horizon = 4 horizon = 4
offset = 1 offset = 1
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe) codes = self.engine.fetch_codes(ref_date, universe)
...@@ -88,7 +94,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -88,7 +94,7 @@ class TestSqlEngine(unittest.TestCase):
horizon = 4 horizon = 4
offset = 0 offset = 0
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe) codes = self.engine.fetch_codes(ref_date, universe)
...@@ -108,7 +114,9 @@ class TestSqlEngine(unittest.TestCase): ...@@ -108,7 +114,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values) np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_dx_return_range(self): def test_sql_engine_fetch_dx_return_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse') ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
dx_return = self.engine.fetch_dx_return_range(universe, dx_return = self.engine.fetch_dx_return_range(universe,
...@@ -138,7 +146,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -138,7 +146,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_dx_return_index(self): def test_sql_engine_fetch_dx_return_index(self):
horizon = 4 horizon = 4
offset = 1 offset = 1
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
dx_return = self.engine.fetch_dx_return_index(ref_date, dx_return = self.engine.fetch_dx_return_index(ref_date,
905, 905,
horizon=horizon, horizon=horizon,
...@@ -159,7 +167,9 @@ class TestSqlEngine(unittest.TestCase): ...@@ -159,7 +167,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values) np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_dx_return_index_range(self): def test_sql_engine_fetch_dx_return_index_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse') ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
index_code = 906 index_code = 906
dx_return = self.engine.fetch_dx_return_index_range(index_code, dx_return = self.engine.fetch_dx_return_index_range(index_code,
...@@ -184,7 +194,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -184,7 +194,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_return.dx.values, res.chgPct.values) np.testing.assert_array_almost_equal(calculated_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_factor(self): def test_sql_engine_fetch_factor(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe) codes = self.engine.fetch_codes(ref_date, universe)
factor = 'ROE' factor = 'ROE'
...@@ -202,7 +212,9 @@ class TestSqlEngine(unittest.TestCase): ...@@ -202,7 +212,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(factor_data.ROE.values, df.ROE.values) np.testing.assert_array_almost_equal(factor_data.ROE.values, df.ROE.values)
def test_sql_engine_fetch_factor_range(self): def test_sql_engine_fetch_factor_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse') ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
factor = 'ROE' factor = 'ROE'
...@@ -224,7 +236,9 @@ class TestSqlEngine(unittest.TestCase): ...@@ -224,7 +236,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_factor.ROE.values, df.ROE.values) np.testing.assert_array_almost_equal(calculated_factor.ROE.values, df.ROE.values)
def test_sql_engine_fetch_factor_range_forward(self): def test_sql_engine_fetch_factor_range_forward(self):
ref_dates = makeSchedule('2017-01-01', '2017-09-30', '60b', 'china.sse') ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
ref_dates = ref_dates + [advanceDateByCalendar('china.sse', ref_dates[-1], '60b').strftime('%Y-%m-%d')] ref_dates = ref_dates + [advanceDateByCalendar('china.sse', ref_dates[-1], '60b').strftime('%Y-%m-%d')]
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
factor = 'ROE' factor = 'ROE'
...@@ -248,7 +262,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -248,7 +262,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_factor.dx.values, df.ROE.values) np.testing.assert_array_almost_equal(calculated_factor.dx.values, df.ROE.values)
def test_sql_engine_fetch_benchmark(self): def test_sql_engine_fetch_benchmark(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
benchmark = 906 benchmark = 906
index_data = self.engine.fetch_benchmark(ref_date, benchmark) index_data = self.engine.fetch_benchmark(ref_date, benchmark)
...@@ -264,7 +278,9 @@ class TestSqlEngine(unittest.TestCase): ...@@ -264,7 +278,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(df.weight.values, index_data.weight.values) np.testing.assert_array_almost_equal(df.weight.values, index_data.weight.values)
def test_sql_engine_fetch_benchmark_range(self): def test_sql_engine_fetch_benchmark_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-09-30', '60b', 'china.sse') ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-9m'),
self.ref_date,
'60b', 'china.sse')
benchmark = 906 benchmark = 906
index_data = self.engine.fetch_benchmark_range(benchmark, dates=ref_dates) index_data = self.engine.fetch_benchmark_range(benchmark, dates=ref_dates)
...@@ -282,7 +298,7 @@ class TestSqlEngine(unittest.TestCase): ...@@ -282,7 +298,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_data.weight.values, expected_data.weight.values) np.testing.assert_array_almost_equal(calculated_data.weight.values, expected_data.weight.values)
def test_sql_engine_fetch_risk_model(self): def test_sql_engine_fetch_risk_model(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31') ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000']) universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe) codes = self.engine.fetch_codes(ref_date, universe)
...@@ -293,4 +309,43 @@ class TestSqlEngine(unittest.TestCase): ...@@ -293,4 +309,43 @@ class TestSqlEngine(unittest.TestCase):
RiskCovShort.trade_date == ref_date RiskCovShort.trade_date == ref_date
) )
cov_df = pd.read_sql(query, con=self.engine.engine) cov_df = pd.read_sql(query, con=self.engine.engine).sort_values('FactorID')
factors = cov_df.Factor.tolist()
np.testing.assert_array_almost_equal(
risk_cov[factors].values, cov_df[factors].values
)
query = select([RiskExposure]).where(
and_(
RiskExposure.trade_date == ref_date,
RiskExposure.code.in_(codes)
)
)
exp_df = pd.read_sql(query, con=self.engine.engine)
np.testing.assert_array_almost_equal(
exp_df[factors].values, risk_exp[factors].values
)
def test_sql_engine_fetch_industry_matrix(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
risk_matrix = self.engine.fetch_industry_matrix(ref_date, codes, 'sw', 1)
query = select([Industry.code, Industry.industryName1]).where(
and_(
Industry.trade_date == ref_date,
Industry.industry == '申万行业分类',
Industry.code.in_(codes)
)
)
df = pd.read_sql(query, con=self.engine.engine)
df = pd.get_dummies(df, prefix="", prefix_sep="")
self.assertEqual(len(risk_matrix), len(df))
np.testing.assert_array_almost_equal(
df[risk_matrix.columns[2:]].values, risk_matrix.iloc[:, 2:].values
)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment