Commit 754c77ac authored by Dr.李's avatar Dr.李

Merge remote-tracking branch 'origin/master'

parents 7186574b b1a7a2bf
......@@ -644,12 +644,48 @@ class Strategy(Base):
class Universe(Base):
__tablename__ = 'universe'
__table_args__ = (
Index('universe_idx', 'trade_date', 'universe', 'code', unique=True),
Index('universe_trade_date_code_uindex', 'trade_date', 'code', unique=True),
)
trade_date = Column(DateTime, primary_key=True, nullable=False)
code = Column(Integer, primary_key=True, nullable=False)
universe = Column(String(20), primary_key=True, nullable=False)
code = Column(BigInteger, primary_key=True, nullable=False)
aerodef = Column(Integer)
agriforest = Column(Integer)
auto = Column(Integer)
bank = Column(Integer)
builddeco = Column(Integer)
chem = Column(Integer)
conmat = Column(Integer)
commetrade = Column(Integer)
computer = Column(Integer)
conglomerates = Column(Integer)
eleceqp = Column(Integer)
electronics = Column(Integer)
foodbever = Column(Integer)
health = Column(Integer)
houseapp = Column(Integer)
ironsteel = Column(Integer)
leiservice = Column(Integer)
lightindus = Column(Integer)
machiequip = Column(Integer)
media = Column(Integer)
mining = Column(Integer)
nonbankfinan = Column(Integer)
nonfermetal = Column(Integer)
realestate = Column(Integer)
telecom = Column(Integer)
textile = Column(Integer)
transportation = Column(Integer)
utilities = Column(Integer)
ashare = Column(Integer)
ashare_ex = Column(Integer)
cyb = Column(Integer)
hs300 = Column(Integer)
sh50 = Column(Integer)
zxb = Column(Integer)
zz1000 = Column(Integer)
zz500 = Column(Integer)
zz800 = Column(Integer)
class Uqer(Base):
......
......@@ -31,8 +31,8 @@ class Universe(object):
special_codes: Iterable = None,
filter_cond=None):
self.name = name
self.base_universe = sorted(base_universe) if base_universe else None
self.exclude_universe = sorted(exclude_universe) if exclude_universe else None
self.base_universe = sorted([u.lower() for u in base_universe]) if base_universe else None
self.exclude_universe = sorted([u.lower() for u in exclude_universe]) if exclude_universe else None
self.special_codes = sorted(special_codes) if special_codes else None
self.filter_cond = filter_cond
......@@ -50,16 +50,27 @@ class Universe(object):
def _query_statements(self, start_date, end_date, dates):
or_conditions = []
for u in self.base_universe:
or_conditions.append(
getattr(UniverseTable, u) == 1
)
if self.special_codes:
or_conditions.append(UniverseTable.code.in_(self.special_codes))
query = or_(
UniverseTable.universe.in_(self.base_universe),
*or_conditions
)
and_conditions = []
if self.exclude_universe:
for u in self.exclude_universe:
and_conditions.append( getattr(UniverseTable, u) != 1)
return and_(
query,
*and_conditions,
UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date),
)
......@@ -67,11 +78,11 @@ class Universe(object):
universe_cond = self._query_statements(start_date, end_date, dates)
if self.filter_cond is None and self.exclude_universe is None:
if self.filter_cond is None:
# simple case
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
universe_cond
).distinct()
)
return pd.read_sql(query, engine.engine)
else:
if self.filter_cond is not None:
......
......@@ -10,7 +10,7 @@ import unittest
import numpy as np
import pandas as pd
from scipy.stats import rankdata
from sqlalchemy import select, and_
from sqlalchemy import select, and_, or_
from PyFin.api import makeSchedule
from PyFin.api import advanceDateByCalendar
from PyFin.api import bizDatesList
......@@ -48,9 +48,12 @@ class TestSqlEngine(unittest.TestCase):
query = select([UniverseTable.code]).where(
and_(
UniverseTable.trade_date == ref_date,
UniverseTable.universe.in_(['zz500', 'zz1000'])
or_(
UniverseTable.zz500 == 1,
UniverseTable.zz1000 == 1
)
)
).distinct()
)
df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
self.assertListEqual(codes, list(df.code.values))
......@@ -65,9 +68,12 @@ class TestSqlEngine(unittest.TestCase):
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
and_(
UniverseTable.trade_date.in_(ref_dates),
UniverseTable.universe.in_(['zz500', 'zz1000'])
or_(
UniverseTable.zz500 == 1,
UniverseTable.zz1000 == 1
)
)
).distinct()
)
df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
......@@ -76,6 +82,22 @@ class TestSqlEngine(unittest.TestCase):
expected_codes = list(sorted(df[df.trade_date == ref_date].code.values))
self.assertListEqual(calculated_codes, expected_codes)
def test_sdl_engine_fetch_codes_with_exclude_universe(self):
ref_date = self.ref_date
universe = Universe('custom', ['zz500'], exclude_universe=['cyb'])
codes = self.engine.fetch_codes(ref_date, universe)
query = select([UniverseTable.code]).where(
and_(
UniverseTable.trade_date == ref_date,
UniverseTable.zz500 == 1,
UniverseTable.cyb == 0
)
)
df = pd.read_sql(query, con=self.engine.engine).sort_values('code')
self.assertListEqual(codes, list(df.code.values))
def test_sql_engine_fetch_dx_return(self):
horizon = 4
offset = 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment