Commit 93d382eb authored by Dr.李's avatar Dr.李

FEATURE: use global statistics

parent b0dbb3ea
......@@ -62,6 +62,12 @@ class SqlEngine:
db_session = orm.sessionmaker(bind=self._engine)
return db_session()
def _create_stats(self, df, horizon, offset):
df.set_index("trade_date", inplace=True)
df["dx"] = np.log(1. + df["chgPct"])
df = df.groupby("code").rolling(window=horizon + 1)['dx'].sum().shift(-(offset + 1)).dropna().reset_index()
return df
def fetch_dx_return(self,
ref_date: str,
codes: Iterable[int],
......@@ -81,7 +87,6 @@ class SqlEngine:
else:
end_date = expiry_date
query = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_(
Market.trade_date.between(start_date, end_date),
......@@ -91,9 +96,7 @@ class SqlEngine:
).order_by(Market.trade_date, Market.code)
df = pd.read_sql(query, self.session.bind).dropna()
df.set_index("trade_date", inplace=True)
df["dx"] = np.log(1. + df["chgPct"])
df = df.groupby("code").rolling(window=horizon+1)['dx'].sum().shift(-(offset+1)).dropna().reset_index()
df = self._create_stats(df, horizon, offset)
df = df[df.trade_date == ref_date]
if neutralized_risks:
......@@ -121,6 +124,6 @@ if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url)
df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
# df = sql_engine.fetch_dx_return("2020-09-25", codes=["2010000001"])
# df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300"))
df = sql_engine.fetch_dx_return("2020-09-10", codes=["2010000001"])
print(df)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment