Commit 8633e03b authored by Dr.李's avatar Dr.李

made it work for several table source

parent 1a6908af
......@@ -108,7 +108,13 @@ class SqlEngine(object):
benchmark: int = None,
risk_model: str = 'short') -> Dict[str, pd.DataFrame]:
factor_str = ','.join('uqer.' + f for f in factors)
def mapping_factors(factors):
factor_list = ','.join("'" + f + "'" for f in factors)
sql = "select factor, source from factor_master where factor in ({factor_list})".format(factor_list=factor_list)
results = self.engine.execute(sql).fetchall()
return ','.join(r[1].strip() + '.' + r[0].strip() for r in results)
factor_str = mapping_factors(factors)
total_risk_factors = risk_styles + industry_styles
risk_str = ','.join('risk_exposure.' + f for f in total_risk_factors)
......@@ -120,6 +126,8 @@ class SqlEngine(object):
" from (uqer INNER JOIN" \
" risk_exposure on uqer.Date = risk_exposure.Date and uqer.Code = risk_exposure.Code)" \
" INNER JOIN market on uqer.Date = market.Date and uqer.Code = market.Code" \
" INNER JOIN tiny on uqer.Date = tiny.Date and uqer.Code = tiny.Code" \
" INNER JOIN legacy_factor on uqer.Date = legacy_factor.Date and uqer.Code = legacy_factor.Code" \
" INNER JOIN daily_return on uqer.Date = daily_return.Date and uqer.Code = daily_return.Code" \
" INNER JOIN {risk_table} on uqer.Date = {risk_table}.Date and uqer.Code = {risk_table}.Code" \
" where uqer.Date = '{ref_date}' and uqer.Code in ({codes})".format(factors=factor_str,
......@@ -139,7 +147,7 @@ class SqlEngine(object):
risk_cov_data = pd.read_sql(sql, self.engine).sort_values('FactorID')
total_data = {'factor': factor_data, 'risk_cov': risk_cov_data}
total_data = {'risk_cov': risk_cov_data}
if benchmark:
sql = "select Code, weight / 100. as weight from index_components " \
......@@ -148,13 +156,17 @@ class SqlEngine(object):
benchmark_data = pd.read_sql(sql, self.engine)
total_data['benchmark'] = benchmark_data
factor_data = pd.merge(factor_data, benchmark_data, how='left', on=['Code'])
factor_data['weight'] = factor_data['weight'].fillna(0.)
total_data['factor'] = factor_data
append_industry_info(factor_data)
return total_data
if __name__ == '__main__':
db_url = 'mysql+mysqldb://root:we083826@localhost/alpha?charset=utf8'
db_url = 'mssql+pymssql://licheng:A12345678!@10.63.6.220/alpha?charset=utf8'
universe = Universe('zz500', ['zz500'])
engine = SqlEngine(db_url)
......@@ -166,7 +178,7 @@ if __name__ == '__main__':
for i in range(10):
factors = engine.fetch_factors_meta()
codes = engine.fetch_codes('2017-07-04', universe)
total_data = engine.fetch_data(ref_date, ['EPS'], [1, 5], 905)
total_data = engine.fetch_data(ref_date, ['EPS', 'DROEAfterNonRecurring'], [1, 5], 905)
print(dt.datetime.now() - start)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment