Commit c29dfebb authored by Dr.李's avatar Dr.李

update tests

parent 4df5acb4
......@@ -496,7 +496,6 @@ class SqlEngine(object):
df = pd.read_sql(query, self.engine) \
.replace([-np.inf, np.inf], np.nan) \
.dropna() \
.sort_values(['trade_date', 'code'])
return pd.merge(df, codes[['trade_date', 'code']], how='inner')
......
......@@ -5,13 +5,14 @@ Created on 2018-4-17
@author: cheng.li
"""
import random
import unittest
import numpy as np
import pandas as pd
from sqlalchemy import select, and_
from PyFin.api import adjustDateByCalendar
from PyFin.api import makeSchedule
from PyFin.api import advanceDateByCalendar
from PyFin.api import bizDatesList
from alphamind.tests.test_suite import SKIP_ENGINE_TESTS
from alphamind.data.dbmodel.models import Universe as UniverseTable
from alphamind.data.dbmodel.models import Market
......@@ -20,8 +21,10 @@ from alphamind.data.dbmodel.models import IndexComponent
from alphamind.data.dbmodel.models import Uqer
from alphamind.data.dbmodel.models import RiskCovShort
from alphamind.data.dbmodel.models import RiskExposure
from alphamind.data.dbmodel.models import Industry
from alphamind.data.engines.sqlengine import SqlEngine
from alphamind.data.engines.universe import Universe
from alphamind.utilities import alpha_logger
@unittest.skipIf(SKIP_ENGINE_TESTS, "Omit sql engine tests")
......@@ -29,9 +32,12 @@ class TestSqlEngine(unittest.TestCase):
def setUp(self):
self.engine = SqlEngine()
dates_list = bizDatesList('china.sse', '2010-10-01', '2018-04-27')
self.ref_date = random.choice(dates_list).strftime('%Y-%m-%d')
alpha_logger.info("Test date: {0}".format(self.ref_date))
def test_sql_engine_fetch_codes(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
......@@ -67,7 +73,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_dx_return(self):
horizon = 4
offset = 1
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
......@@ -88,7 +94,7 @@ class TestSqlEngine(unittest.TestCase):
horizon = 4
offset = 0
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
......@@ -108,7 +114,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_dx_return_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse')
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000'])
dx_return = self.engine.fetch_dx_return_range(universe,
......@@ -138,7 +146,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_dx_return_index(self):
horizon = 4
offset = 1
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
dx_return = self.engine.fetch_dx_return_index(ref_date,
905,
horizon=horizon,
......@@ -159,7 +167,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(dx_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_dx_return_index_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse')
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
index_code = 906
dx_return = self.engine.fetch_dx_return_index_range(index_code,
......@@ -184,7 +194,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_return.dx.values, res.chgPct.values)
def test_sql_engine_fetch_factor(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
factor = 'ROE'
......@@ -202,7 +212,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(factor_data.ROE.values, df.ROE.values)
def test_sql_engine_fetch_factor_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-06-30', '60b', 'china.sse')
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000'])
factor = 'ROE'
......@@ -224,7 +236,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_factor.ROE.values, df.ROE.values)
def test_sql_engine_fetch_factor_range_forward(self):
ref_dates = makeSchedule('2017-01-01', '2017-09-30', '60b', 'china.sse')
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
ref_dates = ref_dates + [advanceDateByCalendar('china.sse', ref_dates[-1], '60b').strftime('%Y-%m-%d')]
universe = Universe('custom', ['zz500', 'zz1000'])
factor = 'ROE'
......@@ -248,7 +262,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_factor.dx.values, df.ROE.values)
def test_sql_engine_fetch_benchmark(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
benchmark = 906
index_data = self.engine.fetch_benchmark(ref_date, benchmark)
......@@ -264,7 +278,9 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(df.weight.values, index_data.weight.values)
def test_sql_engine_fetch_benchmark_range(self):
ref_dates = makeSchedule('2017-01-01', '2017-09-30', '60b', 'china.sse')
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-9m'),
self.ref_date,
'60b', 'china.sse')
benchmark = 906
index_data = self.engine.fetch_benchmark_range(benchmark, dates=ref_dates)
......@@ -282,7 +298,7 @@ class TestSqlEngine(unittest.TestCase):
np.testing.assert_array_almost_equal(calculated_data.weight.values, expected_data.weight.values)
def test_sql_engine_fetch_risk_model(self):
ref_date = adjustDateByCalendar('china.sse', '2017-01-31')
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
......@@ -293,4 +309,43 @@ class TestSqlEngine(unittest.TestCase):
RiskCovShort.trade_date == ref_date
)
cov_df = pd.read_sql(query, con=self.engine.engine)
cov_df = pd.read_sql(query, con=self.engine.engine).sort_values('FactorID')
factors = cov_df.Factor.tolist()
np.testing.assert_array_almost_equal(
risk_cov[factors].values, cov_df[factors].values
)
query = select([RiskExposure]).where(
and_(
RiskExposure.trade_date == ref_date,
RiskExposure.code.in_(codes)
)
)
exp_df = pd.read_sql(query, con=self.engine.engine)
np.testing.assert_array_almost_equal(
exp_df[factors].values, risk_exp[factors].values
)
def test_sql_engine_fetch_industry_matrix(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
codes = self.engine.fetch_codes(ref_date, universe)
risk_matrix = self.engine.fetch_industry_matrix(ref_date, codes, 'sw', 1)
query = select([Industry.code, Industry.industryName1]).where(
and_(
Industry.trade_date == ref_date,
Industry.industry == '申万行业分类',
Industry.code.in_(codes)
)
)
df = pd.read_sql(query, con=self.engine.engine)
df = pd.get_dummies(df, prefix="", prefix_sep="")
self.assertEqual(len(risk_matrix), len(df))
np.testing.assert_array_almost_equal(
df[risk_matrix.columns[2:]].values, risk_matrix.iloc[:, 2:].values
)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment