diff --git a/include/LightGBM/cuda/cuda_tree.hpp b/include/LightGBM/cuda/cuda_tree.hpp index 1ed8c2b95c0d..9d89dc3b7465 100644 --- a/include/LightGBM/cuda/cuda_tree.hpp +++ b/include/LightGBM/cuda/cuda_tree.hpp @@ -134,6 +134,10 @@ class CUDATree : public Tree { void LaunchAddBiasKernel(const double val); + void RecordBranchFeatures(const int left_leaf_index, + const int right_leaf_index, + const int real_feature_index); + int* cuda_left_child_; int* cuda_right_child_; int* cuda_split_feature_inner_; diff --git a/src/io/cuda/cuda_tree.cpp b/src/io/cuda/cuda_tree.cpp index 0a5cc11a9c8d..b7ecee6e6167 100644 --- a/src/io/cuda/cuda_tree.cpp +++ b/src/io/cuda/cuda_tree.cpp @@ -216,6 +216,7 @@ int CUDATree::Split(const int leaf_index, const MissingType missing_type, const CUDASplitInfo* cuda_split_info) { LaunchSplitKernel(leaf_index, real_feature_index, real_threshold, missing_type, cuda_split_info); + RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index); ++num_leaves_; return num_leaves_ - 1; } @@ -235,9 +236,20 @@ int CUDATree::SplitCategorical(const int leaf_index, cuda_bitset_inner_.PushBack(cuda_bitset_inner, cuda_bitset_inner_len); ++num_leaves_; ++num_cat_; + RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index); return num_leaves_ - 1; } +void CUDATree::RecordBranchFeatures(const int left_leaf_index, + const int right_leaf_index, + const int real_feature_index) { + if (track_branch_features_) { + branch_features_[right_leaf_index] = branch_features_[left_leaf_index]; + branch_features_[right_leaf_index].push_back(real_feature_index); + branch_features_[left_leaf_index].push_back(real_feature_index); + } +} + void CUDATree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const { diff --git a/src/treelearner/cuda/cuda_best_split_finder.cpp b/src/treelearner/cuda/cuda_best_split_finder.cpp index 51589a673aa8..fdca46ec1647 100644 --- a/src/treelearner/cuda/cuda_best_split_finder.cpp +++ b/src/treelearner/cuda/cuda_best_split_finder.cpp @@ -17,6 +17,7 @@ CUDABestSplitFinder::CUDABestSplitFinder( const hist_t* cuda_hist, const Dataset* train_data, const std::vector& feature_hist_offsets, + const bool select_features_by_node, const Config* config): num_features_(train_data->num_features()), num_leaves_(config->num_leaves), @@ -36,6 +37,7 @@ CUDABestSplitFinder::CUDABestSplitFinder( use_smoothing_(config->path_smooth > 0), path_smooth_(config->path_smooth), num_total_bin_(feature_hist_offsets.empty() ? 0 : static_cast(feature_hist_offsets.back())), + select_features_by_node_(select_features_by_node), cuda_hist_(cuda_hist) { InitFeatureMetaInfo(train_data); cuda_leaf_best_split_info_ = nullptr; @@ -105,6 +107,11 @@ void CUDABestSplitFinder::Init() { AllocateCUDAMemory(&cuda_feature_hist_index_buffer_, static_cast(num_total_bin_), __FILE__, __LINE__); } } + + if (select_features_by_node_) { + is_feature_used_by_smaller_node_.Resize(num_features_); + is_feature_used_by_larger_node_.Resize(num_features_); + } } void CUDABestSplitFinder::InitCUDAFeatureMetaInfo() { @@ -364,6 +371,16 @@ void CUDABestSplitFinder::AllocateCatVectors(CUDASplitInfo* cuda_split_infos, ui LaunchAllocateCatVectorsKernel(cuda_split_infos, cat_threshold_vec, cat_threshold_real_vec, len); } +void CUDABestSplitFinder::SetUsedFeatureByNode(const std::vector& is_feature_used_by_smaller_node, + const std::vector& is_feature_used_by_larger_node) { + if (select_features_by_node_) { + CopyFromHostToCUDADevice(is_feature_used_by_smaller_node_.RawData(), + is_feature_used_by_smaller_node.data(), is_feature_used_by_smaller_node.size(), __FILE__, __LINE__); + CopyFromHostToCUDADevice(is_feature_used_by_larger_node_.RawData(), + is_feature_used_by_larger_node.data(), is_feature_used_by_larger_node.size(), __FILE__, __LINE__); + } +} + } // namespace LightGBM #endif // USE_CUDA_EXP diff --git a/src/treelearner/cuda/cuda_best_split_finder.cu b/src/treelearner/cuda/cuda_best_split_finder.cu index e11fe436a320..04896c40e7a9 100644 --- a/src/treelearner/cuda/cuda_best_split_finder.cu +++ b/src/treelearner/cuda/cuda_best_split_finder.cu @@ -1375,7 +1375,6 @@ __global__ void FindBestSplitsForLeafKernel_GlobalMemory( is_larger_leaf_valid #define FindBestSplitsForLeafKernel_ARGS \ - cuda_is_feature_used_bytree_, \ num_tasks_, \ cuda_split_find_tasks_.RawData(), \ cuda_randoms_.RawData(), \ @@ -1430,29 +1429,35 @@ void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner1(LaunchFindBest template void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBestSplitsForLeafKernel_PARAMS) { + const int8_t* is_feature_used_by_smaller_node = cuda_is_feature_used_bytree_; + const int8_t* is_feature_used_by_larger_node = cuda_is_feature_used_bytree_; + if (select_features_by_node_) { + is_feature_used_by_smaller_node = is_feature_used_by_smaller_node_.RawData(); + is_feature_used_by_larger_node = is_feature_used_by_larger_node_.RawData(); + } if (!use_global_memory_) { if (is_smaller_leaf_valid) { FindBestSplitsForLeafKernel <<>> - (FindBestSplitsForLeafKernel_ARGS); + (is_feature_used_by_smaller_node, FindBestSplitsForLeafKernel_ARGS); } SynchronizeCUDADevice(__FILE__, __LINE__); if (is_larger_leaf_valid) { FindBestSplitsForLeafKernel <<>> - (FindBestSplitsForLeafKernel_ARGS); + (is_feature_used_by_larger_node, FindBestSplitsForLeafKernel_ARGS); } } else { if (is_smaller_leaf_valid) { FindBestSplitsForLeafKernel_GlobalMemory <<>> - (FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS); + (is_feature_used_by_smaller_node, FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS); } SynchronizeCUDADevice(__FILE__, __LINE__); if (is_larger_leaf_valid) { FindBestSplitsForLeafKernel_GlobalMemory <<>> - (FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS); + (is_feature_used_by_larger_node, FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS); } } } diff --git a/src/treelearner/cuda/cuda_best_split_finder.hpp b/src/treelearner/cuda/cuda_best_split_finder.hpp index 3efc6011c83b..e9c12922cde6 100644 --- a/src/treelearner/cuda/cuda_best_split_finder.hpp +++ b/src/treelearner/cuda/cuda_best_split_finder.hpp @@ -46,6 +46,7 @@ class CUDABestSplitFinder { const hist_t* cuda_hist, const Dataset* train_data, const std::vector& feature_hist_offsets, + const bool select_features_by_node, const Config* config); ~CUDABestSplitFinder(); @@ -88,6 +89,9 @@ class CUDABestSplitFinder { void ResetConfig(const Config* config, const hist_t* cuda_hist); + void SetUsedFeatureByNode(const std::vector& is_feature_used_by_smaller_node, + const std::vector& is_feature_used_by_larger_node); + private: #define LaunchFindBestSplitsForLeafKernel_PARAMS \ const CUDALeafSplitsStruct* smaller_leaf_splits, \ @@ -172,6 +176,8 @@ class CUDABestSplitFinder { int max_num_categorical_bin_; // marks whether a feature is categorical std::vector is_categorical_; + // whether need to select features by node + bool select_features_by_node_; // CUDA memory, held by this object // for per leaf best split information @@ -195,6 +201,9 @@ class CUDABestSplitFinder { int max_num_categories_in_split_; // used for extremely randomized trees CUDAVector cuda_randoms_; + // features used by node + CUDAVector is_feature_used_by_smaller_node_; + CUDAVector is_feature_used_by_larger_node_; // CUDA memory, held by other object const hist_t* cuda_hist_; diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp index 9bf02b1553b7..f8e6fbfec725 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp @@ -55,8 +55,9 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_ cuda_histogram_constructor_->cuda_hist_pointer())); cuda_data_partition_->Init(); + select_features_by_node_ = !config_->interaction_constraints_vector.empty() || config_->feature_fraction_bynode < 1.0; cuda_best_split_finder_.reset(new CUDABestSplitFinder(cuda_histogram_constructor_->cuda_hist(), - train_data_, this->share_state_->feature_hist_offsets(), config_)); + train_data_, this->share_state_->feature_hist_offsets(), select_features_by_node_, config_)); cuda_best_split_finder_->Init(); leaf_best_split_feature_.resize(config_->num_leaves, -1); @@ -149,6 +150,9 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, sum_hessians_in_larger_leaf); global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf"); global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf"); + + SelectFeatureByNode(tree.get()); + cuda_best_split_finder_->FindBestSplitsForLeaf( cuda_smaller_leaf_splits_->GetCUDAStruct(), cuda_larger_leaf_splits_->GetCUDAStruct(), @@ -464,6 +468,18 @@ void CUDASingleGPUTreeLearner::ResetBoostingOnGPU(const bool boosting_on_cuda) { } } +void CUDASingleGPUTreeLearner::SelectFeatureByNode(const Tree* tree) { + if (select_features_by_node_) { + // use feature interaction constraint or sample features by node + const std::vector& is_feature_used_by_smaller_node = col_sampler_.GetByNode(tree, smaller_leaf_index_); + std::vector is_feature_used_by_larger_node; + if (larger_leaf_index_ >= 0) { + is_feature_used_by_larger_node = col_sampler_.GetByNode(tree, larger_leaf_index_); + } + cuda_best_split_finder_->SetUsedFeatureByNode(is_feature_used_by_smaller_node, is_feature_used_by_larger_node); + } +} + #ifdef DEBUG void CUDASingleGPUTreeLearner::CheckSplitValid( const int left_leaf, diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp index e65cd428bcd1..b1922f5f28c5 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp @@ -66,6 +66,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { void AllocateBitset(); + void SelectFeatureByNode(const Tree* tree); + #ifdef DEUBG void CheckSplitValid( const int left_leaf, const int right_leaf, @@ -100,6 +102,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { int best_leaf_index_; int num_cat_threshold_; bool has_categorical_feature_; + // whether need to select features by node + bool select_features_by_node_; std::vector categorical_bin_to_value_; std::vector categorical_bin_offsets_; diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 8594c09ededf..8318216e1411 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3118,7 +3118,6 @@ def _imptcs_to_numpy(X, impcts_dict): assert tree_df.loc[0, col] is None -@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version') def test_interaction_constraints(): X, y = load_boston(return_X_y=True) num_features = X.shape[1]