Commit 304766f4 authored by Dr.李

Fixed a bug that occurred when a risk factor is also one of the requested factors

parent d49fb0c7
......@@ -9,6 +9,7 @@ from typing import Iterable
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
import numpy as np
import pandas as pd
import sqlalchemy as sa
......@@ -81,7 +82,7 @@ macro_styles = ['COUNTRY']
total_risk_factors = risk_styles + industry_styles + macro_styles
factor_tables = [Uqer, Tiny, LegacyFactor, Experimental]
factor_tables = [Uqer, Tiny, LegacyFactor, Experimental, RiskExposure]
def append_industry_info(df):
......@@ -222,7 +223,11 @@ class SqlEngine(object):
factors: Iterable[object],
codes: Iterable[int]) -> pd.DataFrame:
transformer = Transformer(factors)
if isinstance(factors, Transformer):
transformer = factors
else:
transformer = Transformer(factors)
dependency = transformer.dependency
factor_cols = _map_factors(dependency)
......@@ -245,12 +250,16 @@ class SqlEngine(object):
def fetch_factor_range(self,
universe: Universe,
factors: Iterable[object],
factors: Union[Transformer, Iterable[object]],
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame:
transformer = Transformer(factors)
if isinstance(factors, Transformer):
transformer = factors
else:
transformer = Transformer(factors)
dependency = transformer.dependency
factor_cols = _map_factors(dependency)
......@@ -304,7 +313,8 @@ class SqlEngine(object):
def fetch_risk_model(self,
ref_date: str,
codes: Iterable[int],
risk_model: str = 'short') -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_model: str = 'short',
excluded: Iterable[str]=None) -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_cov_table, special_risk_table = _map_risk_model_table(risk_model)
cov_risk_cols = [risk_cov_table.__table__.columns[f] for f in total_risk_factors]
......@@ -315,7 +325,7 @@ class SqlEngine(object):
)
risk_cov = pd.read_sql(query, self.engine).sort_values('FactorID')
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors]
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors if f not in set(excluded)]
big_table = outerjoin(special_risk_table, RiskExposure,
and_(special_risk_table.Date == RiskExposure.Date,
special_risk_table.Code == RiskExposure.Code))
......@@ -333,7 +343,8 @@ class SqlEngine(object):
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
risk_model: str = 'short') -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_model: str = 'short',
excluded: Iterable[str] = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_cov_table, special_risk_table = _map_risk_model_table(risk_model)
......@@ -349,7 +360,7 @@ class SqlEngine(object):
risk_cov = pd.read_sql(query, self.engine).sort_values(['Date', 'FactorID'])
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors]
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors if f not in set(excluded)]
big_table = outerjoin(special_risk_table, RiskExposure,
and_(special_risk_table.Date == RiskExposure.Date,
special_risk_table.Code == RiskExposure.Code))
......@@ -374,7 +385,8 @@ class SqlEngine(object):
total_data = {}
factor_data = self.fetch_factor(ref_date, factors, codes)
transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date, transformer, codes)
if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark)
......@@ -383,7 +395,8 @@ class SqlEngine(object):
factor_data['weight'] = factor_data['weight'].fillna(0.)
if risk_model:
risk_cov, risk_exp = self.fetch_risk_model(ref_date, codes, risk_model)
excluded = list(set(total_risk_factors).intersection(transformer.dependency))
risk_cov, risk_exp = self.fetch_risk_model(ref_date, codes, risk_model, excluded)
factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Code'])
total_data['risk_cov'] = risk_cov
......@@ -402,8 +415,8 @@ class SqlEngine(object):
risk_model: str = 'short') -> Dict[str, pd.DataFrame]:
total_data = {}
factor_data = self.fetch_factor_range(universe, factors, start_date, end_date, dates)
transformer = Transformer(factors)
factor_data = self.fetch_factor_range(universe, transformer, start_date, end_date, dates)
if benchmark:
benchmark_data = self.fetch_benchmark_range(benchmark, start_date, end_date, dates)
......@@ -412,7 +425,8 @@ class SqlEngine(object):
factor_data['weight'] = factor_data['weight'].fillna(0.)
if risk_model:
risk_cov, risk_exp = self.fetch_risk_model_range(universe, start_date, end_date, dates, risk_model)
excluded = list(set(total_risk_factors).intersection(transformer.dependency))
risk_cov, risk_exp = self.fetch_risk_model_range(universe, start_date, end_date, dates, risk_model, excluded)
factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Date', 'Code'])
total_data['risk_cov'] = risk_cov
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment