Commit 03ea0ad0 authored by Dr.李's avatar Dr.李

update engines

parent 6c3e1b28
......@@ -166,6 +166,9 @@ def factor_analysis(factors: pd.DataFrame,
method='risk_neutral',
**kwargs) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
if risk_exp is not None:
risk_exp = risk_exp[:, risk_exp.sum(axis=0) != 0]
data_pack = FDataPack(raw_factors=factors.values,
d1returns=d1returns,
groups=industry,
......
......@@ -22,8 +22,7 @@ risk_styles = ['BETA',
'BTOP',
'LEVERAGE',
'LIQUIDTY',
'SIZENL',
'COUNTRY']
'SIZENL']
industry_styles = [
'Bank',
......@@ -68,12 +67,14 @@ def append_industry_info(df):
class SqlEngine(object):
def __init__(self,
db_url: str,
universe: Universe):
db_url: str):
self.engine = sa.create_engine(db_url)
self.unv = universe
def fetch_codes(self, ref_date: str) -> List[int]:
def fetch_factors_meta(self) -> pd.DataFrame:
sql = "select * from factor_master"
return pd.read_sql(sql, self.engine)
def fetch_codes(self, ref_date: str, univ: Universe) -> List[int]:
def get_universe(univ, ref_date):
univ_str = ','.join("'" + u + "'" for u in univ)
......@@ -85,19 +86,19 @@ class SqlEngine(object):
codes_set = None
if self.unv.include_universe:
include_codes_set = get_universe(self.unv.include_universe, ref_date)
if univ.include_universe:
include_codes_set = get_universe(univ.include_universe, ref_date)
codes_set = include_codes_set
if self.unv.exclude_universe:
exclude_codes_set = get_universe(self.unv.exclude_universe, ref_date)
if univ.exclude_universe:
exclude_codes_set = get_universe(univ.exclude_universe, ref_date)
codes_set -= exclude_codes_set
if self.unv.include_codes:
codes_set = codes_set.union(self.unv.include_codes)
if univ.include_codes:
codes_set = codes_set.union(univ.include_codes)
if self.unv.exclude_codes:
codes_set -= set(self.unv.exclude_codes)
if univ.exclude_codes:
codes_set -= set(univ.exclude_codes)
return sorted(codes_set)
......@@ -154,16 +155,17 @@ class SqlEngine(object):
if __name__ == '__main__':
db_url = 'mysql+mysqldb://root:we083826@localhost/alpha?charset=utf8'
universe = Universe(['zz500'])
universe = Universe('zz500', ['zz500'])
engine = SqlEngine(db_url, universe)
engine = SqlEngine(db_url)
ref_date = '2017-07-04'
import datetime as dt
start = dt.datetime.now()
for i in range(500):
codes = engine.fetch_codes('2017-07-04')
for i in range(10):
factors = engine.fetch_factors_meta()
codes = engine.fetch_codes('2017-07-04', universe)
total_data = engine.fetch_data(ref_date, ['EPS'], [1, 5], 905)
print(dt.datetime.now() - start)
......
......@@ -11,11 +11,13 @@ from typing import Iterable
class Universe(object):
def __init__(self,
name,
include_universe: Iterable[str]=None,
exclude_universe: Iterable[str]=None,
include_codes: Iterable[str]=None,
exclude_codes: Iterable[str]=None):
self.name = name
self.include_universe = include_universe
self.exclude_universe = exclude_universe
self.include_codes = include_codes
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment