Commit fa7f2d68 authored by Dr.李's avatar Dr.李

added n_jobs for xgb trainer

parent 92b23c60
...@@ -98,11 +98,13 @@ class XGBRegressor(ModelBase): ...@@ -98,11 +98,13 @@ class XGBRegressor(ModelBase):
max_depth: int=3, max_depth: int=3,
features: List=None, features: List=None,
formulas: dict = None, formulas: dict = None,
n_jobs: int=1,
**kwargs): **kwargs):
super().__init__(features, formulas=formulas) super().__init__(features, formulas=formulas)
self.impl = XGBRegressorImpl(n_estimators=n_estimators, self.impl = XGBRegressorImpl(n_estimators=n_estimators,
learning_rate=learning_rate, learning_rate=learning_rate,
max_depth=max_depth, max_depth=max_depth,
n_jobs=n_jobs,
**kwargs) **kwargs)
def save(self) -> dict: def save(self) -> dict:
...@@ -134,11 +136,13 @@ class XGBClassifier(ModelBase): ...@@ -134,11 +136,13 @@ class XGBClassifier(ModelBase):
max_depth: int=3, max_depth: int=3,
features: List = None, features: List = None,
formulas: dict = None, formulas: dict = None,
n_jobs: int=1,
**kwargs): **kwargs):
super().__init__(features, formulas=formulas) super().__init__(features, formulas=formulas)
self.impl = XGBClassifierImpl(n_estimators=n_estimators, self.impl = XGBClassifierImpl(n_estimators=n_estimators,
learning_rate=learning_rate, learning_rate=learning_rate,
max_depth=max_depth, max_depth=max_depth,
n_jobs=n_jobs,
**kwargs) **kwargs)
def save(self) -> dict: def save(self) -> dict:
...@@ -177,7 +181,8 @@ class XGBTrainer(ModelBase): ...@@ -177,7 +181,8 @@ class XGBTrainer(ModelBase):
colsample_bytree=1., colsample_bytree=1.,
features: List = None, features: List = None,
formulas: dict = None, formulas: dict = None,
random_state=0, random_state: int=0,
n_jobs: int=1,
**kwargs): **kwargs):
super().__init__(features, formulas=formulas) super().__init__(features, formulas=formulas)
self.params = { self.params = {
...@@ -189,6 +194,7 @@ class XGBTrainer(ModelBase): ...@@ -189,6 +194,7 @@ class XGBTrainer(ModelBase):
'tree_method': tree_method, 'tree_method': tree_method,
'subsample': subsample, 'subsample': subsample,
'colsample_bytree': colsample_bytree, 'colsample_bytree': colsample_bytree,
'nthread': n_jobs,
'seed': random_state 'seed': random_state
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment