diff --git a/include/LightGBM/cuda/cuda_algorithms.hpp b/include/LightGBM/cuda/cuda_algorithms.hpp
index 1df479210bac..777880a6d017 100644
--- a/include/LightGBM/cuda/cuda_algorithms.hpp
+++ b/include/LightGBM/cuda/cuda_algorithms.hpp
@@ -18,17 +18,11 @@
 
 #include <algorithm>
 
-#define NUM_BANKS_DATA_PARTITION (16)
-#define LOG_NUM_BANKS_DATA_PARTITION (4)
 #define GLOBAL_PREFIX_SUM_BLOCK_SIZE (1024)
-
 #define BITONIC_SORT_NUM_ELEMENTS (1024)
 #define BITONIC_SORT_DEPTH (11)
 #define BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE (10)
 
-#define CONFLICT_FREE_INDEX(n) \
-  ((n) + ((n) >> LOG_NUM_BANKS_DATA_PARTITION)) \
-
 namespace LightGBM {
 
 template <typename T>
@@ -223,6 +217,54 @@ __device__ __forceinline__ void BitonicArgSort_1024(const VAL_T* scores, INDEX_T* indices, const INDEX_T num_items) {
   }
 }
 
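+// Argsort of 2048 items with 1024 threads, used below for queries with more
+// than 1024 documents: each 1024-item half is first bitonic-sorted in opposite
+// directions, one compare-and-swap across the midpoint then moves the larger
+// element of each pair into the front half, and a final merge pass inside each
+// half leaves the whole index array sorted by descending score.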
+template <typename VAL_T, typename INDEX_T, bool ASCENDING>
+__device__ __forceinline__ void BitonicArgSort_2048(const VAL_T* scores, INDEX_T* indices) {
+  for (INDEX_T base = 0; base < 2048; base += 1024) {
+    for (INDEX_T outer_depth = 10; outer_depth >= 1; --outer_depth) {
+      const INDEX_T outer_segment_length = 1 << (11 - outer_depth);
+      const INDEX_T outer_segment_index = threadIdx.x / outer_segment_length;
+      const bool ascending = ((base == 0) ^ ASCENDING) ? (outer_segment_index % 2 > 0) : (outer_segment_index % 2 == 0);
+      for (INDEX_T inner_depth = outer_depth; inner_depth < 11; ++inner_depth) {
+        const INDEX_T segment_length = 1 << (11 - inner_depth);
+        const INDEX_T half_segment_length = segment_length >> 1;
+        const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
+        if (half_segment_index % 2 == 0) {
+          const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
+          if ((scores[indices[threadIdx.x + base]] > scores[indices[index_to_compare]]) == ascending) {
+            const INDEX_T index = indices[threadIdx.x + base];
+            indices[threadIdx.x + base] = indices[index_to_compare];
+            indices[index_to_compare] = index;
+          }
+        }
+        __syncthreads();
+      }
+    }
+  }
+  const unsigned int index_to_compare = threadIdx.x + 1024;
+  if (scores[indices[index_to_compare]] > scores[indices[threadIdx.x]]) {
+    const INDEX_T temp_index = indices[index_to_compare];
+    indices[index_to_compare] = indices[threadIdx.x];
+    indices[threadIdx.x] = temp_index;
+  }
+  __syncthreads();
+  for (INDEX_T base = 0; base < 2048; base += 1024) {
+    for (INDEX_T inner_depth = 1; inner_depth < 11; ++inner_depth) {
+      const INDEX_T segment_length = 1 << (11 - inner_depth);
+      const INDEX_T half_segment_length = segment_length >> 1;
+      const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
+      if (half_segment_index % 2 == 0) {
+        const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
+        if (scores[indices[threadIdx.x + base]] < scores[indices[index_to_compare]]) {
+          const INDEX_T index = indices[threadIdx.x + base];
+          indices[threadIdx.x + base] = indices[index_to_compare];
+          indices[index_to_compare] = index;
+        }
+      }
+      __syncthreads();
+    }
+  }
+}
+
 template <typename VAL_T, typename INDEX_T, bool ASCENDING, uint32_t BLOCK_DIM, uint32_t MAX_DEPTH>
 __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, const int len) {
   __shared__ VAL_T shared_values[BLOCK_DIM];
@@ -387,6 +429,12 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, const int len) {
   }
 }
 
+void BitonicArgSortItemsGlobal(
+  const double* scores,
+  const int num_queries,
+  const data_size_t* cuda_query_boundaries,
+  data_size_t* out_indices);
+
 template <typename VAL_T, typename INDEX_T, bool ASCENDING>
 void BitonicArgSortGlobal(const VAL_T* values, INDEX_T* indices, const size_t len);
 
diff --git a/src/cuda/cuda_algorithms.cu b/src/cuda/cuda_algorithms.cu
index 7f84955ce09c..c8eb61c3cc72 100644
--- a/src/cuda/cuda_algorithms.cu
+++ b/src/cuda/cuda_algorithms.cu
@@ -77,6 +77,34 @@ void ShufflePrefixSumGlobal(uint64_t* values, size_t len, uint64_t* block_prefix_sum_buffer) {
   ShufflePrefixSumGlobalInner(values, len, block_prefix_sum_buffer);
 }
 
+__global__ void BitonicArgSortItemsGlobalKernel(const double* scores,
+  const int num_queries,
+  const data_size_t* cuda_query_boundaries,
+  data_size_t* out_indices) {
+  const int query_index_start = static_cast<int>(blockIdx.x) * BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE;
+  const int query_index_end = min(query_index_start + BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE, num_queries);
+  for (int query_index = query_index_start; query_index < query_index_end; ++query_index) {
+    const data_size_t query_item_start = cuda_query_boundaries[query_index];
+    const data_size_t query_item_end = cuda_query_boundaries[query_index + 1];
+    const data_size_t num_items_in_query = query_item_end - query_item_start;
+    BitonicArgSortDevice<double, data_size_t, false, BITONIC_SORT_NUM_ELEMENTS, BITONIC_SORT_DEPTH>(scores + query_item_start,
+      out_indices + query_item_start,
+      num_items_in_query);
+    __syncthreads();
+  }
+}
+
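+// Host-side launcher: each block argsorts the items of
+// BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE consecutive queries directly in global
+// memory, writing the per-query item orderings into out_indices.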
+void BitonicArgSortItemsGlobal(
+  const double* scores,
+  const int num_queries,
+  const data_size_t* cuda_query_boundaries,
+  data_size_t* out_indices) {
+  const int num_blocks = (num_queries + BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE - 1) / BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE;
+  BitonicArgSortItemsGlobalKernel<<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(
+    scores, num_queries, cuda_query_boundaries, out_indices);
+  SynchronizeCUDADevice(__FILE__, __LINE__);
+}
+
 template <typename T>
 __global__ void BlockReduceSum(T* block_buffer, const data_size_t num_blocks) {
   __shared__ T shared_buffer[32];
diff --git a/src/objective/cuda/cuda_rank_objective.cpp b/src/objective/cuda/cuda_rank_objective.cpp
new file mode 100644
index 000000000000..d99597727862
--- /dev/null
+++ b/src/objective/cuda/cuda_rank_objective.cpp
@@ -0,0 +1,65 @@
+/*!
+ * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for
+ * license information.
+ */
+
+#ifdef USE_CUDA_EXP
+
+#include <string>
+#include <vector>
+
+#include "cuda_rank_objective.hpp"
+
+namespace LightGBM {
+
+CUDALambdarankNDCG::CUDALambdarankNDCG(const Config& config):
+LambdarankNDCG(config) {}
+
+CUDALambdarankNDCG::CUDALambdarankNDCG(const std::vector<std::string>& strs): LambdarankNDCG(strs) {}
+
+void CUDALambdarankNDCG::Init(const Metadata& metadata, data_size_t num_data) {
+  const int num_threads = OMP_NUM_THREADS();
+  LambdarankNDCG::Init(metadata, num_data);
+
+  std::vector<data_size_t> thread_max_num_items_in_query(num_threads);
+  Threading::For<data_size_t>(0, num_queries_, 1,
+    [this, &thread_max_num_items_in_query] (int thread_index, data_size_t start, data_size_t end) {
+      for (data_size_t query_index = start; query_index < end; ++query_index) {
+        const data_size_t query_item_count = query_boundaries_[query_index + 1] - query_boundaries_[query_index];
+        if (query_item_count > thread_max_num_items_in_query[thread_index]) {
+          thread_max_num_items_in_query[thread_index] = query_item_count;
+        }
+      }
+    });
+  data_size_t max_items_in_query = 0;
+  for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
+    if (thread_max_num_items_in_query[thread_index] > max_items_in_query) {
+      max_items_in_query = thread_max_num_items_in_query[thread_index];
+    }
+  }
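+  // round the largest query size up to the next power of two;
+  // LaunchGetGradientsKernel picks the kernel variant (1024-item, 2048-item,
+  // or global-memory sort) from this value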
+  max_items_in_query_aligned_ = 1;
+  --max_items_in_query;
+  while (max_items_in_query > 0) {
+    max_items_in_query >>= 1;
+    max_items_in_query_aligned_ <<= 1;
+  }
+  if (max_items_in_query_aligned_ > 2048) {
+    cuda_item_indices_buffer_.Resize(static_cast<size_t>(metadata.query_boundaries()[metadata.num_queries()]));
+  }
+  cuda_labels_ = metadata.cuda_metadata()->cuda_label();
+  cuda_query_boundaries_ = metadata.cuda_metadata()->cuda_query_boundaries();
+  cuda_inverse_max_dcgs_.Resize(inverse_max_dcgs_.size());
+  CopyFromHostToCUDADevice<double>(cuda_inverse_max_dcgs_.RawData(), inverse_max_dcgs_.data(), inverse_max_dcgs_.size(), __FILE__, __LINE__);
+  cuda_label_gain_.Resize(label_gain_.size());
+  CopyFromHostToCUDADevice<double>(cuda_label_gain_.RawData(), label_gain_.data(), label_gain_.size(), __FILE__, __LINE__);
+}
+
+void CUDALambdarankNDCG::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
+  LaunchGetGradientsKernel(score, gradients, hessians);
+}
+
+}  // namespace LightGBM
+
+#endif  // USE_CUDA_EXP
diff --git a/src/objective/cuda/cuda_rank_objective.cu b/src/objective/cuda/cuda_rank_objective.cu
new file mode 100644
index 000000000000..5055a07a05f7
--- /dev/null
+++ b/src/objective/cuda/cuda_rank_objective.cu
@@ -0,0 +1,382 @@
+/*!
+ * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for
+ * license information.
+ */
+
+#ifdef USE_CUDA_EXP
+
+#include "cuda_rank_objective.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+
+namespace LightGBM {
+
+template <bool MAX_ITEM_GREATER_THAN_1024, data_size_t NUM_RANK_LABEL>
+__global__ void GetGradientsKernel_LambdarankNDCG(const double* cuda_scores, const label_t* cuda_labels, const data_size_t num_data,
+  const data_size_t num_queries, const data_size_t* cuda_query_boundaries, const double* cuda_inverse_max_dcgs,
+  const bool norm, const double sigmoid, const int truncation_level, const double* cuda_label_gain, const data_size_t num_rank_label,
+  score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
+  __shared__ score_t shared_scores[MAX_ITEM_GREATER_THAN_1024 ? 2048 : 1024];
+  __shared__ uint16_t shared_indices[MAX_ITEM_GREATER_THAN_1024 ? 2048 : 1024];
+  __shared__ score_t shared_lambdas[MAX_ITEM_GREATER_THAN_1024 ? 2048 : 1024];
+  __shared__ score_t shared_hessians[MAX_ITEM_GREATER_THAN_1024 ? 2048 : 1024];
+  __shared__ double shared_label_gain[NUM_RANK_LABEL > 1024 ? 1 : NUM_RANK_LABEL];
+  const double* label_gain_ptr = nullptr;
+  if (NUM_RANK_LABEL <= 1024) {
+    for (uint32_t i = threadIdx.x; i < num_rank_label; i += blockDim.x) {
+      shared_label_gain[i] = cuda_label_gain[i];
+    }
+    __syncthreads();
+    label_gain_ptr = shared_label_gain;
+  } else {
+    label_gain_ptr = cuda_label_gain;
+  }
+  const data_size_t query_index_start = static_cast<data_size_t>(blockIdx.x) * NUM_QUERY_PER_BLOCK;
+  const data_size_t query_index_end = min(query_index_start + NUM_QUERY_PER_BLOCK, num_queries);
+  for (data_size_t query_index = query_index_start; query_index < query_index_end; ++query_index) {
+    const double inverse_max_dcg = cuda_inverse_max_dcgs[query_index];
+    const data_size_t query_start = cuda_query_boundaries[query_index];
+    const data_size_t query_end = cuda_query_boundaries[query_index + 1];
+    const data_size_t query_item_count = query_end - query_start;
+    const double* cuda_scores_pointer = cuda_scores + query_start;
+    score_t* cuda_out_gradients_pointer = cuda_out_gradients + query_start;
+    score_t* cuda_out_hessians_pointer = cuda_out_hessians + query_start;
+    const label_t* cuda_label_pointer = cuda_labels + query_start;
+    if (threadIdx.x < query_item_count) {
+      shared_scores[threadIdx.x] = cuda_scores_pointer[threadIdx.x];
+      shared_indices[threadIdx.x] = static_cast<uint16_t>(threadIdx.x);
+      shared_lambdas[threadIdx.x] = 0.0f;
+      shared_hessians[threadIdx.x] = 0.0f;
+    } else {
+      shared_scores[threadIdx.x] = kMinScore;
+      shared_indices[threadIdx.x] = static_cast<uint16_t>(threadIdx.x);
+    }
+    if (MAX_ITEM_GREATER_THAN_1024) {
+      if (query_item_count > 1024) {
+        const unsigned int threadIdx_x_plus_1024 = threadIdx.x + 1024;
+        if (threadIdx_x_plus_1024 < query_item_count) {
+          shared_scores[threadIdx_x_plus_1024] = cuda_scores_pointer[threadIdx_x_plus_1024];
+          shared_indices[threadIdx_x_plus_1024] = static_cast<uint16_t>(threadIdx_x_plus_1024);
+          shared_lambdas[threadIdx_x_plus_1024] = 0.0f;
+          shared_hessians[threadIdx_x_plus_1024] = 0.0f;
+        } else {
+          shared_scores[threadIdx_x_plus_1024] = kMinScore;
+          shared_indices[threadIdx_x_plus_1024] = static_cast<uint16_t>(threadIdx_x_plus_1024);
+        }
+      }
+    }
+    __syncthreads();
+    if (MAX_ITEM_GREATER_THAN_1024) {
+      if (query_item_count > 1024) {
+        BitonicArgSort_2048<score_t, uint16_t, false>(shared_scores, shared_indices);
+      } else {
+        BitonicArgSort_1024<score_t, uint16_t, false>(shared_scores, shared_indices, static_cast<uint16_t>(query_item_count));
+      }
+    } else {
+      BitonicArgSort_1024<score_t, uint16_t, false>(shared_scores, shared_indices, static_cast<uint16_t>(query_item_count));
+    }
+    __syncthreads();
+    // get best and worst score
+    const double best_score = shared_scores[shared_indices[0]];
+    data_size_t worst_idx = query_item_count - 1;
+    if (worst_idx > 0 && shared_scores[shared_indices[worst_idx]] == kMinScore) {
+      worst_idx -= 1;
+    }
+    const double worst_score = shared_scores[shared_indices[worst_idx]];
+    __shared__ double sum_lambdas;
+    if (threadIdx.x == 0) {
+      sum_lambdas = 0.0f;
+    }
+    __syncthreads();
+    // start accumulating lambdas by pairs that contain at least one document above the truncation level
+    const data_size_t num_items_i = min(query_item_count - 1, truncation_level);
+    const data_size_t num_j_per_i = query_item_count - 1;
+    const data_size_t s = num_j_per_i - num_items_i + 1;
+    const data_size_t num_pairs = (num_j_per_i + s) * num_items_i / 2;
+    double thread_sum_lambdas = 0.0f;
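+    // decode the flat pair_index into a pair (i, j) with i < j: counted from
+    // the last row, row i owns s, s + 1, ..., num_j_per_i candidate js, so
+    // row_index is recovered via an integer square root of the
+    // triangular-number prefix sum, and j follows from the offset of
+    // pair_index within its row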
+    for (data_size_t pair_index = static_cast<data_size_t>(threadIdx.x); pair_index < num_pairs; pair_index += static_cast<data_size_t>(blockDim.x)) {
+      const double square = 2 * static_cast<double>(pair_index) + s * s - s;
+      const double sqrt_result = floor(sqrt(square));
+      const data_size_t row_index = static_cast<data_size_t>(floor(sqrt(square - sqrt_result)) + 1 - s);
+      const data_size_t i = num_items_i - 1 - row_index;
+      const data_size_t j = num_j_per_i - (pair_index - (2 * s + row_index - 1) * row_index / 2);
+      if (cuda_label_pointer[shared_indices[i]] != cuda_label_pointer[shared_indices[j]] && shared_scores[shared_indices[j]] != kMinScore) {
+        data_size_t high_rank, low_rank;
+        if (cuda_label_pointer[shared_indices[i]] > cuda_label_pointer[shared_indices[j]]) {
+          high_rank = i;
+          low_rank = j;
+        } else {
+          high_rank = j;
+          low_rank = i;
+        }
+        const data_size_t high = shared_indices[high_rank];
+        const int high_label = static_cast<int>(cuda_label_pointer[high]);
+        const double high_score = shared_scores[high];
+        const double high_label_gain = label_gain_ptr[high_label];
+        const double high_discount = log2(2.0f + high_rank);
+        const data_size_t low = shared_indices[low_rank];
+        const int low_label = static_cast<int>(cuda_label_pointer[low]);
+        const double low_score = shared_scores[low];
+        const double low_label_gain = label_gain_ptr[low_label];
+        const double low_discount = log2(2.0f + low_rank);
+
+        const double delta_score = high_score - low_score;
+
+        // get dcg gap
+        const double dcg_gap = high_label_gain - low_label_gain;
+        // get discount of this pair
+        const double paired_discount = fabs(high_discount - low_discount);
+        // get delta NDCG
+        double delta_pair_NDCG = dcg_gap * paired_discount * inverse_max_dcg;
+        // regularize delta_pair_NDCG by the score distance
+        if (norm && best_score != worst_score) {
+          delta_pair_NDCG /= (0.01f + fabs(delta_score));
+        }
+        // calculate lambda for this pair
+        double p_lambda = 1.0f / (1.0f + exp(sigmoid * delta_score));
+        double p_hessian = p_lambda * (1.0f - p_lambda);
+        // update
+        p_lambda *= -sigmoid * delta_pair_NDCG;
+        p_hessian *= sigmoid * sigmoid * delta_pair_NDCG;
+        atomicAdd_block(shared_lambdas + low, -static_cast<score_t>(p_lambda));
+        atomicAdd_block(shared_hessians + low, static_cast<score_t>(p_hessian));
+        atomicAdd_block(shared_lambdas + high, static_cast<score_t>(p_lambda));
+        atomicAdd_block(shared_hessians + high, static_cast<score_t>(p_hessian));
+        // lambda is negative, so subtract to accumulate
+        thread_sum_lambdas -= 2 * p_lambda;
+      }
+    }
+    atomicAdd_block(&sum_lambdas, thread_sum_lambdas);
+    __syncthreads();
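+    // as in the CPU lambdarank implementation, rescale lambdas and hessians by
+    // log2(1 + sum_lambdas) / sum_lambdas when normalization is enabled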
+    if (norm && sum_lambdas > 0) {
+      const double norm_factor = log2(1 + sum_lambdas) / sum_lambdas;
+      if (threadIdx.x < static_cast<unsigned int>(query_item_count)) {
+        cuda_out_gradients_pointer[threadIdx.x] = static_cast<score_t>(shared_lambdas[threadIdx.x] * norm_factor);
+        cuda_out_hessians_pointer[threadIdx.x] = static_cast<score_t>(shared_hessians[threadIdx.x] * norm_factor);
+      }
+      if (MAX_ITEM_GREATER_THAN_1024) {
+        if (query_item_count > 1024) {
+          const unsigned int threadIdx_x_plus_1024 = threadIdx.x + 1024;
+          if (threadIdx_x_plus_1024 < static_cast<unsigned int>(query_item_count)) {
+            cuda_out_gradients_pointer[threadIdx_x_plus_1024] = static_cast<score_t>(shared_lambdas[threadIdx_x_plus_1024] * norm_factor);
+            cuda_out_hessians_pointer[threadIdx_x_plus_1024] = static_cast<score_t>(shared_hessians[threadIdx_x_plus_1024] * norm_factor);
+          }
+        }
+      }
+    } else {
+      if (threadIdx.x < static_cast<unsigned int>(query_item_count)) {
+        cuda_out_gradients_pointer[threadIdx.x] = static_cast<score_t>(shared_lambdas[threadIdx.x]);
+        cuda_out_hessians_pointer[threadIdx.x] = static_cast<score_t>(shared_hessians[threadIdx.x]);
+      }
+      if (MAX_ITEM_GREATER_THAN_1024) {
+        if (query_item_count > 1024) {
+          const unsigned int threadIdx_x_plus_1024 = threadIdx.x + 1024;
+          if (threadIdx_x_plus_1024 < static_cast<unsigned int>(query_item_count)) {
+            cuda_out_gradients_pointer[threadIdx_x_plus_1024] = static_cast<score_t>(shared_lambdas[threadIdx_x_plus_1024]);
+            cuda_out_hessians_pointer[threadIdx_x_plus_1024] = static_cast<score_t>(shared_hessians[threadIdx_x_plus_1024]);
+          }
+        }
+      }
+    }
+    __syncthreads();
+  }
+}
+
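+// Variant for queries with more than 2048 documents: item indices are
+// pre-sorted in global memory (see BitonicArgSortItemsGlobal) and lambdas and
+// hessians are accumulated directly in the global output buffers instead of
+// shared memory.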
+template <data_size_t NUM_RANK_LABEL>
+__global__ void GetGradientsKernel_LambdarankNDCG_Sorted(
+  const double* cuda_scores, const int* cuda_item_indices_buffer, const label_t* cuda_labels, const data_size_t num_data,
+  const data_size_t num_queries, const data_size_t* cuda_query_boundaries, const double* cuda_inverse_max_dcgs,
+  const bool norm, const double sigmoid, const int truncation_level, const double* cuda_label_gain, const data_size_t num_rank_label,
+  score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
+  __shared__ double shared_label_gain[NUM_RANK_LABEL > 1024 ? 1 : NUM_RANK_LABEL];
+  const double* label_gain_ptr = nullptr;
+  if (NUM_RANK_LABEL <= 1024) {
+    for (uint32_t i = threadIdx.x; i < static_cast<uint32_t>(num_rank_label); i += blockDim.x) {
+      shared_label_gain[i] = cuda_label_gain[i];
+    }
+    __syncthreads();
+    label_gain_ptr = shared_label_gain;
+  } else {
+    label_gain_ptr = cuda_label_gain;
+  }
+  const data_size_t query_index_start = static_cast<data_size_t>(blockIdx.x) * NUM_QUERY_PER_BLOCK;
+  const data_size_t query_index_end = min(query_index_start + NUM_QUERY_PER_BLOCK, num_queries);
+  for (data_size_t query_index = query_index_start; query_index < query_index_end; ++query_index) {
+    const double inverse_max_dcg = cuda_inverse_max_dcgs[query_index];
+    const data_size_t query_start = cuda_query_boundaries[query_index];
+    const data_size_t query_end = cuda_query_boundaries[query_index + 1];
+    const data_size_t query_item_count = query_end - query_start;
+    const double* cuda_scores_pointer = cuda_scores + query_start;
+    const int* cuda_item_indices_buffer_pointer = cuda_item_indices_buffer + query_start;
+    score_t* cuda_out_gradients_pointer = cuda_out_gradients + query_start;
+    score_t* cuda_out_hessians_pointer = cuda_out_hessians + query_start;
+    const label_t* cuda_label_pointer = cuda_labels + query_start;
+    // get best and worst score
+    const double best_score = cuda_scores_pointer[cuda_item_indices_buffer_pointer[0]];
+    data_size_t worst_idx = query_item_count - 1;
+    if (worst_idx > 0 && cuda_scores_pointer[cuda_item_indices_buffer_pointer[worst_idx]] == kMinScore) {
+      worst_idx -= 1;
+    }
+    const double worst_score = cuda_scores_pointer[cuda_item_indices_buffer_pointer[worst_idx]];
+    __shared__ double sum_lambdas;
+    if (threadIdx.x == 0) {
+      sum_lambdas = 0.0f;
+    }
+    for (int item_index = static_cast<int>(threadIdx.x); item_index < query_item_count; item_index += static_cast<int>(blockDim.x)) {
+      cuda_out_gradients_pointer[item_index] = 0.0f;
+      cuda_out_hessians_pointer[item_index] = 0.0f;
+    }
+    __syncthreads();
+    // start accumulating lambdas by pairs that contain at least one document above the truncation level
+    const data_size_t num_items_i = min(query_item_count - 1, truncation_level);
+    const data_size_t num_j_per_i = query_item_count - 1;
+    const data_size_t s = num_j_per_i - num_items_i + 1;
+    const data_size_t num_pairs = (num_j_per_i + s) * num_items_i / 2;
+    double thread_sum_lambdas = 0.0f;
+    for (data_size_t pair_index = static_cast<data_size_t>(threadIdx.x); pair_index < num_pairs; pair_index += static_cast<data_size_t>(blockDim.x)) {
+      const double square = 2 * static_cast<double>(pair_index) + s * s - s;
+      const double sqrt_result = floor(sqrt(square));
+      const data_size_t row_index = static_cast<data_size_t>(floor(sqrt(square - sqrt_result)) + 1 - s);
+      const data_size_t i = num_items_i - 1 - row_index;
+      const data_size_t j = num_j_per_i - (pair_index - (2 * s + row_index - 1) * row_index / 2);
+      if (j > i) {
+        // skip pairs with the same labels
+        if (cuda_label_pointer[cuda_item_indices_buffer_pointer[i]] != cuda_label_pointer[cuda_item_indices_buffer_pointer[j]] && cuda_scores_pointer[cuda_item_indices_buffer_pointer[j]] != kMinScore) {
+          data_size_t high_rank, low_rank;
+          if (cuda_label_pointer[cuda_item_indices_buffer_pointer[i]] > cuda_label_pointer[cuda_item_indices_buffer_pointer[j]]) {
+            high_rank = i;
+            low_rank = j;
+          } else {
+            high_rank = j;
+            low_rank = i;
+          }
+          const data_size_t high = cuda_item_indices_buffer_pointer[high_rank];
+          const int high_label = static_cast<int>(cuda_label_pointer[high]);
+          const double high_score = cuda_scores_pointer[high];
+          const double high_label_gain = label_gain_ptr[high_label];
+          const double high_discount = log2(2.0f + high_rank);
+          const data_size_t low = cuda_item_indices_buffer_pointer[low_rank];
+          const int low_label = static_cast<int>(cuda_label_pointer[low]);
+          const double low_score = cuda_scores_pointer[low];
+          const double low_label_gain = label_gain_ptr[low_label];
+          const double low_discount = log2(2.0f + low_rank);
+
+          const double delta_score = high_score - low_score;
+
+          // get dcg gap
+          const double dcg_gap = high_label_gain - low_label_gain;
+          // get discount of this pair
+          const double paired_discount = fabs(high_discount - low_discount);
+          // get delta NDCG
+          double delta_pair_NDCG = dcg_gap * paired_discount * inverse_max_dcg;
+          // regularize delta_pair_NDCG by the score distance
+          if (norm && best_score != worst_score) {
+            delta_pair_NDCG /= (0.01f + fabs(delta_score));
+          }
+          // calculate lambda for this pair
+          double p_lambda = 1.0f / (1.0f + exp(sigmoid * delta_score));
+          double p_hessian = p_lambda * (1.0f - p_lambda);
+          // update
+          p_lambda *= -sigmoid * delta_pair_NDCG;
+          p_hessian *= sigmoid * sigmoid * delta_pair_NDCG;
+          atomicAdd_block(cuda_out_gradients_pointer + low, -static_cast<score_t>(p_lambda));
+          atomicAdd_block(cuda_out_hessians_pointer + low, static_cast<score_t>(p_hessian));
+          atomicAdd_block(cuda_out_gradients_pointer + high, static_cast<score_t>(p_lambda));
+          atomicAdd_block(cuda_out_hessians_pointer + high, static_cast<score_t>(p_hessian));
+          // lambda is negative, so subtract to accumulate
+          thread_sum_lambdas -= 2 * p_lambda;
+        }
+      }
+    }
+    atomicAdd_block(&sum_lambdas, thread_sum_lambdas);
+    __syncthreads();
+    if (norm && sum_lambdas > 0) {
+      const double norm_factor = log2(1 + sum_lambdas) / sum_lambdas;
+      for (int item_index = static_cast<int>(threadIdx.x); item_index < query_item_count; item_index += static_cast<int>(blockDim.x)) {
+        cuda_out_gradients_pointer[item_index] *= norm_factor;
+        cuda_out_hessians_pointer[item_index] *= norm_factor;
+      }
+    }
+    __syncthreads();
+  }
+}
+
+void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
+  const int num_blocks = (num_queries_ + NUM_QUERY_PER_BLOCK - 1) / NUM_QUERY_PER_BLOCK;
+  const data_size_t num_rank_label = static_cast<data_size_t>(label_gain_.size());
+
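+  // one kernel instantiation per power-of-two bound on the number of distinct
+  // rank labels, so that the label-gain table can be staged in shared memory
+  // whenever it holds at most 1024 entries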
+  #define GetGradientsKernel_LambdarankNDCG_ARGS \
+    score, cuda_labels_, num_data_, \
+    num_queries_, cuda_query_boundaries_, cuda_inverse_max_dcgs_.RawData(), \
+    norm_, sigmoid_, truncation_level_, cuda_label_gain_.RawData(), num_rank_label, \
+    gradients, hessians
+
+  #define GetGradientsKernel_LambdarankNDCG_Sorted_ARGS \
+    score, cuda_item_indices_buffer_.RawData(), cuda_labels_, num_data_, \
+    num_queries_, cuda_query_boundaries_, cuda_inverse_max_dcgs_.RawData(), \
+    norm_, sigmoid_, truncation_level_, cuda_label_gain_.RawData(), num_rank_label, \
+    gradients, hessians
+
+  if (max_items_in_query_aligned_ <= 1024) {
+    if (num_rank_label <= 32) {
+      GetGradientsKernel_LambdarankNDCG<false, 32><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 64) {
+      GetGradientsKernel_LambdarankNDCG<false, 64><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 128) {
+      GetGradientsKernel_LambdarankNDCG<false, 128><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 256) {
+      GetGradientsKernel_LambdarankNDCG<false, 256><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 512) {
+      GetGradientsKernel_LambdarankNDCG<false, 512><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 1024) {
+      GetGradientsKernel_LambdarankNDCG<false, 1024><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else {
+      GetGradientsKernel_LambdarankNDCG<false, 2048><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    }
+  } else if (max_items_in_query_aligned_ <= 2048) {
+    if (num_rank_label <= 32) {
+      GetGradientsKernel_LambdarankNDCG<true, 32><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 64) {
+      GetGradientsKernel_LambdarankNDCG<true, 64><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 128) {
+      GetGradientsKernel_LambdarankNDCG<true, 128><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 256) {
+      GetGradientsKernel_LambdarankNDCG<true, 256><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 512) {
+      GetGradientsKernel_LambdarankNDCG<true, 512><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else if (num_rank_label <= 1024) {
+      GetGradientsKernel_LambdarankNDCG<true, 1024><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    } else {
+      GetGradientsKernel_LambdarankNDCG<true, 2048><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_ARGS);
+    }
+  } else {
+    BitonicArgSortItemsGlobal(score, num_queries_, cuda_query_boundaries_, cuda_item_indices_buffer_.RawData());
+    if (num_rank_label <= 32) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<32><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else if (num_rank_label <= 64) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<64><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else if (num_rank_label <= 128) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<128><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else if (num_rank_label <= 256) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<256><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else if (num_rank_label <= 512) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<512><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else if (num_rank_label <= 1024) {
+      GetGradientsKernel_LambdarankNDCG_Sorted<1024><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    } else {
+      GetGradientsKernel_LambdarankNDCG_Sorted<2048><<<num_blocks, 1024>>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
+    }
+  }
+  SynchronizeCUDADevice(__FILE__, __LINE__);
+
+  #undef GetGradientsKernel_LambdarankNDCG_ARGS
+  #undef GetGradientsKernel_LambdarankNDCG_Sorted_ARGS
+}
+
+}  // namespace LightGBM
+
+#endif  // USE_CUDA_EXP
diff --git a/src/objective/cuda/cuda_rank_objective.hpp b/src/objective/cuda/cuda_rank_objective.hpp
new file mode 100644
index 000000000000..575859cf5c50
--- /dev/null
+++ b/src/objective/cuda/cuda_rank_objective.hpp
@@ -0,0 +1,56 @@
+/*!
+ * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for
+ * license information.
+ */
+
+#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_RANK_OBJECTIVE_HPP_
+#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_RANK_OBJECTIVE_HPP_
+
+#ifdef USE_CUDA_EXP
+
+#define NUM_QUERY_PER_BLOCK (10)
+
+#include <LightGBM/cuda/cuda_algorithms.hpp>
+#include <LightGBM/cuda/cuda_utils.h>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "../rank_objective.hpp"
+
+namespace LightGBM {
+
+class CUDALambdarankNDCG : public CUDAObjectiveInterface, public LambdarankNDCG {
+ public:
+  explicit CUDALambdarankNDCG(const Config& config);
+
+  explicit CUDALambdarankNDCG(const std::vector<std::string>& strs);
+
+  void Init(const Metadata& metadata, data_size_t num_data) override;
+
+  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;
+
+  bool IsCUDAObjective() const override { return true; }
+
+ protected:
+  void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const;
+
+  // CUDA memory, held by this object
+  CUDAVector<double> cuda_inverse_max_dcgs_;
+  CUDAVector<double> cuda_label_gain_;
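+  // allocated only when some query holds more than 2048 documents, i.e. when
+  // the global-memory argsort path is taken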
+  CUDAVector<data_size_t> cuda_item_indices_buffer_;
+
+  // CUDA memory, held by other objects
+  const label_t* cuda_labels_;
+  const data_size_t* cuda_query_boundaries_;
+
+  // Host memory
+  int max_items_in_query_aligned_;
+};
+
+}  // namespace LightGBM
+
+#endif  // USE_CUDA_EXP
+#endif  // LIGHTGBM_OBJECTIVE_CUDA_CUDA_RANK_OBJECTIVE_HPP_
diff --git a/src/objective/objective_function.cpp b/src/objective/objective_function.cpp
index c61199aede76..4a279ec3b37d 100644
--- a/src/objective/objective_function.cpp
+++ b/src/objective/objective_function.cpp
@@ -11,6 +11,7 @@
 #include "xentropy_objective.hpp"
 
 #include "cuda/cuda_binary_objective.hpp"
+#include "cuda/cuda_rank_objective.hpp"
 #include "cuda/cuda_regression_objective.hpp"
 
 namespace LightGBM {
@@ -38,8 +39,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
   } else if (type == std::string("binary")) {
     return new CUDABinaryLogloss(config);
   } else if (type == std::string("lambdarank")) {
-    Log::Warning("Objective lambdarank is not implemented in cuda_exp version. Fall back to boosting on CPU.");
-    return new LambdarankNDCG(config);
+    return new CUDALambdarankNDCG(config);
  } else if (type == std::string("rank_xendcg")) {
     Log::Warning("Objective rank_xendcg is not implemented in cuda_exp version. Fall back to boosting on CPU.");
     return new RankXENDCG(config);
diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp
index 239bb3651f53..6849fd20f3d8 100644
--- a/src/objective/rank_objective.hpp
+++ b/src/objective/rank_objective.hpp
@@ -255,7 +255,7 @@ class LambdarankNDCG : public RankingObjective {
 
   const char* GetName() const override { return "lambdarank"; }
 
- private:
+ protected:
   /*! \brief Sigmoid param */
   double sigmoid_;
   /*! \brief Normalize the lambdas or not */