Commit 56a7e393 authored by Dr.李's avatar Dr.李

refactor return get

parent 6b86de35
...@@ -155,6 +155,13 @@ class SqlEngine(object): ...@@ -155,6 +155,13 @@ class SqlEngine(object):
dates: Iterable[str] = None) -> pd.DataFrame: dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates) return universe.query(self, start_date, end_date, dates)
def _create_stats(self, table, horizon, offset):
stats = func.sum(self.ln_func(1. + table.chgPct)).over(
partition_by=table.code,
order_by=table.trade_date,
rows=(1 + DAILY_RETURN_OFFSET + offset, 1 + horizon + DAILY_RETURN_OFFSET + offset)).label('dx')
return stats
def fetch_dx_return(self, def fetch_dx_return(self,
ref_date: str, ref_date: str,
codes: Iterable[int], codes: Iterable[int],
...@@ -169,10 +176,7 @@ class SqlEngine(object): ...@@ -169,10 +176,7 @@ class SqlEngine(object):
else: else:
end_date = expiry_date end_date = expiry_date
stats = func.sum(self.ln_func(1. + Market.chgPct)).over( stats = self._create_stats(Market, horizon, offset)
partition_by=Market.code,
order_by=Market.trade_date,
rows=(1 + DAILY_RETURN_OFFSET + offset, 1 + horizon + DAILY_RETURN_OFFSET + offset)).label('dx')
query = select([Market.trade_date, Market.code, stats]).where( query = select([Market.trade_date, Market.code, stats]).where(
and_( and_(
...@@ -200,10 +204,7 @@ class SqlEngine(object): ...@@ -200,10 +204,7 @@ class SqlEngine(object):
end_date = advanceDateByCalendar('china.sse', end_date, end_date = advanceDateByCalendar('china.sse', end_date,
str(1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime('%Y-%m-%d') str(1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime('%Y-%m-%d')
stats = func.sum(self.ln_func(1. + Market.chgPct)).over( stats = self._create_stats(Market, horizon, offset)
partition_by=Market.code,
order_by=Market.trade_date,
rows=(1 + offset + DAILY_RETURN_OFFSET, 1 + horizon + offset + DAILY_RETURN_OFFSET)).label('dx')
cond = universe._query_statements(start_date, end_date, None) cond = universe._query_statements(start_date, end_date, None)
...@@ -243,10 +244,7 @@ class SqlEngine(object): ...@@ -243,10 +244,7 @@ class SqlEngine(object):
else: else:
end_date = expiry_date end_date = expiry_date
stats = func.sum(self.ln_func(1. + IndexMarket.chgPct)).over( stats = self._create_stats(IndexMarket, horizon, offset)
partition_by=IndexMarket.indexCode,
order_by=IndexMarket.trade_date,
rows=(1 + DAILY_RETURN_OFFSET + offset, 1 + horizon + DAILY_RETURN_OFFSET + offset)).label('dx')
query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]).where( query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]).where(
and_( and_(
...@@ -274,10 +272,7 @@ class SqlEngine(object): ...@@ -274,10 +272,7 @@ class SqlEngine(object):
end_date = advanceDateByCalendar('china.sse', end_date, end_date = advanceDateByCalendar('china.sse', end_date,
str(1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime('%Y-%m-%d') str(1 + horizon + offset + DAILY_RETURN_OFFSET) + 'b').strftime('%Y-%m-%d')
stats = func.sum(self.ln_func(1. + IndexMarket.chgPct)).over( stats = self._create_stats(IndexMarket, horizon, offset)
partition_by=IndexMarket.indexCode,
order_by=IndexMarket.trade_date,
rows=(1 + offset + DAILY_RETURN_OFFSET, 1 + horizon + offset + DAILY_RETURN_OFFSET)).label('dx')
query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]) \ query = select([IndexMarket.trade_date, IndexMarket.indexCode.label('code'), stats]) \
.where( .where(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment