Commit 304766f4 authored by Dr.李's avatar Dr.李

fixed bug when risk factor is in factors

parent d49fb0c7
...@@ -9,6 +9,7 @@ from typing import Iterable ...@@ -9,6 +9,7 @@ from typing import Iterable
from typing import List from typing import List
from typing import Dict from typing import Dict
from typing import Tuple from typing import Tuple
from typing import Union
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import sqlalchemy as sa import sqlalchemy as sa
...@@ -81,7 +82,7 @@ macro_styles = ['COUNTRY'] ...@@ -81,7 +82,7 @@ macro_styles = ['COUNTRY']
total_risk_factors = risk_styles + industry_styles + macro_styles total_risk_factors = risk_styles + industry_styles + macro_styles
factor_tables = [Uqer, Tiny, LegacyFactor, Experimental] factor_tables = [Uqer, Tiny, LegacyFactor, Experimental, RiskExposure]
def append_industry_info(df): def append_industry_info(df):
...@@ -222,7 +223,11 @@ class SqlEngine(object): ...@@ -222,7 +223,11 @@ class SqlEngine(object):
factors: Iterable[object], factors: Iterable[object],
codes: Iterable[int]) -> pd.DataFrame: codes: Iterable[int]) -> pd.DataFrame:
transformer = Transformer(factors) if isinstance(factors, Transformer):
transformer = factors
else:
transformer = Transformer(factors)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency) factor_cols = _map_factors(dependency)
...@@ -245,12 +250,16 @@ class SqlEngine(object): ...@@ -245,12 +250,16 @@ class SqlEngine(object):
def fetch_factor_range(self, def fetch_factor_range(self,
universe: Universe, universe: Universe,
factors: Iterable[object], factors: Union[Transformer, Iterable[object]],
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame: dates: Iterable[str] = None) -> pd.DataFrame:
transformer = Transformer(factors) if isinstance(factors, Transformer):
transformer = factors
else:
transformer = Transformer(factors)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency) factor_cols = _map_factors(dependency)
...@@ -304,7 +313,8 @@ class SqlEngine(object): ...@@ -304,7 +313,8 @@ class SqlEngine(object):
def fetch_risk_model(self, def fetch_risk_model(self,
ref_date: str, ref_date: str,
codes: Iterable[int], codes: Iterable[int],
risk_model: str = 'short') -> Tuple[pd.DataFrame, pd.DataFrame]: risk_model: str = 'short',
excluded: Iterable[str]=None) -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_cov_table, special_risk_table = _map_risk_model_table(risk_model) risk_cov_table, special_risk_table = _map_risk_model_table(risk_model)
cov_risk_cols = [risk_cov_table.__table__.columns[f] for f in total_risk_factors] cov_risk_cols = [risk_cov_table.__table__.columns[f] for f in total_risk_factors]
...@@ -315,7 +325,7 @@ class SqlEngine(object): ...@@ -315,7 +325,7 @@ class SqlEngine(object):
) )
risk_cov = pd.read_sql(query, self.engine).sort_values('FactorID') risk_cov = pd.read_sql(query, self.engine).sort_values('FactorID')
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors] risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors if f not in set(excluded)]
big_table = outerjoin(special_risk_table, RiskExposure, big_table = outerjoin(special_risk_table, RiskExposure,
and_(special_risk_table.Date == RiskExposure.Date, and_(special_risk_table.Date == RiskExposure.Date,
special_risk_table.Code == RiskExposure.Code)) special_risk_table.Code == RiskExposure.Code))
...@@ -333,7 +343,8 @@ class SqlEngine(object): ...@@ -333,7 +343,8 @@ class SqlEngine(object):
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None, dates: Iterable[str] = None,
risk_model: str = 'short') -> Tuple[pd.DataFrame, pd.DataFrame]: risk_model: str = 'short',
excluded: Iterable[str] = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
risk_cov_table, special_risk_table = _map_risk_model_table(risk_model) risk_cov_table, special_risk_table = _map_risk_model_table(risk_model)
...@@ -349,7 +360,7 @@ class SqlEngine(object): ...@@ -349,7 +360,7 @@ class SqlEngine(object):
risk_cov = pd.read_sql(query, self.engine).sort_values(['Date', 'FactorID']) risk_cov = pd.read_sql(query, self.engine).sort_values(['Date', 'FactorID'])
risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors] risk_exposure_cols = [RiskExposure.__table__.columns[f] for f in total_risk_factors if f not in set(excluded)]
big_table = outerjoin(special_risk_table, RiskExposure, big_table = outerjoin(special_risk_table, RiskExposure,
and_(special_risk_table.Date == RiskExposure.Date, and_(special_risk_table.Date == RiskExposure.Date,
special_risk_table.Code == RiskExposure.Code)) special_risk_table.Code == RiskExposure.Code))
...@@ -374,7 +385,8 @@ class SqlEngine(object): ...@@ -374,7 +385,8 @@ class SqlEngine(object):
total_data = {} total_data = {}
factor_data = self.fetch_factor(ref_date, factors, codes) transformer = Transformer(factors)
factor_data = self.fetch_factor(ref_date, transformer, codes)
if benchmark: if benchmark:
benchmark_data = self.fetch_benchmark(ref_date, benchmark) benchmark_data = self.fetch_benchmark(ref_date, benchmark)
...@@ -383,7 +395,8 @@ class SqlEngine(object): ...@@ -383,7 +395,8 @@ class SqlEngine(object):
factor_data['weight'] = factor_data['weight'].fillna(0.) factor_data['weight'] = factor_data['weight'].fillna(0.)
if risk_model: if risk_model:
risk_cov, risk_exp = self.fetch_risk_model(ref_date, codes, risk_model) excluded = list(set(total_risk_factors).intersection(transformer.dependency))
risk_cov, risk_exp = self.fetch_risk_model(ref_date, codes, risk_model, excluded)
factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Code']) factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Code'])
total_data['risk_cov'] = risk_cov total_data['risk_cov'] = risk_cov
...@@ -402,8 +415,8 @@ class SqlEngine(object): ...@@ -402,8 +415,8 @@ class SqlEngine(object):
risk_model: str = 'short') -> Dict[str, pd.DataFrame]: risk_model: str = 'short') -> Dict[str, pd.DataFrame]:
total_data = {} total_data = {}
transformer = Transformer(factors)
factor_data = self.fetch_factor_range(universe, factors, start_date, end_date, dates) factor_data = self.fetch_factor_range(universe, transformer, start_date, end_date, dates)
if benchmark: if benchmark:
benchmark_data = self.fetch_benchmark_range(benchmark, start_date, end_date, dates) benchmark_data = self.fetch_benchmark_range(benchmark, start_date, end_date, dates)
...@@ -412,7 +425,8 @@ class SqlEngine(object): ...@@ -412,7 +425,8 @@ class SqlEngine(object):
factor_data['weight'] = factor_data['weight'].fillna(0.) factor_data['weight'] = factor_data['weight'].fillna(0.)
if risk_model: if risk_model:
risk_cov, risk_exp = self.fetch_risk_model_range(universe, start_date, end_date, dates, risk_model) excluded = list(set(total_risk_factors).intersection(transformer.dependency))
risk_cov, risk_exp = self.fetch_risk_model_range(universe, start_date, end_date, dates, risk_model, excluded)
factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Date', 'Code']) factor_data = pd.merge(factor_data, risk_exp, how='left', on=['Date', 'Code'])
total_data['risk_cov'] = risk_cov total_data['risk_cov'] = risk_cov
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment