Commit d7f8463a authored by Dr.李

Use a dict instead of an array to avoid the overflow problem

parent beb0dcd3
...@@ -21,9 +21,13 @@ class LinearModel(object): ...@@ -21,9 +21,13 @@ class LinearModel(object):
self.model_parameter = _train(x, y, groups) self.model_parameter = _train(x, y, groups)
def predict(self, x, groups=None): def predict(self, x, groups=None):
if groups is not None and self.model_parameter.ndim == 2: if groups is not None and isinstance(self.model_parameter, dict):
names = np.unique(groups) names = np.unique(groups)
return multiple_prediction(names, self.model_parameter, x, groups) pred_v = np.zeros(x.shape[0])
for name in names:
this_param = self.model_parameter[name]
_prediction_group(name, groups, this_param, x, pred_v)
return pred_v
elif self.model_parameter is None: elif self.model_parameter is None:
raise ValueError("linear model is not calibrated yet") raise ValueError("linear model is not calibrated yet")
elif groups is None: elif groups is None:
...@@ -33,13 +37,9 @@ class LinearModel(object): ...@@ -33,13 +37,9 @@ class LinearModel(object):
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def multiple_prediction(names, model_parames, x, groups): def _prediction_group(name, groups, this_param, x, pred_v):
pred_v = np.zeros(x.shape[0])
for name in names:
this_param = model_parames[name]
idx = groups == name idx = groups == name
pred_v[idx] = x[idx] @ this_param pred_v[idx] = x[idx] @ this_param
return pred_v
def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray: def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray:
...@@ -47,19 +47,19 @@ def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray: ...@@ -47,19 +47,19 @@ def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray:
return ls_fit(x, y) return ls_fit(x, y)
else: else:
groups_ids = groupby(groups) groups_ids = groupby(groups)
res_beta = np.zeros((max(groups_ids.keys())+1, x.shape[1])) res_beta = {}
for k, curr_idx in groups_ids.items(): for k, curr_idx in groups_ids.items():
_train_sub_group(x, y, k, curr_idx, res_beta) res_beta[k] = _train_sub_group(x, y, curr_idx)
return res_beta return res_beta
@nb.njit(nogil=True, cache=True) @nb.njit(nogil=True, cache=True)
def _train_sub_group(x, y, k, curr_idx, res): def _train_sub_group(x, y, curr_idx):
curr_x = x[curr_idx] curr_x = x[curr_idx]
curr_y = y[curr_idx] curr_y = y[curr_idx]
res[k] = ls_fit(curr_x, curr_y) return ls_fit(curr_x, curr_y)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment