Commit 93d382eb authored by Dr.李

FEATURE: use global statistics

parent b0dbb3ea
...@@ -62,6 +62,12 @@ class SqlEngine: ...@@ -62,6 +62,12 @@ class SqlEngine:
db_session = orm.sessionmaker(bind=self._engine) db_session = orm.sessionmaker(bind=self._engine)
return db_session() return db_session()
def _create_stats(self, df, horizon, offset):
df.set_index("trade_date", inplace=True)
df["dx"] = np.log(1. + df["chgPct"])
df = df.groupby("code").rolling(window=horizon + 1)['dx'].sum().shift(-(offset + 1)).dropna().reset_index()
return df
def fetch_dx_return(self, def fetch_dx_return(self,
ref_date: str, ref_date: str,
codes: Iterable[int], codes: Iterable[int],
...@@ -81,7 +87,6 @@ class SqlEngine: ...@@ -81,7 +87,6 @@ class SqlEngine:
else: else:
end_date = expiry_date end_date = expiry_date
query = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where( query = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_( and_(
Market.trade_date.between(start_date, end_date), Market.trade_date.between(start_date, end_date),
...@@ -91,9 +96,7 @@ class SqlEngine: ...@@ -91,9 +96,7 @@ class SqlEngine:
).order_by(Market.trade_date, Market.code) ).order_by(Market.trade_date, Market.code)
df = pd.read_sql(query, self.session.bind).dropna() df = pd.read_sql(query, self.session.bind).dropna()
df.set_index("trade_date", inplace=True) df = self._create_stats(df, horizon, offset)
df["dx"] = np.log(1. + df["chgPct"])
df = df.groupby("code").rolling(window=horizon+1)['dx'].sum().shift(-(offset+1)).dropna().reset_index()
df = df[df.trade_date == ref_date] df = df[df.trade_date == ref_date]
if neutralized_risks: if neutralized_risks:
...@@ -121,6 +124,6 @@ if __name__ == "__main__": ...@@ -121,6 +124,6 @@ if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url) sql_engine = SqlEngine(db_url=db_url)
df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300")) # df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
# df = sql_engine.fetch_dx_return("2020-09-25", codes=["2010000001"]) df = sql_engine.fetch_dx_return("2020-09-10", codes=["2010000001"])
print(df) print(df)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment