Commit 1a454a09 authored by Dr.李's avatar Dr.李

remove the use of array initialize in groupby

parent 83b11bc0
......@@ -51,7 +51,7 @@ cpdef list groupby(long[:] groups):
deref(it).second.push_back(i)
for v in group_ids.values():
res.append(array(v, dtype=np.int64))
res.append(v)
return res
......
......@@ -13,20 +13,22 @@ cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
cpdef void set_value_bool(unsigned char[:, :] mat, long long[:, :] index):
cpdef void set_value_bool(unsigned char[:, :] mat, list index, long long[:, :] used_level):
cdef size_t length = index.shape[0]
cdef size_t width = index.shape[1]
cdef size_t length = used_level.shape[0]
cdef size_t width = used_level.shape[1]
cdef size_t i
cdef size_t j
cdef unsigned char* mat_ptr = &mat[0, 0]
cdef long long* index_ptr = &index[0, 0]
cdef long long* used_level_ptr = &used_level[0, 0]
cdef size_t k
cdef size_t l
for i in range(length):
k = i * width
for j in range(width):
mat_ptr[index_ptr[k + j] * width + j] = True
l = index[used_level_ptr[k + j]]
mat_ptr[l * width + j] = True
@cython.boundscheck(False)
......
......@@ -21,15 +21,16 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
weights = zeros((length, 1))
if groups is not None:
group_ids = groupby(groups)
masks = zeros(length, dtype=bool)
masks = zeros((length, 1), dtype=bool)
for current_index in group_ids:
current_ordering = neg_er[current_index].argsort()
masks[current_index[current_ordering[:use_rank]]] = True
current_ordering.shape = -1, 1
set_value_bool(masks.view(dtype=np.uint8), current_index, current_ordering[:use_rank])
weights[masks] = 1.
else:
ordering = neg_er.argsort()
weights[ordering[:use_rank]] = 1.
return weights
return weights.reshape(er.shape)
else:
length = er.shape[0]
width = er.shape[1]
......@@ -41,9 +42,7 @@ def rank_build(er: np.ndarray, use_rank: int, groups: np.ndarray=None) -> np.nda
masks = zeros((length, width), dtype=bool)
for current_index in group_ids:
current_ordering = neg_er[current_index].argsort(axis=0)
total_index = current_index[current_ordering[:use_rank]]
set_value_bool(masks.view(dtype=np.uint8), total_index)
set_value_bool(masks.view(dtype=np.uint8), current_index, current_ordering[:use_rank])
for j in range(width):
weights[masks[:, j], j] = 1.
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment