Commit fa7f2d68 authored by Dr.李's avatar Dr.李

added n_jobs for xgb trainer

parent 92b23c60
......@@ -98,11 +98,13 @@ class XGBRegressor(ModelBase):
max_depth: int=3,
features: List=None,
formulas: dict = None,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
self.impl = XGBRegressorImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
n_jobs=n_jobs,
**kwargs)
def save(self) -> dict:
......@@ -134,12 +136,14 @@ class XGBClassifier(ModelBase):
max_depth: int=3,
features: List = None,
formulas: dict = None,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
self.impl = XGBClassifierImpl(n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
**kwargs)
learning_rate=learning_rate,
max_depth=max_depth,
n_jobs=n_jobs,
**kwargs)
def save(self) -> dict:
model_desc = super().save()
......@@ -177,7 +181,8 @@ class XGBTrainer(ModelBase):
colsample_bytree=1.,
features: List = None,
formulas: dict = None,
random_state=0,
random_state: int=0,
n_jobs: int=1,
**kwargs):
super().__init__(features, formulas=formulas)
self.params = {
......@@ -189,6 +194,7 @@ class XGBTrainer(ModelBase):
'tree_method': tree_method,
'subsample': subsample,
'colsample_bytree': colsample_bytree,
'nthread': n_jobs,
'seed': random_state
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment