Commit 257bf603 authored by Dr.李's avatar Dr.李

update new universe

parent ad7c2454
...@@ -24,11 +24,11 @@ class Universe(object): ...@@ -24,11 +24,11 @@ class Universe(object):
def __init__(self, def __init__(self,
name, name,
base_universe: Iterable[str]=None, base_universe: Iterable[str] = None,
include_universe: Iterable[str]=None, include_universe: Iterable[str] = None,
exclude_universe: Iterable[str]=None, exclude_universe: Iterable[str] = None,
include_codes: Iterable[str]=None, include_codes: Iterable[str] = None,
exclude_codes: Iterable[str]=None): exclude_codes: Iterable[str] = None):
self.name = name self.name = name
self.base_universe = base_universe self.base_universe = base_universe
...@@ -84,7 +84,8 @@ class Universe(object): ...@@ -84,7 +84,8 @@ class Universe(object):
def query_range(self, start_date=None, end_date=None, dates=None): def query_range(self, start_date=None, end_date=None, dates=None):
all_and_conditions, all_or_conditions = self._create_condition() all_and_conditions, all_or_conditions = self._create_condition()
dates_cond = UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date) dates_cond = UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date,
end_date)
if all_or_conditions: if all_or_conditions:
query = and_( query = and_(
...@@ -102,26 +103,29 @@ class Universe(object): ...@@ -102,26 +103,29 @@ class Universe(object):
class UniverseNew(object): class UniverseNew(object):
def __init__(self, name, base_universe): def __init__(self, name, base_universe, filter_cond=None):
self.name = name self.name = name
self.base_universe = base_universe self.base_universe = base_universe
self.filter_cond = filter_cond
def query(self, engine, ref_date: str, filter_cond=None) -> pd.DataFrame: def query(self, engine, start_date: str=None, end_date: str=None, dates=None) -> pd.DataFrame:
if filter_cond is None: universe_cond = and_(
# simple case UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date),
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
and_(
UniverseTable.trade_date == ref_date,
UniverseTable.universe.in_(self.base_universe) UniverseTable.universe.in_(self.base_universe)
) )
if self.filter_cond is None:
# simple case
query = select([UniverseTable.trade_date, UniverseTable.code]).where(
universe_cond
) )
return pd.read_sql(query, engine.engine) return pd.read_sql(query, engine.engine)
else: else:
if isinstance(filter_cond, Transformer): if isinstance(self.filter_cond, Transformer):
transformer = filter_cond transformer = self.filter_cond
else: else:
transformer = Transformer(filter_cond) transformer = Transformer(self.filter_cond)
dependency = transformer.dependency dependency = transformer.dependency
factor_cols = _map_factors(dependency, factor_tables) factor_cols = _map_factors(dependency, factor_tables)
...@@ -131,13 +135,9 @@ class UniverseNew(object): ...@@ -131,13 +135,9 @@ class UniverseNew(object):
if t.__table__.name != FullFactor.__table__.name: if t.__table__.name != FullFactor.__table__.name:
big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date, big_table = outerjoin(big_table, t, and_(FullFactor.trade_date == t.trade_date,
FullFactor.code == t.code, FullFactor.code == t.code,
FullFactor.trade_date == ref_date)) FullFactor.trade_date.in_(
dates) if dates else FullFactor.trade_date.between(
universe_cond = and_( start_date, end_date)))
UniverseTable.trade_date == ref_date,
UniverseTable.universe.in_(self.base_universe)
)
big_table = join(big_table, UniverseTable, big_table = join(big_table, UniverseTable,
and_(FullFactor.trade_date == UniverseTable.trade_date, and_(FullFactor.trade_date == UniverseTable.trade_date,
FullFactor.code == UniverseTable.code, FullFactor.code == UniverseTable.code,
...@@ -152,14 +152,19 @@ class UniverseNew(object): ...@@ -152,14 +152,19 @@ class UniverseNew(object):
filter_fields = transformer.names filter_fields = transformer.names
pyFinAssert(len(filter_fields) == 1, ValueError, "filter fields can only be 1") pyFinAssert(len(filter_fields) == 1, ValueError, "filter fields can only be 1")
df = transformer.transform('code', df) df = transformer.transform('code', df)
return df[df[filter_fields[0]] == 1].reset_index()[['trade_date', 'code']] df = df[df[filter_fields[0]] == 1].reset_index()[['trade_date', 'code']]
return df
if __name__ == '__main__': if __name__ == '__main__':
from PyFin.api import * from PyFin.api import *
from alphamind.data.engines.sqlengine import SqlEngine from alphamind.data.engines.sqlengine import SqlEngine
engine = SqlEngine() engine = SqlEngine()
universe = UniverseNew('ss', ['hs300']) universe = UniverseNew('ss', ['hs300'])
print(universe.query(engine, '2017-12-21', LAST('closePrice') < 5)) print(universe.query(engine,
start_date='2017-12-21',
end_date='2017-12-25'))
print(universe.query(engine,
dates=['2017-12-21', '2017-12-25']))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment