Commit 06d3980a authored by Dr.李

FEATURE: added dynamically loaded tables

parent 893c28c2
...@@ -20,12 +20,6 @@ if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl": ...@@ -20,12 +20,6 @@ if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models.models_rl import SpecificRiskLong from alphamind.data.dbmodel.models.models_rl import SpecificRiskLong
from alphamind.data.dbmodel.models.models_rl import IndexComponent from alphamind.data.dbmodel.models.models_rl import IndexComponent
from alphamind.data.dbmodel.models.models_rl import IndexWeight from alphamind.data.dbmodel.models.models_rl import IndexWeight
from alphamind.data.dbmodel.models.models_rl import FactorMomentum
# from alphamind.data.dbmodel.models.models_rl import FactorValuationEstimation
# from alphamind.data.dbmodel.models.models_rl import FactorVolatilityValue
factor_tables = [Market, RiskExposure, FactorMomentum]
else: else:
from alphamind.data.dbmodel.models.models import Market from alphamind.data.dbmodel.models.models import Market
from alphamind.data.dbmodel.models.models import IndexMarket from alphamind.data.dbmodel.models.models import IndexMarket
...@@ -42,5 +36,3 @@ else: ...@@ -42,5 +36,3 @@ else:
from alphamind.data.dbmodel.models.models import IndexComponent from alphamind.data.dbmodel.models.models import IndexComponent
from alphamind.data.dbmodel.models.models import RiskMaster from alphamind.data.dbmodel.models.models import RiskMaster
from alphamind.data.dbmodel.models.models import Uqer from alphamind.data.dbmodel.models.models import Uqer
\ No newline at end of file
factor_tables = [Market, RiskExposure, Uqer]
...@@ -5,6 +5,7 @@ Created on 2020-10-11 ...@@ -5,6 +5,7 @@ Created on 2020-10-11
@author: cheng.li @author: cheng.li
""" """
import os
from typing import Iterable from typing import Iterable
from typing import List from typing import List
from typing import Tuple from typing import Tuple
...@@ -17,14 +18,17 @@ import sqlalchemy as sa ...@@ -17,14 +18,17 @@ import sqlalchemy as sa
import sqlalchemy.orm as orm import sqlalchemy.orm as orm
from sqlalchemy import ( from sqlalchemy import (
and_, and_,
column,
join, join,
select, select,
outerjoin outerjoin,
Table
) )
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models.models_rl import ( from alphamind.data.dbmodel.models.models_rl import (
metadata,
Market, Market,
IndexMarket, IndexMarket,
Industry, Industry,
...@@ -33,7 +37,6 @@ from alphamind.data.dbmodel.models.models_rl import ( ...@@ -33,7 +37,6 @@ from alphamind.data.dbmodel.models.models_rl import (
IndexComponent, IndexComponent,
IndexWeight, IndexWeight,
) )
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import _map_factors from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing from alphamind.data.processing import factor_processing
...@@ -42,7 +45,6 @@ from alphamind.data.engines.utilities import _map_risk_model_table ...@@ -42,7 +45,6 @@ from alphamind.data.engines.utilities import _map_risk_model_table
from alphamind.portfolio.riskmodel import FactorRiskModel from alphamind.portfolio.riskmodel import FactorRiskModel
from alphamind.data.transformer import Transformer from alphamind.data.transformer import Transformer
risk_styles = ['BETA', risk_styles = ['BETA',
'MOMENTUM', 'MOMENTUM',
'SIZE', 'SIZE',
...@@ -100,9 +102,19 @@ DAILY_RETURN_OFFSET = 0 ...@@ -100,9 +102,19 @@ DAILY_RETURN_OFFSET = 0
class SqlEngine: class SqlEngine:
def __init__(self, db_url: str): def __init__(self, db_url: str, factor_tables: List[str] = None):
self._engine = sa.create_engine(db_url) self._engine = sa.create_engine(db_url)
self._session = self.create_session() self._session = self.create_session()
if factor_tables:
self._factor_tables = [Table(name, metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables]
else:
try:
factor_tables = os.environ["FACTOR_TABLES"]
self._factor_tables = [Table(name.strip(), metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables.split(",")]
except KeyError:
self._factor_tables = []
def __del__(self): def __del__(self):
if self._session: if self._session:
...@@ -327,19 +339,14 @@ class SqlEngine: ...@@ -327,19 +339,14 @@ class SqlEngine:
ref_date: str, ref_date: str,
factors: Iterable[object], factors: Iterable[object],
codes: Iterable[int], codes: Iterable[int],
warm_start: int = 0, warm_start: int = 0) -> pd.DataFrame:
used_factor_tables=None) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
else: else:
transformer = Transformer(factors) transformer = Transformer(factors)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency, self._factor_tables)
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime( start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime(
'%Y-%m-%d') '%Y-%m-%d')
...@@ -350,19 +357,19 @@ class SqlEngine: ...@@ -350,19 +357,19 @@ class SqlEngine:
joined_tables.add(Market.__table__.name) joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables: if t.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["security_code"],
Market.flag == 1, Market.flag == 1,
t.flag == 1)) t.columns["flag"] == 1))
joined_tables.add(t.__table__.name) joined_tables.add(t.name)
query = select( query = select(
[Market.trade_date, Market.code.label("code"), [Market.trade_date, Market.code.label("code"),
Market.chgPct.label("chgPct"), Market.chgPct.label("chgPct"),
Market.secShortName.label("secShortName")] + list( Market.secShortName.label("secShortName")] + list(
factor_cols.keys())) \ column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date), .select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date),
Market.code.in_(codes), Market.code.in_(codes),
Market.flag == 1)) Market.flag == 1))
...@@ -386,8 +393,7 @@ class SqlEngine: ...@@ -386,8 +393,7 @@ class SqlEngine:
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None, dates: Iterable[str] = None,
external_data: pd.DataFrame = None, external_data: pd.DataFrame = None) -> pd.DataFrame:
used_factor_tables=None) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -396,31 +402,28 @@ class SqlEngine: ...@@ -396,31 +402,28 @@ class SqlEngine:
dependency = transformer.dependency dependency = transformer.dependency
if used_factor_tables: factor_cols = _map_factors(dependency, self._factor_tables)
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
big_table = Market big_table = Market
joined_tables = set() joined_tables = set()
joined_tables.add(Market.__table__.name) joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables: if t.name not in joined_tables:
if dates is not None: if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["security_code"],
Market.trade_date.in_(dates), Market.trade_date.in_(dates),
Market.flag == 1, Market.flag == 1,
t.flag == 1)) t.columns["flag"] == 1))
else: else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["security_code"],
Market.trade_date.between(start_date, Market.trade_date.between(start_date,
end_date), end_date),
Market.flag == 1, Market.flag == 1,
t.flag == 1)) t.columns["flag"] == 1))
joined_tables.add(t.__table__.name) joined_tables.add(t.name)
universe_df = universe.query(self, start_date, end_date, dates) universe_df = universe.query(self, start_date, end_date, dates)
...@@ -429,7 +432,7 @@ class SqlEngine: ...@@ -429,7 +432,7 @@ class SqlEngine:
Market.code.label("code"), Market.code.label("code"),
Market.chgPct.label("chgPct"), Market.chgPct.label("chgPct"),
Market.secShortName.label("secShortName")] + list( Market.secShortName.label("secShortName")] + list(
factor_cols.keys())) \ column(k) for k in factor_cols.keys())) \
.select_from(big_table).where( .select_from(big_table).where(
and_( and_(
Market.code.in_(universe_df.code.unique().tolist()), Market.code.in_(universe_df.code.unique().tolist()),
...@@ -772,8 +775,7 @@ class SqlEngine: ...@@ -772,8 +775,7 @@ class SqlEngine:
transformer = Transformer(factors) transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date, factor_data = self.fetch_factor(ref_date,
transformer, transformer,
codes, codes)
used_factor_tables=factor_tables)
if benchmark: if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark) benchmark_data = self.fetch_benchmark(ref_date, benchmark)
...@@ -847,7 +849,8 @@ if __name__ == "__main__": ...@@ -847,7 +849,8 @@ if __name__ == "__main__":
# db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" # db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
db_url = "mysql+mysqldb://dxrw:dxRW20_2@121.37.138.1:13317/dxtest?charset=utf8" db_url = "mysql+mysqldb://dxrw:dxRW20_2@121.37.138.1:13317/dxtest?charset=utf8"
sql_engine = SqlEngine(db_url=db_url) sql_engine = SqlEngine(db_url=db_url, factor_tables=["factor_momentum"])
universe = Universe("hs300") universe = Universe("hs300")
start_date = '2020-01-02' start_date = '2020-01-02'
end_date = '2020-02-21' end_date = '2020-02-21'
...@@ -892,16 +895,16 @@ if __name__ == "__main__": ...@@ -892,16 +895,16 @@ if __name__ == "__main__":
start_date=start_date, start_date=start_date,
end_date=end_date) end_date=end_date)
print(df) print(df)
# df = sql_engine.fetch_risk_model_range(universe=universe, df = sql_engine.fetch_risk_model_range(universe=universe,
# start_date=start_date, start_date=start_date,
# end_date=end_date, end_date=end_date,
# model_type="factor") model_type="factor")
# print(df) print(df)
# df = sql_engine.fetch_data("2020-02-11", factors=factors, codes=["2010031963"], benchmark=300) df = sql_engine.fetch_data("2020-02-11", factors=factors, codes=["2010031963"], benchmark=300)
# print(df) print(df)
# df = sql_engine.fetch_data_range(universe, df = sql_engine.fetch_data_range(universe,
# factors=factors, factors=factors,
# dates=ref_dates, dates=ref_dates,
# benchmark=benchmark)["factor"] benchmark=benchmark)["factor"]
# print(df) print(df)
...@@ -17,7 +17,6 @@ from alphamind.data.dbmodel.models import RiskExposure ...@@ -17,7 +17,6 @@ from alphamind.data.dbmodel.models import RiskExposure
from alphamind.data.dbmodel.models import SpecificRiskDay from alphamind.data.dbmodel.models import SpecificRiskDay
from alphamind.data.dbmodel.models import SpecificRiskLong from alphamind.data.dbmodel.models import SpecificRiskLong
from alphamind.data.dbmodel.models import SpecificRiskShort from alphamind.data.dbmodel.models import SpecificRiskShort
from alphamind.data.dbmodel.models import factor_tables
from alphamind.data.engines.industries import INDUSTRY_MAPPING from alphamind.data.engines.industries import INDUSTRY_MAPPING
...@@ -38,8 +37,8 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict: ...@@ -38,8 +37,8 @@ def _map_factors(factors: Iterable[str], used_factor_tables) -> Dict:
to_keep = factors.copy() to_keep = factors.copy()
for f in factors: for f in factors:
for t in used_factor_tables: for t in used_factor_tables:
if f in t.__table__.columns: if f in t.columns:
factor_cols[t.__table__.columns[f]] = t factor_cols[t.columns[f].name] = t
to_keep.remove(f) to_keep.remove(f)
break break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment