Commit 754c77ac authored by Dr.李's avatar Dr.李

Merge remote-tracking branch 'origin/master'

parents 7186574b b1a7a2bf
...@@ -644,12 +644,48 @@ class Strategy(Base): ...@@ -644,12 +644,48 @@ class Strategy(Base):
class Universe(Base): class Universe(Base):
__tablename__ = 'universe' __tablename__ = 'universe'
__table_args__ = ( __table_args__ = (
Index('universe_idx', 'trade_date', 'universe', 'code', unique=True), Index('universe_trade_date_code_uindex', 'trade_date', 'code', unique=True),
) )
trade_date = Column(DateTime, primary_key=True, nullable=False) trade_date = Column(DateTime, primary_key=True, nullable=False)
code = Column(Integer, primary_key=True, nullable=False) code = Column(BigInteger, primary_key=True, nullable=False)
universe = Column(String(20), primary_key=True, nullable=False) aerodef = Column(Integer)
agriforest = Column(Integer)
auto = Column(Integer)
bank = Column(Integer)
builddeco = Column(Integer)
chem = Column(Integer)
conmat = Column(Integer)
commetrade = Column(Integer)
computer = Column(Integer)
conglomerates = Column(Integer)
eleceqp = Column(Integer)
electronics = Column(Integer)
foodbever = Column(Integer)
health = Column(Integer)
houseapp = Column(Integer)
ironsteel = Column(Integer)
leiservice = Column(Integer)
lightindus = Column(Integer)
machiequip = Column(Integer)
media = Column(Integer)
mining = Column(Integer)
nonbankfinan = Column(Integer)
nonfermetal = Column(Integer)
realestate = Column(Integer)
telecom = Column(Integer)
textile = Column(Integer)
transportation = Column(Integer)
utilities = Column(Integer)
ashare = Column(Integer)
ashare_ex = Column(Integer)
cyb = Column(Integer)
hs300 = Column(Integer)
sh50 = Column(Integer)
zxb = Column(Integer)
zz1000 = Column(Integer)
zz500 = Column(Integer)
zz800 = Column(Integer)
class Uqer(Base): class Uqer(Base):
......
...@@ -31,8 +31,8 @@ class Universe(object): ...@@ -31,8 +31,8 @@ class Universe(object):
special_codes: Iterable = None, special_codes: Iterable = None,
filter_cond=None): filter_cond=None):
self.name = name self.name = name
self.base_universe = sorted(base_universe) if base_universe else None self.base_universe = sorted([u.lower() for u in base_universe]) if base_universe else None
self.exclude_universe = sorted(exclude_universe) if exclude_universe else None self.exclude_universe = sorted([u.lower() for u in exclude_universe]) if exclude_universe else None
self.special_codes = sorted(special_codes) if special_codes else None self.special_codes = sorted(special_codes) if special_codes else None
self.filter_cond = filter_cond self.filter_cond = filter_cond
...@@ -50,16 +50,27 @@ class Universe(object): ...@@ -50,16 +50,27 @@ class Universe(object):
def _query_statements(self, start_date, end_date, dates): def _query_statements(self, start_date, end_date, dates):
or_conditions = [] or_conditions = []
for u in self.base_universe:
or_conditions.append(
getattr(UniverseTable, u) == 1
)
if self.special_codes: if self.special_codes:
or_conditions.append(UniverseTable.code.in_(self.special_codes)) or_conditions.append(UniverseTable.code.in_(self.special_codes))
query = or_( query = or_(
UniverseTable.universe.in_(self.base_universe),
*or_conditions *or_conditions
) )
and_conditions = []
if self.exclude_universe:
for u in self.exclude_universe:
and_conditions.append( getattr(UniverseTable, u) != 1)
return and_( return and_(
query, query,
*and_conditions,
UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date), UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date),
) )
...@@ -67,11 +78,11 @@ class Universe(object): ...@@ -67,11 +78,11 @@ class Universe(object):
universe_cond = self._query_statements(start_date, end_date, dates) universe_cond = self._query_statements(start_date, end_date, dates)
if self.filter_cond is None and self.exclude_universe is None: if self.filter_cond is None:
# simple case # simple case
query = select([UniverseTable.trade_date, UniverseTable.code]).where( query = select([UniverseTable.trade_date, UniverseTable.code]).where(
universe_cond universe_cond
).distinct() )
return pd.read_sql(query, engine.engine) return pd.read_sql(query, engine.engine)
else: else:
if self.filter_cond is not None: if self.filter_cond is not None:
......
...@@ -10,7 +10,7 @@ import unittest ...@@ -10,7 +10,7 @@ import unittest
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from scipy.stats import rankdata from scipy.stats import rankdata
from sqlalchemy import select, and_ from sqlalchemy import select, and_, or_
from PyFin.api import makeSchedule from PyFin.api import makeSchedule
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from PyFin.api import bizDatesList from PyFin.api import bizDatesList
...@@ -48,9 +48,12 @@ class TestSqlEngine(unittest.TestCase): ...@@ -48,9 +48,12 @@ class TestSqlEngine(unittest.TestCase):
query = select([UniverseTable.code]).where( query = select([UniverseTable.code]).where(
and_( and_(
UniverseTable.trade_date == ref_date, UniverseTable.trade_date == ref_date,
UniverseTable.universe.in_(['zz500', 'zz1000']) or_(
UniverseTable.zz500 == 1,
UniverseTable.zz1000 == 1
)
)
) )
).distinct()
df = pd.read_sql(query, con=self.engine.engine).sort_values('code') df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
self.assertListEqual(codes, list(df.code.values)) self.assertListEqual(codes, list(df.code.values))
...@@ -65,9 +68,12 @@ class TestSqlEngine(unittest.TestCase): ...@@ -65,9 +68,12 @@ class TestSqlEngine(unittest.TestCase):
query = select([UniverseTable.trade_date, UniverseTable.code]).where( query = select([UniverseTable.trade_date, UniverseTable.code]).where(
and_( and_(
UniverseTable.trade_date.in_(ref_dates), UniverseTable.trade_date.in_(ref_dates),
UniverseTable.universe.in_(['zz500', 'zz1000']) or_(
UniverseTable.zz500 == 1,
UniverseTable.zz1000 == 1
)
)
) )
).distinct()
df = pd.read_sql(query, con=self.engine.engine).sort_values('code') df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
...@@ -76,6 +82,22 @@ class TestSqlEngine(unittest.TestCase): ...@@ -76,6 +82,22 @@ class TestSqlEngine(unittest.TestCase):
expected_codes = list(sorted(df[df.trade_date == ref_date].code.values)) expected_codes = list(sorted(df[df.trade_date == ref_date].code.values))
self.assertListEqual(calculated_codes, expected_codes) self.assertListEqual(calculated_codes, expected_codes)
def test_sdl_engine_fetch_codes_with_exclude_universe(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500'], exclude_universe=['cyb'])
codes = self.engine.fetch_codes(ref_date, universe)
query = select([UniverseTable.code]).where(
and_(
UniverseTable.trade_date == ref_date,
UniverseTable.zz500 == 1,
UniverseTable.cyb == 0
)
)
df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
self.assertListEqual(codes, list(df.code.values))
def test_sql_engine_fetch_dx_return(self): def test_sql_engine_fetch_dx_return(self):
horizon = 4 horizon = 4
offset = 1 offset = 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment