Commit a29ba940 authored by Dr.李's avatar Dr.李

Merge branch 'dxhb' of https://github.com/alpha-miner/alpha-mind into dxhb

parents 6bef6bee 88651c5e
...@@ -84,5 +84,29 @@ class _StkUniverse(Base): ...@@ -84,5 +84,29 @@ class _StkUniverse(Base):
is_verify = Column(INT, index=True, server_default=text("'0'")) is_verify = Column(INT, index=True, server_default=text("'0'"))
class _SwIndustryDaily(Base):
__tablename__ = 'sw_industry_daily'
__table_args__ = (
Index('sw_industry_daily_uindex', 'trade_date', 'industry_code1', 'symbol', 'flag', unique=True),
)
id = Column(INT, primary_key=True)
trade_date = Column(Date, nullable=False)
symbol = Column(Text, nullable=False)
company_id = Column(Text, nullable=False)
code = Column("security_code", Text, nullable=False)
sname = Column(Text, nullable=False)
industry_code1 = Column(Text, nullable=False)
industry_name1 = Column(Text)
industry_code2 = Column(Text)
industry_name2 = Column(Text)
industry_code3 = Column(Text)
industry_name3 = Column(Text)
Industry_code4 = Column(Text)
Industry_name4 = Column(Text)
flag = Column(INT, server_default=text("'1'"))
Market = _StkDailyPricePro Market = _StkDailyPricePro
Universe = _StkUniverse Universe = _StkUniverse
Industry = _SwIndustryDaily
...@@ -5,6 +5,7 @@ Created on 2020-10-11 ...@@ -5,6 +5,7 @@ Created on 2020-10-11
@author: cheng.li @author: cheng.li
""" """
import os
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import List from typing import List
...@@ -24,10 +25,18 @@ from sqlalchemy import ( ...@@ -24,10 +25,18 @@ from sqlalchemy import (
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_rl import ( from alphamind.data.dbmodel.models_rl import (
Market Market,
Industry
) )
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models_rl import Universe as UniverseTable
else:
from alphamind.data.dbmodel.models import Universe as UniverseTable
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing from alphamind.data.processing import factor_processing
from alphamind.data.engines.utilities import _map_industry_category
DAILY_RETURN_OFFSET = 0 DAILY_RETURN_OFFSET = 0
...@@ -108,6 +117,41 @@ class SqlEngine: ...@@ -108,6 +117,41 @@ class SqlEngine:
post_process=post_process) post_process=post_process)
return df[['code', 'dx']] return df[['code', 'dx']]
def fetch_dx_return_range(self,
universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
horizon: int = 0,
offset: int = 0,
benchmark: int = None) -> pd.DataFrame:
if dates:
start_date = dates[0]
end_date = dates[-1]
end_date = advanceDateByCalendar('china.sse', end_date,
str(
1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime(
'%Y-%m-%d')
codes = universe.query(self.engine, start_date, end_date, dates)
t1 = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_(
Market.trade_date.between(start_date, end_date),
Market.code.in_(codes.code.unique().tolist()),
Market.flag == 1
)
)
df1 = pd.read_sql(t1, self.session.bind).dropna()
df2 = self.fetch_codes_range(universe, start_date, end_date, dates)
df = pd.merge(df1, df2, on=["trade_date", "code"])
df = self._create_stats(df, horizon, offset)
if dates:
df = df[df.trade_date.isin(dates)]
return df.reset_index(drop=True).sort_values(['trade_date', 'code'])
def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]: def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]:
df = universe.query(self, ref_date, ref_date) df = universe.query(self, ref_date, ref_date)
return sorted(df.code.tolist()) return sorted(df.code.tolist())
...@@ -119,11 +163,44 @@ class SqlEngine: ...@@ -119,11 +163,44 @@ class SqlEngine:
dates: Iterable[str] = None) -> pd.DataFrame: dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates) return universe.query(self, start_date, end_date, dates)
def fetch_industry(self,
ref_date: str,
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
code_name = 'industry_code' + str(level)
category_name = 'industry_name' + str(level)
cond = and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
Industry.flag == 1
) if codes else and_(
Industry.trade_date == ref_date,
Industry.flag == 1
)
query = select([Industry.code.label("code"),
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
cond
).distinct()
return pd.read_sql(query, self.engine).dropna().drop_duplicates(['code'])
if __name__ == "__main__": if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url) sql_engine = SqlEngine(db_url=db_url)
# df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300")) universe = Universe("hs300")
df = sql_engine.fetch_dx_return("2020-09-10", codes=["2010000001"]) start_date = '2020-09-29'
end_date = '2020-10-10'
df = sql_engine.fetch_codes_range(start_date='start_date', end_date=end_date, universe=Universe("hs300"))
print(df)
df = sql_engine.fetch_dx_return("2020-10-09", codes=["2010031963"])
print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date)
print(df)
df = sql_engine.fetch_industry(ref_date="2020-10-09", codes=["2010031963"])
print(df) print(df)
...@@ -5,6 +5,7 @@ Created on 2017-12-25 ...@@ -5,6 +5,7 @@ Created on 2017-12-25
@author: cheng.li @author: cheng.li
""" """
import os
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
...@@ -51,7 +52,20 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict: ...@@ -51,7 +52,20 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
return factor_cols return factor_cols
def _map_industry_category(category: str) -> str: if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
def _map_industry_category(category: str) -> str:
if category == 'sw':
return '申万行业分类(2014)'
elif category == 'zz':
return '中证行业分类'
elif category == 'zx':
return '中信标普行业分类'
elif category == 'zjh':
return '证监会行业分类(2012)-证监会'
else:
raise ValueError("No other industry is supported at the current time")
else:
def _map_industry_category(category: str) -> str:
if category == 'sw': if category == 'sw':
return '申万行业分类' return '申万行业分类'
elif category == 'sw_adj': elif category == 'sw_adj':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment