From 582c760ecd7d76ec60dba6164c1be338125a492c Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Fri, 25 Oct 2024 03:12:25 +0000 Subject: [PATCH 1/3] update gbdt --- src/boosting/gbdt.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 737c69072b64..4d6601b84918 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -104,12 +104,12 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective() && !data_sample_strategy_->IsHessianChange(); // for sample strategy with Hessian change, fall back to boosting on CPU - tree_learner_ = std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, - config_.get(), boosting_on_gpu_)); + tree_learner_ = nullptr; // std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, + // config_.get(), boosting_on_gpu_)); // init tree learner - tree_learner_->Init(train_data_, is_constant_hessian_); - tree_learner_->SetForcedSplit(&forced_splits_json_); + // tree_learner_->Init(train_data_, is_constant_hessian_); + // tree_learner_->SetForcedSplit(&forced_splits_json_); // push training metrics training_metrics_.clear(); @@ -227,7 +227,7 @@ void GBDT::Boosting() { if (config_->bagging_by_query) { data_sample_strategy_->Bagging(iter_, tree_learner_.get(), gradients_.data(), hessians_.data()); objective_function_-> - GetGradients(GetTrainingScore(&num_score), data_sample_strategy_->num_sampled_queries(), data_sample_strategy_->sampled_query_indices(), gradients_pointer_, hessians_pointer_); + GetGradientsWithSampledQueries(GetTrainingScore(&num_score), data_sample_strategy_->num_sampled_queries(), data_sample_strategy_->sampled_query_indices(), gradients_pointer_, hessians_pointer_); } else { objective_function_-> GetGradients(GetTrainingScore(&num_score), gradients_pointer_, hessians_pointer_); From b9e143bad1c79a9375cdfd9b58546306fc7ceb28 Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Fri, 25 Oct 2024 03:13:16 +0000 Subject: [PATCH 2/3] changes for cuda tree --- include/LightGBM/cuda/cuda_tree.hpp | 52 +++--- src/io/cuda/cuda_tree.cpp | 238 +++++++--------------------- src/io/cuda/cuda_tree.cu | 92 +++++------ 3 files changed, 127 insertions(+), 255 deletions(-) diff --git a/include/LightGBM/cuda/cuda_tree.hpp b/include/LightGBM/cuda/cuda_tree.hpp index 7ab06190481b..dcd7c41fff8b 100644 --- a/include/LightGBM/cuda/cuda_tree.hpp +++ b/include/LightGBM/cuda/cuda_tree.hpp @@ -79,25 +79,25 @@ class CUDATree : public Tree { inline void AsConstantTree(double val, int count) override; - const int* cuda_leaf_parent() const { return cuda_leaf_parent_; } + const int* cuda_leaf_parent() const { return cuda_leaf_parent_.RawData(); } - const int* cuda_left_child() const { return cuda_left_child_; } + const int* cuda_left_child() const { return cuda_left_child_.RawData(); } - const int* cuda_right_child() const { return cuda_right_child_; } + const int* cuda_right_child() const { return cuda_right_child_.RawData(); } - const int* cuda_split_feature_inner() const { return cuda_split_feature_inner_; } + const int* cuda_split_feature_inner() const { return cuda_split_feature_inner_.RawData(); } - const int* cuda_split_feature() const { return cuda_split_feature_; } + const int* cuda_split_feature() const { return cuda_split_feature_.RawData(); } - const uint32_t* cuda_threshold_in_bin() const { return cuda_threshold_in_bin_; } + const uint32_t* cuda_threshold_in_bin() const { return cuda_threshold_in_bin_.RawData(); } - const double* cuda_threshold() const { return cuda_threshold_; } + const double* cuda_threshold() const { return cuda_threshold_.RawData(); } - const int8_t* cuda_decision_type() const { return cuda_decision_type_; } + const int8_t* cuda_decision_type() const { return cuda_decision_type_.RawData(); } - const double* cuda_leaf_value() const { return cuda_leaf_value_; } + const double* cuda_leaf_value() const { return cuda_leaf_value_.RawData(); } - double* cuda_leaf_value_ref() { return cuda_leaf_value_; } + double* cuda_leaf_value_ref() { return cuda_leaf_value_.RawData(); } inline void Shrinkage(double rate) override; @@ -140,22 +140,22 @@ class CUDATree : public Tree { const int right_leaf_index, const int real_feature_index); - int* cuda_left_child_; - int* cuda_right_child_; - int* cuda_split_feature_inner_; - int* cuda_split_feature_; - int* cuda_leaf_depth_; - int* cuda_leaf_parent_; - uint32_t* cuda_threshold_in_bin_; - double* cuda_threshold_; - double* cuda_internal_weight_; - double* cuda_internal_value_; - int8_t* cuda_decision_type_; - double* cuda_leaf_value_; - data_size_t* cuda_leaf_count_; - double* cuda_leaf_weight_; - data_size_t* cuda_internal_count_; - float* cuda_split_gain_; + CUDAVector cuda_left_child_; + CUDAVector cuda_right_child_; + CUDAVector cuda_split_feature_inner_; + CUDAVector cuda_split_feature_; + CUDAVector cuda_leaf_depth_; + CUDAVector cuda_leaf_parent_; + CUDAVector cuda_threshold_in_bin_; + CUDAVector cuda_threshold_; + CUDAVector cuda_internal_weight_; + CUDAVector cuda_internal_value_; + CUDAVector cuda_decision_type_; + CUDAVector cuda_leaf_value_; + CUDAVector cuda_leaf_count_; + CUDAVector cuda_leaf_weight_; + CUDAVector cuda_internal_count_; + CUDAVector cuda_split_gain_; CUDAVector cuda_bitset_; CUDAVector cuda_bitset_inner_; CUDAVector cuda_cat_boundaries_; diff --git a/src/io/cuda/cuda_tree.cpp b/src/io/cuda/cuda_tree.cpp index c5dee89ca3af..886394c3c0ea 100644 --- a/src/io/cuda/cuda_tree.cpp +++ b/src/io/cuda/cuda_tree.cpp @@ -34,178 +34,50 @@ CUDATree::CUDATree(const Tree* host_tree): } CUDATree::~CUDATree() { - DeallocateCUDAMemory(&cuda_left_child_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_right_child_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_split_feature_inner_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_split_feature_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_depth_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_parent_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_threshold_in_bin_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_threshold_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_internal_weight_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_internal_value_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_decision_type_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_value_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_count_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_weight_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_internal_count_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_split_gain_, __FILE__, __LINE__); gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__); } void CUDATree::InitCUDAMemory() { - AllocateCUDAMemory(&cuda_left_child_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_right_child_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_split_feature_inner_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_split_feature_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_leaf_depth_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_leaf_parent_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_threshold_in_bin_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_threshold_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_decision_type_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_leaf_value_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_internal_weight_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_internal_value_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_leaf_weight_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_leaf_count_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_internal_count_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - AllocateCUDAMemory(&cuda_split_gain_, - static_cast(max_leaves_), - __FILE__, - __LINE__); - SetCUDAMemory(cuda_leaf_value_, 0.0f, 1, __FILE__, __LINE__); - SetCUDAMemory(cuda_leaf_weight_, 0.0f, 1, __FILE__, __LINE__); - SetCUDAMemory(cuda_leaf_parent_, -1, 1, __FILE__, __LINE__); + cuda_left_child_.Resize(static_cast(max_leaves_)); + cuda_right_child_.Resize(static_cast(max_leaves_)); + cuda_split_feature_inner_.Resize(static_cast(max_leaves_)); + cuda_split_feature_.Resize(static_cast(max_leaves_)); + cuda_leaf_depth_.Resize(static_cast(max_leaves_)); + cuda_leaf_parent_.Resize(static_cast(max_leaves_)); + cuda_threshold_in_bin_.Resize(static_cast(max_leaves_)); + cuda_threshold_.Resize(static_cast(max_leaves_)); + cuda_decision_type_.Resize(static_cast(max_leaves_)); + cuda_leaf_value_.Resize(static_cast(max_leaves_)); + cuda_internal_weight_.Resize(static_cast(max_leaves_)); + cuda_internal_value_.Resize(static_cast(max_leaves_)); + cuda_leaf_weight_.Resize(static_cast(max_leaves_)); + cuda_leaf_count_.Resize(static_cast(max_leaves_)); + cuda_internal_count_.Resize(static_cast(max_leaves_)); + cuda_split_gain_.Resize(static_cast(max_leaves_)); + SetCUDAMemory(cuda_leaf_value_.RawData(), 0.0f, 1, __FILE__, __LINE__); + SetCUDAMemory(cuda_leaf_weight_.RawData(), 0.0f, 1, __FILE__, __LINE__); + SetCUDAMemory(cuda_leaf_parent_.RawData(), -1, 1, __FILE__, __LINE__); CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_)); SynchronizeCUDADevice(__FILE__, __LINE__); } void CUDATree::InitCUDA() { - InitCUDAMemoryFromHostMemory(&cuda_left_child_, - left_child_.data(), - left_child_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_right_child_, - right_child_.data(), - right_child_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_split_feature_inner_, - split_feature_inner_.data(), - split_feature_inner_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_split_feature_, - split_feature_.data(), - split_feature_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_threshold_in_bin_, - threshold_in_bin_.data(), - threshold_in_bin_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_threshold_, - threshold_.data(), - threshold_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_leaf_depth_, - leaf_depth_.data(), - leaf_depth_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_decision_type_, - decision_type_.data(), - decision_type_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_internal_weight_, - internal_weight_.data(), - internal_weight_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_internal_value_, - internal_value_.data(), - internal_value_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_internal_count_, - internal_count_.data(), - internal_count_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_leaf_count_, - leaf_count_.data(), - leaf_count_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_split_gain_, - split_gain_.data(), - split_gain_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_leaf_value_, - leaf_value_.data(), - leaf_value_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_leaf_weight_, - leaf_weight_.data(), - leaf_weight_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_leaf_parent_, - leaf_parent_.data(), - leaf_parent_.size(), - __FILE__, - __LINE__); + cuda_left_child_.InitFromHostVector(left_child_); + cuda_right_child_.InitFromHostVector(right_child_); + cuda_split_feature_inner_.InitFromHostVector(split_feature_inner_); + cuda_split_feature_.InitFromHostVector(split_feature_); + cuda_threshold_in_bin_.InitFromHostVector(threshold_in_bin_); + cuda_threshold_.InitFromHostVector(threshold_); + cuda_leaf_depth_.InitFromHostVector(leaf_depth_); + cuda_decision_type_.InitFromHostVector(decision_type_); + cuda_internal_weight_.InitFromHostVector(internal_weight_); + cuda_internal_value_.InitFromHostVector(internal_value_); + cuda_internal_count_.InitFromHostVector(internal_count_); + cuda_leaf_count_.InitFromHostVector(leaf_count_); + cuda_split_gain_.InitFromHostVector(split_gain_); + cuda_leaf_value_.InitFromHostVector(leaf_value_); + cuda_leaf_weight_.InitFromHostVector(leaf_weight_); + cuda_leaf_parent_.InitFromHostVector(leaf_parent_); CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_)); SynchronizeCUDADevice(__FILE__, __LINE__); } @@ -293,22 +165,22 @@ void CUDATree::ToHost() { leaf_depth_.resize(max_leaves_); const size_t num_leaves_size = static_cast(num_leaves_); - CopyFromCUDADeviceToHost(left_child_.data(), cuda_left_child_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(right_child_.data(), cuda_right_child_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(split_feature_inner_.data(), cuda_split_feature_inner_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(split_feature_.data(), cuda_split_feature_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(threshold_in_bin_.data(), cuda_threshold_in_bin_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(threshold_.data(), cuda_threshold_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(decision_type_.data(), cuda_decision_type_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(split_gain_.data(), cuda_split_gain_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(leaf_parent_.data(), cuda_leaf_parent_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(leaf_value_.data(), cuda_leaf_value_, num_leaves_size, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(leaf_weight_.data(), cuda_leaf_weight_, num_leaves_size, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(leaf_count_.data(), cuda_leaf_count_, num_leaves_size, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(internal_value_.data(), cuda_internal_value_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(internal_weight_.data(), cuda_internal_weight_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(internal_count_.data(), cuda_internal_count_, num_leaves_size - 1, __FILE__, __LINE__); - CopyFromCUDADeviceToHost(leaf_depth_.data(), cuda_leaf_depth_, num_leaves_size, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(left_child_.data(), cuda_left_child_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(right_child_.data(), cuda_right_child_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(split_feature_inner_.data(), cuda_split_feature_inner_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(split_feature_.data(), cuda_split_feature_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(threshold_in_bin_.data(), cuda_threshold_in_bin_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(threshold_.data(), cuda_threshold_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(decision_type_.data(), cuda_decision_type_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(split_gain_.data(), cuda_split_gain_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_parent_.data(), cuda_leaf_parent_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_value_.data(), cuda_leaf_value_.RawData(), num_leaves_size, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_weight_.data(), cuda_leaf_weight_.RawData(), num_leaves_size, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_count_.data(), cuda_leaf_count_.RawData(), num_leaves_size, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(internal_value_.data(), cuda_internal_value_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(internal_weight_.data(), cuda_internal_weight_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(internal_count_.data(), cuda_internal_count_.RawData(), num_leaves_size - 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_depth_.data(), cuda_leaf_depth_.RawData(), num_leaves_size, __FILE__, __LINE__); if (num_cat_ > 0) { cuda_cat_boundaries_inner_.Resize(num_cat_ + 1); @@ -323,17 +195,17 @@ void CUDATree::ToHost() { } void CUDATree::SyncLeafOutputFromHostToCUDA() { - CopyFromHostToCUDADevice(cuda_leaf_value_, leaf_value_.data(), leaf_value_.size(), __FILE__, __LINE__); + CopyFromHostToCUDADevice(cuda_leaf_value_.RawData(), leaf_value_.data(), leaf_value_.size(), __FILE__, __LINE__); } void CUDATree::SyncLeafOutputFromCUDAToHost() { - CopyFromCUDADeviceToHost(leaf_value_.data(), cuda_leaf_value_, leaf_value_.size(), __FILE__, __LINE__); + CopyFromCUDADeviceToHost(leaf_value_.data(), cuda_leaf_value_.RawData(), leaf_value_.size(), __FILE__, __LINE__); } void CUDATree::AsConstantTree(double val, int count) { Tree::AsConstantTree(val, count); - CopyFromHostToCUDADevice(cuda_leaf_value_, &val, 1, __FILE__, __LINE__); - CopyFromHostToCUDADevice(cuda_leaf_count_, &count, 1, __FILE__, __LINE__); + CopyFromHostToCUDADevice(cuda_leaf_value_.RawData(), &val, 1, __FILE__, __LINE__); + CopyFromHostToCUDADevice(cuda_leaf_count_.RawData(), &count, 1, __FILE__, __LINE__); } } // namespace LightGBM diff --git a/src/io/cuda/cuda_tree.cu b/src/io/cuda/cuda_tree.cu index 87abfc1353b4..821b572464a2 100644 --- a/src/io/cuda/cuda_tree.cu +++ b/src/io/cuda/cuda_tree.cu @@ -139,22 +139,22 @@ void CUDATree::LaunchSplitKernel(const int leaf_index, cuda_split_info, // tree structure num_leaves_, - cuda_leaf_parent_, - cuda_leaf_depth_, - cuda_left_child_, - cuda_right_child_, - cuda_split_feature_inner_, - cuda_split_feature_, - cuda_split_gain_, - cuda_internal_weight_, - cuda_internal_value_, - cuda_internal_count_, - cuda_leaf_weight_, - cuda_leaf_value_, - cuda_leaf_count_, - cuda_decision_type_, - cuda_threshold_in_bin_, - cuda_threshold_); + cuda_leaf_parent_.RawData(), + cuda_leaf_depth_.RawData(), + cuda_left_child_.RawData(), + cuda_right_child_.RawData(), + cuda_split_feature_inner_.RawData(), + cuda_split_feature_.RawData(), + cuda_split_gain_.RawData(), + cuda_internal_weight_.RawData(), + cuda_internal_value_.RawData(), + cuda_internal_count_.RawData(), + cuda_leaf_weight_.RawData(), + cuda_leaf_value_.RawData(), + cuda_leaf_count_.RawData(), + cuda_decision_type_.RawData(), + cuda_threshold_in_bin_.RawData(), + cuda_threshold_.RawData()); } __global__ void SplitCategoricalKernel( // split information @@ -264,22 +264,22 @@ void CUDATree::LaunchSplitCategoricalKernel(const int leaf_index, cuda_split_info, // tree structure num_leaves_, - cuda_leaf_parent_, - cuda_leaf_depth_, - cuda_left_child_, - cuda_right_child_, - cuda_split_feature_inner_, - cuda_split_feature_, - cuda_split_gain_, - cuda_internal_weight_, - cuda_internal_value_, - cuda_internal_count_, - cuda_leaf_weight_, - cuda_leaf_value_, - cuda_leaf_count_, - cuda_decision_type_, - cuda_threshold_in_bin_, - cuda_threshold_, + cuda_leaf_parent_.RawData(), + cuda_leaf_depth_.RawData(), + cuda_left_child_.RawData(), + cuda_right_child_.RawData(), + cuda_split_feature_inner_.RawData(), + cuda_split_feature_.RawData(), + cuda_split_gain_.RawData(), + cuda_internal_weight_.RawData(), + cuda_internal_value_.RawData(), + cuda_internal_count_.RawData(), + cuda_leaf_weight_.RawData(), + cuda_leaf_value_.RawData(), + cuda_leaf_count_.RawData(), + cuda_decision_type_.RawData(), + cuda_threshold_in_bin_.RawData(), + cuda_threshold_.RawData(), cuda_bitset_len, cuda_bitset_inner_len, num_cat_, @@ -297,7 +297,7 @@ __global__ void ShrinkageKernel(const double rate, double* cuda_leaf_value, cons void CUDATree::LaunchShrinkageKernel(const double rate) { const int num_threads_per_block = 1024; const int num_blocks = (num_leaves_ + num_threads_per_block - 1) / num_threads_per_block; - ShrinkageKernel<<>>(rate, cuda_leaf_value_, num_leaves_); + ShrinkageKernel<<>>(rate, cuda_leaf_value_.RawData(), num_leaves_); } __global__ void AddBiasKernel(const double val, double* cuda_leaf_value, const int num_leaves) { @@ -310,7 +310,7 @@ __global__ void AddBiasKernel(const double val, double* cuda_leaf_value, const i void CUDATree::LaunchAddBiasKernel(const double val) { const int num_threads_per_block = 1024; const int num_blocks = (num_leaves_ + num_threads_per_block - 1) / num_threads_per_block; - AddBiasKernel<<>>(val, cuda_leaf_value_, num_leaves_); + AddBiasKernel<<>>(val, cuda_leaf_value_.RawData(), num_leaves_); } template @@ -416,12 +416,12 @@ void CUDATree::LaunchAddPredictionToScoreKernel( cuda_column_data->cuda_feature_to_column(), nullptr, // tree information - cuda_threshold_in_bin_, - cuda_decision_type_, - cuda_split_feature_inner_, - cuda_left_child_, - cuda_right_child_, - cuda_leaf_value_, + cuda_threshold_in_bin_.RawData(), + cuda_decision_type_.RawData(), + cuda_split_feature_inner_.RawData(), + cuda_left_child_.RawData(), + cuda_right_child_.RawData(), + cuda_leaf_value_.RawData(), cuda_bitset_inner_.RawDataReadOnly(), cuda_cat_boundaries_inner_.RawDataReadOnly(), // output @@ -440,12 +440,12 @@ void CUDATree::LaunchAddPredictionToScoreKernel( cuda_column_data->cuda_feature_to_column(), used_data_indices, // tree information - cuda_threshold_in_bin_, - cuda_decision_type_, - cuda_split_feature_inner_, - cuda_left_child_, - cuda_right_child_, - cuda_leaf_value_, + cuda_threshold_in_bin_.RawData(), + cuda_decision_type_.RawData(), + cuda_split_feature_inner_.RawData(), + cuda_left_child_.RawData(), + cuda_right_child_.RawData(), + cuda_leaf_value_.RawData(), cuda_bitset_inner_.RawDataReadOnly(), cuda_cat_boundaries_inner_.RawDataReadOnly(), // output From 483e521debe08f1d61beaa22aeed75eee284eb3e Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Fri, 25 Oct 2024 03:43:37 +0000 Subject: [PATCH 3/3] use CUDAVector for cuda column data --- include/LightGBM/cuda/cuda_column_data.hpp | 68 ++++--- src/io/cuda/cuda_column_data.cpp | 210 ++++++--------------- src/io/cuda/cuda_column_data.cu | 6 +- 3 files changed, 96 insertions(+), 188 deletions(-) diff --git a/include/LightGBM/cuda/cuda_column_data.hpp b/include/LightGBM/cuda/cuda_column_data.hpp index 56ffee3e07ed..964a3218a2b9 100644 --- a/include/LightGBM/cuda/cuda_column_data.hpp +++ b/include/LightGBM/cuda/cuda_column_data.hpp @@ -38,13 +38,11 @@ class CUDAColumnData { const std::vector& feature_mfb_is_na, const std::vector& feature_to_column); - const void* GetColumnData(const int column_index) const { return data_by_column_[column_index]; } + const void* GetColumnData(const int column_index) const { return data_by_column_[column_index]->RawData(); } void CopySubrow(const CUDAColumnData* full_set, const data_size_t* used_indices, const data_size_t num_used_indices); - void* const* cuda_data_by_column() const { return cuda_data_by_column_; } - - void* const* data_by_column() const { return data_by_column_.data(); } + void* const* cuda_data_by_column() const { return cuda_data_by_column_.RawData(); } uint32_t feature_min_bin(const int feature_index) const { return feature_min_bin_[feature_index]; } @@ -64,27 +62,27 @@ class CUDAColumnData { uint8_t feature_mfb_is_na(const int feature_index) const { return feature_mfb_is_na_[feature_index]; } - const uint32_t* cuda_feature_min_bin() const { return cuda_feature_min_bin_; } + const uint32_t* cuda_feature_min_bin() const { return cuda_feature_min_bin_.RawData(); } - const uint32_t* cuda_feature_max_bin() const { return cuda_feature_max_bin_; } + const uint32_t* cuda_feature_max_bin() const { return cuda_feature_max_bin_.RawData(); } - const uint32_t* cuda_feature_offset() const { return cuda_feature_offset_; } + const uint32_t* cuda_feature_offset() const { return cuda_feature_offset_.RawData(); } - const uint32_t* cuda_feature_most_freq_bin() const { return cuda_feature_most_freq_bin_; } + const uint32_t* cuda_feature_most_freq_bin() const { return cuda_feature_most_freq_bin_.RawData(); } - const uint32_t* cuda_feature_default_bin() const { return cuda_feature_default_bin_; } + const uint32_t* cuda_feature_default_bin() const { return cuda_feature_default_bin_.RawData(); } - const uint8_t* cuda_feature_missing_is_zero() const { return cuda_feature_missing_is_zero_; } + const uint8_t* cuda_feature_missing_is_zero() const { return cuda_feature_missing_is_zero_.RawData(); } - const uint8_t* cuda_feature_missing_is_na() const { return cuda_feature_missing_is_na_; } + const uint8_t* cuda_feature_missing_is_na() const { return cuda_feature_missing_is_na_.RawData(); } - const uint8_t* cuda_feature_mfb_is_zero() const { return cuda_feature_mfb_is_zero_; } + const uint8_t* cuda_feature_mfb_is_zero() const { return cuda_feature_mfb_is_zero_.RawData(); } - const uint8_t* cuda_feature_mfb_is_na() const { return cuda_feature_mfb_is_na_; } + const uint8_t* cuda_feature_mfb_is_na() const { return cuda_feature_mfb_is_na_.RawData(); } - const int* cuda_feature_to_column() const { return cuda_feature_to_column_; } + const int* cuda_feature_to_column() const { return cuda_feature_to_column_.RawData(); } - const uint8_t* cuda_column_bit_type() const { return cuda_column_bit_type_; } + const uint8_t* cuda_column_bit_type() const { return cuda_column_bit_type_.RawData(); } int feature_to_column(const int feature_index) const { return feature_to_column_[feature_index]; } @@ -92,7 +90,7 @@ class CUDAColumnData { private: template - void InitOneColumnData(const void* in_column_data, BinIterator* bin_iterator, void** out_column_data_pointer); + void InitOneColumnData(const void* in_column_data, BinIterator* bin_iterator, CUDAVector* out_column_data_pointer); void LaunchCopySubrowKernel(void* const* in_cuda_data_by_column); @@ -100,6 +98,14 @@ class CUDAColumnData { void ResizeWhenCopySubrow(const data_size_t num_used_indices); + std::vector GetDataByColumnPointers(const std::vector>>& data_by_column) const { + std::vector data_by_column_pointers(data_by_column.size(), nullptr); + for (size_t i = 0; i < data_by_column.size(); ++i) { + data_by_column_pointers[i] = reinterpret_cast(data_by_column[i]->RawData()); + } + return data_by_column_pointers; + } + int gpu_device_id_; int num_threads_; data_size_t num_data_; @@ -114,24 +120,24 @@ class CUDAColumnData { std::vector feature_missing_is_na_; std::vector feature_mfb_is_zero_; std::vector feature_mfb_is_na_; - void** cuda_data_by_column_; + CUDAVector cuda_data_by_column_; std::vector feature_to_column_; - std::vector data_by_column_; - - uint8_t* cuda_column_bit_type_; - uint32_t* cuda_feature_min_bin_; - uint32_t* cuda_feature_max_bin_; - uint32_t* cuda_feature_offset_; - uint32_t* cuda_feature_most_freq_bin_; - uint32_t* cuda_feature_default_bin_; - uint8_t* cuda_feature_missing_is_zero_; - uint8_t* cuda_feature_missing_is_na_; - uint8_t* cuda_feature_mfb_is_zero_; - uint8_t* cuda_feature_mfb_is_na_; - int* cuda_feature_to_column_; + std::vector>> data_by_column_; + + CUDAVector cuda_column_bit_type_; + CUDAVector cuda_feature_min_bin_; + CUDAVector cuda_feature_max_bin_; + CUDAVector cuda_feature_offset_; + CUDAVector cuda_feature_most_freq_bin_; + CUDAVector cuda_feature_default_bin_; + CUDAVector cuda_feature_missing_is_zero_; + CUDAVector cuda_feature_missing_is_na_; + CUDAVector cuda_feature_mfb_is_zero_; + CUDAVector cuda_feature_mfb_is_na_; + CUDAVector cuda_feature_to_column_; // used when bagging with subset - data_size_t* cuda_used_indices_; + CUDAVector cuda_used_indices_; data_size_t num_used_indices_; data_size_t cur_subset_buffer_size_; }; diff --git a/src/io/cuda/cuda_column_data.cpp b/src/io/cuda/cuda_column_data.cpp index 6c82508e2909..7fe2238defa0 100644 --- a/src/io/cuda/cuda_column_data.cpp +++ b/src/io/cuda/cuda_column_data.cpp @@ -14,45 +14,14 @@ CUDAColumnData::CUDAColumnData(const data_size_t num_data, const int gpu_device_ num_data_ = num_data; gpu_device_id_ = gpu_device_id >= 0 ? gpu_device_id : 0; SetCUDADevice(gpu_device_id_, __FILE__, __LINE__); - cuda_used_indices_ = nullptr; - cuda_data_by_column_ = nullptr; - cuda_column_bit_type_ = nullptr; - cuda_feature_min_bin_ = nullptr; - cuda_feature_max_bin_ = nullptr; - cuda_feature_offset_ = nullptr; - cuda_feature_most_freq_bin_ = nullptr; - cuda_feature_default_bin_ = nullptr; - cuda_feature_missing_is_zero_ = nullptr; - cuda_feature_missing_is_na_ = nullptr; - cuda_feature_mfb_is_zero_ = nullptr; - cuda_feature_mfb_is_na_ = nullptr; - cuda_feature_to_column_ = nullptr; data_by_column_.clear(); } -CUDAColumnData::~CUDAColumnData() { - DeallocateCUDAMemory(&cuda_used_indices_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_data_by_column_, __FILE__, __LINE__); - for (size_t i = 0; i < data_by_column_.size(); ++i) { - DeallocateCUDAMemory(&data_by_column_[i], __FILE__, __LINE__); - } - DeallocateCUDAMemory(&cuda_column_bit_type_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_min_bin_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_max_bin_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_offset_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_most_freq_bin_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_default_bin_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_missing_is_zero_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_missing_is_na_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_mfb_is_zero_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_mfb_is_na_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_to_column_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_used_indices_, __FILE__, __LINE__); -} +CUDAColumnData::~CUDAColumnData() {} template -void CUDAColumnData::InitOneColumnData(const void* in_column_data, BinIterator* bin_iterator, void** out_column_data_pointer) { - BIN_TYPE* cuda_column_data = nullptr; +void CUDAColumnData::InitOneColumnData(const void* in_column_data, BinIterator* bin_iterator, CUDAVector* out_column_data_pointer) { + CUDAVector cuda_column_data; if (!IS_SPARSE) { if (IS_4BIT) { std::vector expanded_column_data(num_data_, 0); @@ -60,17 +29,9 @@ void CUDAColumnData::InitOneColumnData(const void* in_column_data, BinIterator* for (data_size_t i = 0; i < num_data_; ++i) { expanded_column_data[i] = static_cast((in_column_data_reintrepreted[i >> 1] >> ((i & 1) << 2)) & 0xf); } - InitCUDAMemoryFromHostMemory(&cuda_column_data, - expanded_column_data.data(), - static_cast(num_data_), - __FILE__, - __LINE__); + cuda_column_data.InitFromHostVector(expanded_column_data); } else { - InitCUDAMemoryFromHostMemory(&cuda_column_data, - reinterpret_cast(in_column_data), - static_cast(num_data_), - __FILE__, - __LINE__); + cuda_column_data.InitFromHostMemory(reinterpret_cast(in_column_data), static_cast(num_data_)); } } else { // need to iterate bin iterator @@ -78,13 +39,9 @@ void CUDAColumnData::InitOneColumnData(const void* in_column_data, BinIterator* for (data_size_t i = 0; i < num_data_; ++i) { expanded_column_data[i] = static_cast(bin_iterator->RawGet(i)); } - InitCUDAMemoryFromHostMemory(&cuda_column_data, - expanded_column_data.data(), - static_cast(num_data_), - __FILE__, - __LINE__); + cuda_column_data.InitFromHostVector(expanded_column_data); } - *out_column_data_pointer = reinterpret_cast(cuda_column_data); + out_column_data_pointer->MoveFrom(cuda_column_data, sizeof(BIN_TYPE) * cuda_column_data.Size()); } void CUDAColumnData::Init(const int num_columns, @@ -112,7 +69,9 @@ void CUDAColumnData::Init(const int num_columns, feature_missing_is_na_ = feature_missing_is_na; feature_mfb_is_zero_ = feature_mfb_is_zero; feature_mfb_is_na_ = feature_mfb_is_na; - data_by_column_.resize(num_columns_, nullptr); + for (int column_index = 0; column_index < num_columns_; ++column_index) { + data_by_column_.emplace_back(new CUDAVector()); + } OMP_INIT_EX(); #pragma omp parallel num_threads(num_threads_) { @@ -125,24 +84,24 @@ void CUDAColumnData::Init(const int num_columns, // is dense column if (bit_type == 4) { column_bit_type_[column_index] = 8; - InitOneColumnData(column_data[column_index], nullptr, &data_by_column_[column_index]); + InitOneColumnData(column_data[column_index], nullptr, data_by_column_[column_index].get()); } else if (bit_type == 8) { - InitOneColumnData(column_data[column_index], nullptr, &data_by_column_[column_index]); + InitOneColumnData(column_data[column_index], nullptr, data_by_column_[column_index].get()); } else if (bit_type == 16) { - InitOneColumnData(column_data[column_index], nullptr, &data_by_column_[column_index]); + InitOneColumnData(column_data[column_index], nullptr, data_by_column_[column_index].get()); } else if (bit_type == 32) { - InitOneColumnData(column_data[column_index], nullptr, &data_by_column_[column_index]); + InitOneColumnData(column_data[column_index], nullptr, data_by_column_[column_index].get()); } else { Log::Fatal("Unknow column bit type %d", bit_type); } } else { // is sparse column if (bit_type == 8) { - InitOneColumnData(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]); + InitOneColumnData(nullptr, column_bin_iterator[column_index], data_by_column_[column_index].get()); } else if (bit_type == 16) { - InitOneColumnData(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]); + InitOneColumnData(nullptr, column_bin_iterator[column_index], data_by_column_[column_index].get()); } else if (bit_type == 32) { - InitOneColumnData(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]); + InitOneColumnData(nullptr, column_bin_iterator[column_index], data_by_column_[column_index].get()); } else { Log::Fatal("Unknow column bit type %d", bit_type); } @@ -152,11 +111,7 @@ void CUDAColumnData::Init(const int num_columns, } OMP_THROW_EX(); feature_to_column_ = feature_to_column; - InitCUDAMemoryFromHostMemory(&cuda_data_by_column_, - data_by_column_.data(), - data_by_column_.size(), - __FILE__, - __LINE__); + cuda_data_by_column_.InitFromHostVector(GetDataByColumnPointers(data_by_column_)); InitColumnMetaInfo(); } @@ -177,11 +132,13 @@ void CUDAColumnData::CopySubrow( feature_mfb_is_zero_ = full_set->feature_mfb_is_zero_; feature_mfb_is_na_ = full_set->feature_mfb_is_na_; feature_to_column_ = full_set->feature_to_column_; - if (cuda_used_indices_ == nullptr) { + if (cuda_used_indices_.Size() == 0) { // initialize the subset cuda column data const size_t num_used_indices_size = static_cast(num_used_indices); - AllocateCUDAMemory(&cuda_used_indices_, num_used_indices_size, __FILE__, __LINE__); - data_by_column_.resize(num_columns_, nullptr); + cuda_used_indices_.Resize(num_used_indices_size); + for (int column_index = 0; column_index < num_columns_; ++column_index) { + data_by_column_.emplace_back(new CUDAVector()); + } OMP_INIT_EX(); #pragma omp parallel num_threads(num_threads_) { @@ -191,23 +148,23 @@ void CUDAColumnData::CopySubrow( OMP_LOOP_EX_BEGIN(); const uint8_t bit_type = column_bit_type_[column_index]; if (bit_type == 8) { - uint8_t* column_data = nullptr; - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + CUDAVector column_data; + column_data.Resize(num_used_indices_size); + data_by_column_[column_index]->MoveFrom(column_data, sizeof(uint8_t) * column_data.Size()); } else if (bit_type == 16) { - uint16_t* column_data = nullptr; - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + CUDAVector column_data; + column_data.Resize(num_used_indices_size); + data_by_column_[column_index]->MoveFrom(column_data, sizeof(uint16_t) * column_data.Size()); } else if (bit_type == 32) { - uint32_t* column_data = nullptr; - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + CUDAVector column_data; + column_data.Resize(num_used_indices_size); + data_by_column_[column_index]->MoveFrom(column_data, sizeof(uint32_t) * column_data.Size()); } OMP_LOOP_EX_END(); } } OMP_THROW_EX(); - InitCUDAMemoryFromHostMemory(&cuda_data_by_column_, data_by_column_.data(), data_by_column_.size(), __FILE__, __LINE__); + cuda_data_by_column_.InitFromHostVector(GetDataByColumnPointers(data_by_column_)); InitColumnMetaInfo(); cur_subset_buffer_size_ = num_used_indices; } else { @@ -216,23 +173,23 @@ void CUDAColumnData::CopySubrow( cur_subset_buffer_size_ = num_used_indices; } } - CopyFromHostToCUDADevice(cuda_used_indices_, used_indices, static_cast(num_used_indices), __FILE__, __LINE__); + cuda_used_indices_.InitFromHostMemory(used_indices, static_cast(num_used_indices)); num_used_indices_ = num_used_indices; for (int column_index = 0; column_index < num_columns_; ++column_index) { if (column_bit_type_[column_index] == 8) { CopyFromCUDADeviceToCUDADevice( - reinterpret_cast(data_by_column_[column_index]), - reinterpret_cast(full_set->data_by_column()[column_index]) + used_indices[0], + reinterpret_cast(data_by_column_[column_index]->RawData()), + reinterpret_cast(full_set->GetColumnData(column_index)) + used_indices[0], static_cast(num_used_indices_), __FILE__, __LINE__); } else if (column_bit_type_[column_index] == 16) { CopyFromCUDADeviceToCUDADevice( - reinterpret_cast(data_by_column_[column_index]), - reinterpret_cast(full_set->data_by_column()[column_index]) + used_indices[0], + reinterpret_cast(data_by_column_[column_index]->RawData()), + reinterpret_cast(full_set->GetColumnData(column_index)) + used_indices[0], static_cast(num_used_indices_), __FILE__, __LINE__); } else if (column_bit_type_[column_index] == 32) { CopyFromCUDADeviceToCUDADevice( - reinterpret_cast(data_by_column_[column_index]), - reinterpret_cast(full_set->data_by_column()[column_index]) + used_indices[0], + reinterpret_cast(data_by_column_[column_index]->RawData()), + reinterpret_cast(full_set->GetColumnData(column_index)) + used_indices[0], static_cast(num_used_indices_), __FILE__, __LINE__); } } @@ -241,8 +198,7 @@ void CUDAColumnData::CopySubrow( void CUDAColumnData::ResizeWhenCopySubrow(const data_size_t num_used_indices) { const size_t num_used_indices_size = static_cast(num_used_indices); - DeallocateCUDAMemory(&cuda_used_indices_, __FILE__, __LINE__); - AllocateCUDAMemory(&cuda_used_indices_, num_used_indices_size, __FILE__, __LINE__); + cuda_used_indices_.Resize(num_used_indices_size); OMP_INIT_EX(); #pragma omp parallel num_threads(num_threads_) { @@ -252,85 +208,31 @@ void CUDAColumnData::ResizeWhenCopySubrow(const data_size_t num_used_indices) { OMP_LOOP_EX_BEGIN(); const uint8_t bit_type = column_bit_type_[column_index]; if (bit_type == 8) { - uint8_t* column_data = reinterpret_cast(data_by_column_[column_index]); - DeallocateCUDAMemory(&column_data, __FILE__, __LINE__); - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + data_by_column_[column_index]->Resize(sizeof(uint8_t) * num_used_indices_size); } else if (bit_type == 16) { - uint16_t* column_data = reinterpret_cast(data_by_column_[column_index]); - DeallocateCUDAMemory(&column_data, __FILE__, __LINE__); - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + data_by_column_[column_index]->Resize(sizeof(uint16_t) * num_used_indices_size); } else if (bit_type == 32) { - uint32_t* column_data = reinterpret_cast(data_by_column_[column_index]); - DeallocateCUDAMemory(&column_data, __FILE__, __LINE__); - AllocateCUDAMemory(&column_data, num_used_indices_size, __FILE__, __LINE__); - data_by_column_[column_index] = reinterpret_cast(column_data); + data_by_column_[column_index]->Resize(sizeof(uint32_t) * num_used_indices_size); } OMP_LOOP_EX_END(); } } OMP_THROW_EX(); - DeallocateCUDAMemory(&cuda_data_by_column_, __FILE__, __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_data_by_column_, data_by_column_.data(), data_by_column_.size(), __FILE__, __LINE__); + cuda_data_by_column_.InitFromHostVector(GetDataByColumnPointers(data_by_column_)); } void CUDAColumnData::InitColumnMetaInfo() { - InitCUDAMemoryFromHostMemory(&cuda_column_bit_type_, - column_bit_type_.data(), - column_bit_type_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_max_bin_, - feature_max_bin_.data(), - feature_max_bin_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_min_bin_, - feature_min_bin_.data(), - feature_min_bin_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_offset_, - feature_offset_.data(), - feature_offset_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_most_freq_bin_, - feature_most_freq_bin_.data(), - feature_most_freq_bin_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_default_bin_, - feature_default_bin_.data(), - feature_default_bin_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_missing_is_zero_, - feature_missing_is_zero_.data(), - feature_missing_is_zero_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_missing_is_na_, - feature_missing_is_na_.data(), - feature_missing_is_na_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_mfb_is_zero_, - feature_mfb_is_zero_.data(), - feature_mfb_is_zero_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_mfb_is_na_, - feature_mfb_is_na_.data(), - feature_mfb_is_na_.size(), - __FILE__, - __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_feature_to_column_, - feature_to_column_.data(), - feature_to_column_.size(), - __FILE__, - __LINE__); + cuda_column_bit_type_.InitFromHostVector(column_bit_type_); + cuda_feature_max_bin_.InitFromHostVector(feature_max_bin_); + cuda_feature_min_bin_.InitFromHostVector(feature_min_bin_); + cuda_feature_offset_.InitFromHostVector(feature_offset_); + cuda_feature_most_freq_bin_.InitFromHostVector(feature_most_freq_bin_); + cuda_feature_default_bin_.InitFromHostVector(feature_default_bin_); + cuda_feature_missing_is_zero_.InitFromHostVector(feature_missing_is_zero_); + cuda_feature_missing_is_na_.InitFromHostVector(feature_missing_is_na_); + cuda_feature_mfb_is_zero_.InitFromHostVector(feature_mfb_is_zero_); + cuda_feature_mfb_is_na_.InitFromHostVector(feature_mfb_is_na_); + cuda_feature_to_column_.InitFromHostVector(feature_to_column_); } } // namespace LightGBM diff --git a/src/io/cuda/cuda_column_data.cu b/src/io/cuda/cuda_column_data.cu index 75ff6234e09e..6341ca76129d 100644 --- a/src/io/cuda/cuda_column_data.cu +++ b/src/io/cuda/cuda_column_data.cu @@ -49,11 +49,11 @@ void CUDAColumnData::LaunchCopySubrowKernel(void* const* in_cuda_data_by_column) const int num_blocks = (num_used_indices_ + COPY_SUBROW_BLOCK_SIZE_COLUMN_DATA - 1) / COPY_SUBROW_BLOCK_SIZE_COLUMN_DATA; CopySubrowKernel_ColumnData<<>>( in_cuda_data_by_column, - cuda_column_bit_type_, - cuda_used_indices_, + cuda_column_bit_type_.RawData(), + cuda_used_indices_.RawData(), num_used_indices_, num_columns_, - cuda_data_by_column_); + cuda_data_by_column_.RawData()); } } // namespace LightGBM