Commit c14a3881 authored by Dr.李's avatar Dr.李

FIX: wrong dx return calculation

parent 64259fc1
...@@ -20,6 +20,10 @@ if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl": ...@@ -20,6 +20,10 @@ if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
from alphamind.data.dbmodel.models.models_rl import SpecificRiskLong from alphamind.data.dbmodel.models.models_rl import SpecificRiskLong
from alphamind.data.dbmodel.models.models_rl import IndexComponent from alphamind.data.dbmodel.models.models_rl import IndexComponent
from alphamind.data.dbmodel.models.models_rl import IndexWeight from alphamind.data.dbmodel.models.models_rl import IndexWeight
from alphamind.data.dbmodel.models.models_rl import FactorMomentum
factor_tables = [Market, RiskExposure, FactorMomentum]
else: else:
from alphamind.data.dbmodel.models.models import Market from alphamind.data.dbmodel.models.models import Market
from alphamind.data.dbmodel.models.models import IndexMarket from alphamind.data.dbmodel.models.models import IndexMarket
...@@ -35,3 +39,5 @@ else: ...@@ -35,3 +39,5 @@ else:
from alphamind.data.dbmodel.models.models import FactorMaster from alphamind.data.dbmodel.models.models import FactorMaster
from alphamind.data.dbmodel.models.models import IndexComponent from alphamind.data.dbmodel.models.models import IndexComponent
from alphamind.data.dbmodel.models.models import RiskMaster from alphamind.data.dbmodel.models.models import RiskMaster
factor_tables = [Market, RiskExposure]
...@@ -401,6 +401,85 @@ class _SpecificRiskShort(Base): ...@@ -401,6 +401,85 @@ class _SpecificRiskShort(Base):
SRISK = Column(FLOAT) SRISK = Column(FLOAT)
# Factor tables
class _FactorMomentum(Base):
__tablename__ = 'factor_momentum'
__table_args__ = (
Index('factor_momentum_uindex', 'trade_date', 'security_code', 'flag', unique=True),
)
id = Column(INT, primary_key=True)
code = Column("security_code", Text, nullable=False)
trade_date = Column(Date, nullable=False)
ADX14D = Column(FLOAT)
ADXR14D = Column(FLOAT)
APBMA5D = Column(FLOAT)
ARC50D = Column(FLOAT)
BBI = Column(FLOAT)
BIAS10D = Column(FLOAT)
BIAS20D = Column(FLOAT)
BIAS5D = Column(FLOAT)
BIAS60D = Column(FLOAT)
CCI10D = Column(FLOAT)
CCI20D = Column(FLOAT)
CCI5D = Column(FLOAT)
CCI88D = Column(FLOAT)
ChgTo1MAvg = Column(FLOAT)
ChgTo1YAvg = Column(FLOAT)
ChgTo3MAvg = Column(FLOAT)
ChkOsci3D10D = Column(FLOAT)
ChkVol10D = Column(FLOAT)
DEA = Column(FLOAT)
EMA10D = Column(FLOAT)
EMA120D = Column(FLOAT)
EMA12D = Column(FLOAT)
EMA20D = Column(FLOAT)
EMA26D = Column(FLOAT)
EMA5D = Column(FLOAT)
EMA60D = Column(FLOAT)
EMV14D = Column(FLOAT)
EMV6D = Column(FLOAT)
Fiftytwoweekhigh = Column(FLOAT)
HT_TRENDLINE = Column(FLOAT)
KAMA10D = Column(FLOAT)
MA10Close = Column(FLOAT)
MA10D = Column(FLOAT)
MA10RegressCoeff12 = Column(FLOAT)
MA10RegressCoeff6 = Column(FLOAT)
MA120D = Column(FLOAT)
MA20D = Column(FLOAT)
MA5D = Column(FLOAT)
MA60D = Column(FLOAT)
MACD12D26D = Column(FLOAT)
MIDPOINT10D = Column(FLOAT)
MIDPRICE10D = Column(FLOAT)
MTM10D = Column(FLOAT)
PLRC12D = Column(FLOAT)
PLRC6D = Column(FLOAT)
PM10D = Column(FLOAT)
PM120D = Column(FLOAT)
PM20D = Column(FLOAT)
PM250D = Column(FLOAT)
PM5D = Column(FLOAT)
PM60D = Column(FLOAT)
PMDif5D20D = Column(FLOAT)
PMDif5D60D = Column(FLOAT)
RCI12D = Column(FLOAT)
RCI24D = Column(FLOAT)
SAR = Column(FLOAT)
SAREXT = Column(FLOAT)
SMA15D = Column(FLOAT)
TEMA10D = Column(FLOAT)
TEMA5D = Column(FLOAT)
TRIMA10D = Column(FLOAT)
TRIX10D = Column(FLOAT)
TRIX5D = Column(FLOAT)
UOS7D14D28D = Column(FLOAT)
WMA10D = Column(FLOAT)
flag = Column(INT, server_default=text("'1'"))
Market = _StkDailyPricePro Market = _StkDailyPricePro
IndexMarket = _IndexDailyPrice IndexMarket = _IndexDailyPrice
Universe = _StkUniverse Universe = _StkUniverse
...@@ -414,3 +493,5 @@ SpecificRiskShort = _SpecificRiskShort ...@@ -414,3 +493,5 @@ SpecificRiskShort = _SpecificRiskShort
SpecificRiskLong = _SpecificRiskLong SpecificRiskLong = _SpecificRiskLong
IndexComponent = _IndexComponent IndexComponent = _IndexComponent
IndexWeight = _Index IndexWeight = _Index
FactorMomentum = _FactorMomentum
...@@ -31,7 +31,7 @@ from alphamind.data.dbmodel.models.models_rl import ( ...@@ -31,7 +31,7 @@ from alphamind.data.dbmodel.models.models_rl import (
RiskExposure, RiskExposure,
Universe as UniverseTable, Universe as UniverseTable,
IndexComponent, IndexComponent,
IndexWeight IndexWeight,
) )
from alphamind.data.engines.utilities import factor_tables from alphamind.data.engines.utilities import factor_tables
from alphamind.data.engines.utilities import _map_factors from alphamind.data.engines.utilities import _map_factors
...@@ -229,10 +229,12 @@ class SqlEngine: ...@@ -229,10 +229,12 @@ class SqlEngine:
) )
) )
df1 = pd.read_sql(t1, self.session.bind).dropna() df1 = pd.read_sql(t1, self.session.bind).dropna()
df2 = self.fetch_codes_range(universe, start_date, end_date, dates) df1 = self._create_stats(df1, horizon, offset)
df2 = self.fetch_codes_range(universe, start_date, end_date, dates)
df2["trade_date"] = pd.to_datetime(df2["trade_date"])
df = pd.merge(df1, df2, on=["trade_date", "code"]) df = pd.merge(df1, df2, on=["trade_date", "code"])
df = self._create_stats(df, horizon, offset)
if dates: if dates:
df = df[df.trade_date.isin(dates)] df = df[df.trade_date.isin(dates)]
...@@ -711,6 +713,43 @@ class SqlEngine: ...@@ -711,6 +713,43 @@ class SqlEngine:
).distinct() ).distinct()
return pd.read_sql(query, self.engine) return pd.read_sql(query, self.engine)
def fetch_data(self,
ref_date: str,
factors: Iterable[str],
codes: Iterable[int],
benchmark: int = None,
risk_model: str = 'short',
industry: str = 'sw') -> Dict[str, pd.DataFrame]:
total_data = dict()
transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date,
transformer,
codes,
used_factor_tables=factor_tables)
if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark)
total_data['benchmark'] = benchmark_data
factor_data = pd.merge(factor_data, benchmark_data, how='left', on=['code'])
factor_data['weight'] = factor_data['weight'].fillna(0.)
if risk_model:
excluded = list(set(total_risk_factors).intersection(transformer.dependency))
risk_cov, risk_exp = self.fetch_risk_model(ref_date, codes, risk_model, excluded)
factor_data = pd.merge(factor_data, risk_exp, how='left', on=['code'])
total_data['risk_cov'] = risk_cov
industry_info = self.fetch_industry(ref_date=ref_date,
codes=codes,
category=industry)
factor_data = pd.merge(factor_data, industry_info, on=['code'])
total_data['factor'] = factor_data
return total_data
def fetch_data_range(self, def fetch_data_range(self,
universe: Universe, universe: Universe,
factors: Iterable[str], factors: Iterable[str],
...@@ -721,7 +760,8 @@ class SqlEngine: ...@@ -721,7 +760,8 @@ class SqlEngine:
risk_model: str = 'short', risk_model: str = 'short',
industry: str = 'sw', industry: str = 'sw',
external_data: pd.DataFrame = None) -> Dict[str, pd.DataFrame]: external_data: pd.DataFrame = None) -> Dict[str, pd.DataFrame]:
total_data = dict()
total_data = {}
transformer = Transformer(factors) transformer = Transformer(factors)
factor_data = self.fetch_factor_range(universe, factor_data = self.fetch_factor_range(universe,
transformer, transformer,
...@@ -757,22 +797,25 @@ class SqlEngine: ...@@ -757,22 +797,25 @@ class SqlEngine:
if __name__ == "__main__": if __name__ == "__main__":
from PyFin.api import makeSchedule
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
sql_engine = SqlEngine(db_url=db_url) sql_engine = SqlEngine(db_url=db_url)
universe = Universe("hs300") universe = Universe("hs300")
start_date = '2020-01-01' start_date = '2020-01-01'
end_date = '2020-02-21' end_date = '2020-04-21'
benchmark = 300 benchmark = 300
df = sql_engine.fetch_factor("2020-02-21", factors=["BETA"], codes=["2010031963"]) factors = ["EMA5D", "EMV6D"]
print(df) ref_dates = makeSchedule(start_date, end_date, "10b", 'china.sse')
df = sql_engine.fetch_factor_range(universe=universe, start_date=start_date, end_date=end_date, factors=["BETA"]) # df = sql_engine.fetch_factor("2020-02-21", factors=factors, codes=["2010031963"])
print(df) # print(df)
df = sql_engine.fetch_codes_range(start_date=start_date, end_date=end_date, universe=Universe("hs300")) # df = sql_engine.fetch_factor_range(universe=universe, start_date=start_date, end_date=end_date, factors=factors)
print(df) # print(df)
df = sql_engine.fetch_dx_return("2020-10-09", codes=["2010031963"]) # df = sql_engine.fetch_codes_range(start_date=start_date, end_date=end_date, universe=Universe("hs300"))
print(df) # print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date) # df = sql_engine.fetch_dx_return("2020-10-09", codes=["2010031963"])
# print(df)
df = sql_engine.fetch_dx_return_range(universe, dates=ref_dates, horizon=9)
print(df) print(df)
df = sql_engine.fetch_dx_return_index("2020-10-09", index_code=benchmark) df = sql_engine.fetch_dx_return_index("2020-10-09", index_code=benchmark)
print(df) print(df)
...@@ -805,4 +848,12 @@ if __name__ == "__main__": ...@@ -805,4 +848,12 @@ if __name__ == "__main__":
end_date=end_date, end_date=end_date,
model_type="factor") model_type="factor")
print(df) print(df)
df = sql_engine.fetch_data("2020-02-11", factors=factors, codes=["2010031963"], benchmark=300)
print(df)
df = sql_engine.fetch_data_range(universe,
factors=factors,
start_date=start_date,
end_date=end_date,
benchmark=300)
print(df)
...@@ -17,8 +17,8 @@ from alphamind.data.dbmodel.models import RiskExposure ...@@ -17,8 +17,8 @@ from alphamind.data.dbmodel.models import RiskExposure
from alphamind.data.dbmodel.models import SpecificRiskDay from alphamind.data.dbmodel.models import SpecificRiskDay
from alphamind.data.dbmodel.models import SpecificRiskLong from alphamind.data.dbmodel.models import SpecificRiskLong
from alphamind.data.dbmodel.models import SpecificRiskShort from alphamind.data.dbmodel.models import SpecificRiskShort
from alphamind.data.dbmodel.models import factor_tables
from alphamind.data.engines.industries import INDUSTRY_MAPPING from alphamind.data.engines.industries import INDUSTRY_MAPPING
factor_tables = [Market, RiskExposure]
def _map_risk_model_table(risk_model: str) -> tuple: def _map_risk_model_table(risk_model: str) -> tuple:
......
...@@ -10,9 +10,11 @@ import os ...@@ -10,9 +10,11 @@ import os
SKIP_ENGINE_TESTS = True SKIP_ENGINE_TESTS = True
if not SKIP_ENGINE_TESTS: if not SKIP_ENGINE_TESTS:
try:
DATA_ENGINE_URI = os.environ['DB_URI'] DATA_ENGINE_URI = os.environ['DB_URI']
else: except KeyError:
DATA_ENGINE_URI = None DATA_ENGINE_URI = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
if __name__ == '__main__': if __name__ == '__main__':
from simpleutils import add_parent_path from simpleutils import add_parent_path
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment