Commit aec38028 authored by Dr.李

Fixed the case where there is a dependency on hist

parent 0d692780
...@@ -356,9 +356,11 @@ class SqlEngine(object): ...@@ -356,9 +356,11 @@ class SqlEngine(object):
factor_cols = _map_factors(dependency, factor_tables) factor_cols = _map_factors(dependency, factor_tables)
big_table = FullFactor big_table = FullFactor
joined_tables = set()
joined_tables.add(FullFactor.__table__.name)
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
if t.__table__.name != FullFactor.__table__.name: if t.__table__.name not in joined_tables:
if dates is not None: if dates is not None:
big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date,
FullFactor.code == t.code, FullFactor.code == t.code,
...@@ -367,20 +369,18 @@ class SqlEngine(object): ...@@ -367,20 +369,18 @@ class SqlEngine(object):
big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date,
FullFactor.code == t.code, FullFactor.code == t.code,
FullFactor.trade_date.between(start_date, end_date))) FullFactor.trade_date.between(start_date, end_date)))
joined_tables.add(t.__table__.name)
cond = universe._query_statements(start_date, end_date, dates) universe_df = universe.query(self, start_date, end_date, dates)
big_table = join(big_table, UniverseTable,
and_(
FullFactor.trade_date == UniverseTable.trade_date,
FullFactor.code == UniverseTable.code,
cond
)
)
query = select( query = select(
[FullFactor.trade_date, FullFactor.code, FullFactor.isOpen] + list(factor_cols.keys())) \ [FullFactor.trade_date, FullFactor.code, FullFactor.isOpen] + list(factor_cols.keys())) \
.select_from(big_table).distinct() .select_from(big_table).where(
and_(
FullFactor.code.in_(universe_df.code.unique().tolist()),
FullFactor.trade_date.in_(dates) if dates is not None else FullFactor.trade_date.between(start_date, end_date)
)
).distinct()
df = pd.read_sql(query, self.engine) df = pd.read_sql(query, self.engine)
if universe.is_filtered: if universe.is_filtered:
...@@ -391,7 +391,6 @@ class SqlEngine(object): ...@@ -391,7 +391,6 @@ class SqlEngine(object):
df = pd.merge(df, external_data, on=['trade_date', 'code']).dropna() df = pd.merge(df, external_data, on=['trade_date', 'code']).dropna()
df.sort_values(['trade_date', 'code'], inplace=True) df.sort_values(['trade_date', 'code'], inplace=True)
df.set_index('trade_date', inplace=True) df.set_index('trade_date', inplace=True)
res = transformer.transform('code', df) res = transformer.transform('code', df)
...@@ -400,7 +399,8 @@ class SqlEngine(object): ...@@ -400,7 +399,8 @@ class SqlEngine(object):
df[col] = res[col].values df[col] = res[col].values
df['isOpen'] = df.isOpen.astype(bool) df['isOpen'] = df.isOpen.astype(bool)
return df.reset_index() df = df.reset_index()
return pd.merge(df, universe_df[['trade_date', 'code']], how='inner')
def fetch_benchmark(self, def fetch_benchmark(self,
ref_date: str, ref_date: str,
...@@ -924,7 +924,10 @@ class SqlEngine(object): ...@@ -924,7 +924,10 @@ class SqlEngine(object):
if __name__ == '__main__': if __name__ == '__main__':
from PyFin.api import *
engine = SqlEngine() engine = SqlEngine()
ref_date = '2017-06-29' ref_date = '2017-06-29'
universe = Universe('', ['zz800']) universe = Universe('', ['zz800'])
p_returns = engine.fetch_dx_return_range(universe, ref_date, ref_date, horizon=0)
dates = makeSchedule('2010-01-01', '2018-02-01', '10b', 'china.sse')
df = engine.fetch_factor_range(universe, DIFF('roe_q'), dates=dates)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment