Skip to content

Commit

Permalink
[ci][fix] Fix cuda_exp ci (#5438)
Browse files Browse the repository at this point in the history
* fix cuda_exp ci

* fix ci failures introduced by #5279

* cleanup cuda.yml

* fix test.sh

* clean up test.sh

* clean up test.sh

* skip lines by cuda_exp in test_register_logger

* Update tests/python_package_test/test_utilities.py

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
shiyu1994 and StrikerRUS authored Aug 29, 2022
1 parent ef006b7 commit be7f321
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 64 deletions.
15 changes: 7 additions & 8 deletions .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ on:
env:
github_actions: 'true'
os_name: linux
task: cuda
conda_env: test-env

jobs:
test:
name: ${{ matrix.tree_learner }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
name: ${{ matrix.task }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
runs-on: [self-hosted, linux]
timeout-minutes: 60
strategy:
Expand All @@ -27,27 +26,27 @@ jobs:
compiler: gcc
python_version: "3.8"
cuda_version: "11.7.1"
tree_learner: cuda
task: cuda
- method: pip
compiler: clang
python_version: "3.9"
cuda_version: "10.0"
tree_learner: cuda
task: cuda
- method: wheel
compiler: gcc
python_version: "3.10"
cuda_version: "9.0"
tree_learner: cuda
task: cuda
- method: source
compiler: gcc
python_version: "3.8"
cuda_version: "11.7.1"
tree_learner: cuda_exp
task: cuda_exp
- method: pip
compiler: clang
python_version: "3.9"
cuda_version: "10.0"
tree_learner: cuda_exp
task: cuda_exp
steps:
- name: Setup or update software on host machine
run: |
Expand Down Expand Up @@ -86,7 +85,7 @@ jobs:
GITHUB_ACTIONS=${{ env.github_actions }}
OS_NAME=${{ env.os_name }}
COMPILER=${{ matrix.compiler }}
TASK=${{ env.task }}
TASK=${{ matrix.task }}
METHOD=${{ matrix.method }}
CONDA_ENV=${{ env.conda_env }}
PYTHON_VERSION=${{ matrix.python_version }}
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/cuda/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ class CUDAVector {
return data_;
}

const T* RawDataReadOnly() const {
return data_;
}

private:
T* data_;
size_t size_;
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class TreeLearner {
*/
virtual void ResetConfig(const Config* config) = 0;

/*!
* \brief Reset boosting_on_gpu_
* \param boosting_on_gpu flag for boosting on GPU
*/
virtual void ResetBoostingOnGPU(const bool /*boosting_on_gpu*/) {}

virtual void SetForcedSplit(const Json* forced_split_json) = 0;

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/cuda/cuda_score_updater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iterati
has_init_score_ = true;
CopyFromHostToCUDADevice<double>(cuda_score_, init_score, total_size, __FILE__, __LINE__);
} else {
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(num_data_), __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(total_size), __FILE__, __LINE__);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (boosting_on_cuda_) {
Expand Down
8 changes: 3 additions & 5 deletions src/boosting/cuda/cuda_score_updater.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,22 @@ namespace LightGBM {

__global__ void AddScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
score[data_index + offset] += val;
score[data_index] += val;
}
}

void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
Log::Debug("Adding init score = %lf", val);
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
}

__global__ void MultiplyScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
Expand All @@ -39,7 +37,7 @@ __global__ void MultiplyScoreConstantKernel(

void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
}

} // namespace LightGBM
Expand Down
115 changes: 79 additions & 36 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ GBDT::GBDT()
average_output_ = false;
tree_learner_ = nullptr;
linear_tree_ = false;
gradients_pointer_ = nullptr;
hessians_pointer_ = nullptr;
boosting_on_gpu_ = false;
}

GBDT::~GBDT() {
Expand Down Expand Up @@ -95,9 +98,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective

is_constant_hessian_ = GetIsConstHessian(objective_function);

const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
config_.get(), boosting_on_gpu));
config_.get(), boosting_on_gpu_));

// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);
Expand All @@ -112,7 +115,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective

#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp")) {
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
} else {
#endif // USE_CUDA_EXP
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
Expand All @@ -123,9 +126,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_data_ = train_data_->num_data();
// create buffer for gradients and Hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -137,17 +145,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
#ifndef USE_CUDA_EXP
}
#else // USE_CUDA_EXP
} else {
if (config_->device_type == std::string("cuda_exp")) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
}
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}
#endif // USE_CUDA_EXP

// get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
Expand Down Expand Up @@ -440,23 +445,36 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
Boosting();
gradients = gradients_pointer_;
hessians = hessians_pointer_;
#ifndef USE_CUDA_EXP
}
#else // USE_CUDA_EXP
} else {
if (config_->device_type == std::string("cuda_exp")) {
const size_t total_size = static_cast<size_t>(num_data_ * num_class_);
CopyFromHostToCUDADevice<score_t>(gradients_pointer_, gradients, total_size, __FILE__, __LINE__);
CopyFromHostToCUDADevice<score_t>(hessians_pointer_, hessians, total_size, __FILE__, __LINE__);
// use customized objective function
CHECK(objective_function_ == nullptr);
if (config_->boosting == std::string("goss")) {
// need to copy customized gradients when using GOSS
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < total_size; ++i) {
gradients_[i] = gradients[i];
hessians_[i] = hessians[i];
}
CHECK_EQ(gradients_pointer_, gradients_.data());
CHECK_EQ(hessians_pointer_, hessians_.data());
gradients = gradients_pointer_;
hessians = hessians_pointer_;
}
}
#endif // USE_CUDA_EXP

// bagging logic
Bagging(iter_);

if (gradients != nullptr && is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_ && config_->boosting != std::string("goss")) {
// allocate gradients_ and hessians_ for copy gradients for using data subset
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}

bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
Expand All @@ -465,7 +483,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
auto grad = gradients + offset;
auto hess = hessians + offset;
// need to copy gradients for bagging subset.
if (is_use_subset_ && bag_data_cnt_ < num_data_ && config_->device_type != std::string("cuda_exp")) {
if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
for (int i = 0; i < bag_data_cnt_; ++i) {
gradients_pointer_[offset + i] = grad[bag_data_indices_[i]];
hessians_pointer_[offset + i] = hess[bag_data_indices_[i]];
Expand Down Expand Up @@ -591,13 +609,12 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {

std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
#ifdef USE_CUDA_EXP
const bool boosting_on_cuda = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
const bool evaluation_on_cuda = metric->IsCUDAMetric();
if ((boosting_on_cuda && evaluation_on_cuda) || (!boosting_on_cuda && !evaluation_on_cuda)) {
if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) {
#endif // USE_CUDA_EXP
return metric->Eval(score, objective_function_);
#ifdef USE_CUDA_EXP
} else if (boosting_on_cuda && !evaluation_on_cuda) {
} else if (boosting_on_gpu_ && !evaluation_on_cuda) {
const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > host_score_.size()) {
host_score_.resize(total_size, 0.0f);
Expand Down Expand Up @@ -804,17 +821,16 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
}
training_metrics_.shrink_to_fit();

#ifdef USE_CUDA_EXP
const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
#endif // USE_CUDA_EXP
boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);

if (train_data != train_data_) {
train_data_ = train_data;
// not same training data, need reset score and others
// create score tracker
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp")) {
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
} else {
#endif // USE_CUDA_EXP
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
Expand All @@ -834,9 +850,14 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*

// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -848,6 +869,12 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}

max_feature_idx_ = train_data_->num_total_features() - 1;
Expand Down Expand Up @@ -879,6 +906,10 @@ void GBDT::ResetConfig(const Config* config) {
if (tree_learner_ != nullptr) {
tree_learner_->ResetConfig(new_config.get());
}

boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);

if (train_data_ != nullptr) {
ResetBaggingConfig(new_config.get(), false);
}
Expand Down Expand Up @@ -953,10 +984,16 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
need_re_bagging_ = true;

if (is_use_subset_ && bag_data_cnt_ < num_data_) {
if (objective_function_ == nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
// resize gradient vectors to copy the customized gradients for goss or bagging with subset
if (objective_function_ != nullptr) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && objective_function_ != nullptr && objective_function_->IsCUDAObjective()) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -968,6 +1005,12 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}
}
} else {
Expand Down
2 changes: 2 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ class GBDT : public GBDTBase {
score_t* gradients_pointer_;
/*! \brief Pointer to hessian vector, can be on CPU or GPU */
score_t* hessians_pointer_;
/*! \brief Whether boosting is done on GPU, used for cuda_exp */
bool boosting_on_gpu_;
#ifdef USE_CUDA_EXP
/*! \brief Buffer for scores when boosting is on GPU but evaluation is not, used only with cuda_exp */
mutable std::vector<double> host_score_;
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/rf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RF : public GBDT {
auto hess = hessians + offset;

// need to copy gradients for bagging subset.
if (is_use_subset_ && bag_data_cnt_ < num_data_) {
if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
for (int i = 0; i < bag_data_cnt_; ++i) {
tmp_grad_[i] = grad[bag_data_indices_[i]];
tmp_hess_[i] = hess[bag_data_indices_[i]];
Expand Down
Loading

0 comments on commit be7f321

Please sign in to comment.