Commit 5fa76dd3 authored by Dr.李

update store

parent 3f27c3e7
...@@ -6,6 +6,7 @@ Created on 2017-6-26 ...@@ -6,6 +6,7 @@ Created on 2017-6-26
""" """
from typing import Iterable from typing import Iterable
from typing import Union
import sqlalchemy as sa import sqlalchemy as sa
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -64,44 +65,67 @@ industry_styles = [ ...@@ -64,44 +65,67 @@ industry_styles = [
] ]
def fetch_codes(codes: Union[str, Iterable[int]], start_date, end_date, engine):
    """Turn a ``codes`` argument into a universe table or a SQL ``IN`` list.

    Args:
        codes: either a universe name (``str``) or an iterable of security
            codes. Any other value (e.g. ``None``) selects nothing.
        start_date: inclusive lower bound on ``universe.Date``; only used
            when ``codes`` is a universe name.
        end_date: inclusive upper bound on ``universe.Date``; only used
            when ``codes`` is a universe name.
        engine: connection handed to ``pandas.read_sql``; only used when
            ``codes`` is a universe name.

    Returns:
        A ``(code_table, code_str)`` pair:

        * universe name -> ``(DataFrame with Date and Code columns, None)``
        * iterable of codes -> ``(None, 'c1,c2,...')``; an empty iterable
          yields an empty string
        * anything else -> ``(None, None)``
    """
    if isinstance(codes, str):
        # Universe lookup. The values are passed as bound parameters instead
        # of being formatted into the statement, so a quote inside any of
        # them cannot alter the query. ``sa`` is the module-level
        # ``import sqlalchemy as sa``.
        sql = sa.text("select Date, Code from universe"
                      " where Date >= :start_date and Date <= :end_date"
                      " and universe = :universe_name")
        params = {'start_date': start_date,
                  'end_date': end_date,
                  'universe_name': codes}
        return pd.read_sql(sql, engine, params=params), None

    if hasattr(codes, '__iter__'):
        # Explicit code list: rendered as the body of an ``IN (...)`` clause.
        return None, ','.join(str(c) for c in codes)

    return None, None
def industry_mapping(industry_arr, industry_dummies):
    """Pick one industry label per row of dummy flags.

    For every row of ``industry_dummies`` the row is used to index
    ``industry_arr`` and the first selected element is kept.

    Assumes ``industry_arr`` is a numpy array of labels and each row is a
    boolean mask of the same length, as built by the caller in this file
    -- TODO confirm for other callers. A row that selects nothing raises
    ``IndexError``.

    Returns:
        A list with one label per row, in row order.
    """
    labels = []
    for flags in industry_dummies:
        selected = industry_arr[flags]
        labels.append(selected[0])
    return labels
def fetch_data(factors: Iterable[str], def fetch_data(factors: Iterable[str],
start_date: str, start_date: str,
end_date: str, end_date: str,
codes: Iterable[int] = None, codes: Union[str, Iterable[int]] = None,
benchmark: int = None, benchmark: int = None,
risk_model: str = 'day') -> dict: risk_model: str = 'day') -> dict:
engine = sa.create_engine('mysql+mysqldb://{user}:{password}@{host}/{db}?charset={charset}' engine = sa.create_engine('mysql+mysqldb://{user}:{password}@{host}/{db}?charset={charset}'
.format(**db_settings['uqer'])) .format(**db_settings['uqer']))
factor_str = ','.join('factors.' + f for f in factors) factor_str = ','.join('factors.' + f for f in factors)
if codes: code_table, code_str = fetch_codes(codes, start_date, end_date, engine)
code_str = ','.join(str(c) for c in codes)
else:
code_str = None
total_risk_factors = risk_styles + industry_styles total_risk_factors = risk_styles + industry_styles
risk_str = ','.join('risk_exposure.' + f for f in total_risk_factors) risk_str = ','.join('risk_exposure.' + f for f in total_risk_factors)
if code_str: if code_str:
sql = "select factors.Date, factors.Code, {0}, {3}" \ sql = "select factors.Date, factors.Code, {0}, {3}, market.chgPct, market.isOpen" \
" from factors INNER JOIN" \ " from (factors INNER JOIN" \
" risk_exposure on factors.Date = risk_exposure.Date and factors.Code = risk_exposure.Code" \ " risk_exposure on factors.Date = risk_exposure.Date and factors.Code = risk_exposure.Code)" \
" INNER JOIN market on factors.Date = market.Date and factors.Code = market.Code" \
" where factors.Date >= '{1}' and factors.Date <= '{2}' and factors.Code in ({4})".format(factor_str, " where factors.Date >= '{1}' and factors.Date <= '{2}' and factors.Code in ({4})".format(factor_str,
start_date, start_date,
end_date, end_date,
risk_str, risk_str,
code_str) code_str)
else: else:
sql = "select factors.Date, factors.Code, {0}, {3}" \ sql = "select factors.Date, factors.Code, {0}, {3}, market.chgPct, market.isOpen" \
" from factors INNER JOIN" \ " from (factors INNER JOIN" \
" risk_exposure on factors.Date = risk_exposure.Date and factors.Code = risk_exposure.Code" \ " risk_exposure on factors.Date = risk_exposure.Date and factors.Code = risk_exposure.Code)" \
" INNER JOIN market on factors.Date = market.Date and factors.Code = market.Code" \
" where factors.Date >= '{1}' and factors.Date <= '{2}'".format(factor_str, " where factors.Date >= '{1}' and factors.Date <= '{2}'".format(factor_str,
start_date, start_date,
end_date, end_date,
risk_str) risk_str)
factor_data = pd.read_sql(sql, engine) factor_data = pd.read_sql(sql, engine)
if code_table is not None:
factor_data = pd.merge(factor_data, code_table, on=['Date', 'Code'])
risk_cov_table = 'risk_cov_' + risk_model risk_cov_table = 'risk_cov_' + risk_model
risk_str = ','.join(risk_cov_table + '.' + f for f in total_risk_factors) risk_str = ','.join(risk_cov_table + '.' + f for f in total_risk_factors)
...@@ -126,7 +150,7 @@ def fetch_data(factors: Iterable[str], ...@@ -126,7 +150,7 @@ def fetch_data(factors: Iterable[str],
industry_arr = np.array(industry_styles) industry_arr = np.array(industry_styles)
industry_dummies = factor_data[industry_styles].values.astype(bool) industry_dummies = factor_data[industry_styles].values.astype(bool)
factor_data['industry'] = [industry_arr[row][0] for row in industry_dummies] factor_data['industry'] = industry_mapping(industry_arr, industry_dummies)
return total_data return total_data
...@@ -135,6 +159,6 @@ if __name__ == '__main__': ...@@ -135,6 +159,6 @@ if __name__ == '__main__':
import datetime as dt import datetime as dt
start = dt.datetime.now() start = dt.datetime.now()
res = fetch_data(['EPS'], '2017-01-03', '2017-06-05', benchmark=905) res = fetch_data(['EPS'], '2017-01-03', '2017-06-05', benchmark=905, codes='zz500')
print(res) print(res)
print(dt.datetime.now() - start) print(dt.datetime.now() - start)
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment