Commit 14ea7d7c authored by Dr.李's avatar Dr.李

fixed bug

parent b83b8b64
...@@ -544,7 +544,20 @@ class SqlEngine(object): ...@@ -544,7 +544,20 @@ class SqlEngine(object):
df = self.fetch_industry(ref_date, codes, category, level) df = self.fetch_industry(ref_date, codes, category, level)
df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="") df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
industries = industry_list(category, level) industries = industry_list(category, level)
return df[['code', 'industry_code'] + industries]
in_s = []
out_s = []
for i in industries:
if i in df:
in_s.append(i)
else:
out_s.append(i)
res = df[['code', 'industry_code'] + in_s]
for i in out_s:
res[i] = 0
return res
def fetch_industry_range(self, def fetch_industry_range(self,
universe: Universe, universe: Universe,
...@@ -554,7 +567,6 @@ class SqlEngine(object): ...@@ -554,7 +567,6 @@ class SqlEngine(object):
category: str = 'sw', category: str = 'sw',
level: int = 1): level: int = 1):
industry_category_name = _map_industry_category(category) industry_category_name = _map_industry_category(category)
cond = universe._query_statements(start_date, end_date, dates) cond = universe._query_statements(start_date, end_date, dates)
big_table = join(Industry, UniverseTable, big_table = join(Industry, UniverseTable,
...@@ -563,8 +575,7 @@ class SqlEngine(object): ...@@ -563,8 +575,7 @@ class SqlEngine(object):
Industry.code == UniverseTable.code, Industry.code == UniverseTable.code,
Industry.industry == industry_category_name, Industry.industry == industry_category_name,
cond cond
) ))
)
code_name = 'industryID' + str(level) code_name = 'industryID' + str(level)
category_name = 'industryName' + str(level) category_name = 'industryName' + str(level)
...@@ -591,7 +602,20 @@ class SqlEngine(object): ...@@ -591,7 +602,20 @@ class SqlEngine(object):
df = self.fetch_industry_range(universe, start_date, end_date, dates, category, level) df = self.fetch_industry_range(universe, start_date, end_date, dates, category, level)
df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="") df = pd.get_dummies(df, columns=['industry'], prefix="", prefix_sep="")
industries = industry_list(category, level) industries = industry_list(category, level)
return df[['trade_date', 'code', 'industry_code'] + industries]
in_s = []
out_s = []
for i in industries:
if i in df:
in_s.append(i)
else:
out_s.append(i)
res = df[['code', 'industry_code'] + in_s]
for i in out_s:
res[i] = 0
return res
def fetch_data(self, ref_date: str, def fetch_data(self, ref_date: str,
factors: Iterable[str], factors: Iterable[str],
...@@ -889,5 +913,5 @@ if __name__ == '__main__': ...@@ -889,5 +913,5 @@ if __name__ == '__main__':
df = engine.fetch_industry_matrix(ref_date, codes.code.tolist(), 'dx', 1) df = engine.fetch_industry_matrix(ref_date, codes.code.tolist(), 'dx', 1)
print(df) print(df)
df = engine.fetch_industry_matrix_range(universe, '2011-12-28', '2017-12-31', category='sw', level=1) df = engine.fetch_industry_matrix_range(universe, '2017-12-28', '2017-12-31', category='zz', level=2)
print(df) print(df)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment