Commit e25b625d authored by Dr.李's avatar Dr.李

simplify codes

parent d533b7a2
......@@ -46,15 +46,6 @@ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_exp
for i in range(explained.shape[2]):
explained[curr_idx] = ls_explain(curr_x, b)
start = diff_loc + 1
curr_idx = order[start:]
curr_x, b = _sub_step(x, y, curr_idx, res)
if output_exposure:
for i in range(exposure.shape[2]):
exposure[curr_idx, :, i] = b[:, i]
if output_explained:
for i in range(explained.shape[2]):
explained[curr_idx] = ls_explain(curr_x, b)
else:
b = ls_fit(x, y)
res = ls_res(x, y, b)
......
......@@ -60,8 +60,6 @@ def _train_loop(index_diff, order, x, y):
for k, diff_loc in enumerate(index_diff):
res_beta[k] = _train_sub_group(x, y, order[start:diff_loc + 1])
start = diff_loc + 1
res_beta[k + 1] = _train_sub_group(x, y, order[start:])
return res_beta
......
......@@ -29,12 +29,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort()
current_ordering.shape = -1, 1
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
else:
ordering = neg_er.argsort()
use_rank = int(percent * len(neg_er))
......@@ -53,10 +47,6 @@ def percent_build(er: np.ndarray, percent: float, groups: np.ndarray=None) -> np
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1)
start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort(axis=0)
use_rank = int(percent * len(current_index))
set_value(weights, current_index[current_ordering[:use_rank]], 1)
else:
ordering = neg_er.argsort(axis=0)
use_rank = int(percent * len(neg_er))
......
......@@ -28,11 +28,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
current_ordering.shape = -1, 1
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort()
current_ordering.shape = -1, 1
set_value(weights, current_index[current_ordering[:use_rank]], 1.)
else:
ordering = neg_er.argsort()
weights[ordering[:use_rank]] = 1.
......@@ -49,10 +44,6 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
current_ordering = neg_er[current_index].argsort(axis=0)
set_value(weights, current_index[current_ordering[:use_rank]], 1)
start = diff_loc + 1
current_index = order[start:]
current_ordering = neg_er[current_index].argsort(axis=0)
set_value(weights, current_index[current_ordering[:use_rank]], 1)
else:
ordering = neg_er.argsort(axis=0)
set_value(weights, ordering[:use_rank], 1.)
......
......@@ -18,7 +18,7 @@ def groupby(groups):
order = groups.argsort()
t = groups[order]
index_diff = np.where(np.diff(t))[0]
return index_diff, order
return np.concatenate([index_diff, [len(groups)]]), order
@nb.njit(nogil=True, cache=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment