[CUDA] CUDA Quantized Training (fixes #5606) #5933

Merged · 52 commits · Oct 8, 2023

Commits
8187759
add quantized training (first stage)
shiyu1994 Dec 22, 2022
8480873
Merge remote-tracking branch 'origin/master' into quantized-training
shiyu1994 Mar 23, 2023
9e5d46b
add histogram construction functions for integer gradients
shiyu1994 Mar 23, 2023
dd2a3b4
add stochastic rounding (see the sketch after this commit list)
shiyu1994 Mar 23, 2023
41c6c79
update docs
shiyu1994 Mar 23, 2023
dfb5bc4
fix compilation errors by adding template instantiations
shiyu1994 Mar 23, 2023
d830128
update files for compilation
shiyu1994 Mar 23, 2023
e82675b
fix compilation of gpu version
shiyu1994 Mar 23, 2023
1d68e97
initialize gradient discretizer before share states
shiyu1994 Mar 30, 2023
4ccdf34
Merge remote-tracking branch 'origin/master' into quantized-training
shiyu1994 Mar 30, 2023
27dbf8c
add a test case for quantized training
shiyu1994 Apr 5, 2023
5c8aac1
add quantized training for data distributed training
shiyu1994 Apr 5, 2023
1fd115a
Delete origin.pred
shiyu1994 Apr 6, 2023
197b394
Delete ifelse.pred
shiyu1994 Apr 6, 2023
7140bb8
Delete LightGBM_model.txt
shiyu1994 Apr 6, 2023
1f142d5
remove useless changes
shiyu1994 Apr 6, 2023
22a98b7
Merge remote-tracking branch 'origin/master' into quantized-training
shiyu1994 Apr 19, 2023
bc848a0
Merge branch 'quantized-training' of https://github.com/Microsoft/Lig…
shiyu1994 Apr 19, 2023
d5fc93d
fix lint error
shiyu1994 Apr 19, 2023
ed066d0
remove debug loggings
shiyu1994 Apr 19, 2023
06826f0
fix mismatch of vector and allocator types
shiyu1994 Apr 24, 2023
025ad39
remove changes in main.cpp
shiyu1994 Apr 25, 2023
baef468
fix bugs with uninitialized gradient discretizer
shiyu1994 Apr 25, 2023
ce93015
initialize ordered gradients in gradient discretizer
shiyu1994 Apr 25, 2023
2b1118c
disable quantized training with gpu and cuda
shiyu1994 Apr 25, 2023
487f2c4
fix bug in data parallel tree learner
shiyu1994 Apr 26, 2023
8c0e67b
make quantized training test deterministic
shiyu1994 Apr 26, 2023
6a76fde
make quantized training in test case more accurate
shiyu1994 Apr 26, 2023
0812403
refactor test_quantized_training
shiyu1994 Apr 26, 2023
9c8894b
fix leaf splits initialization with quantized training
shiyu1994 May 4, 2023
788e1aa
check distributed quantized training result
shiyu1994 May 5, 2023
bf759a9
add cuda gradient discretizer
shiyu1994 Jun 16, 2023
ba20a6d
Merge branch 'master' into cuda-quantized-training
shiyu1994 Jun 16, 2023
3593b2c
Merge remote-tracking branch 'origin/master' into cuda-quantized-training
shiyu1994 Jun 27, 2023
d7298a7
add quantized training for CUDA version in tree learner
shiyu1994 Jun 28, 2023
3ee27f2
remove cuda compute capability 6.1 and 6.2
shiyu1994 Jul 17, 2023
deace6c
Merge branch 'master' into cuda-quantized-training
shiyu1994 Jul 18, 2023
48df866
fix parts of gpu quantized training errors and warnings
shiyu1994 Jul 20, 2023
1d4b02e
fix build-python.sh to install locally built version
shiyu1994 Aug 8, 2023
a8ebc93
Merge branch 'master' into cuda-quantized-training
shiyu1994 Aug 8, 2023
3eaa652
fix memory access bugs
shiyu1994 Aug 9, 2023
cf12051
fix lint errors
shiyu1994 Aug 9, 2023
043fbcb
mark cuda quantized training on cuda with categorical features as unsupported
shiyu1994 Aug 9, 2023
89357e5
rename cuda_utils.h to cuda_utils.hu
shiyu1994 Aug 9, 2023
c7c5d57
enable quantized training with cuda
shiyu1994 Aug 9, 2023
6e3a271
Merge branch 'master' into cuda-quantized-training
shiyu1994 Sep 12, 2023
800a378
fix cuda quantized training with sparse row data
shiyu1994 Sep 12, 2023
6b687b0
allow using global memory buffer in histogram construction with cuda …
shiyu1994 Sep 12, 2023
bd5935d
Merge branch 'master' into cuda-quantized-training
shiyu1994 Sep 30, 2023
811e729
Merge branch 'master' into cuda-quantized-training
shiyu1994 Sep 30, 2023
2cb1abb
recover build-python.sh
shiyu1994 Oct 1, 2023
b71fc86
Merge branch 'master' into cuda-quantized-training
shiyu1994 Oct 3, 2023
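
The "add stochastic rounding" commit above refers to the rounding step used when gradients are discretized. Below is a minimal sketch of unbiased stochastic rounding, not the PR's actual kernel; the function name and the source of `uniform01` (e.g. a cuRAND draw) are assumptions for illustration:

```cuda
// Unbiased stochastic rounding for gradient discretization (sketch).
// uniform01 is assumed to be a uniform random draw in [0, 1), e.g. from cuRAND.
__device__ inline int32_t StochasticRound(float gradient, float scale, float uniform01) {
  const float scaled = gradient / scale;   // map the gradient onto the integer grid
  const float floored = floorf(scaled);
  const float frac = scaled - floored;     // fractional part in [0, 1)
  // Round up with probability frac so that E[result] == scaled,
  // keeping the discretized gradients unbiased in expectation.
  return static_cast<int32_t>(floored) + (uniform01 < frac ? 1 : 0);
}
```

Rounding stochastically rather than to-nearest is what lets low-bit gradient sums stay unbiased when accumulated over many data points.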
2 changes: 1 addition & 1 deletion .ci/check_python_dists.sh
@@ -25,7 +25,7 @@ if [ $PY_MINOR_VER -gt 7 ]; then
pydistcheck \
--inspect \
--ignore 'compiled-objects-have-debug-symbols,distro-too-large-compressed' \
- --max-allowed-size-uncompressed '70M' \
+ --max-allowed-size-uncompressed '100M' \
--max-allowed-files 800 \
${DIST_DIR}/* || exit -1
elif { test $(uname -m) = "aarch64"; }; then
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_algorithms.hpp
@@ -13,7 +13,7 @@
#include <stdio.h>

#include <LightGBM/bin.h>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/log.h>

#include <algorithm>
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_column_data.hpp
@@ -9,7 +9,7 @@
#define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_

#include <LightGBM/config.h>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>

2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_metadata.hpp
@@ -8,7 +8,7 @@
#ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_
#define LIGHTGBM_CUDA_CUDA_METADATA_HPP_

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/meta.h>

#include <vector>
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_metric.hpp
@@ -9,7 +9,7 @@

#ifdef USE_CUDA

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/metric.h>

namespace LightGBM {
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_objective_function.hpp
@@ -9,7 +9,7 @@

#ifdef USE_CUDA

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_row_data.hpp
@@ -10,7 +10,7 @@

#include <LightGBM/bin.h>
#include <LightGBM/config.h>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/dataset.h>
#include <LightGBM/train_share_states.h>
#include <LightGBM/utils/openmp_wrapper.h>
2 changes: 2 additions & 0 deletions include/LightGBM/cuda/cuda_split_info.hpp
@@ -24,12 +24,14 @@ class CUDASplitInfo {

double left_sum_gradients;
double left_sum_hessians;
+ int64_t left_sum_of_gradients_hessians;
data_size_t left_count;
double left_gain;
double left_value;

double right_sum_gradients;
double right_sum_hessians;
+ int64_t right_sum_of_gradients_hessians;
data_size_t right_count;
double right_gain;
double right_value;
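The two added fields store a gradient sum and a hessian sum packed into one 64-bit integer. Here is a sketch of that packing convention, assuming the usual quantized-training layout (discretized gradient in the high 32 bits, hessian in the low 32 bits); `PackGradHess`/`UnpackGradHess` are illustrative names, not functions from this PR:

```cuda
#include <cstdint>

// Pack a discretized gradient/hessian pair so that one 64-bit addition
// accumulates both sums at once (sketch, assumed layout).
__host__ __device__ inline int64_t PackGradHess(int32_t grad, int32_t hess) {
  return (static_cast<int64_t>(grad) << 32) |
         (static_cast<int64_t>(hess) & 0xFFFFFFFFLL);
}

__host__ __device__ inline void UnpackGradHess(int64_t packed, int32_t* grad, int32_t* hess) {
  *grad = static_cast<int32_t>(packed >> 32);           // arithmetic shift restores the sign
  *hess = static_cast<int32_t>(packed & 0xFFFFFFFFLL);  // hessians are non-negative
}
```

Because hessians are non-negative, additions in the low word never borrow from the gradient word as long as the hessian sum fits in 32 bits.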
include/LightGBM/cuda/cuda_utils.h → include/LightGBM/cuda/cuda_utils.hu (renamed)
@@ -7,15 +7,21 @@
#define LIGHTGBM_CUDA_CUDA_UTILS_H_

#ifdef USE_CUDA

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

#include <LightGBM/utils/log.h>

#include <algorithm>
#include <vector>
#include <cmath>

namespace LightGBM {

+ typedef unsigned long long atomic_add_long_t;

#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
if (code != cudaSuccess) {
@@ -125,13 +131,19 @@ class CUDAVector {
T* new_data = nullptr;
AllocateCUDAMemory<T>(&new_data, size, __FILE__, __LINE__);
if (size_ > 0 && data_ != nullptr) {
- CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size, __FILE__, __LINE__);
+ const size_t size_for_old_content = std::min<size_t>(size_, size);
+ CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size_for_old_content, __FILE__, __LINE__);
}
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
data_ = new_data;
size_ = size;
}

+ void InitFromHostVector(const std::vector<T>& host_vector) {
+   Resize(host_vector.size());
+   CopyFromHostToCUDADevice(data_, host_vector.data(), host_vector.size(), __FILE__, __LINE__);
+ }

void Clear() {
if (size_ > 0 && data_ != nullptr) {
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
@@ -171,6 +183,10 @@ class CUDAVector {
return data_;
}

+ void SetValue(int value) {
+   SetCUDAMemory<T>(data_, value, size_, __FILE__, __LINE__);
+ }

const T* RawDataReadOnly() const {
return data_;
}
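The new `atomic_add_long_t` typedef exists because CUDA declares its 64-bit integer `atomicAdd` overload for `unsigned long long`, not `int64_t`. A sketch of how a packed signed sum can be accumulated through it (modeled on, but not copied from, the PR's histogram code); two's-complement wraparound makes the unsigned addition bit-identical to the signed one:

```cuda
// Atomically add a packed signed 64-bit gradient/hessian value (sketch).
// CUDA only exposes atomicAdd(unsigned long long*, unsigned long long),
// hence the round trip through atomic_add_long_t.
__device__ inline void AtomicAddInt64(int64_t* address, int64_t value) {
  atomicAdd(reinterpret_cast<atomic_add_long_t*>(address),
            static_cast<atomic_add_long_t>(value));
}
```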
2 changes: 1 addition & 1 deletion include/LightGBM/sample_strategy.h
@@ -6,7 +6,7 @@
#ifndef LIGHTGBM_SAMPLE_STRATEGY_H_
#define LIGHTGBM_SAMPLE_STRATEGY_H_

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
2 changes: 1 addition & 1 deletion src/boosting/cuda/cuda_score_updater.hpp
@@ -8,7 +8,7 @@

#ifdef USE_CUDA

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>

#include "../score_updater.hpp"

2 changes: 1 addition & 1 deletion src/cuda/cuda_utils.cpp
@@ -5,7 +5,7 @@

#ifdef USE_CUDA

- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>

namespace LightGBM {

4 changes: 0 additions & 4 deletions src/io/config.cpp
@@ -389,10 +389,6 @@ void Config::CheckParamConflict() {
if (deterministic) {
Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic.");
}
- if (use_quantized_grad) {
-   Log::Warning("Quantized training is not supported by CUDA tree learner. Switch to full precision training.");
-   use_quantized_grad = false;
- }
}
// linear tree learner must be serial type and run on CPU device
if (linear_tree) {
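Deleting this guard is what actually enables the feature: previously `use_quantized_grad` was silently reset to `false` whenever the CUDA tree learner was selected. A minimal C API sketch of the now-valid combination; `device_type`, `use_quantized_grad`, and `num_grad_quant_bins` are existing LightGBM parameters, while the data path and iteration count are placeholders:

```cuda
#include <LightGBM/c_api.h>

int main() {
  DatasetHandle train_data;
  LGBM_DatasetCreateFromFile("train.txt", "", nullptr, &train_data);  // placeholder path
  BoosterHandle booster;
  // With the conflict check removed, use_quantized_grad is honored on device_type=cuda.
  LGBM_BoosterCreate(train_data,
                     "objective=binary device_type=cuda "
                     "use_quantized_grad=true num_grad_quant_bins=4",
                     &booster);
  int is_finished = 0;
  for (int i = 0; i < 100 && !is_finished; ++i) {
    LGBM_BoosterUpdateOneIter(booster, &is_finished);
  }
  LGBM_BoosterFree(booster);
  LGBM_DatasetFree(train_data);
  return 0;
}
```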
2 changes: 1 addition & 1 deletion src/metric/cuda/cuda_binary_metric.hpp
@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

2 changes: 1 addition & 1 deletion src/metric/cuda/cuda_pointwise_metric.hpp
@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

2 changes: 1 addition & 1 deletion src/metric/cuda/cuda_regression_metric.hpp
@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
- #include <LightGBM/cuda/cuda_utils.h>
+ #include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

19 changes: 16 additions & 3 deletions src/treelearner/cuda/cuda_best_split_finder.cpp
@@ -40,6 +40,9 @@ CUDABestSplitFinder::CUDABestSplitFinder(
select_features_by_node_(select_features_by_node),
cuda_hist_(cuda_hist) {
InitFeatureMetaInfo(train_data);
+ if (has_categorical_feature_ && config->use_quantized_grad) {
+   Log::Fatal("Quantized training on GPU with categorical features is not supported yet.");
+ }
(inline review comment from the author on the check above: Link #6119)
cuda_leaf_best_split_info_ = nullptr;
cuda_best_split_info_ = nullptr;
cuda_best_split_info_buffer_ = nullptr;
@@ -326,13 +329,23 @@ void CUDABestSplitFinder::FindBestSplitsForLeaf(
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
- const double sum_hessians_in_larger_leaf) {
+ const double sum_hessians_in_larger_leaf,
+ const score_t* grad_scale,
+ const score_t* hess_scale,
+ const uint8_t smaller_num_bits_in_histogram_bins,
+ const uint8_t larger_num_bits_in_histogram_bins) {
const bool is_smaller_leaf_valid = (num_data_in_smaller_leaf > min_data_in_leaf_ &&
sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_);
const bool is_larger_leaf_valid = (num_data_in_larger_leaf > min_data_in_leaf_ &&
sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_ && larger_leaf_index >= 0);
- LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
-   smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
+ if (grad_scale != nullptr && hess_scale != nullptr) {
+   LaunchFindBestSplitsDiscretizedForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
+     smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid,
+     grad_scale, hess_scale, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
+ } else {
+   LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
+     smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
+ }
global_timer.Start("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel");
LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
SynchronizeCUDADevice(__FILE__, __LINE__);
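The dispatch above treats non-null `grad_scale`/`hess_scale` pointers as the signal that quantized training is active; the discretized kernel must then scale integer histogram sums back to real-valued gradient/hessian sums before split gains are compared. A sketch of that scaling step (illustrative only; the function and variable names are hypothetical):

```cuda
#include <LightGBM/meta.h>  // defines score_t
using LightGBM::score_t;

// Recover real-valued per-leaf sums from discretized integer sums (sketch).
// grad_scale / hess_scale are the per-leaf quantization scales passed into
// FindBestSplitsForLeaf above.
__device__ inline void RecoverRealSums(int64_t int_sum_gradient, int64_t int_sum_hessian,
                                       score_t grad_scale, score_t hess_scale,
                                       double* sum_gradient, double* sum_hessian) {
  *sum_gradient = static_cast<double>(int_sum_gradient) * static_cast<double>(grad_scale);
  *sum_hessian  = static_cast<double>(int_sum_hessian) * static_cast<double>(hess_scale);
}
```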