Commit f45e145e authored by Dr.李's avatar Dr.李

made factor calculation based on minimum time interval

parent 304766f4
...@@ -221,7 +221,8 @@ class SqlEngine(object): ...@@ -221,7 +221,8 @@ class SqlEngine(object):
def fetch_factor(self, def fetch_factor(self,
ref_date: str, ref_date: str,
factors: Iterable[object], factors: Iterable[object],
codes: Iterable[int]) -> pd.DataFrame: codes: Iterable[int],
default_window: int=0) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -232,20 +233,26 @@ class SqlEngine(object): ...@@ -232,20 +233,26 @@ class SqlEngine(object):
factor_cols = _map_factors(dependency) factor_cols = _map_factors(dependency)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-default_window) + 'b').strftime('%Y-%m-%d')
end_date = ref_date
big_table = Market big_table = Market
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
big_table = outerjoin(big_table, t, and_(Market.Date == t.Date, Market.Code == t.Code)) big_table = outerjoin(big_table, t, and_(Market.Date == t.Date, Market.Code == t.Code))
query = select([Market.Code, Market.isOpen] + list(factor_cols.keys())) \ query = select([Market.Date, Market.Code, Market.isOpen] + list(factor_cols.keys())) \
.select_from(big_table) \ .select_from(big_table) \
.where(and_(Market.Date == ref_date, Market.Code.in_(codes))) .where(and_(Market.Date.between(start_date, end_date), Market.Code.in_(codes)))
df = pd.read_sql(query, self.engine).sort_values('Code') df = pd.read_sql(query, self.engine).sort_values(['Date', 'Code']).set_index('Date')
res = transformer.transform('Code', df) res = transformer.transform('Code', df)
for col in res.columns: for col in res.columns:
if col not in set(['Code', 'isOpen']) and col not in df.columns: if col not in set(['Code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values df[col] = res[col].values
df = df.loc[ref_date]
df.index = list(range(len(df)))
return df return df
def fetch_factor_range(self, def fetch_factor_range(self,
...@@ -253,7 +260,8 @@ class SqlEngine(object): ...@@ -253,7 +260,8 @@ class SqlEngine(object):
factors: Union[Transformer, Iterable[object]], factors: Union[Transformer, Iterable[object]],
start_date: str = None, start_date: str = None,
end_date: str = None, end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame: dates: Iterable[str] = None,
default_window: int=0) -> pd.DataFrame:
if isinstance(factors, Transformer): if isinstance(factors, Transformer):
transformer = factors transformer = factors
...@@ -263,7 +271,14 @@ class SqlEngine(object): ...@@ -263,7 +271,14 @@ class SqlEngine(object):
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency) factor_cols = _map_factors(dependency)
q2 = universe.query_range(start_date, end_date, dates).alias('temp_universe') if dates:
real_start_date = advanceDateByCalendar('china.sse', dates[0], str(-default_window) + 'b').strftime('%Y-%m-%d')
real_end_date = dates[-1]
else:
real_start_date = advanceDateByCalendar('china.sse', start_date, str(-default_window) + 'b').strftime('%Y-%m-%d')
real_end_date = end_date
q2 = universe.query_range(real_start_date, real_end_date).alias('temp_universe')
big_table = join(Market, q2, and_(Market.Date == q2.c.Date, Market.Code == q2.c.Code)) big_table = join(Market, q2, and_(Market.Date == q2.c.Date, Market.Code == q2.c.Code))
for t in set(factor_cols.values()): for t in set(factor_cols.values()):
...@@ -278,7 +293,10 @@ class SqlEngine(object): ...@@ -278,7 +293,10 @@ class SqlEngine(object):
for col in res.columns: for col in res.columns:
if col not in set(['Code', 'isOpen']) and col not in df.columns: if col not in set(['Code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values df[col] = res[col].values
if dates:
df = df[df.index.isin(dates)]
else:
df = df[start_date:end_date]
return df.reset_index() return df.reset_index()
def fetch_benchmark(self, def fetch_benchmark(self,
...@@ -446,7 +464,6 @@ if __name__ == '__main__': ...@@ -446,7 +464,6 @@ if __name__ == '__main__':
ref_date = '2017-08-10' ref_date = '2017-08-10'
codes = engine.fetch_codes(universe=universe, ref_date='2017-08-10') codes = engine.fetch_codes(universe=universe, ref_date='2017-08-10')
MAXIMUM(('EPS', 'ROEDiluted')) data2 = engine.fetch_factor_range(universe=universe, start_date='2017-08-01', end_date='2017-08-10', factors={'factor': MAXIMUM(('EPS', 'ROEDiluted'))})
data2 = engine.fetch_factor_range(universe=universe, dates=['2017-08-01', '2017-08-10'], factors={'factor': MAXIMUM(('EPS', 'ROEDiluted'))})
print(codes) print(codes)
print(data2) print(data2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment