Commit e90c7f82 authored by Dr.李's avatar Dr.李

fixed tests

parent f73bef69
......@@ -265,10 +265,6 @@ class SqlEngine(object):
query = select([t]).select_from(big_table)
df = pd.read_sql(query, self.session.bind).dropna()
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
df = pd.merge(df, codes, how='inner', on=['trade_date', 'code'])
if dates:
df = df[df.trade_date.isin(dates)]
return df.sort_values(['trade_date', 'code']).drop_duplicates(['trade_date', 'code'])
......@@ -428,8 +424,6 @@ class SqlEngine(object):
).distinct()
df = pd.read_sql(query, self.engine).replace([-np.inf, np.inf], np.nan)
if universe.is_filtered:
df = pd.merge(df, universe_df, how='inner', on=['trade_date', 'code'])
if external_data is not None:
df = pd.merge(df, external_data, on=['trade_date', 'code']).dropna()
......@@ -620,12 +614,7 @@ class SqlEngine(object):
risk_exp = pd.read_sql(query, self.engine).sort_values(['trade_date', 'code']).dropna()
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
risk_exp = pd.merge(risk_exp, codes, how='inner', on=['trade_date', 'code']).sort_values(
['trade_date', 'code'])
return risk_cov, risk_exp.drop_duplicates(['trade_date', 'code'])
return risk_cov, risk_exp
def fetch_industry(self,
ref_date: str,
......@@ -683,13 +672,10 @@ class SqlEngine(object):
query = select([Industry.trade_date,
Industry.code,
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).select_from(big_table).distinct()
getattr(Industry, category_name).label('industry')]).select_from(big_table)\
.order_by(Industry.trade_date, Industry.code)
df = pd.read_sql(query, self.engine).dropna()
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
df = pd.merge(df, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code'])
return df.drop_duplicates(['trade_date', 'code'])
return pd.read_sql(query, self.engine).dropna()
def fetch_industry_matrix_range(self,
universe: Universe,
......
......@@ -42,7 +42,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_codes(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
query = select([UniverseTable.code]).where(
......@@ -62,7 +62,7 @@ class TestSqlEngine(unittest.TestCase):
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes_range(universe, dates=ref_dates)
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
......@@ -84,7 +84,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sdl_engine_fetch_codes_with_exclude_universe(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500'], exclude_universe=['cyb'])
universe = Universe('zz500') - Universe('cyb')
codes = self.engine.fetch_codes(ref_date, universe)
query = select([UniverseTable.code]).where(
......@@ -102,7 +102,7 @@ class TestSqlEngine(unittest.TestCase):
horizon = 4
offset = 1
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
dx_return = self.engine.fetch_dx_return(ref_date, codes, horizon=horizon, offset=offset)
......@@ -123,7 +123,7 @@ class TestSqlEngine(unittest.TestCase):
horizon = 4
offset = 0
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
dx_return = self.engine.fetch_dx_return(ref_date, codes, horizon=horizon, offset=offset)
......@@ -145,7 +145,7 @@ class TestSqlEngine(unittest.TestCase):
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
dx_return = self.engine.fetch_dx_return_range(universe,
dates=ref_dates,
......@@ -223,7 +223,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_factor(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
factor = 'ROE'
......@@ -243,7 +243,7 @@ class TestSqlEngine(unittest.TestCase):
ref_dates = makeSchedule(advanceDateByCalendar('china.sse', self.ref_date, '-6m'),
self.ref_date,
'60b', 'china.sse')
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
factor = 'ROE'
factor_data = self.engine.fetch_factor_range(universe, factor, dates=ref_dates)
......@@ -268,7 +268,7 @@ class TestSqlEngine(unittest.TestCase):
self.ref_date,
'60b', 'china.sse')
ref_dates = ref_dates + [advanceDateByCalendar('china.sse', ref_dates[-1], '60b').strftime('%Y-%m-%d')]
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
factor = 'ROE'
factor_data = self.engine.fetch_factor_range_forward(universe, factor, dates=ref_dates)
......@@ -329,7 +329,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_risk_model(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
risk_cov, risk_exp = self.engine.fetch_risk_model(ref_date, codes, risk_model='short')
......@@ -359,7 +359,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_industry_matrix(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
ind_matrix = self.engine.fetch_industry_matrix(ref_date, codes, 'sw', 1)
......@@ -382,7 +382,7 @@ class TestSqlEngine(unittest.TestCase):
def test_sql_engine_fetch_factor_by_categories(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500', 'zz1000'])
universe = Universe('zz500') + Universe('zz1000')
codes = self.engine.fetch_codes(ref_date, universe)
factor1 = {'f': CSRank('ROE', groups='sw1')}
......
......@@ -21,7 +21,7 @@ class TestComposer(unittest.TestCase):
def test_data_meta_persistence(self):
freq = '5b'
universe = Universe('custom', ['zz800'])
universe = Universe('zz800')
batch = 4
neutralized_risk = ['SIZE']
risk_model = 'long'
......@@ -55,7 +55,7 @@ class TestComposer(unittest.TestCase):
def test_composer_persistence(self):
freq = '5b'
universe = Universe('custom', ['zz800'])
universe = Universe('zz800')
batch = 4
neutralized_risk = ['SIZE']
risk_model = 'long'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment