Commit d449b787 authored by Dr.李's avatar Dr.李

FEATURE: added code return range

parent 93d382eb
...@@ -5,6 +5,7 @@ Created on 2020-10-11 ...@@ -5,6 +5,7 @@ Created on 2020-10-11
@author: cheng.li @author: cheng.li
""" """
import os
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import List from typing import List
...@@ -26,6 +27,12 @@ from PyFin.api import advanceDateByCalendar ...@@ -26,6 +27,12 @@ from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_rl import ( from alphamind.data.dbmodel.models_rl import (
Market Market
) )
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models_rl import Universe as UniverseTable
else:
from alphamind.data.dbmodel.models import Universe as UniverseTable
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing from alphamind.data.processing import factor_processing
...@@ -108,6 +115,41 @@ class SqlEngine: ...@@ -108,6 +115,41 @@ class SqlEngine:
post_process=post_process) post_process=post_process)
return df[['code', 'dx']] return df[['code', 'dx']]
def fetch_dx_return_range(self,
universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
horizon: int = 0,
offset: int = 0,
benchmark: int = None) -> pd.DataFrame:
if dates:
start_date = dates[0]
end_date = dates[-1]
end_date = advanceDateByCalendar('china.sse', end_date,
str(
1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime(
'%Y-%m-%d')
codes = universe.query(self.engine, start_date, end_date, dates)
t1 = select([Market.trade_date, Market.code.label("code"), Market.chgPct.label("chgPct")]).where(
and_(
Market.trade_date.between(start_date, end_date),
Market.code.in_(codes.code.unique().tolist()),
Market.flag == 1
)
)
df1 = pd.read_sql(t1, self.session.bind).dropna()
df2 = self.fetch_codes_range(universe, start_date, end_date, dates)
df = pd.merge(df1, df2, on=["trade_date", "code"])
df = self._create_stats(df, horizon, offset)
if dates:
df = df[df.trade_date.isin(dates)]
return df.reset_index(drop=True).sort_values(['trade_date', 'code'])
def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]: def fetch_codes(self, ref_date: str, universe: Universe) -> List[int]:
df = universe.query(self, ref_date, ref_date) df = universe.query(self, ref_date, ref_date)
return sorted(df.code.tolist()) return sorted(df.code.tolist())
...@@ -124,6 +166,12 @@ if __name__ == "__main__": ...@@ -124,6 +166,12 @@ if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url) sql_engine = SqlEngine(db_url=db_url)
# df = sql_engine.fetch_codes_range(start_date='2020-09-29', end_date='2020-10-10', universe=Universe("hs300")) universe = Universe("hs300")
df = sql_engine.fetch_dx_return("2020-09-10", codes=["2010000001"]) start_date = '2020-09-29'
end_date = '2020-10-10'
df = sql_engine.fetch_codes_range(start_date='start_date', end_date=end_date, universe=Universe("hs300"))
print(df)
df = sql_engine.fetch_dx_return("2020-10-09", codes=["2010031963"])
print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date)
print(df) print(df)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment