Commit 30e7eeff authored by Dr.李's avatar Dr.李

update models

parent a6e9e0c8
...@@ -173,6 +173,7 @@ class XGBTrainer(ModelBase): ...@@ -173,6 +173,7 @@ class XGBTrainer(ModelBase):
subsample=1., subsample=1.,
colsample_bytree=1., colsample_bytree=1.,
features: List = None, features: List = None,
random_state=0,
**kwargs): **kwargs):
super().__init__(features) super().__init__(features)
self.params = { self.params = {
...@@ -183,13 +184,15 @@ class XGBTrainer(ModelBase): ...@@ -183,13 +184,15 @@ class XGBTrainer(ModelBase):
'booster': booster, 'booster': booster,
'tree_method': tree_method, 'tree_method': tree_method,
'subsample': subsample, 'subsample': subsample,
'colsample_bytree': colsample_bytree 'colsample_bytree': colsample_bytree,
'seed': random_state
} }
self.eval_sample = eval_sample self.eval_sample = eval_sample
self.num_boost_round = n_estimators self.num_boost_round = n_estimators
self.early_stopping_rounds = early_stopping_rounds self.early_stopping_rounds = early_stopping_rounds
self.impl = None self.impl = None
self.kwargs = kwargs
def fit(self, x, y): def fit(self, x, y):
if self.eval_sample: if self.eval_sample:
...@@ -203,12 +206,14 @@ class XGBTrainer(ModelBase): ...@@ -203,12 +206,14 @@ class XGBTrainer(ModelBase):
dtrain=d_train, dtrain=d_train,
num_boost_round=self.num_boost_round, num_boost_round=self.num_boost_round,
evals=[(d_eval, 'eval')], evals=[(d_eval, 'eval')],
verbose_eval=False) verbose_eval=False,
**self.kwargs)
else: else:
d_train = xgb.DMatrix(x, y) d_train = xgb.DMatrix(x, y)
self.impl = xgb.train(params=self.params, self.impl = xgb.train(params=self.params,
dtrain=d_train, dtrain=d_train,
num_boost_round=self.num_boost_round) num_boost_round=self.num_boost_round,
**self.kwargs)
self.trained_time = arrow.now().format("YYYY-MM-DD HH:mm:ss") self.trained_time = arrow.now().format("YYYY-MM-DD HH:mm:ss")
......
...@@ -31,6 +31,7 @@ class TestTreeModel(unittest.TestCase): ...@@ -31,6 +31,7 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10) sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x)) np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
np.testing.assert_array_almost_equal(model.importances, new_model.importances)
def test_random_forest_classify_persistence(self): def test_random_forest_classify_persistence(self):
model = RandomForestClassifier(features=list(range(10))) model = RandomForestClassifier(features=list(range(10)))
...@@ -43,6 +44,7 @@ class TestTreeModel(unittest.TestCase): ...@@ -43,6 +44,7 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10) sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x)) np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
np.testing.assert_array_almost_equal(model.importances, new_model.importances)
def test_xgb_regress_persistence(self): def test_xgb_regress_persistence(self):
model = XGBRegressor(features=list(range(10))) model = XGBRegressor(features=list(range(10)))
...@@ -54,6 +56,7 @@ class TestTreeModel(unittest.TestCase): ...@@ -54,6 +56,7 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10) sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x)) np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
np.testing.assert_array_almost_equal(model.importances, new_model.importances)
def test_xgb_classify_persistence(self): def test_xgb_classify_persistence(self):
model = XGBClassifier(features=list(range(10))) model = XGBClassifier(features=list(range(10)))
...@@ -66,8 +69,36 @@ class TestTreeModel(unittest.TestCase): ...@@ -66,8 +69,36 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10) sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x)) np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
np.testing.assert_array_almost_equal(model.importances, new_model.importances)
def test_xgb_trainer_persisence(self): def test_xgb_trainer_equal_classifier(self):
sample_x = np.random.randn(100, 10)
model1 = XGBClassifier(n_estimators=100,
learning_rate=0.1,
max_depth=3,
features=list(range(10)),
random_state=42)
model2 = XGBTrainer(features=list(range(10)),
objective='reg:logistic',
booster='gbtree',
tree_method='exact',
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42)
y = np.where(self.y > 0, 1, 0)
model1.fit(self.x, y)
model2.fit(self.x, y)
predict1 = model1.predict(sample_x)
predict2 = model2.predict(sample_x)
predict2 = np.where(predict2 > 0.5, 1., 0.)
np.testing.assert_array_almost_equal(predict1, predict2)
def test_xgb_trainer_persistence(self):
model = XGBTrainer(features=list(range(10)), model = XGBTrainer(features=list(range(10)),
objective='binary:logistic', objective='binary:logistic',
booster='gbtree', booster='gbtree',
...@@ -82,3 +113,4 @@ class TestTreeModel(unittest.TestCase): ...@@ -82,3 +113,4 @@ class TestTreeModel(unittest.TestCase):
sample_x = np.random.randn(100, 10) sample_x = np.random.randn(100, 10)
np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x)) np.testing.assert_array_almost_equal(model.predict(sample_x), new_model.predict(sample_x))
np.testing.assert_array_almost_equal(model.importances, new_model.importances)
...@@ -4,7 +4,12 @@ cd xgboost ...@@ -4,7 +4,12 @@ cd xgboost
git submodule init git submodule init
git submodule update git submodule update
mkdir build
cd build
cmake ..
make -j4 make -j4
cd ..
cd python-package cd python-package
python setup.py install python setup.py install
...@@ -22,4 +27,4 @@ if [ $? -ne 0 ] ; then ...@@ -22,4 +27,4 @@ if [ $? -ne 0 ] ; then
exit 1 exit 1
fi fi
cd ../.. cd ../..
\ No newline at end of file
Subproject commit bf4367184164e593cd2856ef38f8dd4f8cc76999 Subproject commit a187ed6c8f3aa40b47d5be80667cbbe6a6fd563d
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment