Commit d4f35fab authored by Dr.李's avatar Dr.李

added optional output for explained

parent 2c095621
......@@ -6,14 +6,19 @@ Created on 2017-4-25
"""
import numpy as np
from numpy import zeros
from numpy.linalg import solve
from typing import Tuple
from typing import Union
from alphamind.aggregate import groupby
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Tuple[np.ndarray, np.ndarray]:
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, keep_explained=False) \
-> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
if groups is not None:
res = np.zeros(y.shape)
res = zeros(y.shape)
if keep_explained:
explained = zeros((x.shape[1],) + y.shape)
groups_ids = groupby(groups)
for curr_idx in groups_ids:
......@@ -21,10 +26,18 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> Tuple[n
curr_y = y[curr_idx]
b = ls_fit(x[curr_idx], y[curr_idx])
res[curr_idx] = ls_res(curr_x, curr_y, b)
return res
if keep_explained:
explained[curr_idx] = ls_explain(curr_x, curr_y, b)
if keep_explained:
return res, explained
else:
return res
else:
b = ls_fit(x, y)
return ls_res(x, y, b)
if keep_explained:
return ls_res(x, y, b), ls_explain(x, y, b)
else:
return ls_res(x, y, b)
def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray:
......@@ -37,8 +50,19 @@ def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
return y - x @ b
def ls_explained(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
pass
def ls_explain(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
if y.ndim == 1:
return y.reshape((-1, 1)) - b * x
else:
n_samples = y.shape[0]
dependends = y.shape[1]
factors = x.shape[1]
explained = zeros((n_samples, factors, dependends))
for i in range(dependends):
this_y = y[:, [i]]
explained[:, :, i] = this_y - b[:, i] * x
return explained
if __name__ == '__main__':
......
......@@ -40,7 +40,9 @@ class TestNeutralize(unittest.TestCase):
curr_y = y[groups == i]
model.fit(curr_x, curr_y)
exp_res = curr_y - curr_x @ model.coef_.T
np.testing.assert_array_almost_equal(calc_res[groups ==i ], exp_res)
np.testing.assert_array_almost_equal(calc_res[groups ==i
], exp_res)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment