Commit 706f4e19 authored by Dr.李's avatar Dr.李

made neutralize workable with WLS

parent 0b987f3e
......@@ -13,12 +13,20 @@ from typing import Dict
import alphamind.utilities as utils
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False, output_exposure=False) \
def neutralize(x: np.ndarray,
y: np.ndarray,
groups: np.ndarray=None,
output_explained: bool=False,
output_exposure: bool=False,
weights: np.ndarray=None) \
-> Union[np.ndarray, Tuple[np.ndarray, Dict]]:
if y.ndim == 1:
y = y.reshape((-1, 1))
if weights is None:
weights = np.ones(len(y), dtype=float)
if groups is not None:
res = np.zeros(y.shape)
......@@ -38,7 +46,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
start = 0
for diff_loc in index_diff:
curr_idx = order[start:diff_loc + 1]
curr_x, b = _sub_step(x, y, curr_idx, res)
curr_x, b = _sub_step(x, y, weights, curr_idx, res)
if output_exposure:
for i in range(exposure.shape[2]):
exposure[curr_idx, :, i] = b[:, i]
......@@ -47,7 +55,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
explained[curr_idx] = ls_explain(curr_x, b)
start = diff_loc + 1
else:
b = ls_fit(x, y)
b = ls_fit(x, y, weights)
res = ls_res(x, y, b)
if output_explained:
......@@ -68,18 +76,19 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
@nb.njit(nogil=True, cache=True)
def _sub_step(x, y, curr_idx, res):
def _sub_step(x, y, w, curr_idx, res):
curr_x = x[curr_idx]
curr_y = y[curr_idx]
b = ls_fit(curr_x, curr_y)
curr_w = w[curr_idx]
b = ls_fit(curr_x, curr_y, curr_w)
res[curr_idx] = ls_res(curr_x, curr_y, b)
return curr_x, b
@nb.njit(nogil=True, cache=True)
def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray:
def ls_fit(x: np.ndarray, y: np.ndarray, w: np.ndarray) -> np.ndarray:
x_bar = x.T
b = np.linalg.solve(x_bar @ x, x_bar @ y)
b = np.linalg.solve(x_bar * w @ x, x_bar * w @ y)
return b
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment