Commit c95f7a2b authored by Dr.李's avatar Dr.李

Merge remote-tracking branch 'origin/master'

parents 9ace5e9b 9a15510b
......@@ -14,7 +14,7 @@ import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlalchemy.orm as orm
from sqlalchemy import select, and_, outerjoin, join, delete, insert
from sqlalchemy import select, and_, outerjoin, join, delete, insert, column
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam
from alphamind.data.engines.universe import Universe
......@@ -706,16 +706,58 @@ class SqlEngine(object):
def fetch_trade_status(self,
ref_date: str,
codes: Iterable[int]):
codes: Iterable[int],
offset=0):
query = select([Market.code, Market.isOpen]).where(
target_date = advanceDateByCalendar('china.sse', ref_date,
str(offset) + 'b').strftime('%Y%m%d')
stats = func.lead(Market.isOpen, 1).over(
partition_by=Market.code,
order_by=Market.trade_date).label('is_open')
cte = select([Market.trade_date, Market.code, stats]).where(
and_(
Market.trade_date == ref_date,
Market.trade_date.in_([ref_date, target_date]),
Market.code.in_(codes)
)
)
).cte('cte')
query = select([column('code'), column('is_open')]).select_from(cte).where(
column('trade_date') == ref_date
).order_by(column('code'))
return pd.read_sql(query, self.engine).sort_values(['code'])
def fetch_trade_status_range(self,
universe: Universe,
start_date: str = None,
end_date: str = None,
dates: Iterable[str] = None,
offset=0):
codes = universe.query(self, start_date, end_date, dates)
if dates:
start_date = dates[0]
end_date = dates[-1]
end_date = advanceDateByCalendar('china.sse', end_date,
str(offset) + 'b').strftime('%Y-%m-%d')
stats = func.lead(Market.isOpen, 1).over(
partition_by=Market.code,
order_by=Market.trade_date).label('is_open')
cte = select([Market.trade_date, Market.code, stats]).where(
and_(
Market.trade_date.between(start_date, end_date),
Market.code.in_(codes.code.unique().tolist())
)
).cte('cte')
query = select([cte]).select_from(cte).order_by(cte.columns['trade_date'], cte.columns['code'])
df = pd.read_sql(query, self.engine)
return pd.merge(df, codes[['trade_date', 'code']], on=['trade_date', 'code'])
def fetch_data(self,
ref_date: str,
factors: Iterable[str],
......@@ -992,10 +1034,12 @@ if __name__ == '__main__':
from PyFin.api import *
engine = SqlEngine()
ref_date = '2017-06-29'
universe = Universe('', ['zz800'])
ref_date = '2017-05-03'
universe = Universe('custon', ['zz800'])
codes = engine.fetch_codes(ref_date, universe)
dates = makeSchedule('2018-01-01', '2018-02-01', '10b', 'china.sse')
factor_data = engine.fetch_dx_return('2018-01-30', codes, neutralized_risks=risk_styles+industry_styles)
print(factor_data)
# df = engine.fetch_trade_status(ref_date, codes, offset=1)
dates = ['2017-05-02', '2017-05-03', '2017-05-04']
df = engine.fetch_trade_status_range(universe, dates=dates, offset=1)
print(df)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment