Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] fix quantized training (fixes #5982) (fixes #5994) #6092

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner(
auto ptr_ordered_grad = gradients;
auto ptr_ordered_hess = hessians;
if (num_used_dense_group > 0) {
if (USE_INDICES) {
if (USE_HESSIAN) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
if (USE_QUANT_GRAD) {
int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients);
const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients);
if (USE_INDICES) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ordered_hessians[i] = hessians[data_indices[i]];
ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
ptr_ordered_hess = ordered_hessians;
} else {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ptr_ordered_grad = reinterpret_cast<const score_t*>(ordered_gradients);
ptr_ordered_hess = nullptr;
}
} else {
if (USE_INDICES) {
if (USE_HESSIAN) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ordered_hessians[i] = hessians[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
ptr_ordered_hess = ordered_hessians;
} else {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
}
ptr_ordered_grad = ordered_gradients;
}
}
OMP_INIT_EX();
Expand Down
19 changes: 19 additions & 0 deletions src/treelearner/leaf_splits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ class LeafSplits {
weight_ = weight;
}

/*!
* \brief Init split on current leaf on partial data.
* \param leaf Index of current leaf
* \param data_partition current data partition
* \param sum_gradients
* \param sum_hessians
 * \param sum_gradients_and_hessians Packed 64-bit sum of discretized gradients (high 32 bits) and hessians (low 32 bits)
 * \param weight
*/
void Init(int leaf, const DataPartition* data_partition, double sum_gradients,
          double sum_hessians, int64_t sum_gradients_and_hessians, double weight) {
  // Record which leaf this split state describes and its weight.
  leaf_index_ = leaf;
  weight_ = weight;
  // Cache the (floating-point) gradient/hessian sums plus the packed
  // integer sum used by quantized training.
  sum_gradients_ = sum_gradients;
  sum_hessians_ = sum_hessians;
  int_sum_gradients_and_hessians_ = sum_gradients_and_hessians;
  // Fetch the indices of the rows currently assigned to this leaf;
  // num_data_in_leaf_ is filled in by the partition as a side effect.
  data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
}

/*!
* \brief Init split on current leaf on partial data.
* \param leaf Index of current leaf
Expand Down
115 changes: 96 additions & 19 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
#endif

// init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
if (!config_->use_quantized_grad) {
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
}
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_sum_gradient_and_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_sum_gradient_and_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_sum_gradient_and_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_sum_gradient_and_hessian,
best_split_info.left_output);
}
}
if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) {
gradient_discretizer_->SetNumBitsInHistogramBin<false>(*left_leaf, *right_leaf,
data_partition_->leaf_count(*left_leaf),
data_partition_->leaf_count(*right_leaf));
}

#ifdef DEBUG
CheckSplit(best_split_info, *left_leaf, *right_leaf);
#endif

auto leaves_need_update = constraints_->Update(
is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type, best_split_info.right_output,
Expand Down Expand Up @@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
*split = bests[best_idx];
}

#ifdef DEBUG
/*!
 * \brief Debug-only consistency check for a just-applied split: recomputes the
 *        discretized gradient/hessian sums over the rows assigned to the two
 *        child leaves and verifies they match the packed sums recorded in the
 *        chosen split, as well as the recorded data counts.
 * \param best_split_info Split chosen for the parent leaf; its packed
 *        *_sum_gradient_and_hessian fields store the gradient in the high
 *        32 bits and the hessian in the low 32 bits
 * \param left_leaf_index Index of the left child leaf in the data partition
 * \param right_leaf_index Index of the right child leaf in the data partition
 */
void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
  data_size_t num_data_in_left = 0;
  data_size_t num_data_in_right = 0;
  const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left);
  const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right);
  if (config_->use_quantized_grad) {
    int32_t sum_left_gradient = 0;
    int32_t sum_left_hessian = 0;
    int32_t sum_right_gradient = 0;
    int32_t sum_right_hessian = 0;
    // Interleaved int8 layout per row: hessian at even index, gradient at odd index.
    const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians();
    for (data_size_t i = 0; i < num_data_in_left; ++i) {
      const data_size_t index = data_indices_in_left[i];
      sum_left_gradient += discretized_grad_and_hess[2 * index + 1];
      sum_left_hessian += discretized_grad_and_hess[2 * index];
    }
    for (data_size_t i = 0; i < num_data_in_right; ++i) {
      const data_size_t index = data_indices_in_right[i];
      sum_right_gradient += discretized_grad_and_hess[2 * index + 1];
      sum_right_hessian += discretized_grad_and_hess[2 * index];
    }
    // Dump the recomputed vs. recorded sums before asserting, so a failing
    // CHECK below still leaves the diagnostic context in the log.
    Log::Warning("============================ start leaf split info ============================");
    Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index);
    Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right);
    Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient,
                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
    Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian,
                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
    Log::Warning("sum_right_gradient = %d, best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient,
                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
    Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian,
                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
    CHECK_EQ(num_data_in_left, best_split_info.left_count);
    CHECK_EQ(num_data_in_right, best_split_info.right_count);
    // Fixed: this CHECK_EQ was missing its terminating semicolon, which broke
    // compilation of DEBUG builds.
    CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
    CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
    CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
    CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
    Log::Warning("============================ end leaf split info ============================");
  }
}
#endif

} // namespace LightGBM
2 changes: 2 additions & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ class SerialTreeLearner: public TreeLearner {

std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);

#ifdef DEBUG
void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
#endif

/*!
* \brief Get the number of data in a leaf
Expand Down
1 change: 0 additions & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster):
'num_grad_quant_bins': 30,
'quant_train_renew_leaf': True,
'verbose': -1,
'force_row_wise': True,
}

quant_dask_classifier = lgb.DaskLGBMRegressor(
Expand Down