Commit 706f4e19 authored by Dr.李's avatar Dr.李

made neutralize workable with WLS

parent 0b987f3e
...@@ -13,12 +13,20 @@ from typing import Dict ...@@ -13,12 +13,20 @@ from typing import Dict
import alphamind.utilities as utils import alphamind.utilities as utils
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False, output_exposure=False) \ def neutralize(x: np.ndarray,
y: np.ndarray,
groups: np.ndarray=None,
output_explained: bool=False,
output_exposure: bool=False,
weights: np.ndarray=None) \
-> Union[np.ndarray, Tuple[np.ndarray, Dict]]: -> Union[np.ndarray, Tuple[np.ndarray, Dict]]:
if y.ndim == 1: if y.ndim == 1:
y = y.reshape((-1, 1)) y = y.reshape((-1, 1))
if weights is None:
weights = np.ones(len(y), dtype=float)
if groups is not None: if groups is not None:
res = np.zeros(y.shape) res = np.zeros(y.shape)
...@@ -38,7 +46,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -38,7 +46,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
start = 0 start = 0
for diff_loc in index_diff: for diff_loc in index_diff:
curr_idx = order[start:diff_loc + 1] curr_idx = order[start:diff_loc + 1]
curr_x, b = _sub_step(x, y, curr_idx, res) curr_x, b = _sub_step(x, y, weights, curr_idx, res)
if output_exposure: if output_exposure:
for i in range(exposure.shape[2]): for i in range(exposure.shape[2]):
exposure[curr_idx, :, i] = b[:, i] exposure[curr_idx, :, i] = b[:, i]
...@@ -47,7 +55,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -47,7 +55,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
explained[curr_idx] = ls_explain(curr_x, b) explained[curr_idx] = ls_explain(curr_x, b)
start = diff_loc + 1 start = diff_loc + 1
else: else:
b = ls_fit(x, y) b = ls_fit(x, y, weights)
res = ls_res(x, y, b) res = ls_res(x, y, b)
if output_explained: if output_explained:
...@@ -68,18 +76,19 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -68,18 +76,19 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def _sub_step(x, y, curr_idx, res): def _sub_step(x, y, w, curr_idx, res):
curr_x = x[curr_idx] curr_x = x[curr_idx]
curr_y = y[curr_idx] curr_y = y[curr_idx]
b = ls_fit(curr_x, curr_y) curr_w = w[curr_idx]
b = ls_fit(curr_x, curr_y, curr_w)
res[curr_idx] = ls_res(curr_x, curr_y, b) res[curr_idx] = ls_res(curr_x, curr_y, b)
return curr_x, b return curr_x, b
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def ls_fit(x: np.ndarray, y: np.ndarray) -> np.ndarray: def ls_fit(x: np.ndarray, y: np.ndarray, w: np.ndarray) -> np.ndarray:
x_bar = x.T x_bar = x.T
b = np.linalg.solve(x_bar @ x, x_bar @ y) b = np.linalg.solve(x_bar * w @ x, x_bar * w @ y)
return b return b
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment