Commit a29ba940 authored by Dr.李's avatar Dr.李

Merge branch 'dxhb' of https://github.com/alpha-miner/alpha-mind into dxhb

parents 6bef6bee 88651c5e
......@@ -84,5 +84,29 @@ class _StkUniverse(Base):
is_verify = Column(INT, index=True, server_default=text("'0'"))
class _SwIndustryDaily(Base):
__tablename__ = 'sw_industry_daily'
__table_args__ = (
Index('sw_industry_daily_uindex', 'trade_date', 'industry_code1', 'symbol', 'flag', unique=True),
)
id = Column(INT, primary_key=True)
trade_date = Column(Date, nullable=False)
symbol = Column(Text, nullable=False)
company_id = Column(Text, nullable=False)
code = Column("security_code", Text, nullable=False)
sname = Column(Text, nullable=False)
industry_code1 = Column(Text, nullable=False)
industry_name1 = Column(Text)
industry_code2 = Column(Text)
industry_name2 = Column(Text)
industry_code3 = Column(Text)
industry_name3 = Column(Text)
Industry_code4 = Column(Text)
Industry_name4 = Column(Text)
flag = Column(INT, server_default=text("'1'"))
Market = _StkDailyPricePro
Universe = _StkUniverse
Industry = _SwIndustryDaily
......@@ -5,6 +5,7 @@ Created on 2020-10-11
@author: cheng.li
"""
import os
from typing import Dict
from typing import Iterable
from typing import List
......@@ -24,10 +25,18 @@ from sqlalchemy import (
from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_rl import (
Market
Market,
Industry
)
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models_rl import Universe as UniverseTable
else:
from alphamind.data.dbmodel.models import Universe as UniverseTable
from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing
from alphamind.data.engines.utilities import _map_industry_category
DAILY_RETURN_OFFSET = 0
......@@ -108,6 +117,41 @@ class SqlEngine:
post_process=post_process)
return df[['code', 'dx']]
def fetch_dx_return_range(self,
universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
horizon: int = 0,
offset: int = 0,
benchmark: int = None) -> pd.DataFrame:
if dates:
start_date = dates[0]
end_date = dates[-1]
end_date = advanceDateByCalendar('china.sse', end_date,
str(
1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime(
'%Y-%m-%d')
codes = universe.query(self.engine, start_date, end_date, dates)
t1 = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_(
Market.trade_date.between(start_date, end_date),
Market.code.in_(codes.code.unique().tolist()),
Market.flag == 1
)
)
df1 = pd.read_sql(t1, self.session.bind).dropna()
df2 = self.fetch_codes_range(universe, start_date, end_date, dates)
df = pd.merge(df1, df2, on=["trade_date", "code"])
df = self._create_stats(df, horizon, offset)
if dates:
df = df[df.trade_date.isin(dates)]
return df.reset_index(drop=True).sort_values(['trade_date', 'code'])
def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]:
df = universe.query(self, ref_date, ref_date)
return sorted(df.code.tolist())
......@@ -119,11 +163,44 @@ class SqlEngine:
dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates)
def fetch_industry(self,
ref_date: str,
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
code_name = 'industry_code' + str(level)
category_name = 'industry_name' + str(level)
cond = and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
Industry.flag == 1
) if codes else and_(
Industry.trade_date == ref_date,
Industry.flag == 1
)
query = select([Industry.code.label("code"),
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
cond
).distinct()
return pd.read_sql(query, self.engine).dropna().drop_duplicates(['code'])
if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url)
# df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
df = sql_engine.fetch_dx_return("2020-09-10", codes=["2010000001"])
universe = Universe("hs300")
start_date = '2020-09-29'
end_date = '2020-10-10'
df = sql_engine.fetch_codes_range(start_date='start_date', end_date=end_date, universe=Universe("hs300"))
print(df)
df = sql_engine.fetch_dx_return("2020-10-09", codes=["2010031963"])
print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date)
print(df)
df = sql_engine.fetch_industry(ref_date="2020-10-09", codes=["2010031963"])
print(df)
......@@ -5,6 +5,7 @@ Created on 2017-12-25
@author: cheng.li
"""
import os
from typing import Dict
from typing import Iterable
......@@ -51,19 +52,32 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
return factor_cols
def _map_industry_category(category: str) -> str:
if category == 'sw':
return '申万行业分类'
elif category == 'sw_adj':
return '申万行业分类修订'
elif category == 'zz':
return '中证行业分类'
elif category == 'dx':
return '东兴行业分类'
elif category == 'zjh':
return '证监会行业V2012'
else:
raise ValueError("No other industry is supported at the current time")
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
def _map_industry_category(category: str) -> str:
if category == 'sw':
return '申万行业分类(2014)'
elif category == 'zz':
return '中证行业分类'
elif category == 'zx':
return '中信标普行业分类'
elif category == 'zjh':
return '证监会行业分类(2012)-证监会'
else:
raise ValueError("No other industry is supported at the current time")
else:
def _map_industry_category(category: str) -> str:
if category == 'sw':
return '申万行业分类'
elif category == 'sw_adj':
return '申万行业分类修订'
elif category == 'zz':
return '中证行业分类'
elif category == 'dx':
return '东兴行业分类'
elif category == 'zjh':
return '证监会行业V2012'
else:
raise ValueError("No other industry is supported at the current time")
def industry_list(category: str, level: int = 1) -> list:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment