Commit fe08443f authored by Dr.李

update api and tests

parent 09337e1b
@@ -28,6 +28,7 @@ from alphamind.data.standardize import standardize
 from alphamind.data.standardize import projection
 from alphamind.data.neutralize import neutralize
 from alphamind.data.rank import rank
+from alphamind.data.rank import percentile
 from alphamind.data.engines.sqlengine import factor_tables
 from alphamind.data.engines.utilities import industry_list
@@ -80,6 +81,7 @@ __all__ = [
     'projection',
     'neutralize',
     'rank',
+    'percentile',
     'factor_tables',
     'industry_list',
     'fetch_data_package',
...
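The new export makes percentile available next to rank at the package level. A minimal usage sketch, assuming the file above is alphamind/api.py and the package is importable:

    import numpy as np
    from alphamind.api import rank, percentile  # percentile is newly re-exported here

    x = np.array([3., 1., 2., 5.])
    print(rank(x))        # zero-based cross-sectional ranks, returned as a 4x1 column
    print(percentile(x))  # the same ranks rescaled onto [0, 1]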
@@ -58,27 +58,23 @@ class Universe(object):
             *or_conditions
         )
 
-        and_conditions = []
-        if self.exclude_universe:
-            and_conditions.append(UniverseTable.universe.notin_(self.exclude_universe))
-
         return and_(
             query,
             UniverseTable.trade_date.in_(dates) if dates else UniverseTable.trade_date.between(start_date, end_date),
-            *and_conditions
         )
 
     def query(self, engine, start_date: str = None, end_date: str = None, dates=None) -> pd.DataFrame:
         universe_cond = self._query_statements(start_date, end_date, dates)
 
-        if self.filter_cond is None:
+        if self.filter_cond is None and self.exclude_universe is None:
             # simple case
             query = select([UniverseTable.trade_date, UniverseTable.code]).where(
                 universe_cond
             ).distinct()
             return pd.read_sql(query, engine.engine)
         else:
+            if self.filter_cond is not None:
                 if isinstance(self.filter_cond, Transformer):
                     transformer = self.filter_cond
                 else:
@@ -141,10 +137,10 @@ if __name__ == '__main__':
     from alphamind.data.engines.sqlengine import SqlEngine
 
     engine = SqlEngine()
-    universe = Universe('ss', ['ashare_ex'], exclude_universe=['hs300', 'zz500'], special_codes=[603138])
+    universe = Universe('custom', ['zz800'], exclude_universe=['Bank'])
     print(universe.query(engine,
-                         start_date='2017-12-21',
-                         end_date='2017-12-25'))
+                         start_date='2018-04-26',
+                         end_date='2018-04-26'))
     print(universe.query(engine,
                          dates=['2017-12-21', '2017-12-25']))
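With this change, a universe with exclude_universe set no longer takes the simple-case branch; the result should still just be a subset of the unfiltered universe. A sketch only, assuming a reachable database behind SqlEngine, that Universe lives in alphamind.data.engines.universe, and that both branches return the trade_date/code columns selected above:

    from alphamind.data.engines.sqlengine import SqlEngine
    from alphamind.data.engines.universe import Universe  # assumed module path for the class above

    engine = SqlEngine()
    base = Universe('custom', ['zz800'])                                # simple-case branch
    trimmed = Universe('custom', ['zz800'], exclude_universe=['Bank'])  # non-simple branch

    df_all = base.query(engine, start_date='2018-04-26', end_date='2018-04-26')
    df_ex = trimmed.query(engine, start_date='2018-04-26', end_date='2018-04-26')

    # excluding a universe should only ever remove codes
    assert set(df_ex.code) <= set(df_all.code)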
@@ -7,6 +7,7 @@ Created on 2017-8-8
 from typing import Optional
 
 import numpy as np
+from scipy.stats import rankdata
 
 import alphamind.utilities as utils
@@ -22,8 +23,30 @@ def rank(x: np.ndarray, groups: Optional[np.ndarray]=None) -> np.ndarray:
         start = 0
         for diff_loc in index_diff:
             curr_idx = order[start:diff_loc + 1]
-            res[curr_idx] = x[curr_idx].argsort(axis=0)
+            res[curr_idx] = rankdata(x[curr_idx]).astype(float) - 1.
             start = diff_loc + 1
         return res
     else:
-        return x.argsort(axis=0).argsort(axis=0)
+        return (rankdata(x).astype(float) - 1.).reshape((-1, 1))
+
+
+def percentile(x: np.ndarray, groups: Optional[np.ndarray]=None) -> np.ndarray:
+    if x.ndim == 1:
+        x = x.reshape((-1, 1))
+
+    if groups is not None:
+        res = np.zeros(x.shape, dtype=float)
+        index_diff, order = utils.groupby(groups)
+        start = 0
+        for diff_loc in index_diff:
+            curr_idx = order[start:diff_loc + 1]
+            curr_values = x[curr_idx]
+            length = len(curr_values) - 1. if len(curr_values) > 1 else 1.
+            res[curr_idx] = ((rankdata(curr_values).astype(float) - 1.) / length).reshape((-1, 1))
+            start = diff_loc + 1
+        return res
+    else:
+        length = len(x) - 1. if len(x) > 1 else 1.
+        return ((rankdata(x).astype(float) - 1.) / length).reshape((-1, 1))
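For a quick sense of the rankdata-based rewrite, a small ungrouped sketch (rank gives zero-based ranks; percentile divides them by n - 1 so each cross-section lands on [0, 1]):

    import numpy as np
    from alphamind.data.rank import rank, percentile

    x = np.array([3., 1., 2., 5.])

    print(rank(x))        # 4x1 column of [2., 0., 1., 3.]
    print(percentile(x))  # 4x1 column of [0.667, 0., 0.333, 1.]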
@@ -356,7 +356,7 @@ class TestSqlEngine(unittest.TestCase):
         )
 
     def test_sql_engine_fetch_factor_by_categories(self):
-        ref_date = '2016-08-01'
+        ref_date = self.ref_date
         universe = Universe('custom', ['zz500', 'zz1000'])
         codes = self.engine.fetch_codes(ref_date, universe)
@@ -377,6 +377,7 @@
         expected_rank = df3[['ROE', 'cat']].groupby('cat').transform(lambda x: rankdata(x.values) - 1.)
         expected_rank[np.isnan(df3.ROE)] = np.nan
+        expected_rank[np.isnan(df3.ROE)] = np.nan
         df3['rank'] = expected_rank['ROE'].values
         np.testing.assert_array_almost_equal(df3['rank'].values,
                                              df1['f'].values)
...
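If a percentile-based expectation were added alongside the rank check above, it could be built the same way; a sketch assuming the same df3 frame with 'ROE' and 'cat' columns (expected_percentile is an illustrative name, not part of this commit):

    expected_percentile = df3[['ROE', 'cat']].groupby('cat').transform(
        lambda x: (rankdata(x.values) - 1.) / max(len(x) - 1., 1.))
    expected_percentile[np.isnan(df3.ROE)] = np.nan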