Commit b83b8b64 authored by Dr.李's avatar Dr.李

added industry matrix

parent ebd16b85
......@@ -24,6 +24,7 @@ from alphamind.data.standardize import standardize
from alphamind.data.standardize import projection
from alphamind.data.neutralize import neutralize
from alphamind.data.engines.sqlengine import factor_tables
from alphamind.data.engines.utilities import industry_list
from alphamind.model import LinearRegression
from alphamind.model import LassoRegression
......@@ -66,6 +67,7 @@ __all__ = [
'projection',
'neutralize',
'factor_tables',
'industry_list',
'fetch_data_package',
'fetch_train_phase',
'fetch_predict_phase',
......
......@@ -42,6 +42,7 @@ from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.utilities import _map_industry_category
from alphamind.data.engines.utilities import _map_risk_model_table
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import industry_list
from PyFin.api import advanceDateByCalendar
risk_styles = ['BETA',
......@@ -516,13 +517,16 @@ class SqlEngine(object):
def fetch_industry(self,
ref_date: str,
codes: Iterable[int],
category: str = 'sw'):
category: str = 'sw',
level: int = 1):
industry_category_name = _map_industry_category(category)
code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level)
query = select([Industry.code,
Industry.industryID1.label('industry_code'),
Industry.industryName1.label('industry')]).where(
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
......@@ -532,6 +536,16 @@ class SqlEngine(object):
return pd.read_sql(query, self.engine)
def fetch_industry_matrix(self,
ref_date: str,
codes: Iterable[int],
category: str = 'sw',
level: int = 1):
df = self.fetch_industry(ref_date, codes, category, level)
df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
industries = industry_list(category, level)
return df[['code', 'industry_code'] + industries]
def fetch_industry_range(self,
universe: Universe,
start_date: str = None,
......@@ -566,6 +580,19 @@ class SqlEngine(object):
df = pd.merge(df, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code'])
return df
def fetch_industry_matrix_range(self,
universe: Universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
category: str = 'sw',
level: int = 1):
df = self.fetch_industry_range(universe, start_date, end_date, dates, category, level)
df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
industries = industry_list(category, level)
return df[['trade_date', 'code', 'industry_code'] + industries]
def fetch_data(self, ref_date: str,
factors: Iterable[str],
codes: Iterable[int],
......@@ -857,7 +884,10 @@ if __name__ == '__main__':
universe = Universe('ss', ['hs300'])
engine = SqlEngine()
df = engine.fetch_industry_range(universe, '2017-12-28', '2017-12-31', category='dx', level=3)
ref_date = '2017-12-28'
codes = universe.query(engine, dates=[ref_date])
df = engine.fetch_industry_matrix(ref_date, codes.code.tolist(), 'dx', 1)
print(df)
df = engine.fetch_industry_matrix_range(universe, '2011-12-28', '2017-12-31', category='sw', level=1)
print(df)
\ No newline at end of file
......@@ -13,6 +13,7 @@ from alphamind.data.dbmodel.models import RiskCovLong
from alphamind.data.dbmodel.models import FullFactor
from alphamind.data.dbmodel.models import Gogoal
from alphamind.data.dbmodel.models import Experimental
from alphamind.data.engines.industries import INDUSTRY_MAPPING
factor_tables = [FullFactor, Gogoal, Experimental]
......@@ -46,12 +47,12 @@ def _map_industry_category(category: str) -> str:
elif category == 'zz':
return '中证行业分类'
elif category == 'dx':
return '中证行业分类'
return '东兴行业分类'
elif category == 'zjh':
return '证监会行业V2012'
else:
raise ValueError("No other industry is supported at the current time")
def industry_list(catrgory, level=1):
pass
\ No newline at end of file
def industry_list(category: str, level: int=1) -> list:
return INDUSTRY_MAPPING[category][level]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment