Commit 636bc2c2 authored by Dr.李's avatar Dr.李

some change

parent d94862d5
...@@ -40,7 +40,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -40,7 +40,7 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
for curr_idx in groups_ids: for curr_idx in groups_ids:
curr_x = x[curr_idx] curr_x = x[curr_idx]
curr_y = y[curr_idx] curr_y = y[curr_idx]
b = ls_fit(x[curr_idx], y[curr_idx]) b = ls_fit(curr_x, curr_y)
res[curr_idx] = ls_res(curr_x, curr_y, b) res[curr_idx] = ls_res(curr_x, curr_y, b)
if output_exposure: if output_exposure:
for i in range(exposure.shape[2]): for i in range(exposure.shape[2]):
...@@ -83,8 +83,9 @@ def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray: ...@@ -83,8 +83,9 @@ def ls_res(x: np.ndarray, y: np.ndarray, b: np.ndarray) -> np.ndarray:
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def ls_explain(x: np.ndarray, b: np.ndarray) -> np.ndarray: def ls_explain(x: np.ndarray, b: np.ndarray) -> np.ndarray:
explained = np.zeros(x.shape + (b.shape[1],)) n = b.shape[1]
for i in range(b.shape[1]): explained = np.zeros(x.shape + (n,))
for i in range(n):
explained[:, :, i] = b[:, i] * x explained[:, :, i] = b[:, i] * x
return explained return explained
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment