Commit 5df30c1b authored by Dr.李

update neutralize

parent 31569ef4
...@@ -10,21 +10,25 @@ from numpy import zeros ...@@ -10,21 +10,25 @@ from numpy import zeros
from numpy.linalg import solve from numpy.linalg import solve
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
from typing import Dict
from alphamind.aggregate import groupby from alphamind.aggregate import groupby
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False) \ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False, output_exposure=False) \
-> Tuple[np.ndarray, Tuple[Union[np.ndarray, np.ndarray]]]: -> Union[np.ndarray, Tuple[np.ndarray, Dict]]:
if groups is not None: if groups is not None:
res = zeros(y.shape) res = zeros(y.shape)
if y.ndim == 2: if y.ndim == 2:
if output_explained: if output_explained:
explained = zeros(x.shape + (y.shape[1],)) explained = zeros(x.shape + (y.shape[1],))
exposure = zeros(x.shape + (y.shape[1],)) if output_exposure:
exposure = zeros(x.shape + (y.shape[1],))
else: else:
explained = zeros(x.shape) if output_explained:
exposure = zeros(x.shape) explained = zeros(x.shape)
if output_exposure:
exposure = zeros(x.shape)
groups_ids = groupby(groups) groups_ids = groupby(groups)
...@@ -33,24 +37,32 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -33,24 +37,32 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
curr_y = y[curr_idx] curr_y = y[curr_idx]
b = ls_fit(x[curr_idx], y[curr_idx]) b = ls_fit(x[curr_idx], y[curr_idx])
res[curr_idx] = ls_res(curr_x, curr_y, b) res[curr_idx] = ls_res(curr_x, curr_y, b)
if exposure.ndim == 3: if output_exposure and exposure.ndim == 3:
for i in range(exposure.shape[2]): for i in range(exposure.shape[2]):
exposure[curr_idx, :, i] = b[:, i] exposure[curr_idx, :, i] = b[:, i]
else: elif output_exposure:
exposure[curr_idx] = b exposure[curr_idx] = b
if output_explained: if output_explained:
explained[curr_idx] = ls_explain(curr_x, b) explained[curr_idx] = ls_explain(curr_x, b)
if output_explained:
return res, (exposure, explained)
else:
return res, (exposure,)
else: else:
b = ls_fit(x, y) b = ls_fit(x, y)
res = ls_res(x, y, b)
if output_explained: if output_explained:
return ls_res(x, y, b), (b, ls_explain(x, b)) explained = ls_explain(x, b)
else: elif output_exposure:
return ls_res(x, y, b), (b,) exposure = b
output_dict = {}
if output_explained:
output_dict['explained'] = explained
elif output_exposure:
output_dict['exposure'] = exposure
if output_dict:
return res, output_dict
else:
return res
def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray: def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray:
......
...@@ -18,7 +18,7 @@ class TestNeutralize(unittest.TestCase): ...@@ -18,7 +18,7 @@ class TestNeutralize(unittest.TestCase):
y = np.random.randn(3000, 4) y = np.random.randn(3000, 4)
x = np.random.randn(3000, 10) x = np.random.randn(3000, 10)
calc_res, _ = neutralize(x, y) calc_res = neutralize(x, y)
model = LinearRegression(fit_intercept=False) model = LinearRegression(fit_intercept=False)
model.fit(x, y) model.fit(x, y)
...@@ -46,7 +46,7 @@ class TestNeutralize(unittest.TestCase): ...@@ -46,7 +46,7 @@ class TestNeutralize(unittest.TestCase):
y = np.random.randn(3000) y = np.random.randn(3000)
x = np.random.randn(3000, 10) x = np.random.randn(3000, 10)
calc_res, (b, calc_explained) = neutralize(x, y, output_explained=True) calc_res, other_stats = neutralize(x, y, output_explained=True)
model = LinearRegression(fit_intercept=False) model = LinearRegression(fit_intercept=False)
model.fit(x, y) model.fit(x, y)
...@@ -55,12 +55,12 @@ class TestNeutralize(unittest.TestCase): ...@@ -55,12 +55,12 @@ class TestNeutralize(unittest.TestCase):
exp_explained = x * model.coef_.T exp_explained = x * model.coef_.T
np.testing.assert_array_almost_equal(calc_res, exp_res) np.testing.assert_array_almost_equal(calc_res, exp_res)
np.testing.assert_array_almost_equal(calc_explained, exp_explained) np.testing.assert_array_almost_equal(other_stats['explained'], exp_explained)
y = np.random.randn(3000, 4) y = np.random.randn(3000, 4)
x = np.random.randn(3000, 10) x = np.random.randn(3000, 10)
calc_res, (b, calc_explained) = neutralize(x, y, output_explained=True) calc_res, other_stats = neutralize(x, y, output_explained=True)
model = LinearRegression(fit_intercept=False) model = LinearRegression(fit_intercept=False)
model.fit(x, y) model.fit(x, y)
...@@ -70,14 +70,14 @@ class TestNeutralize(unittest.TestCase): ...@@ -70,14 +70,14 @@ class TestNeutralize(unittest.TestCase):
for i in range(y.shape[1]): for i in range(y.shape[1]):
exp_explained = x * model.coef_.T[:, i] exp_explained = x * model.coef_.T[:, i]
np.testing.assert_array_almost_equal(calc_explained[:, :, i], exp_explained) np.testing.assert_array_almost_equal(other_stats['explained'][:, :, i], exp_explained)
def test_neutralize_explain_output_with_group(self): def test_neutralize_explain_output_with_group(self):
y = np.random.randn(3000) y = np.random.randn(3000)
x = np.random.randn(3000, 10) x = np.random.randn(3000, 10)
groups = np.random.randint(30, size=3000) groups = np.random.randint(30, size=3000)
calc_res, (b, calc_explained) = neutralize(x, y, groups, output_explained=True) calc_res, other_stats = neutralize(x, y, groups, output_explained=True)
model = LinearRegression(fit_intercept=False) model = LinearRegression(fit_intercept=False)
for i in range(30): for i in range(30):
...@@ -87,12 +87,12 @@ class TestNeutralize(unittest.TestCase): ...@@ -87,12 +87,12 @@ class TestNeutralize(unittest.TestCase):
exp_res = curr_y - curr_x @ model.coef_.T exp_res = curr_y - curr_x @ model.coef_.T
exp_explained = curr_x * model.coef_.T exp_explained = curr_x * model.coef_.T
np.testing.assert_array_almost_equal(calc_res[groups == i], exp_res) np.testing.assert_array_almost_equal(calc_res[groups == i], exp_res)
np.testing.assert_array_almost_equal(calc_explained[groups == i], exp_explained) np.testing.assert_array_almost_equal(other_stats['explained'][groups == i], exp_explained)
y = np.random.randn(3000, 4) y = np.random.randn(3000, 4)
x = np.random.randn(3000, 10) x = np.random.randn(3000, 10)
calc_res, (b, calc_explained) = neutralize(x, y, groups, output_explained=True) calc_res, other_stats = neutralize(x, y, groups, output_explained=True)
model = LinearRegression(fit_intercept=False) model = LinearRegression(fit_intercept=False)
for i in range(30): for i in range(30):
...@@ -104,7 +104,7 @@ class TestNeutralize(unittest.TestCase): ...@@ -104,7 +104,7 @@ class TestNeutralize(unittest.TestCase):
for j in range(y.shape[1]): for j in range(y.shape[1]):
exp_explained = curr_x * model.coef_.T[:, j] exp_explained = curr_x * model.coef_.T[:, j]
np.testing.assert_array_almost_equal(calc_explained[groups == i, :, j], exp_explained) np.testing.assert_array_almost_equal(other_stats['explained'][groups == i, :, j], exp_explained)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment