Unverified Commit 77b97192 authored by iLampard's avatar iLampard Committed by GitHub

Merge pull request #16 from alpha-miner/master

merge update
parents 2112699d bce4c53e
......@@ -615,7 +615,7 @@ class SqlEngine(object):
def fetch_industry(self,
ref_date: str,
codes: Iterable[int],
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
......@@ -623,21 +623,26 @@ class SqlEngine(object):
code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level)
query = select([Industry.code,
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
and_(
cond = and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
Industry.industry == industry_category_name
) if codes else and_(
Industry.trade_date == ref_date,
Industry.industry == industry_category_name
)
query = select([Industry.code,
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
cond
).distinct()
return pd.read_sql(query, self.engine).dropna().drop_duplicates(['code'])
def fetch_industry_matrix(self,
ref_date: str,
codes: Iterable[int],
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
df = self.fetch_industry(ref_date, codes, category, level)
......
......@@ -424,7 +424,6 @@ class TestSqlEngine(unittest.TestCase):
df3 = self.engine.fetch_factor(ref_date, raw_factor, codes)
ind_matrix = self.engine.fetch_industry_matrix(ref_date, codes, 'sw', 1)
cols = sorted(ind_matrix.columns[2:].tolist())
series = (ind_matrix[cols] * np.array(range(1, len(cols)+1))).sum(axis=1)
......
......@@ -25,5 +25,20 @@ class TestUniverse(unittest.TestCase):
universe = Universe('zz500')
univ_desc = universe.save()
loaded_universe = load_universe(univ_desc)
self.assertEqual(universe, loaded_universe)
def test_universe_arithmic(self):
universe = Universe('zz500') + Universe('hs300')
univ_desc = universe.save()
loaded_universe = load_universe(univ_desc)
self.assertEqual(universe, loaded_universe)
universe = Universe('zz500') - Universe('hs300')
univ_desc = universe.save()
loaded_universe = load_universe(univ_desc)
self.assertEqual(universe, loaded_universe)
universe = Universe('zz500') & Universe('hs300')
univ_desc = universe.save()
loaded_universe = load_universe(univ_desc)
self.assertEqual(universe, loaded_universe)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment