Commit d7f8463a authored by Dr.李's avatar Dr.李

using dict instead of array to avoid overflow problem

parent beb0dcd3
......@@ -21,9 +21,13 @@ class LinearModel(object):
self.model_parameter = _train(x, y, groups)
def predict(self, x, groups=None):
if groups is not None and self.model_parameter.ndim == 2:
if groups is not None and isinstance(self.model_parameter, dict):
names = np.unique(groups)
return multiple_prediction(names, self.model_parameter, x, groups)
pred_v = np.zeros(x.shape[0])
for name in names:
this_param = self.model_parameter[name]
_prediction_group(name, groups, this_param, x, pred_v)
return pred_v
elif self.model_parameter is None:
raise ValueError("linear model is not calibrated yet")
elif groups is None:
......@@ -33,13 +37,9 @@ class LinearModel(object):
@nb.njit(nogil=True, cache=True)
def multiple_prediction(names, model_parames, x, groups):
pred_v = np.zeros(x.shape[0])
for name in names:
this_param = model_parames[name]
def _prediction_group(name, groups, this_param, x, pred_v):
idx = groups == name
pred_v[idx] = x[idx] @ this_param
return pred_v
def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray:
......@@ -47,19 +47,19 @@ def _train(x: np.ndarray, y: np.ndarray, groups: np.ndarray=None) -> np.ndarray:
return ls_fit(x, y)
else:
groups_ids = groupby(groups)
res_beta = np.zeros((max(groups_ids.keys())+1, x.shape[1]))
res_beta = {}
for k, curr_idx in groups_ids.items():
_train_sub_group(x, y, k, curr_idx, res_beta)
res_beta[k] = _train_sub_group(x, y, curr_idx)
return res_beta
@nb.njit(nogil=True, cache=True)
def _train_sub_group(x, y, k, curr_idx, res):
def _train_sub_group(x, y, curr_idx):
curr_x = x[curr_idx]
curr_y = y[curr_idx]
res[k] = ls_fit(curr_x, curr_y)
return ls_fit(curr_x, curr_y)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment