Commit ebd16b85 authored by Dr.李's avatar Dr.李

added more industry styles

parent 4944d755
This diff is collapsed.
...@@ -508,7 +508,8 @@ class SqlEngine(object): ...@@ -508,7 +508,8 @@ class SqlEngine(object):
if universe.is_filtered: if universe.is_filtered:
codes = universe.query(self, start_date, end_date, dates) codes = universe.query(self, start_date, end_date, dates)
risk_exp = pd.merge(risk_exp, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code']) risk_exp = pd.merge(risk_exp, codes, how='inner', on=['trade_date', 'code']).sort_values(
['trade_date', 'code'])
return risk_cov, risk_exp return risk_cov, risk_exp
...@@ -536,7 +537,8 @@ class SqlEngine(object): ...@@ -536,7 +537,8 @@ class SqlEngine(object):
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None, dates: Iterable[str] = None,
category: str = 'sw'): category: str = 'sw',
level: int = 1):
industry_category_name = _map_industry_category(category) industry_category_name = _map_industry_category(category)
cond = universe._query_statements(start_date, end_date, dates) cond = universe._query_statements(start_date, end_date, dates)
...@@ -550,10 +552,13 @@ class SqlEngine(object): ...@@ -550,10 +552,13 @@ class SqlEngine(object):
) )
) )
code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level)
query = select([Industry.trade_date, query = select([Industry.trade_date,
Industry.code, Industry.code,
Industry.industryID1.label('industry_code'), getattr(Industry, code_name).label('industry_code'),
Industry.industryName1.label('industry')]).select_from(big_table).distinct() getattr(Industry, category_name).label('industry')]).select_from(big_table).distinct()
df = pd.read_sql(query, self.engine) df = pd.read_sql(query, self.engine)
if universe.is_filtered: if universe.is_filtered:
...@@ -802,10 +807,10 @@ class SqlEngine(object): ...@@ -802,10 +807,10 @@ class SqlEngine(object):
else: else:
id_filter = 'in_' id_filter = 'in_'
t = select([table.trade_id]).\ t = select([table.trade_id]). \
where(and_(table.trade_date <= ref_date, where(and_(table.trade_date <= ref_date,
table.operation == 'withdraw')).alias('t') table.operation == 'withdraw')).alias('t')
query = select([table]).\ query = select([table]). \
where(and_(getattr(table.trade_id, id_filter)(t), where(and_(getattr(table.trade_id, id_filter)(t),
table.trade_date <= ref_date, table.trade_date <= ref_date,
table.operation == 'lend')) table.operation == 'lend'))
...@@ -823,7 +828,7 @@ class SqlEngine(object): ...@@ -823,7 +828,7 @@ class SqlEngine(object):
rule = x['price_rule'].split('@') rule = x['price_rule'].split('@')
if rule[0] in ['closePrice', 'openPrice']: if rule[0] in ['closePrice', 'openPrice']:
query = select([getattr(Market, rule[0])]).\ query = select([getattr(Market, rule[0])]). \
where(and_(Market.code == code, Market.trade_date == rule[1])) where(and_(Market.code == code, Market.trade_date == rule[1]))
data = pd.read_sql(query, self.engine) data = pd.read_sql(query, self.engine)
if not data.empty: if not data.empty:
...@@ -835,6 +840,7 @@ class SqlEngine(object): ...@@ -835,6 +840,7 @@ class SqlEngine(object):
else: else:
raise KeyError('do not have rule for %s' % x['price_rule']) raise KeyError('do not have rule for %s' % x['price_rule'])
return price return price
df['price'] = df.apply(lambda x: parse_price_rule(x), axis=1) df['price'] = df.apply(lambda x: parse_price_rule(x), axis=1)
df.drop(['remark', 'price_rule', 'operation'], axis=1, inplace=True) df.drop(['remark', 'price_rule', 'operation'], axis=1, inplace=True)
...@@ -848,12 +854,10 @@ class SqlEngine(object): ...@@ -848,12 +854,10 @@ class SqlEngine(object):
if __name__ == '__main__': if __name__ == '__main__':
universe = Universe('ss', ['hs300']) universe = Universe('ss', ['hs300'])
engine = SqlEngine() engine = SqlEngine()
df = engine.fetch_outright_status('2017-12-28') df = engine.fetch_industry_range(universe, '2017-12-28', '2017-12-31', category='dx', level=3)
print(df) print(df)
...@@ -43,5 +43,15 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict: ...@@ -43,5 +43,15 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
def _map_industry_category(category: str) -> str: def _map_industry_category(category: str) -> str:
if category == 'sw': if category == 'sw':
return '申万行业分类' return '申万行业分类'
elif category == 'zz':
return '中证行业分类'
elif category == 'dx':
return '中证行业分类'
elif category == 'zjh':
return '证监会行业V2012'
else: else:
raise ValueError("No other industry is supported at the current time") raise ValueError("No other industry is supported at the current time")
def industry_list(catrgory, level=1):
pass
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment