Commit e25b625d authored by Dr.李's avatar Dr.李

simplify codes

parent d533b7a2
...@@ -46,15 +46,6 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp ...@@ -46,15 +46,6 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
for i in range(explained.shape[2]): for i in range(explained.shape[2]):
explained[curr_idx] = ls_explain(curr_x, b) explained[curr_idx] = ls_explain(curr_x, b)
start = diff_loc + 1 start = diff_loc + 1
curr_idx = order[start:]
curr_x, b = _sub_step(x, y, curr_idx, res)
if output_exposure:
for i in range(exposure.shape[2]):
exposure[curr_idx, :, i] = b[:, i]
if output_explained:
for i in range(explained.shape[2]):
explained[curr_idx] = ls_explain(curr_x, b)
else: else:
b = ls_fit(x, y) b = ls_fit(x, y)
res = ls_res(x, y, b) res = ls_res(x, y, b)
......
...@@ -60,8 +60,6 @@ def _train_loop(index_diff, order, x, y): ...@@ -60,8 +60,6 @@ def _train_loop(index_diff, order, x, y):
for k, diff_loc in enumerate(index_diff): for k, diff_loc in enumerate(index_diff):
res_beta[k] = _train_sub_group(x, y, order[start:diff_loc + 1]) res_beta[k] = _train_sub_group(x, y, order[start:diff_loc + 1])
start = diff_loc + 1 start = diff_loc + 1
res_beta[k + 1] = _train_sub_group(x, y, order[start:])
return res_beta return res_beta
......
...@@ -29,12 +29,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np ...@@ -29,12 +29,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np
use_rank = int(percent * len(current_index)) use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1.) set_value(weights, current_index[current_ordering[:use_rank]], 1.)
start = diff_loc + 1 start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort()
current_ordering.shape = -1, 1
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
else: else:
ordering = neg_er.argsort() ordering = neg_er.argsort()
use_rank = int(percent * len(neg_er)) use_rank = int(percent * len(neg_er))
...@@ -53,10 +47,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np ...@@ -53,10 +47,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np
use_rank = int(percent * len(current_index)) use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1) set_value(weights, current_index[current_ordering[:use_rank]], 1)
start = diff_loc + 1 start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort(axis=0)
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1)
else: else:
ordering = neg_er.argsort(axis=0) ordering = neg_er.argsort(axis=0)
use_rank = int(percent * len(neg_er)) use_rank = int(percent * len(neg_er))
......
...@@ -28,11 +28,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda ...@@ -28,11 +28,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
current_ordering.shape = -1, 1 current_ordering.shape = -1, 1
set_value(weights, current_index[current_ordering[:use_rank]], 1.) set_value(weights, current_index[current_ordering[:use_rank]], 1.)
start = diff_loc + 1 start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort()
current_ordering.shape = -1, 1
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
else: else:
ordering = neg_er.argsort() ordering = neg_er.argsort()
weights[ordering[:use_rank]] = 1. weights[ordering[:use_rank]] = 1.
...@@ -49,10 +44,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda ...@@ -49,10 +44,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
current_ordering = neg_er[current_index].argsort(axis=0) current_ordering = neg_er[current_index].argsort(axis=0)
set_value(weights, current_index[current_ordering[:use_rank]], 1) set_value(weights, current_index[current_ordering[:use_rank]], 1)
start = diff_loc + 1 start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort(axis=0)
set_value(weights, current_index[current_ordering[:use_rank]], 1)
else: else:
ordering = neg_er.argsort(axis=0) ordering = neg_er.argsort(axis=0)
set_value(weights, ordering[:use_rank], 1.) set_value(weights, ordering[:use_rank], 1.)
......
...@@ -18,7 +18,7 @@ def groupby(groups): ...@@ -18,7 +18,7 @@ def groupby(groups):
order = groups.argsort() order = groups.argsort()
t = groups[order] t = groups[order]
index_diff = np.where(np.diff(t))[0] index_diff = np.where(np.diff(t))[0]
return index_diff, order return np.concatenate([index_diff, [len(groups)]]), order
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment