Commit 140b4f7c authored by Dr.李's avatar Dr.李

made neutralize workable with groups

parent 5e4c6ca9
...@@ -7,11 +7,23 @@ Created on 2017-4-25 ...@@ -7,11 +7,23 @@ Created on 2017-4-25
import numpy as np import numpy as np
from numpy.linalg import solve from numpy.linalg import solve
from alphamind.aggregate import groupby
def neutralize(x: np.ndarray, y: np.ndarray) -> np.ndarray: def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray:
b = ls_fit(x, y) if groups is not None:
return ls_res(x, y, b) res = np.zeros(y.shape)
groups_ids = groupby(groups)
for curr_idx in groups_ids:
curr_x = x[curr_idx]
curr_y = y[curr_idx]
b = ls_fit(x[curr_idx], y[curr_idx])
res[curr_idx] = ls_res(curr_x, curr_y, b)
return res
else:
b = ls_fit(x, y)
return ls_res(x, y, b)
def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray: def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray:
...@@ -22,3 +34,12 @@ def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray: ...@@ -22,3 +34,12 @@ def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray:
def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray: def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
return y - x @ b return y - x @ b
if __name__ == '__main__':
x = np.random.randn(3000, 3)
y = np.random.randn(3000, 2)
groups = np.random.randint(30, size=3000)
print(neutralize(x, y, groups))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment