[CUDA] Initial work for boosting and evaluation with CUDA #5279

Merged: 6 commits, Jul 29, 2022
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -395,6 +395,8 @@ if(USE_CUDA OR USE_CUDA_EXP)
src/treelearner/*.cu
endif()
if(USE_CUDA_EXP)
src/boosting/cuda/*.cpp
src/boosting/cuda/*.cu
src/treelearner/cuda/*.cpp
src/treelearner/cuda/*.cu
src/io/cuda/*.cu
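These additions extend the existing CUDA source globs so that the new `src/boosting/cuda/` sources below are compiled whenever the experimental `USE_CUDA_EXP` build option is enabled.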
25 changes: 25 additions & 0 deletions include/LightGBM/cuda/cuda_tree.hpp
@@ -56,6 +56,27 @@ class CUDATree : public Tree {
uint32_t* cuda_bitset_inner,
size_t cuda_bitset_inner_len);

/*!
* \brief Add the prediction values of this tree model to scores
* \param data The dataset
* \param num_data Number of total data points
* \param score Score array that the predictions are added to
*/
void AddPredictionToScore(const Dataset* data,
data_size_t num_data,
double* score) const override;

/*!
* \brief Add the prediction values of this tree model to scores
* \param data The dataset
* \param used_data_indices Indices of the data points to use
* \param num_data Number of total data points
* \param score Score array that the predictions are added to
*/
void AddPredictionToScore(const Dataset* data,
const data_size_t* used_data_indices,
data_size_t num_data, double* score) const override;

const int* cuda_leaf_parent() const { return cuda_leaf_parent_; }

const int* cuda_left_child() const { return cuda_left_child_; }
@@ -105,6 +126,10 @@ class CUDATree : public Tree {
size_t cuda_bitset_len,
size_t cuda_bitset_inner_len);

void LaunchAddPredictionToScoreKernel(const Dataset* data,
const data_size_t* used_data_indices,
data_size_t num_data, double* score) const;

void LaunchShrinkageKernel(const double rate);

void LaunchAddBiasKernel(const double val);
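For orientation, here is a minimal sketch of the kind of kernel that `LaunchAddPredictionToScoreKernel` wraps: one thread per data point walks the tree and adds the reached leaf value to that point's score. This is not the PR's implementation; the array names, the dense feature layout, and the `~child` leaf encoding are illustrative assumptions (missing values, categorical splits, and binned features are omitted).

```cpp
// Illustrative sketch only -- not the kernel from this PR.
// Each thread handles one row: traverse the tree from the root and add the
// leaf value to the row's score. Leaves are assumed to be encoded as
// ~leaf_index (a negative value) inside the child arrays.
__global__ void AddPredictionToScoreKernelSketch(
    const double* row_features,   // hypothetical dense layout: [num_data x num_features]
    const int num_features,
    const int* left_child,        // per-node left child, or ~leaf_index if it is a leaf
    const int* right_child,       // per-node right child, or ~leaf_index if it is a leaf
    const int* split_feature,     // per-node feature used for the split
    const double* threshold,      // per-node split threshold
    const double* leaf_value,     // per-leaf output value
    const int num_data,
    double* score) {
  const int row = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  if (row >= num_data) {
    return;
  }
  int node = 0;  // root
  while (node >= 0) {
    const double feature_value =
        row_features[static_cast<long long>(row) * num_features + split_feature[node]];
    node = (feature_value <= threshold[node]) ? left_child[node] : right_child[node];
  }
  score[row] += leaf_value[~node];
}
```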
5 changes: 5 additions & 0 deletions include/LightGBM/metric.h
@@ -55,6 +55,11 @@ class Metric {
* \param config Config for metric
*/
LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const Config& config);

/*!
* \brief Whether this metric is evaluated on CUDA
*/
virtual bool IsCUDAMetric() const { return false; }
};

/*!
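To show how the new hook is meant to be used, below is a hypothetical CUDA-backed metric that overrides `IsCUDAMetric()` so the booster knows it may pass a device-resident score buffer to `Eval`. The class and its internals are illustrative and assume the standard `Metric` interface (`Init`, `GetName`, `factor_to_bigger_better`, `Eval`); it is not one of the metrics touched by this PR.

```cpp
#include <string>
#include <vector>

#include <LightGBM/metric.h>

namespace LightGBM {

// Hypothetical example: a metric whose Eval() expects `score` to point to CUDA device memory.
class CUDAExampleMetric : public Metric {
 public:
  explicit CUDAExampleMetric(const Config& /*config*/) {}

  void Init(const Metadata& metadata, data_size_t num_data) override {
    metadata_ = &metadata;
    num_data_ = num_data;
  }

  const std::vector<std::string>& GetName() const override { return name_; }

  double factor_to_bigger_better() const override { return -1.0; }

  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
    // A real implementation would launch a reduction kernel over the device buffer `score` here.
    (void)score;
    (void)objective;
    return std::vector<double>(1, 0.0);
  }

  // Tells the boosting side that this metric consumes CUDA scores directly.
  bool IsCUDAMetric() const override { return true; }

 private:
  const Metadata* metadata_ = nullptr;
  data_size_t num_data_ = 0;
  std::vector<std::string> name_{"cuda_example_metric"};
};

}  // namespace LightGBM
```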
5 changes: 5 additions & 0 deletions include/LightGBM/objective_function.h
@@ -88,6 +88,11 @@ class ObjectiveFunction {
* \brief Load objective function from string object
*/
LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& str);

/*!
* \brief Whether gradients of this objective function are computed on CUDA
*/
virtual bool IsCUDAObjective() const { return false; }
};

} // namespace LightGBM
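The reviewer exchange further below (about `boosting_on_cuda`) hinges on this hook: boosting can only stay on the GPU when the chosen objective has a CUDA implementation. Here is a sketch of the kind of check described in that reply; the real logic sits around line 98 of gbdt.cpp in this PR, and the helper name here is hypothetical.

```cpp
#include <string>

#include <LightGBM/config.h>
#include <LightGBM/objective_function.h>

namespace LightGBM {

// Hypothetical helper mirroring the workaround described in the review reply:
// keep boosting on CUDA only when the experimental CUDA device is selected and
// the objective function has a CUDA implementation; otherwise fall back to the
// host path (and CUDAScoreUpdater keeps copying scores back to its host buffer).
inline bool BoostingOnCUDA(const Config& config, const ObjectiveFunction* objective_function) {
  return config.device_type == std::string("cuda_exp") &&
         objective_function != nullptr &&
         objective_function->IsCUDAObjective();
}

}  // namespace LightGBM
```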
3 changes: 2 additions & 1 deletion include/LightGBM/tree_learner.h
@@ -103,7 +103,8 @@ class TreeLearner {
*/
static TreeLearner* CreateTreeLearner(const std::string& learner_type,
const std::string& device_type,
const Config* config);
const Config* config,
const bool boosting_on_cuda);
};

} // namespace LightGBM
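Call sites of the factory then just thread the flag through, roughly like the hypothetical helper below (the real call sits in gbdt.cpp; `config->tree_learner` and `config->device_type` are the usual config fields).

```cpp
#include <memory>

#include <LightGBM/config.h>
#include <LightGBM/tree_learner.h>

namespace LightGBM {

// Hypothetical call site: the new boosting_on_cuda argument is forwarded from the
// booster into TreeLearner::CreateTreeLearner.
inline std::unique_ptr<TreeLearner> MakeTreeLearner(const Config* config, bool boosting_on_cuda) {
  return std::unique_ptr<TreeLearner>(
      TreeLearner::CreateTreeLearner(config->tree_learner,   // e.g. "serial"
                                     config->device_type,    // e.g. "cuda_exp"
                                     config,
                                     boosting_on_cuda));
}

}  // namespace LightGBM
```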
94 changes: 94 additions & 0 deletions src/boosting/cuda/cuda_score_updater.cpp
@@ -0,0 +1,94 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/

#include "cuda_score_updater.hpp"

#ifdef USE_CUDA_EXP

namespace LightGBM {

CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iteration, const bool boosting_on_cuda):
ScoreUpdater(data, num_tree_per_iteration), num_threads_per_block_(1024), boosting_on_cuda_(boosting_on_cuda) {
Collaborator: Is boosting_on_cuda always true? Should we keep this variable?

Collaborator (Author): No. This is a workaround for the case where some objective functions do not yet have a CUDA implementation. See line 98 of gbdt.cpp in this PR.

num_data_ = data->num_data();
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration;
InitCUDA(total_size);
has_init_score_ = false;
const double* init_score = data->metadata().init_score();
// if an initial score exists, start from it
if (init_score != nullptr) {
if ((data->metadata().num_init_score() % num_data_) != 0
|| (data->metadata().num_init_score() / num_data_) != num_tree_per_iteration) {
Log::Fatal("Number of classes for initial score error");
}
has_init_score_ = true;
CopyFromHostToCUDADevice<double>(cuda_score_, init_score, total_size, __FILE__, __LINE__);
} else {
// zero-initialize the full score buffer (num_data_ * num_tree_per_iteration entries)
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(total_size), __FILE__, __LINE__);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (boosting_on_cuda_) {
// clear host score buffer
score_.clear();
score_.shrink_to_fit();
}
}

void CUDAScoreUpdater::InitCUDA(const size_t total_size) {
AllocateCUDAMemory<double>(&cuda_score_, total_size, __FILE__, __LINE__);
}

CUDAScoreUpdater::~CUDAScoreUpdater() {
DeallocateCUDAMemory<double>(&cuda_score_, __FILE__, __LINE__);
}

inline void CUDAScoreUpdater::AddScore(double val, int cur_tree_id) {
Common::FunctionTimer fun_timer("CUDAScoreUpdater::AddScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
LaunchAddScoreConstantKernel(val, offset);
if (!boosting_on_cuda_) {
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
}

inline void CUDAScoreUpdater::AddScore(const Tree* tree, int cur_tree_id) {
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
tree->AddPredictionToScore(data_, num_data_, cuda_score_ + offset);
if (!boosting_on_cuda_) {
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
}

inline void CUDAScoreUpdater::AddScore(const TreeLearner* tree_learner, const Tree* tree, int cur_tree_id) {
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
tree_learner->AddPredictionToScore(tree, cuda_score_ + offset);
if (!boosting_on_cuda_) {
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
}

inline void CUDAScoreUpdater::AddScore(const Tree* tree, const data_size_t* data_indices,
data_size_t data_cnt, int cur_tree_id) {
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
tree->AddPredictionToScore(data_, data_indices, data_cnt, cuda_score_ + offset);
if (!boosting_on_cuda_) {
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
}

inline void CUDAScoreUpdater::MultiplyScore(double val, int cur_tree_id) {
Common::FunctionTimer fun_timer("CUDAScoreUpdater::MultiplyScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
LaunchMultiplyScoreConstantKernel(val, offset);
if (!boosting_on_cuda_) {
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
47 changes: 47 additions & 0 deletions src/boosting/cuda/cuda_score_updater.cu
@@ -0,0 +1,47 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/

#include "cuda_score_updater.hpp"

#ifdef USE_CUDA_EXP

namespace LightGBM {

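// Adds the constant `val` to every element of the score slice that starts at `offset` (one thread per data point).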
__global__ void AddScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
score[data_index + offset] += val;
}
}

void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_ - 1) / num_threads_per_block_;  // ceiling division over data points
Log::Warning("adding init score = %f", val);
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
}

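// Multiplies every element of a per-tree score slice by `val` (one thread per data point).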
__global__ void MultiplyScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
score[data_index + offset] *= val;  // apply the same per-tree offset as the add kernel
}
}

void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_ - 1) / num_threads_per_block_;  // ceiling division over data points
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
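A note on the launch configuration in this file: with `num_threads_per_block_` fixed at 1024, the grid size is the ceiling of `num_data_` divided by the block size, and the `data_index < num_data` guard inside each kernel makes any surplus threads harmless. A minimal helper for that computation (illustrative, not part of the PR):

```cpp
// Standard ceiling division for choosing the number of CUDA blocks.
// Example: NumBlocks(2500, 1024) == 3, which covers indices 0..2499 and leaves
// 572 idle threads in the last block (filtered out by the in-kernel bounds check).
inline int NumBlocks(int num_data, int threads_per_block) {
  return (num_data + threads_per_block - 1) / threads_per_block;
}
```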
65 changes: 65 additions & 0 deletions src/boosting/cuda/cuda_score_updater.hpp
@@ -0,0 +1,65 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/

#ifndef LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_
#define LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_utils.h>

#include "../score_updater.hpp"

namespace LightGBM {

class CUDAScoreUpdater: public ScoreUpdater {
public:
CUDAScoreUpdater(const Dataset* data, int num_tree_per_iteration, const bool boosting_on_cuda);

~CUDAScoreUpdater();

inline void AddScore(double val, int cur_tree_id) override;

inline void AddScore(const Tree* tree, int cur_tree_id) override;

inline void AddScore(const TreeLearner* tree_learner, const Tree* tree, int cur_tree_id) override;

inline void AddScore(const Tree* tree, const data_size_t* data_indices,
data_size_t data_cnt, int cur_tree_id) override;

inline void MultiplyScore(double val, int cur_tree_id) override;

inline const double* score() const override {
if (boosting_on_cuda_) {
return cuda_score_;
} else {
return score_.data();
}
}

/*! \brief Disable copy */
CUDAScoreUpdater& operator=(const CUDAScoreUpdater&) = delete;

CUDAScoreUpdater(const CUDAScoreUpdater&) = delete;

private:
void InitCUDA(const size_t total_size);

void LaunchAddScoreConstantKernel(const double val, const size_t offset);

void LaunchMultiplyScoreConstantKernel(const double val, const size_t offset);

double* cuda_score_;

const int num_threads_per_block_;

const bool boosting_on_cuda_;
};

} // namespace LightGBM

#endif // USE_CUDA_EXP

#endif // LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_
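Note on the `score()` accessor above: when `boosting_on_cuda_` is true, the constructor releases the host `score_` buffer and `score()` hands out the CUDA device pointer, so downstream consumers (objective and metric) must themselves be CUDA-aware, which is what the new `IsCUDAObjective()` / `IsCUDAMetric()` hooks signal. When it is false, every `AddScore` / `MultiplyScore` call copies the updated slice back to the host buffer, and `score()` returns host memory as before.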