-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CUDA] Initial work for boosting and evaluation with CUDA #5279
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
061356a
initial work for boosting and evaluation with CUDA
shiyu1994 14a3c55
fix compatibility with CPU code
shiyu1994 de33ffd
fix creating objective without USE_CUDA_EXP
shiyu1994 fde00bd
fix static analysis errors
shiyu1994 49daa0f
fix static analysis errors
shiyu1994 e80435d
Merge remote-tracking branch 'origin/master' into shiyu-cuda-obj
shiyu1994 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/*! | ||
* Copyright (c) 2021 Microsoft Corporation. All rights reserved. | ||
* Licensed under the MIT License. See LICENSE file in the project root for license information. | ||
*/ | ||
|
||
#include "cuda_score_updater.hpp" | ||
|
||
#ifdef USE_CUDA_EXP | ||
|
||
namespace LightGBM { | ||
|
||
CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iteration, const bool boosting_on_cuda): | ||
ScoreUpdater(data, num_tree_per_iteration), num_threads_per_block_(1024), boosting_on_cuda_(boosting_on_cuda) { | ||
num_data_ = data->num_data(); | ||
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration; | ||
InitCUDA(total_size); | ||
has_init_score_ = false; | ||
const double* init_score = data->metadata().init_score(); | ||
// if exists initial score, will start from it | ||
if (init_score != nullptr) { | ||
if ((data->metadata().num_init_score() % num_data_) != 0 | ||
|| (data->metadata().num_init_score() / num_data_) != num_tree_per_iteration) { | ||
Log::Fatal("Number of class for initial score error"); | ||
} | ||
has_init_score_ = true; | ||
CopyFromHostToCUDADevice<double>(cuda_score_, init_score, total_size, __FILE__, __LINE__); | ||
} else { | ||
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
SynchronizeCUDADevice(__FILE__, __LINE__); | ||
if (boosting_on_cuda_) { | ||
// clear host score buffer | ||
score_.clear(); | ||
score_.shrink_to_fit(); | ||
} | ||
} | ||
|
||
void CUDAScoreUpdater::InitCUDA(const size_t total_size) { | ||
AllocateCUDAMemory<double>(&cuda_score_, total_size, __FILE__, __LINE__); | ||
} | ||
|
||
CUDAScoreUpdater::~CUDAScoreUpdater() { | ||
DeallocateCUDAMemory<double>(&cuda_score_, __FILE__, __LINE__); | ||
} | ||
|
||
inline void CUDAScoreUpdater::AddScore(double val, int cur_tree_id) { | ||
Common::FunctionTimer fun_timer("CUDAScoreUpdater::AddScore", global_timer); | ||
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id; | ||
LaunchAddScoreConstantKernel(val, offset); | ||
if (!boosting_on_cuda_) { | ||
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
} | ||
|
||
inline void CUDAScoreUpdater::AddScore(const Tree* tree, int cur_tree_id) { | ||
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer); | ||
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id; | ||
tree->AddPredictionToScore(data_, num_data_, cuda_score_ + offset); | ||
if (!boosting_on_cuda_) { | ||
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
} | ||
|
||
inline void CUDAScoreUpdater::AddScore(const TreeLearner* tree_learner, const Tree* tree, int cur_tree_id) { | ||
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer); | ||
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id; | ||
tree_learner->AddPredictionToScore(tree, cuda_score_ + offset); | ||
if (!boosting_on_cuda_) { | ||
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
} | ||
|
||
inline void CUDAScoreUpdater::AddScore(const Tree* tree, const data_size_t* data_indices, | ||
data_size_t data_cnt, int cur_tree_id) { | ||
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer); | ||
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id; | ||
tree->AddPredictionToScore(data_, data_indices, data_cnt, cuda_score_ + offset); | ||
if (!boosting_on_cuda_) { | ||
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
} | ||
|
||
inline void CUDAScoreUpdater::MultiplyScore(double val, int cur_tree_id) { | ||
Common::FunctionTimer fun_timer("CUDAScoreUpdater::MultiplyScore", global_timer); | ||
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id; | ||
LaunchMultiplyScoreConstantKernel(val, offset); | ||
if (!boosting_on_cuda_) { | ||
CopyFromCUDADeviceToHost<double>(score_.data() + offset, cuda_score_ + offset, static_cast<size_t>(num_data_), __FILE__, __LINE__); | ||
} | ||
} | ||
|
||
} // namespace LightGBM | ||
|
||
#endif // USE_CUDA_EXP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/*! | ||
* Copyright (c) 2021 Microsoft Corporation. All rights reserved. | ||
* Licensed under the MIT License. See LICENSE file in the project root for license information. | ||
*/ | ||
|
||
#include "cuda_score_updater.hpp" | ||
|
||
#ifdef USE_CUDA_EXP | ||
|
||
namespace LightGBM { | ||
|
||
__global__ void AddScoreConstantKernel( | ||
const double val, | ||
const size_t offset, | ||
const data_size_t num_data, | ||
double* score) { | ||
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x); | ||
if (data_index < num_data) { | ||
score[data_index + offset] += val; | ||
} | ||
} | ||
|
||
void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) { | ||
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_; | ||
Log::Warning("adding init score = %f", val); | ||
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_); | ||
} | ||
|
||
__global__ void MultiplyScoreConstantKernel( | ||
const double val, | ||
const size_t offset, | ||
const data_size_t num_data, | ||
double* score) { | ||
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x); | ||
if (data_index < num_data) { | ||
score[data_index] *= val; | ||
} | ||
} | ||
|
||
void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) { | ||
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_; | ||
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_); | ||
} | ||
|
||
} // namespace LightGBM | ||
|
||
#endif // USE_CUDA_EXP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/*! | ||
* Copyright (c) 2021 Microsoft Corporation. All rights reserved. | ||
* Licensed under the MIT License. See LICENSE file in the project root for license information. | ||
*/ | ||
|
||
#ifndef LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_ | ||
#define LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_ | ||
|
||
#ifdef USE_CUDA_EXP | ||
|
||
#include <LightGBM/cuda/cuda_utils.h> | ||
|
||
#include "../score_updater.hpp" | ||
|
||
namespace LightGBM { | ||
|
||
class CUDAScoreUpdater: public ScoreUpdater { | ||
public: | ||
CUDAScoreUpdater(const Dataset* data, int num_tree_per_iteration, const bool boosting_on_cuda); | ||
|
||
~CUDAScoreUpdater(); | ||
|
||
inline void AddScore(double val, int cur_tree_id) override; | ||
|
||
inline void AddScore(const Tree* tree, int cur_tree_id) override; | ||
|
||
inline void AddScore(const TreeLearner* tree_learner, const Tree* tree, int cur_tree_id) override; | ||
|
||
inline void AddScore(const Tree* tree, const data_size_t* data_indices, | ||
data_size_t data_cnt, int cur_tree_id) override; | ||
|
||
inline void MultiplyScore(double val, int cur_tree_id) override; | ||
|
||
inline const double* score() const override { | ||
if (boosting_on_cuda_) { | ||
return cuda_score_; | ||
} else { | ||
return score_.data(); | ||
} | ||
} | ||
|
||
/*! \brief Disable copy */ | ||
CUDAScoreUpdater& operator=(const CUDAScoreUpdater&) = delete; | ||
|
||
CUDAScoreUpdater(const CUDAScoreUpdater&) = delete; | ||
|
||
private: | ||
void InitCUDA(const size_t total_size); | ||
|
||
void LaunchAddScoreConstantKernel(const double val, const size_t offset); | ||
|
||
void LaunchMultiplyScoreConstantKernel(const double val, const size_t offset); | ||
|
||
double* cuda_score_; | ||
|
||
const int num_threads_per_block_; | ||
|
||
const bool boosting_on_cuda_; | ||
}; | ||
|
||
} // namespace LightGBM | ||
|
||
#endif // USE_CUDA_EXP | ||
|
||
#endif // LIGHTGBM_BOOSTING_CUDA_CUDA_SCORE_UPDATER_HPP_ |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
boosting_on_cuda
always true? should we keep this variable?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No. This is a work around when there are some objective functions not implemented by the CUDA version yet. See line 98 of gbdt.cpp in this PR.