Commit 34497c4c authored by Dr.李's avatar Dr.李

simplified the groupby and group_mapping function

parent 636bc2c2
...@@ -9,13 +9,8 @@ Created on 2017-4-26 ...@@ -9,13 +9,8 @@ Created on 2017-4-26
import numpy as np import numpy as np
from numpy import zeros from numpy import zeros
cimport numpy as np cimport numpy as np
from numpy import array
cimport cython cimport cython
from libcpp.vector cimport vector as cpp_vector
from libcpp.unordered_map cimport unordered_map as cpp_map
from cython.operator cimport dereference as deref
ctypedef long long int64_t
@cython.boundscheck(False) @cython.boundscheck(False)
...@@ -23,23 +18,20 @@ ctypedef long long int64_t ...@@ -23,23 +18,20 @@ ctypedef long long int64_t
@cython.initializedcheck(False) @cython.initializedcheck(False)
cpdef groupby(long[:] groups): cpdef groupby(long[:] groups):
cdef long long length = groups.shape[0] cdef size_t length = groups.shape[0]
cdef cpp_map[long, cpp_vector[int64_t]] group_ids cdef dict group_ids = {}
cdef long long i cdef size_t i
cdef long curr_tag cdef long curr_tag
cdef cpp_map[long, cpp_vector[int64_t]].iterator it
cdef np.ndarray[long long, ndim=1] npy_array
for i in range(length): for i in range(length):
curr_tag = groups[i] curr_tag = groups[i]
it = group_ids.find(curr_tag)
if it == group_ids.end(): try:
group_ids[curr_tag].append(i)
except KeyError:
group_ids[curr_tag] = [i] group_ids[curr_tag] = [i]
else:
deref(it).second.push_back(i)
return [np.array(v) for v in group_ids.values()] return [array(v) for v in group_ids.values()]
@cython.boundscheck(False) @cython.boundscheck(False)
...@@ -48,20 +40,18 @@ cpdef groupby(long[:] groups): ...@@ -48,20 +40,18 @@ cpdef groupby(long[:] groups):
cpdef np.ndarray[long, ndim=1] group_mapping(long[:] groups): cpdef np.ndarray[long, ndim=1] group_mapping(long[:] groups):
cdef size_t length = groups.shape[0] cdef size_t length = groups.shape[0]
cdef np.ndarray[long, ndim=1] res= zeros(length, dtype=long) cdef np.ndarray[long, ndim=1] res= zeros(length, dtype=long)
cdef cpp_map[long, long] current_hold cdef dict current_hold = {}
cdef long curr_tag cdef long curr_tag
cdef long running_tag = -1 cdef long running_tag = -1
cdef size_t i = 0 cdef size_t i
cdef cpp_map[long, long].iterator it
for i in range(length): for i in range(length):
curr_tag = groups[i] curr_tag = groups[i]
it = current_hold.find(curr_tag) try:
if it == current_hold.end(): res[i] = current_hold[curr_tag]
except KeyError:
running_tag += 1 running_tag += 1
res[i] = running_tag res[i] = running_tag
current_hold[curr_tag] = running_tag current_hold[curr_tag] = running_tag
else:
res[i] = deref(it).second
return res return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment