From f1328d5c5fee1da5907371e292dda3e8344b4b63 Mon Sep 17 00:00:00 2001 From: shiyu1994 Date: Wed, 8 Jun 2022 11:03:10 +0800 Subject: [PATCH] Clear split info buffer in cost efficient gradient boosting before every iteration (fix partially #3679) (#5164) * clear split info buffer in cegb_ before every iteration * check nullable of cegb_ in serial_tree_learner.cpp * add a test case for checking the split buffer in CEGB * switch to Threading::For instead of raw OpenMP * apply review suggestions * apply review comments * remove device cpu --- .../cost_effective_gradient_boosting.hpp | 14 ++++++ src/treelearner/serial_tree_learner.cpp | 4 ++ tests/python_package_test/test_engine.py | 45 +++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/src/treelearner/cost_effective_gradient_boosting.hpp b/src/treelearner/cost_effective_gradient_boosting.hpp index d66f2f4f92ad..4c29deb82de4 100644 --- a/src/treelearner/cost_effective_gradient_boosting.hpp +++ b/src/treelearner/cost_effective_gradient_boosting.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -32,6 +33,7 @@ class CostEfficientGradientBoosting { return true; } } + void Init() { auto train_data = tree_learner_->train_data_; if (!init_) { @@ -63,6 +65,17 @@ class CostEfficientGradientBoosting { } init_ = true; } + + void BeforeTrain() { + // clear the splits in splits_per_leaf_ + Threading::For<size_t>(0, splits_per_leaf_.size(), 1024, + [this] (int /*thread_index*/, size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + splits_per_leaf_[i].Reset(); + } + }); + } + double DeltaGain(int feature_index, int real_fidx, int leaf_index, int num_data_in_leaf, SplitInfo split_info) { auto config = tree_learner_->config_; @@ -82,6 +95,7 @@ class CostEfficientGradientBoosting { feature_index] = split_info; return delta; } + void UpdateLeafBestSplits(Tree* tree, int best_leaf, const SplitInfo* best_split_info, std::vector<SplitInfo>* best_split_per_leaf) { diff --git 
a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index 6fde7e6639bf..402889d3a561 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -278,6 +278,10 @@ void SerialTreeLearner::BeforeTrain() { } larger_leaf_splits_->Init(); + + if (cegb_ != nullptr) { + cegb_->BeforeTrain(); + } } bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 852737bcbc21..0aa27349b677 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3578,3 +3578,48 @@ def test_boost_from_average_with_single_leaf_trees(): preds = model.predict(X) mean_preds = np.mean(preds) assert y.min() <= mean_preds <= y.max() + + +def test_cegb_split_buffer_clean(): + # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811 + # and https://github.com/microsoft/LightGBM/pull/5087 + # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree + # which is done in the fix #5164 + # without the fix: + # Check failed: (best_split_info.left_count) > (0) + + R, C = 1000, 100 + seed = 29 + np.random.seed(seed) + data = np.random.randn(R, C) + for i in range(1, C): + data[i] += data[0] * np.random.randn() + + N = int(0.8 * len(data)) + train_data = data[:N] + test_data = data[N:] + train_y = np.sum(train_data, axis=1) + test_y = np.sum(test_data, axis=1) + + train = lgb.Dataset(train_data, train_y, free_raw_data=True) + + params = { + 'boosting_type': 'gbdt', + 'objective': 'regression', + 'max_bin': 255, + 'num_leaves': 31, + 'seed': 0, + 'learning_rate': 0.1, + 'min_data_in_leaf': 0, + 'verbose': -1, + 'min_split_gain': 1000.0, + 'cegb_penalty_feature_coupled': 5 * np.arange(C), + 'cegb_penalty_split': 0.0002, + 'cegb_tradeoff': 10.0, + 'force_col_wise': True, + } + + model = lgb.train(params, 
train, num_boost_round=10) + predicts = model.predict(test_data) + rmse = np.sqrt(mean_squared_error(test_y, predicts)) + assert rmse < 10.0