Commit edee4aed authored by Dr.李's avatar Dr.李

more conditions in universe

parent cbc97303
...@@ -22,9 +22,16 @@ from alphamind.data.transformer import Transformer ...@@ -22,9 +22,16 @@ from alphamind.data.transformer import Transformer
class Universe(object): class Universe(object):
def __init__(self, name, base_universe, filter_cond=None): def __init__(self,
name: str,
base_universe: Iterable,
exclude_universe: Iterable=None,
special_codes: Iterable=None,
filter_cond=None):
self.name = name self.name = name
self.base_universe = base_universe self.base_universe = base_universe
self.exclude_universe = exclude_universe
self.special_codes = special_codes
self.filter_cond = filter_cond self.filter_cond = filter_cond
@property @property
...@@ -32,9 +39,24 @@ class Universe(object): ...@@ -32,9 +39,24 @@ class Universe(object):
return True if self.filter_cond is not None else False return True if self.filter_cond is not None else False
def _query_statements(self, start_date, end_date, dates): def _query_statements(self, start_date, end_date, dates):
or_conditions = []
if self.special_codes:
or_conditions.append(UniverseTable.code.in_(self.special_codes))
query = or_(
UniverseTable.universe.in_(self.base_universe),
*or_conditions
)
and_conditions = []
if self.exclude_universe:
and_conditions.append(~UniverseTable.universe.in_(self.exclude_universe))
return and_( return and_(
query,
UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date), UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date),
UniverseTable.universe.in_(self.base_universe) *and_conditions
) )
def query(self, engine, start_date: str=None, end_date: str=None, dates=None) -> pd.DataFrame: def query(self, engine, start_date: str=None, end_date: str=None, dates=None) -> pd.DataFrame:
...@@ -87,7 +109,7 @@ if __name__ == '__main__': ...@@ -87,7 +109,7 @@ if __name__ == '__main__':
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
engine = SqlEngine() engine = SqlEngine()
universe = UniverseNew('ss', ['hs300']) universe = Universe('ss', ['hs300'], special_codes=[603138])
print(universe.query(engine, print(universe.query(engine,
start_date='2017-12-21', start_date='2017-12-21',
end_date='2017-12-25')) end_date='2017-12-25'))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment