Commit 5e379c61 authored by Dr.李's avatar Dr.李

FEATURE: updated sqlengine.py

parent de8b6d34
...@@ -5,6 +5,7 @@ Created on 2017-7-7 ...@@ -5,6 +5,7 @@ Created on 2017-7-7
@author: cheng.li @author: cheng.li
""" """
import os
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import List from typing import List
...@@ -15,11 +16,12 @@ import numpy as np ...@@ -15,11 +16,12 @@ import numpy as np
import pandas as pd import pandas as pd
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy.orm as orm import sqlalchemy.orm as orm
from sqlalchemy import select, and_, outerjoin, join, column from sqlalchemy import select, and_, outerjoin, join, column, Table
from sqlalchemy.sql import func from sqlalchemy.sql import func
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models.models import metadata
from alphamind.data.dbmodel.models.models import FactorMaster from alphamind.data.dbmodel.models.models import FactorMaster
from alphamind.data.dbmodel.models.models import IndexComponent from alphamind.data.dbmodel.models.models import IndexComponent
from alphamind.data.dbmodel.models.models import IndexMarket from alphamind.data.dbmodel.models.models import IndexMarket
...@@ -29,7 +31,6 @@ from alphamind.data.dbmodel.models.models import RiskExposure ...@@ -29,7 +31,6 @@ from alphamind.data.dbmodel.models.models import RiskExposure
from alphamind.data.dbmodel.models.models import RiskMaster from alphamind.data.dbmodel.models.models import RiskMaster
from alphamind.data.dbmodel.models.models import Universe as UniverseTable from alphamind.data.dbmodel.models.models import Universe as UniverseTable
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import _map_factors from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.utilities import _map_industry_category from alphamind.data.engines.utilities import _map_industry_category
from alphamind.data.engines.utilities import _map_risk_model_table from alphamind.data.engines.utilities import _map_risk_model_table
...@@ -88,27 +89,40 @@ DAILY_RETURN_OFFSET = 0 ...@@ -88,27 +89,40 @@ DAILY_RETURN_OFFSET = 0
class SqlEngine: class SqlEngine:
def __init__(self, def __init__(self, db_url: str, factor_tables: List[str] = None):
db_url: str = None): self._engine = sa.create_engine(db_url)
self.engine = sa.create_engine(db_url) self._session = self.create_session()
if factor_tables:
self.session = self.create_session() self._factor_tables = [Table(name, metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables]
if self.engine.name == 'mssql':
self.ln_func = func.log
else: else:
try:
factor_tables = os.environ["FACTOR_TABLES"]
self._factor_tables = [Table(name.strip(), metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables.split(",")]
except KeyError:
self._factor_tables = []
self.ln_func = func.ln self.ln_func = func.ln
def __del__(self): def __del__(self):
if self.session: if self._session:
self.session.close() self._session.close()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if self.session: if self._session:
self.session.close() self._session.close()
@property
def engine(self):
return self._engine
@property
def session(self):
return self._session
def create_session(self): def create_session(self):
db_session = orm.sessionmaker(bind=self.engine) db_session = orm.sessionmaker(bind=self.engine)
...@@ -311,8 +325,7 @@ class SqlEngine: ...@@ -311,8 +325,7 @@ class SqlEngine:
ref_date: str, ref_date: str,
factors: Iterable[object], factors: Iterable[object],
codes: Iterable[int], codes: Iterable[int],
warm_start: int = 0, warm_start: int = 0) -> pd.DataFrame:
used_factor_tables=None) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -321,11 +334,7 @@ class SqlEngine: ...@@ -321,11 +334,7 @@ class SqlEngine:
dependency = transformer.dependency dependency = transformer.dependency
if used_factor_tables: factor_cols = _map_factors(dependency, self._factor_tables)
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime( start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime(
'%Y-%m-%d') '%Y-%m-%d')
end_date = ref_date end_date = ref_date
...@@ -335,15 +344,15 @@ class SqlEngine: ...@@ -335,15 +344,15 @@ class SqlEngine:
joined_tables.add(Market.__table__.name) joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables: if t.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code)) Market.code == t.columns["code"]))
joined_tables.add(t.__table__.name) joined_tables.add(t.name)
query = select( query = select(
[Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list( [Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list(
factor_cols.keys())) \ column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date), .select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date),
Market.code.in_(codes))) Market.code.in_(codes)))
...@@ -365,8 +374,7 @@ class SqlEngine: ...@@ -365,8 +374,7 @@ class SqlEngine:
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None, dates: Iterable[str] = None,
external_data: pd.DataFrame = None, external_data: pd.DataFrame = None) -> pd.DataFrame:
used_factor_tables=None) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -374,34 +382,30 @@ class SqlEngine: ...@@ -374,34 +382,30 @@ class SqlEngine:
transformer = Transformer(factors) transformer = Transformer(factors)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency, self._factor_tables)
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
big_table = Market big_table = Market
joined_tables = set() joined_tables = set()
joined_tables.add(Market.__table__.name) joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables: if t.name not in joined_tables:
if dates is not None: if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["code"],
Market.trade_date.in_(dates))) Market.trade_date.in_(dates)))
else: else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["code"],
Market.trade_date.between(start_date, Market.trade_date.between(start_date,
end_date))) end_date)))
joined_tables.add(t.__table__.name) joined_tables.add(t.name)
universe_df = universe.query(self, start_date, end_date, dates) universe_df = universe.query(self, start_date, end_date, dates)
query = select( query = select(
[Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list( [Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list(
factor_cols.keys())) \ column(k) for k in factor_cols.keys())) \
.select_from(big_table).where( .select_from(big_table).where(
and_( and_(
Market.code.in_(universe_df.code.unique().tolist()), Market.code.in_(universe_df.code.unique().tolist()),
...@@ -437,7 +441,7 @@ class SqlEngine: ...@@ -437,7 +441,7 @@ class SqlEngine:
transformer = Transformer(factors) transformer = Transformer(factors)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency, factor_tables) factor_cols = _map_factors(dependency, self._factor_tables)
codes = universe.query(self, start_date, end_date, dates) codes = universe.query(self, start_date, end_date, dates)
total_codes = codes.code.unique().tolist() total_codes = codes.code.unique().tolist()
...@@ -448,17 +452,17 @@ class SqlEngine: ...@@ -448,17 +452,17 @@ class SqlEngine:
joined_tables.add(Market.__table__.name) joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables: if t.name not in joined_tables:
if dates is not None: if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["code"],
Market.trade_date.in_(dates))) Market.trade_date.in_(dates)))
else: else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.code, Market.code == t.columns["code"],
Market.trade_date.between(start_date, Market.trade_date.between(start_date,
end_date))) end_date)))
joined_tables.add(t.__table__.name) joined_tables.add(t.name)
stats = func.lag(list(factor_cols.keys())[0], -1).over( stats = func.lag(list(factor_cols.keys())[0], -1).over(
partition_by=Market.code, partition_by=Market.code,
...@@ -784,8 +788,7 @@ class SqlEngine: ...@@ -784,8 +788,7 @@ class SqlEngine:
transformer = Transformer(factors) transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date, factor_data = self.fetch_factor(ref_date,
transformer, transformer,
codes, codes)
used_factor_tables=factor_tables)
if benchmark: if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark) benchmark_data = self.fetch_benchmark(ref_date, benchmark)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment