Commit 1c3741fb authored by Dr.李's avatar Dr.李

FEATURE: update universe for mysql vendor

parent 04b494ca
......@@ -89,5 +89,57 @@ class _StkDailyPricePro(Base):
update_time = Column(TIMESTAMP, index=True, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"))
class _StkUniverse(Base):
__tablename__ = 'stk_universe'
__table_args__ = (
Index('unique_stk_universe_index', 'trade_date', 'security_code', 'flag', unique=True),
)
id = Column(INTEGER(10), primary_key=True)
trade_date = Column(Date, nullable=False)
code = Column("security_code", String(20), nullable=False)
aerodef = Column(INTEGER(11), server_default=text("'0'"))
agriforest = Column(INTEGER(11), server_default=text("'0'"))
auto = Column(INTEGER(11), server_default=text("'0'"))
bank = Column(INTEGER(11), server_default=text("'0'"))
builddeco = Column(INTEGER(11), server_default=text("'0'"))
chem = Column(INTEGER(11), server_default=text("'0'"))
conmat = Column(INTEGER(11), server_default=text("'0'"))
commetrade = Column(INTEGER(11), server_default=text("'0'"))
computer = Column(INTEGER(11), server_default=text("'0'"))
conglomerates = Column(INTEGER(11), server_default=text("'0'"))
eleceqp = Column(INTEGER(11), server_default=text("'0'"))
electronics = Column(INTEGER(11), server_default=text("'0'"))
foodbever = Column(INTEGER(11), server_default=text("'0'"))
health = Column(INTEGER(11), server_default=text("'0'"))
houseapp = Column(INTEGER(11), server_default=text("'0'"))
ironsteel = Column(INTEGER(11), server_default=text("'0'"))
leiservice = Column(INTEGER(11), server_default=text("'0'"))
lightindus = Column(INTEGER(11), server_default=text("'0'"))
machiequip = Column(INTEGER(11), server_default=text("'0'"))
media = Column(INTEGER(11), server_default=text("'0'"))
mining = Column(INTEGER(11), server_default=text("'0'"))
nonbankfinan = Column(INTEGER(11), server_default=text("'0'"))
nonfermetal = Column(INTEGER(11), server_default=text("'0'"))
realestate = Column(INTEGER(11), server_default=text("'0'"))
telecom = Column(INTEGER(11), server_default=text("'0'"))
textile = Column(INTEGER(11), server_default=text("'0'"))
transportation = Column(INTEGER(11), server_default=text("'0'"))
utilities = Column(INTEGER(11), server_default=text("'0'"))
ashare = Column(INTEGER(11), server_default=text("'0'"))
ashare_ex = Column(INTEGER(11), server_default=text("'0'"))
cyb = Column(INTEGER(11), server_default=text("'0'"))
hs300 = Column(INTEGER(11), server_default=text("'0'"))
sh50 = Column(INTEGER(11), server_default=text("'0'"))
zxb = Column(INTEGER(11), server_default=text("'0'"))
zz1000 = Column(INTEGER(11), server_default=text("'0'"))
zz500 = Column(INTEGER(11), server_default=text("'0'"))
zz800 = Column(INTEGER(11), server_default=text("'0'"))
flag = Column(INTEGER(11), index=True, server_default=text("'1'"))
is_verify = Column(INTEGER(11), index=True, server_default=text("'0'"))
create_time = Column(TIMESTAMP, server_default=text("CURRENT_TIMESTAMP"))
update_time = Column(TIMESTAMP, index=True, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"))
Market = _StkDailyPricePro
Universe = _StkUniverse
......@@ -5,7 +5,11 @@ Created on 2020-10-11
@author: cheng.li
"""
from typing import Dict
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Union
import pandas as pd
......@@ -22,6 +26,8 @@ from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_mysql import (
Market
)
from alphamind.data.dbmodel.models_mysql import Universe as UniverseTable
from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing
......@@ -45,19 +51,14 @@ class SqlEngine:
if self._session:
self._session.close()
@property
def engine(self):
return self._engine
def create_session(self):
db_session = orm.sessionmaker(bind=self._engine)
return db_session()
# def _create_stats(self, table, horizon, offset, code_attr='security_code'):
# stats = func.sum(self._ln_func(1. + table.change_pct)).over(
# partition_by=getattr(table, code_attr),
# order_by=table.trade_date,
# rows=(
# 1 + DAILY_RETURN_OFFSET + offset, 1 + horizon + DAILY_RETURN_OFFSET + offset)).label(
# 'dx')
# return stats
def fetch_dx_return(self,
ref_date: str,
codes: Iterable[int],
......@@ -95,12 +96,26 @@ class SqlEngine:
risk_factors=df[neutralized_risks].values,
post_process=post_process)
return df # df[['code', 'dx']]
df.rename(columns={"security_code": "code", "change_pct": "dx"}, inplace=True)
return df[['code', 'dx']]
def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]:
df = universe.query(self, ref_date, ref_date).rename(columns={"security_code": "code"})
return sorted(df.code.tolist())
def fetch_codes_range(self,
universe: Universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates).rename(columns={"security_code": "code"})
if __name__ == "__main__":
import os
os.environ["DB_VENDOR"] = "mysql"
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url)
df = sql_engine.fetch_dx_return(ref_date='2020-09-29', codes=["2010003704"])
df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
print(df)
......@@ -7,6 +7,7 @@ Created on 2017-7-7
import abc
import sys
import os
import pandas as pd
from sqlalchemy import and_
......@@ -14,7 +15,11 @@ from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import select
from alphamind.data.dbmodel.models import Universe as UniverseTable
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "mysql":
from alphamind.data.dbmodel.models_mysql import Universe as UniverseTable
else:
from alphamind.data.dbmodel.models import Universe as UniverseTable
class BaseUniverse(metaclass=abc.ABCMeta):
......@@ -47,8 +52,15 @@ class BaseUniverse(metaclass=abc.ABCMeta):
pass
def query(self, engine, start_date: str = None, end_date: str = None, dates=None):
if hasattr(UniverseTable, "flag"):
more_conditions = [UniverseTable.flag == 1]
else:
more_conditions = []
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
self._query_statements(start_date, end_date, dates)
and_(
self._query_statements(start_date, end_date, dates),
*more_conditions
)
).order_by(UniverseTable.trade_date, UniverseTable.code)
return pd.read_sql(query, engine.engine)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment