Commit d4f35fab authored by Dr.李

Added an optional output (`keep_explained`) to `neutralize` that returns the explained part alongside the residuals

parent 2c095621
...@@ -6,14 +6,19 @@ Created on 2017-4-25 ...@@ -6,14 +6,19 @@ Created on 2017-4-25
""" """
import numpy as np import numpy as np
from numpy import zeros
from numpy.linalg import solve from numpy.linalg import solve
from typing import Tuple from typing import Tuple
from typing import Union
from alphamind.aggregate import groupby from alphamind.aggregate import groupby
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Tuple[np.ndarray, np.ndarray]: def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, keep_explained=False) \
-> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
if groups is not None: if groups is not None:
res = np.zeros(y.shape) res = zeros(y.shape)
if keep_explained:
explained = zeros((x.shape[1],) + y.shape)
groups_ids = groupby(groups) groups_ids = groupby(groups)
for curr_idx in groups_ids: for curr_idx in groups_ids:
...@@ -21,9 +26,17 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Tuple[n ...@@ -21,9 +26,17 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Tuple[n
curr_y = y[curr_idx] curr_y = y[curr_idx]
b = ls_fit(x[curr_idx], y[curr_idx]) b = ls_fit(x[curr_idx], y[curr_idx])
res[curr_idx] = ls_res(curr_x, curr_y, b) res[curr_idx] = ls_res(curr_x, curr_y, b)
if keep_explained:
explained[curr_idx] = ls_explain(curr_x, curr_y, b)
if keep_explained:
return res, explained
else:
return res return res
else: else:
b = ls_fit(x, y) b = ls_fit(x, y)
if keep_explained:
return ls_res(x, y, b), ls_explain(x, y, b)
else:
return ls_res(x, y, b) return ls_res(x, y, b)
...@@ -37,8 +50,19 @@ def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray: ...@@ -37,8 +50,19 @@ def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
return y - x @ b return y - x @ b
def ls_explained(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray: def ls_explain(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
pass if y.ndim == 1:
return y.reshape((-1, 1)) - b * x
else:
n_samples = y.shape[0]
dependends = y.shape[1]
factors = x.shape[1]
explained = zeros((n_samples, factors, dependends))
for i in range(dependends):
this_y = y[:, [i]]
explained[:, :, i] = this_y - b[:, i] * x
return explained
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -40,7 +40,9 @@ class TestNeutralize(unittest.TestCase): ...@@ -40,7 +40,9 @@ class TestNeutralize(unittest.TestCase):
curr_y = y[groups == i] curr_y = y[groups == i]
model.fit(curr_x, curr_y) model.fit(curr_x, curr_y)
exp_res = curr_y - curr_x @ model.coef_.T exp_res = curr_y - curr_x @ model.coef_.T
np.testing.assert_array_almost_equal(calc_res[groups ==i ], exp_res) np.testing.assert_array_almost_equal(calc_res[groups ==i
], exp_res)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment