Commit ebd16b85 authored by Dr.李's avatar Dr.李

added more industry styles

parent 4944d755
This diff is collapsed.
......@@ -508,7 +508,8 @@ class SqlEngine(object):
if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates)
risk_exp = pd.merge(risk_exp, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code'])
risk_exp = pd.merge(risk_exp, codes, how='inner', on=['trade_date', 'code']).sort_values(
['trade_date', 'code'])
return risk_cov, risk_exp
......@@ -536,7 +537,8 @@ class SqlEngine(object):
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
category: str = 'sw'):
category: str = 'sw',
level: int = 1):
industry_category_name = _map_industry_category(category)
cond = universe._query_statements(start_date, end_date, dates)
......@@ -550,10 +552,13 @@ class SqlEngine(object):
)
)
code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level)
query = select([Industry.trade_date,
Industry.code,
Industry.industryID1.label('industry_code'),
Industry.industryName1.label('industry')]).select_from(big_table).distinct()
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).select_from(big_table).distinct()
df = pd.read_sql(query, self.engine)
if universe.is_filtered:
......@@ -802,10 +807,10 @@ class SqlEngine(object):
else:
id_filter = 'in_'
t = select([table.trade_id]).\
t = select([table.trade_id]). \
where(and_(table.trade_date <= ref_date,
table.operation == 'withdraw')).alias('t')
query = select([table]).\
query = select([table]). \
where(and_(getattr(table.trade_id, id_filter)(t),
table.trade_date <= ref_date,
table.operation == 'lend'))
......@@ -823,7 +828,7 @@ class SqlEngine(object):
rule = x['price_rule'].split('@')
if rule[0] in ['closePrice', 'openPrice']:
query = select([getattr(Market, rule[0])]).\
query = select([getattr(Market, rule[0])]). \
where(and_(Market.code == code, Market.trade_date == rule[1]))
data = pd.read_sql(query, self.engine)
if not data.empty:
......@@ -835,6 +840,7 @@ class SqlEngine(object):
else:
raise KeyError('do not have rule for %s' % x['price_rule'])
return price
df['price'] = df.apply(lambda x: parse_price_rule(x), axis=1)
df.drop(['remark', 'price_rule', 'operation'], axis=1, inplace=True)
......@@ -848,12 +854,10 @@ class SqlEngine(object):
if __name__ == '__main__':
universe = Universe('ss', ['hs300'])
engine = SqlEngine()
df = engine.fetch_outright_status('2017-12-28')
df = engine.fetch_industry_range(universe, '2017-12-28', '2017-12-31', category='dx', level=3)
print(df)
......@@ -43,5 +43,15 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
def _map_industry_category(category: str) -> str:
if category == 'sw':
return '申万行业分类'
elif category == 'zz':
return '中证行业分类'
elif category == 'dx':
return '中证行业分类'
elif category == 'zjh':
return '证监会行业V2012'
else:
raise ValueError("No other industry is supported at the current time")
def industry_list(catrgory, level=1):
pass
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment