Commit 5e379c61 authored by Dr.李's avatar Dr.李

FEATURE: updated sqlengine.py

parent de8b6d34
......@@ -5,6 +5,7 @@ Created on 2017-7-7
@author: cheng.li
"""
import os
from typing import Dict
from typing import Iterable
from typing import List
......@@ -15,11 +16,12 @@ import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlalchemy.orm as orm
from sqlalchemy import select, and_, outerjoin, join, column
from sqlalchemy import select, and_, outerjoin, join, column, Table
from sqlalchemy.sql import func
from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models.models import metadata
from alphamind.data.dbmodel.models.models import FactorMaster
from alphamind.data.dbmodel.models.models import IndexComponent
from alphamind.data.dbmodel.models.models import IndexMarket
......@@ -29,7 +31,6 @@ from alphamind.data.dbmodel.models.models import RiskExposure
from alphamind.data.dbmodel.models.models import RiskMaster
from alphamind.data.dbmodel.models.models import Universe as UniverseTable
from alphamind.data.engines.universe import Universe
from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import _map_factors
from alphamind.data.engines.utilities import _map_industry_category
from alphamind.data.engines.utilities import _map_risk_model_table
......@@ -88,27 +89,40 @@ DAILY_RETURN_OFFSET = 0
class SqlEngine:
def __init__(self,
db_url: str = None):
self.engine = sa.create_engine(db_url)
self.session = self.create_session()
if self.engine.name == 'mssql':
self.ln_func = func.log
def __init__(self, db_url: str, factor_tables: List[str] = None):
self._engine = sa.create_engine(db_url)
self._session = self.create_session()
if factor_tables:
self._factor_tables = [Table(name, metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables]
else:
self.ln_func = func.ln
try:
factor_tables = os.environ["FACTOR_TABLES"]
self._factor_tables = [Table(name.strip(), metadata, autoload=True, autoload_with=self._engine)
for name in factor_tables.split(",")]
except KeyError:
self._factor_tables = []
self.ln_func = func.ln
def __del__(self):
if self.session:
self.session.close()
if self._session:
self._session.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.session:
self.session.close()
if self._session:
self._session.close()
@property
def engine(self):
return self._engine
@property
def session(self):
return self._session
def create_session(self):
db_session = orm.sessionmaker(bind=self.engine)
......@@ -311,8 +325,7 @@ class SqlEngine:
ref_date: str,
factors: Iterable[object],
codes: Iterable[int],
warm_start: int = 0,
used_factor_tables=None) -> pd.DataFrame:
warm_start: int = 0) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
......@@ -321,11 +334,7 @@ class SqlEngine:
dependency = transformer.dependency
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
factor_cols = _map_factors(dependency, self._factor_tables)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-warm_start) + 'b').strftime(
'%Y-%m-%d')
end_date = ref_date
......@@ -335,15 +344,15 @@ class SqlEngine:
joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code))
if t.name not in joined_tables:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["code"]))
joined_tables.add(t.__table__.name)
joined_tables.add(t.name)
query = select(
[Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list(
factor_cols.keys())) \
column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(and_(Market.trade_date.between(start_date, end_date),
Market.code.in_(codes)))
......@@ -365,8 +374,7 @@ class SqlEngine:
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
external_data: pd.DataFrame = None,
used_factor_tables=None) -> pd.DataFrame:
external_data: pd.DataFrame = None) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
......@@ -374,34 +382,30 @@ class SqlEngine:
transformer = Transformer(factors)
dependency = transformer.dependency
if used_factor_tables:
factor_cols = _map_factors(dependency, used_factor_tables)
else:
factor_cols = _map_factors(dependency, factor_tables)
factor_cols = _map_factors(dependency, self._factor_tables)
big_table = Market
joined_tables = set()
joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables:
if t.name not in joined_tables:
if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["code"],
Market.trade_date.in_(dates)))
else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["code"],
Market.trade_date.between(start_date,
end_date)))
joined_tables.add(t.__table__.name)
joined_tables.add(t.name)
universe_df = universe.query(self, start_date, end_date, dates)
query = select(
[Market.trade_date, Market.code, Market.chgPct, Market.secShortName] + list(
factor_cols.keys())) \
column(k) for k in factor_cols.keys())) \
.select_from(big_table).where(
and_(
Market.code.in_(universe_df.code.unique().tolist()),
......@@ -437,7 +441,7 @@ class SqlEngine:
transformer = Transformer(factors)
dependency = transformer.dependency
factor_cols = _map_factors(dependency, factor_tables)
factor_cols = _map_factors(dependency, self._factor_tables)
codes = universe.query(self, start_date, end_date, dates)
total_codes = codes.code.unique().tolist()
......@@ -448,17 +452,17 @@ class SqlEngine:
joined_tables.add(Market.__table__.name)
for t in set(factor_cols.values()):
if t.__table__.name not in joined_tables:
if t.name not in joined_tables:
if dates is not None:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["code"],
Market.trade_date.in_(dates)))
else:
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.trade_date,
Market.code == t.code,
big_table = outerjoin(big_table, t, and_(Market.trade_date == t.columns["trade_date"],
Market.code == t.columns["code"],
Market.trade_date.between(start_date,
end_date)))
joined_tables.add(t.__table__.name)
joined_tables.add(t.name)
stats = func.lag(list(factor_cols.keys())[0], -1).over(
partition_by=Market.code,
......@@ -784,8 +788,7 @@ class SqlEngine:
transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date,
transformer,
codes,
used_factor_tables=factor_tables)
codes)
if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment