Commit f45e145e authored by Dr.李's avatar Dr.李

made factor calculation based on minimum time interval

parent 304766f4
......@@ -221,7 +221,8 @@ class SqlEngine(object):
def fetch_factor(self,
ref_date: str,
factors: Iterable[object],
codes: Iterable[int]) -> pd.DataFrame:
codes: Iterable[int],
default_window: int=0) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
......@@ -232,20 +233,26 @@ class SqlEngine(object):
factor_cols = _map_factors(dependency)
start_date = advanceDateByCalendar('china.sse', ref_date, str(-default_window) + 'b').strftime('%Y-%m-%d')
end_date = ref_date
big_table = Market
for t in set(factor_cols.values()):
big_table = outerjoin(big_table, t, and_(Market.Date == t.Date, Market.Code == t.Code))
query = select([Market.Code, Market.isOpen] + list(factor_cols.keys())) \
query = select([Market.Date, Market.Code, Market.isOpen] + list(factor_cols.keys())) \
.select_from(big_table) \
.where(and_(Market.Date == ref_date, Market.Code.in_(codes)))
.where(and_(Market.Date.between(start_date, end_date), Market.Code.in_(codes)))
df = pd.read_sql(query, self.engine).sort_values('Code')
df = pd.read_sql(query, self.engine).sort_values(['Date', 'Code']).set_index('Date')
res = transformer.transform('Code', df)
for col in res.columns:
if col not in set(['Code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values
df = df.loc[ref_date]
df.index = list(range(len(df)))
return df
def fetch_factor_range(self,
......@@ -253,7 +260,8 @@ class SqlEngine(object):
factors: Union[Transformer, Iterable[object]],
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None) -> pd.DataFrame:
dates: Iterable[str] = None,
default_window: int=0) -> pd.DataFrame:
if isinstance(factors, Transformer):
transformer = factors
......@@ -263,7 +271,14 @@ class SqlEngine(object):
dependency = transformer.dependency
factor_cols = _map_factors(dependency)
q2 = universe.query_range(start_date, end_date, dates).alias('temp_universe')
if dates:
real_start_date = advanceDateByCalendar('china.sse', dates[0], str(-default_window) + 'b').strftime('%Y-%m-%d')
real_end_date = dates[-1]
else:
real_start_date = advanceDateByCalendar('china.sse', start_date, str(-default_window) + 'b').strftime('%Y-%m-%d')
real_end_date = end_date
q2 = universe.query_range(real_start_date, real_end_date).alias('temp_universe')
big_table = join(Market, q2, and_(Market.Date == q2.c.Date, Market.Code == q2.c.Code))
for t in set(factor_cols.values()):
......@@ -278,7 +293,10 @@ class SqlEngine(object):
for col in res.columns:
if col not in set(['Code', 'isOpen']) and col not in df.columns:
df[col] = res[col].values
if dates:
df = df[df.index.isin(dates)]
else:
df = df[start_date:end_date]
return df.reset_index()
def fetch_benchmark(self,
......@@ -446,7 +464,6 @@ if __name__ == '__main__':
ref_date = '2017-08-10'
codes = engine.fetch_codes(universe=universe, ref_date='2017-08-10')
MAXIMUM(('EPS', 'ROEDiluted'))
data2 = engine.fetch_factor_range(universe=universe, dates=['2017-08-01', '2017-08-10'], factors={'factor': MAXIMUM(('EPS', 'ROEDiluted'))})
data2 = engine.fetch_factor_range(universe=universe, start_date='2017-08-01', end_date='2017-08-10', factors={'factor': MAXIMUM(('EPS', 'ROEDiluted'))})
print(codes)
print(data2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment