Commit 2dc2692a authored by Dr.李's avatar Dr.李

using dict as const linear model weights

parent a6fdafd0
...@@ -408,7 +408,8 @@ class SqlEngine(object): ...@@ -408,7 +408,8 @@ class SqlEngine(object):
def fetch_benchmark(self, def fetch_benchmark(self,
ref_date: str, ref_date: str,
benchmark: int) -> pd.DataFrame: benchmark: int,
codes: Iterable[int]=None) -> pd.DataFrame:
query = select([IndexComponent.code, (IndexComponent.weight / 100.).label('weight')]).where( query = select([IndexComponent.code, (IndexComponent.weight / 100.).label('weight')]).where(
and_( and_(
IndexComponent.trade_date == ref_date, IndexComponent.trade_date == ref_date,
...@@ -416,7 +417,13 @@ class SqlEngine(object): ...@@ -416,7 +417,13 @@ class SqlEngine(object):
) )
) )
return pd.read_sql(query, self.engine) df = pd.read_sql(query, self.engine)
if codes:
df.set_index(['code'], inplace=True)
df = df.reindex(codes).fillna(0.)
df.reset_index(inplace=True)
return df
def fetch_benchmark_range(self, def fetch_benchmark_range(self,
benchmark: int, benchmark: int,
......
...@@ -19,7 +19,7 @@ from alphamind.utilities import alpha_logger ...@@ -19,7 +19,7 @@ from alphamind.utilities import alpha_logger
class ConstLinearModelImpl(object): class ConstLinearModelImpl(object):
def __init__(self, weights: np.ndarray = None): def __init__(self, weights: np.ndarray = None):
self.weights = np.array(weights).flatten() self.weights = weights.flatten()
def fit(self, x: np.ndarray, y: np.ndarray): def fit(self, x: np.ndarray, y: np.ndarray):
pass pass
...@@ -32,13 +32,14 @@ class ConstLinearModel(ModelBase): ...@@ -32,13 +32,14 @@ class ConstLinearModel(ModelBase):
def __init__(self, def __init__(self,
features=None, features=None,
weights: np.ndarray = None): weights: dict = None):
super().__init__(features) super().__init__(features)
if features is not None and weights is not None: if features is not None and weights is not None:
pyFinAssert(len(features) == len(weights), pyFinAssert(len(features) == len(weights),
ValueError, ValueError,
"length of features is not equal to length of weights") "length of features is not equal to length of weights")
self.impl = ConstLinearModelImpl(weights) if weights:
self.impl = ConstLinearModelImpl(np.array([weights[name] for name in self.features]))
def save(self): def save(self):
model_desc = super().save() model_desc = super().save()
......
...@@ -20,6 +20,7 @@ class TestLinearModel(unittest.TestCase): ...@@ -20,6 +20,7 @@ class TestLinearModel(unittest.TestCase):
def setUp(self): def setUp(self):
self.n = 3 self.n = 3
self.features = ['a', 'b', 'c']
self.train_x = pd.DataFrame(np.random.randn(1000, self.n), columns=['a', 'b', 'c']) self.train_x = pd.DataFrame(np.random.randn(1000, self.n), columns=['a', 'b', 'c'])
self.train_y = np.random.randn(1000) self.train_y = np.random.randn(1000)
self.train_y_label = np.where(self.train_y > 0., 1, 0) self.train_y_label = np.where(self.train_y > 0., 1, 0)
...@@ -27,16 +28,17 @@ class TestLinearModel(unittest.TestCase): ...@@ -27,16 +28,17 @@ class TestLinearModel(unittest.TestCase):
def test_const_linear_model(self): def test_const_linear_model(self):
weights = np.array([1., 2., 3.]) features = ['c', 'b', 'a']
model = ConstLinearModel(features=['a', 'b', 'c'], weights = dict(c=3., b=2., a=1.)
model = ConstLinearModel(features=features,
weights=weights) weights=weights)
calculated_y = model.predict(self.predict_x) calculated_y = model.predict(self.predict_x)
expected_y = self.predict_x @ weights expected_y = self.predict_x[features] @ np.array([weights[f] for f in features])
np.testing.assert_array_almost_equal(calculated_y, expected_y) np.testing.assert_array_almost_equal(calculated_y, expected_y)
def test_const_linear_model_persistence(self): def test_const_linear_model_persistence(self):
weights = np.array([1., 2., 3.]) weights = dict(c=3., b=2., a=1.)
model = ConstLinearModel(features=['a', 'b', 'c'], model = ConstLinearModel(features=['a', 'b', 'c'],
weights=weights) weights=weights)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment