Commit 06d3980a authored by Dr.李

FEATURE: added dynamically loaded tables

parent 893c28c2
......@@ -20,12 +20,6 @@ if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models.models_rl import SpecificRiskLong
from alphamind.data.dbmodel.models.models_rl import IndexComponent
from alphamind.data.dbmodel.models.models_rl import IndexWeight
from alphamind.data.dbmodel.models.models_rl import FactorMomentum
# from alphamind.data.dbmodel.models.models_rl import FactorValuationEstimation
# from alphamind.data.dbmodel.models.models_rl import FactorVolatilityValue
factor_tables = [Market, RiskExposure, FactorMomentum]
else:
from alphamind.data.dbmodel.models.models import Market
from alphamind.data.dbmodel.models.models import IndexMarket
......@@ -42,5 +36,3 @@ else:
from alphamind.data.dbmodel.models.models import IndexComponent
from alphamind.data.dbmodel.models.models import RiskMaster
from alphamind.data.dbmodel.models.models import Uqer
\ No newline at end of file
factor_tables = [Market, RiskExposure, Uqer]
......@@ -5,6 +5,7 @@ Created on 2020-10-11
@author: cheng.li
"""
import os
from typing import Iterable
from typing import List
from typing import Tuple
......@@ -17,14 +18,17 @@ import sqlalchemy as sa
import sqlalchemy.orm as orm
from sqlalchemy import (
and_,
column,
join,
select,
outerjoin
outerjoin,
Table
)
from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models.models_rl import (
metadata,
Market,
IndexMarket,
Industry,
......@@ -33,7 +37,6 @@ from alphamind.data.dbmodel.models.models_rl import (
IndexComponent,
IndexWeight,
)
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing
......@@ -42,7 +45,6 @@ from alphamind.data.engines.utilities import _map_risk_model_table
from alphamind.portfolio.riskmodel import FactorRiskModel
from alphamind.data.transformer import Transformer
risk_styles = ['BETA',
'MOMENTUM',
'SIZE',
......@@ -100,9 +102,19 @@ DAILY_RETURN_OFFSET = 0
class SqlEngine:
def __init__(self, db_url: str, factor_tables: List[str] = None):
    """Create the engine and session, then reflect the factor tables.

    :param db_url: database URL handed to ``sa.create_engine``.
    :param factor_tables: names of the tables to reflect with the engine.
        When ``None`` or empty, the comma separated ``FACTOR_TABLES``
        environment variable is used instead; when that is unset or
        contains no names, no factor tables are registered.
    """
    self._engine = sa.create_engine(db_url)
    self._session = self.create_session()

    if factor_tables:
        table_names = list(factor_tables)
    else:
        # Only the environment lookup is allowed to miss. Reflection is done
        # outside of any exception handler so that its errors propagate
        # instead of silently leaving the engine without factor tables.
        env_value = os.environ.get("FACTOR_TABLES", "")
        stripped_names = (name.strip() for name in env_value.split(","))
        # Drop empty entries produced by an empty value or a stray comma.
        table_names = [name for name in stripped_names if name]

    self._factor_tables = [
        Table(name, metadata, autoload=True, autoload_with=self._engine)
        for name in table_names
    ]
def __del__(self):
if self._session:
......@@ -327,19 +339,14 @@ class SqlEngine:
ref_date: str,
factors: Iterable[object],
codes: Iterable[int],
warm_start: int = 0,
used_factor_tables=None) -> pd.DataFrame:
warm_start: int = 0) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
else:
transformer = Transformer(factors)
dependency = transformer.dependency
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
factor_cols = _map_factors(dependency, self._factor_tables)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime(
'%Y-%m-%d')
......@@ -350,19 +357,19 @@ class SqlEngine:
joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
if t.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["security_code"],
Market.flag == 1,
t.flag == 1))
t.columns["flag"] == 1))
joined_tables.add(t.__table__.name)
joined_tables.add(t.name)
query = select(
[Market.trade_date, Market.code.label("code"),
Market.chgPct.label("chgPct"),
Market.secShortName.label("secShortName")] + list(
factor_cols.keys())) \
column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date),
Market.code.in_(codes),
Market.flag == 1))
......@@ -386,8 +393,7 @@ class SqlEngine:
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
external_data: pd.DataFrame = None,
used_factor_tables=None) -> pd.DataFrame:
external_data: pd.DataFrame = None) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
......@@ -396,31 +402,28 @@ class SqlEngine:
dependency = transformer.dependency
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
factor_cols = _map_factors(dependency, self._factor_tables)
big_table = Market
joined_tables = set()
joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables:
if t.name not in joined_tables:
if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["security_code"],
Market.trade_date.in_(dates),
Market.flag == 1,
t.flag == 1))
t.columns["flag"] == 1))
else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["security_code"],
Market.trade_date.between(start_date,
end_date),
Market.flag == 1,
t.flag == 1))
joined_tables.add(t.__table__.name)
t.columns["flag"] == 1))
joined_tables.add(t.name)
universe_df = universe.query(self, start_date, end_date, dates)
......@@ -429,7 +432,7 @@ class SqlEngine:
Market.code.label("code"),
Market.chgPct.label("chgPct"),
Market.secShortName.label("secShortName")] + list(
factor_cols.keys())) \
column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(
and_(
Market.code.in_(universe_df.code.unique().tolist()),
......@@ -772,8 +775,7 @@ class SqlEngine:
transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date,
transformer,
codes,
used_factor_tables=factor_tables)
codes)
if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark)
......@@ -847,7 +849,8 @@ if __name__ == "__main__":
# db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
db_url = "mysql+mysqldb://dxrw:dxRW20_2@121.37.138.1:13317/dxtest?charset=utf8"
sql_engine = SqlEngine(db_url=db_url)
sql_engine = SqlEngine(db_url=db_url, factor_tables=["factor_momentum"])
universe = Universe("hs300")
start_date = '2020-01-02'
end_date = '2020-02-21'
......@@ -892,16 +895,16 @@ if __name__ == "__main__":
start_date=start_date,
end_date=end_date)
print(df)
# df = sql_engine.fetch_risk_model_range(universe=universe,
# start_date=start_date,
# end_date=end_date,
# model_type="factor")
# print(df)
# df = sql_engine.fetch_data("2020-02-11", factors=factors, codes=["2010031963"], benchmark=300)
# print(df)
# df = sql_engine.fetch_data_range(universe,
# factors=factors,
# dates=ref_dates,
# benchmark=benchmark)["factor"]
# print(df)
df = sql_engine.fetch_risk_model_range(universe=universe,
start_date=start_date,
end_date=end_date,
model_type="factor")
print(df)
df = sql_engine.fetch_data("2020-02-11", factors=factors, codes=["2010031963"], benchmark=300)
print(df)
df = sql_engine.fetch_data_range(universe,
factors=factors,
dates=ref_dates,
benchmark=benchmark)["factor"]
print(df)
......@@ -17,7 +17,6 @@ from alphamind.data.dbmodel.models import RiskExposure
from alphamind.data.dbmodel.models import SpecificRiskDay
from alphamind.data.dbmodel.models import SpecificRiskLong
from alphamind.data.dbmodel.models import SpecificRiskShort
from alphamind.data.dbmodel.models import factor_tables
from alphamind.data.engines.industries import INDUSTRY_MAPPING
......@@ -38,8 +37,8 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
to_keep = factors.copy()
for f in factors:
for t in used_factor_tables:
if f in t.__table__.columns:
factor_cols[t.__table__.columns[f]] = t
if f in t.columns:
factor_cols[t.columns[f].name] = t
to_keep.remove(f)
break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment