Commit b83b8b64 authored by Dr.李's avatar Dr.李

added industry matrix

parent ebd16b85
...@@ -24,6 +24,7 @@ from alphamind.data.standardize import standardize ...@@ -24,6 +24,7 @@ from alphamind.data.standardize import standardize
from alphamind.data.standardize import projection from alphamind.data.standardize import projection
from alphamind.data.neutralize import neutralize from alphamind.data.neutralize import neutralize
from alphamind.data.engines.sqlengine import factor_tables from alphamind.data.engines.sqlengine import factor_tables
from alphamind.data.engines.utilities import industry_list
from alphamind.model import LinearRegression from alphamind.model import LinearRegression
from alphamind.model import LassoRegression from alphamind.model import LassoRegression
...@@ -66,6 +67,7 @@ __all__ = [ ...@@ -66,6 +67,7 @@ __all__ = [
'projection', 'projection',
'neutralize', 'neutralize',
'factor_tables', 'factor_tables',
'industry_list',
'fetch_data_package', 'fetch_data_package',
'fetch_train_phase', 'fetch_train_phase',
'fetch_predict_phase', 'fetch_predict_phase',
......
...@@ -42,6 +42,7 @@ from alphamind.data.engines.utilities import _map_factors ...@@ -42,6 +42,7 @@ from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.utilities import _map_industry_category from alphamind.data.engines.utilities import _map_industry_category
from alphamind.data.engines.utilities import _map_risk_model_table from alphamind.data.engines.utilities import _map_risk_model_table
from alphamind.data.engines.utilities import factor_tables from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import industry_list
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
risk_styles = ['BETA', risk_styles = ['BETA',
...@@ -516,13 +517,16 @@ class SqlEngine(object): ...@@ -516,13 +517,16 @@ class SqlEngine(object):
def fetch_industry(self, def fetch_industry(self,
ref_date: str, ref_date: str,
codes: Iterable[int], codes: Iterable[int],
category: str = 'sw'): category: str = 'sw',
level: int = 1):
industry_category_name = _map_industry_category(category) industry_category_name = _map_industry_category(category)
code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level)
query = select([Industry.code, query = select([Industry.code,
Industry.industryID1.label('industry_code'), getattr(Industry, code_name).label('industry_code'),
Industry.industryName1.label('industry')]).where( getattr(Industry, category_name).label('industry')]).where(
and_( and_(
Industry.trade_date == ref_date, Industry.trade_date == ref_date,
Industry.code.in_(codes), Industry.code.in_(codes),
...@@ -532,6 +536,16 @@ class SqlEngine(object): ...@@ -532,6 +536,16 @@ class SqlEngine(object):
return pd.read_sql(query, self.engine) return pd.read_sql(query, self.engine)
def fetch_industry_matrix(self,
                          ref_date: str,
                          codes: Iterable[int],
                          category: str = 'sw',
                          level: int = 1):
    """Return a one-hot industry membership matrix for a single date.

    The industry rows are obtained from ``self.fetch_industry`` and the
    ``industry`` name column is expanded into one indicator column per
    industry name.

    :param ref_date: date passed through to ``fetch_industry``
    :param codes: security codes passed through to ``fetch_industry``
    :param category: industry classification key (default ``'sw'``)
    :param level: industry level passed to ``fetch_industry`` and
        ``industry_list`` (default 1)
    :return: DataFrame with columns ``code``, ``industry_code`` followed by
        one column per name returned by ``industry_list(category, level)``,
        in that order
    """
    df = self.fetch_industry(ref_date, codes, category, level)
    # Empty prefix/separator so each dummy column is named exactly after the industry.
    df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
    industries = industry_list(category, level)
    # get_dummies only creates columns for industries that actually occur in
    # the fetched rows. A plain df[[...]] selection raises KeyError as soon as
    # one listed industry has no member among `codes`; reindex keeps the same
    # column layout and fills the absent industries with zeros instead.
    columns = ['code', 'industry_code'] + list(industries)
    return df.reindex(columns=columns, fill_value=0)
def fetch_industry_range(self, def fetch_industry_range(self,
universe: Universe, universe: Universe,
start_date: str = None, start_date: str = None,
...@@ -566,6 +580,19 @@ class SqlEngine(object): ...@@ -566,6 +580,19 @@ class SqlEngine(object):
df = pd.merge(df, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code']) df = pd.merge(df, codes, how='inner', on=['trade_date', 'code']).sort_values(['trade_date', 'code'])
return df return df
def fetch_industry_matrix_range(self,
                                universe: Universe,
                                start_date: str = None,
                                end_date: str = None,
                                dates: Iterable[str] = None,
                                category: str = 'sw',
                                level: int = 1):
    """Return a one-hot industry membership matrix over a range of dates.

    The industry rows are obtained from ``self.fetch_industry_range`` and
    the ``industry`` name column is expanded into one indicator column per
    industry name.

    :param universe: universe passed through to ``fetch_industry_range``
    :param start_date: range start, passed through unchanged
    :param end_date: range end, passed through unchanged
    :param dates: explicit dates, passed through unchanged
    :param category: industry classification key (default ``'sw'``)
    :param level: industry level passed to ``fetch_industry_range`` and
        ``industry_list`` (default 1)
    :return: DataFrame with columns ``trade_date``, ``code``,
        ``industry_code`` followed by one column per name returned by
        ``industry_list(category, level)``, in that order
    """
    df = self.fetch_industry_range(universe, start_date, end_date, dates, category, level)
    # Empty prefix/separator so each dummy column is named exactly after the industry.
    df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
    industries = industry_list(category, level)
    # get_dummies only creates columns for industries present in the fetched
    # rows. A plain df[[...]] selection raises KeyError when a listed industry
    # never occurs in the universe over the period; reindex keeps the same
    # column layout and fills the absent industries with zeros instead.
    columns = ['trade_date', 'code', 'industry_code'] + list(industries)
    return df.reindex(columns=columns, fill_value=0)
def fetch_data(self, ref_date: str, def fetch_data(self, ref_date: str,
factors: Iterable[str], factors: Iterable[str],
codes: Iterable[int], codes: Iterable[int],
...@@ -857,7 +884,10 @@ if __name__ == '__main__': ...@@ -857,7 +884,10 @@ if __name__ == '__main__':
universe = Universe('ss', ['hs300']) universe = Universe('ss', ['hs300'])
engine = SqlEngine() engine = SqlEngine()
ref_date = '2017-12-28'
codes = universe.query(engine, dates=[ref_date])
df = engine.fetch_industry_matrix(ref_date, codes.code.tolist(), 'dx', 1)
print(df)
df = engine.fetch_industry_range(universe, '2017-12-28', '2017-12-31', category='dx', level=3) df = engine.fetch_industry_matrix_range(universe, '2011-12-28', '2017-12-31', category='sw', level=1)
print(df) print(df)
\ No newline at end of file
...@@ -13,6 +13,7 @@ from alphamind.data.dbmodel.models import RiskCovLong ...@@ -13,6 +13,7 @@ from alphamind.data.dbmodel.models import RiskCovLong
from alphamind.data.dbmodel.models import FullFactor from alphamind.data.dbmodel.models import FullFactor
from alphamind.data.dbmodel.models import Gogoal from alphamind.data.dbmodel.models import Gogoal
from alphamind.data.dbmodel.models import Experimental from alphamind.data.dbmodel.models import Experimental
from alphamind.data.engines.industries import INDUSTRY_MAPPING
factor_tables = [FullFactor, Gogoal, Experimental] factor_tables = [FullFactor, Gogoal, Experimental]
...@@ -46,12 +47,12 @@ def _map_industry_category(category: str) -> str: ...@@ -46,12 +47,12 @@ def _map_industry_category(category: str) -> str:
elif category == 'zz': elif category == 'zz':
return '中证行业分类' return '中证行业分类'
elif category == 'dx': elif category == 'dx':
return '中证行业分类' return '东兴行业分类'
elif category == 'zjh': elif category == 'zjh':
return '证监会行业V2012' return '证监会行业V2012'
else: else:
raise ValueError("No other industry is supported at the current time") raise ValueError("No other industry is supported at the current time")
def industry_list(catrgory, level=1): def industry_list(category: str, level: int=1) -> list:
pass return INDUSTRY_MAPPING[category][level]
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment