Commit 3697dcb5 authored by Dr.李's avatar Dr.李

change the implementation of groupby

parent 34497c4c
...@@ -10,6 +10,14 @@ import numpy as np ...@@ -10,6 +10,14 @@ import numpy as np
import numba as nb import numba as nb
def groupby(groups):
a = np.arange(groups.shape[0])
order_group_idx = groups.argsort()
counts = np.bincount(groups)
ret = np.split(a[order_group_idx], np.cumsum(counts)[:-1])
return ret
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def simple_sum(x, axis=0): def simple_sum(x, axis=0):
length, width = x.shape length, width = x.shape
......
...@@ -12,7 +12,7 @@ from numpy.linalg import solve ...@@ -12,7 +12,7 @@ from numpy.linalg import solve
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
from typing import Dict from typing import Dict
from alphamind.groupby import groupby from alphamind.aggregate import groupby
def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False, output_exposure=False) \ def neutralize(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None, output_explained=False, output_exposure=False) \
......
...@@ -9,31 +9,9 @@ Created on 2017-4-26 ...@@ -9,31 +9,9 @@ Created on 2017-4-26
import numpy as np import numpy as np
from numpy import zeros from numpy import zeros
cimport numpy as np cimport numpy as np
from numpy import array
cimport cython cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
cpdef groupby(long[:] groups):
cdef size_t length = groups.shape[0]
cdef dict group_ids = {}
cdef size_t i
cdef long curr_tag
for i in range(length):
curr_tag = groups[i]
try:
group_ids[curr_tag].append(i)
except KeyError:
group_ids[curr_tag] = [i]
return [array(v) for v in group_ids.values()]
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
@cython.initializedcheck(False) @cython.initializedcheck(False)
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
import numba as nb import numba as nb
from numpy import zeros from numpy import zeros
from numpy import zeros_like from numpy import zeros_like
from alphamind.groupby import groupby from alphamind.aggregate import groupby
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment