Commit b0dbb3ea authored by Dr.李's avatar Dr.李

FIX: use universal names

parent 4e3a67a9
......@@ -12,6 +12,7 @@ from typing import Tuple
from typing import Union
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlalchemy.orm as orm
......@@ -53,6 +54,10 @@ class SqlEngine:
def engine(self):
return self._engine
@property
def session(self):
return self._session
def create_session(self):
db_session = orm.sessionmaker(bind=self._engine)
return db_session()
......@@ -76,14 +81,19 @@ class SqlEngine:
else:
end_date = expiry_date
query = select([Market.trade_date, Market.code, Market.chgPct]).where(
query = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_(
Market.trade_date.between(start_date, end_date),
Market.code.in_(codes)
)
Market.code.in_(codes),
Market.flag == 1
)
).order_by(Market.trade_date, Market.code)
df = pd.read_sql(query, self._session.bind).dropna()
df = pd.read_sql(query, self.session.bind).dropna()
df.set_index("trade_date", inplace=True)
df["dx"] = np.log(1. + df["chgPct"])
df = df.groupby("code").rolling(window=horizon+1)['dx'].sum().shift(-(offset+1)).dropna().reset_index()
df = df[df.trade_date == ref_date]
if neutralized_risks:
......@@ -93,12 +103,10 @@ class SqlEngine:
pre_process=pre_process,
risk_factors=df[neutralized_risks].values,
post_process=post_process)
df.rename(columns={"security_code": "code", "change_pct": "dx"}, inplace=True)
return df[['code', 'dx']]
def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]:
df = universe.query(self, ref_date, ref_date).rename(columns={"security_code": "code"})
df = universe.query(self, ref_date, ref_date)
return sorted(df.code.tolist())
def fetch_codes_range(self,
......@@ -106,7 +114,7 @@ class SqlEngine:
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates).rename(columns={"security_code": "code"})
return universe.query(self, start_date, end_date, dates)
if __name__ == "__main__":
......@@ -114,4 +122,5 @@ if __name__ == "__main__":
sql_engine = SqlEngine(db_url=db_url)
df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
# df = sql_engine.fetch_dx_return("2020-09-25", codes=["2010000001"])
print(df)
......@@ -56,7 +56,7 @@ class BaseUniverse(metaclass=abc.ABCMeta):
more_conditions = [UniverseTable.flag == 1]
else:
more_conditions = []
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
query = select([UniverseTable.trade_date, UniverseTable.code.label("code")]).where(
and_(
self._query_statements(start_date, end_date, dates),
*more_conditions
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment