From 9e9b705aa842542d9bcf69b9039b555cae982b80 Mon Sep 17 00:00:00 2001
From: Vvsmile <450864116@qq.com>
Date: Tue, 29 Nov 2022 17:09:02 +0800
Subject: [PATCH] Optimize the implementation of the argsort operator. (#47738)

Optimize the implementation of the argsort operator

---
 paddle/phi/kernels/gpu/argsort_kernel.cu | 424 +++++++++++++++++------
 1 file changed, 314 insertions(+), 110 deletions(-)

diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu
index 6a9c1e275998b..1c3825b90e210 100644
--- a/paddle/phi/kernels/gpu/argsort_kernel.cu
+++ b/paddle/phi/kernels/gpu/argsort_kernel.cu
@@ -64,8 +64,10 @@ struct SegmentOffsetIter {
   int num_cols_;
 };
 
+#define PADDLE_CUDA_NUM_THREADS 1024
+
 template <typename T>
-static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+static __global__ void FillIndex(T *indices, T num_rows, T num_cols) {
   int col_id = threadIdx.x;
   int row_id = blockIdx.x;
 
@@ -78,23 +80,246 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
 
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
-template <typename T, typename IndType>
-void ArgFullSort(const phi::GPUContext& ctx,
-                 const DenseTensor* input,
-                 DenseTensor* output,
-                 DenseTensor* indices,
-                 const IndType num_rows,
-                 const IndType num_cols,
+static __global__ void FillIndexAndSegmentKernel(int2 *data,
+                                                 int numel,
+                                                 int nsort) {
+  CUDA_KERNEL_LOOP(idx, numel) {
+    auto segment = static_cast<int>(idx / nsort);
+    auto sort = static_cast<int>(idx % nsort);
+    data[idx] = int2{segment, sort};
+  }
+}
+
+#define CUB_WRAPPER(func, ctx, ...)                                            \
+  do {                                                                         \
+    size_t temp_storage_bytes = 0;                                             \
+    gpuError_t err;                                                            \
+    err = func(nullptr, temp_storage_bytes, __VA_ARGS__);                      \
+    PADDLE_ENFORCE_GPU_SUCCESS(err);                                           \
+    DenseTensor temp_storage;                                                  \
+    int64_t temp_size = temp_storage_bytes;                                    \
+    temp_storage.Resize({temp_size});                                          \
+    ctx.template Alloc<uint8_t>(&temp_storage);                                \
+    err = func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__); \
+    PADDLE_ENFORCE_GPU_SUCCESS(err);                                           \
+  } while (false)
+
+template <typename KT, typename VT>
+static void RadixSortPairs(const phi::GPUContext &ctx,
+                           const KT *keys_in,
+                           const VT *values_in,
+                           KT *keys_out,
+                           VT *values_out,
+                           int64_t n,
+                           bool descending = false,
+                           int64_t begin_bit = 0,
+                           int64_t end_bit = sizeof(KT) * 8) {
+  if (keys_out == nullptr) {
+    DenseTensor key_out_owner;
+    key_out_owner.Resize({n});
+    ctx.template Alloc<KT>(&key_out_owner);
+    keys_out = key_out_owner.data<KT>();
+  }
+
+  if (descending) {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
+                ctx,
+                keys_in,
+                keys_out,
+                values_in,
+                values_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  } else {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
+                ctx,
+                keys_in,
+                keys_out,
+                values_in,
+                values_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  }
+}
+
+template <typename KT>
+static void RadixSortKeys(const phi::GPUContext &ctx,
+                          const KT *keys_in,
+                          KT *keys_out,
+                          int64_t n,
+                          bool descending,
+                          int64_t begin_bit,
+                          int64_t end_bit) {
+  if (descending) {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortKeysDescending,
+                ctx,
+                keys_in,
+                keys_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  } else {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortKeys,
+                ctx,
+                keys_in,
+                keys_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  }
+}
+
+template <typename T>
+static __global__ void SortPostprocessKernel(const T *in,
+                                             const int2 *i_s_ptr,
+                                             T *out,
+                                             int64_t *index,
+                                             int nsegments,
+                                             int nsort) {
+  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
+    int segment = i / nsort;  // segment_id
+    int j = i % nsort;
+
+    int offset = segment * nsort;
+    const T *in_ = in + offset;
+    T *out_ = out + offset;
+    int64_t *index_ = index + offset;
+    const int2 *i_s_ptr_ = i_s_ptr + offset;
+
+    int idx = i_s_ptr_[j].y;
+    index_[j] = idx;
+    out_[j] = in_[idx];
+  }
+}
+
+template <typename T>
+inline void SegmentedSortPairsByFullSort(const phi::GPUContext &ctx,
+                                         const T *const self_ptr,
+                                         T *const values_ptr,
+                                         int64_t *const indices_ptr,
+                                         const int64_t nsegments,
+                                         const int64_t nsort,
+                                         const int64_t n,
+                                         const bool descending) {
+  int64_t segment_bits = std::max<int64_t>(
+      1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
+
+  const auto numel = nsort * nsegments;
+
+  DenseTensor indices_and_segment;
+  int64_t indices_and_segment_size = numel;
+  indices_and_segment.Resize({indices_and_segment_size * 2});
+  ctx.template Alloc<int64_t>(&indices_and_segment);
+  auto i_s_ptr_base = indices_and_segment.data<int64_t>();
+  auto i_s_ptr = reinterpret_cast<int2 *>(i_s_ptr_base);
+
+  dim3 block = PADDLE_CUDA_NUM_THREADS;
+  auto block_num = (numel - 1) / PADDLE_CUDA_NUM_THREADS + 1;
+  dim3 grid = static_cast<int>(block_num);
+
+  auto cu_stream = ctx.stream();
+
+  FillIndexAndSegmentKernel<<<grid, block, 0, cu_stream>>>(
+      i_s_ptr, numel, nsort);
+
+  DenseTensor indices_and_segment2;
+  int64_t indices_and_segment2_size = numel;
+  indices_and_segment2.Resize({indices_and_segment2_size * 2});
+  ctx.template Alloc<int64_t>(&indices_and_segment2);
+  auto i_s_ptr2_base = indices_and_segment2.data<int64_t>();
+  auto i_s_ptr2 = reinterpret_cast<int2 *>(i_s_ptr2_base);
+
+  RadixSortPairs<T, int2>(
+      ctx, self_ptr, i_s_ptr, nullptr, i_s_ptr2, n, descending);
+
+  RadixSortKeys(ctx,
+                reinterpret_cast<int64_t *>(i_s_ptr2),
+                reinterpret_cast<int64_t *>(i_s_ptr),
+                n,
+                false,
+                0,
+                segment_bits);
+
+  SortPostprocessKernel<<<grid, block, 0, cu_stream>>>(
+      self_ptr, i_s_ptr, values_ptr, indices_ptr, nsegments, nsort);
+}
+
+// The method is called when # of the rows of the input is less than or equal
+// to 4
+template <typename T, typename IndexType>
+void ArgFullSortForTinyRows(const phi::GPUContext &ctx,
+                            const DenseTensor *input,
+                            DenseTensor *output,
+                            DenseTensor *indices,
+                            const IndexType num_rows,
+                            const IndexType num_cols,
+                            const bool descending) {
+  auto gpu_stream = ctx.stream();
+  size_t temp_storage_bytes = -1;
+
+  IndexType numel = num_rows * num_cols;
+  if (numel == 0) {
+    return;
+  }
+
+  IndexType numel_or_intmax =
+      std::min(numel, static_cast<IndexType>(std::numeric_limits<int>::max()));
+  IndexType nsort = num_cols;
+  IndexType nbatch = (numel_or_intmax / nsort) * nsort;
+
+  T *sorted_out_ptr;
+  IndexType *sorted_indices_ptr;
+  const T *input_data = input->data<T>();
+  T *out = ctx.template Alloc<T>(output);
+  IndexType *ind = ctx.template Alloc<IndexType>(indices);
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  int64_t remaining = numel;
+
+  while (remaining > 0) {
+    int64_t n = std::min(remaining, nbatch);
+    IndexType nsegments = n / nsort;
+
+    SegmentedSortPairsByFullSort(ctx,
+                                 input_data,
+                                 sorted_out_ptr,
+                                 sorted_indices_ptr,
+                                 nsegments,
+                                 nsort,
+                                 n,
+                                 descending);
+
+    remaining -= n;
+    input_data += n;
+    sorted_out_ptr += n;
+    sorted_indices_ptr += n;
+  }
+}
+
+template <typename T, typename IndexType>
+void ArgFullSort(const phi::GPUContext &ctx,
+                 const DenseTensor *input,
+                 DenseTensor *output,
+                 DenseTensor *indices,
+                 const IndexType num_rows,
+                 const IndexType num_cols,
                  const bool descending) {
   auto cu_stream = ctx.stream();
   DenseTensor input_indices;
-  const std::vector<IndType> dims = {num_rows, num_cols};
+  const std::vector<IndexType> dims = {num_rows, num_cols};
   auto dim = phi::make_ddim(dims);
   input_indices.Resize(dim);
-  ctx.template Alloc<IndType>(&input_indices);
+  ctx.template Alloc<IndexType>(&input_indices);
   size_t temp_storage_bytes = -1;
 
-  auto ComputeBlockSize = [](IndType col) {
+  auto ComputeBlockSize = [](IndexType col) {
     if (col > 512)
       return 1024;
     else if (col > 256 && col <= 512)
@@ -113,111 +338,70 @@ void ArgFullSort(const phi::GPUContext& ctx,
   int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
   // Init a index array
   FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
-      input_indices.data<IndType>(), num_rows, num_cols);
+      input_indices.data<IndexType>(), num_rows, num_cols);
 
-  T* sorted_out_ptr;
-  IndType* sorted_indices_ptr;
-  const T* inp = input->data<T>();
-  T* out = ctx.template Alloc<T>(output);
-  IndType* ind = ctx.template Alloc<IndType>(indices);
+  T *sorted_out_ptr;
+  IndexType *sorted_indices_ptr;
+  const T *inp = input->data<T>();
+  T *out = ctx.template Alloc<T>(output);
+  IndexType *ind = ctx.template Alloc<IndexType>(indices);
   sorted_out_ptr = out;
   sorted_indices_ptr = ind;
 
   // create iter for counting input
-  cub::CountingInputIterator<IndType> counting_iter(0);
+  cub::CountingInputIterator<IndexType> counting_iter(0);
   // segment_offset is used for move to next row
-  cub::TransformInputIterator<IndType,
-                              SegmentOffsetIter,
-                              cub::CountingInputIterator<IndType>>
+  cub::TransformInputIterator<IndexType,
+                              SegmentOffsetIter,
+                              cub::CountingInputIterator<IndexType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
   gpuError_t err;
   if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
+    CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairsDescending,
+                ctx,
+                inp,
+                sorted_out_ptr,
+                input_indices.data<IndexType>(),
+                sorted_indices_ptr,
+                num_cols * num_rows,
+                num_rows,
+                segment_offsets_t,
+                segment_offsets_t + 1,
+                0,
+                sizeof(T) * 8,
+                ctx.stream());
   } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
+    CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairs,
+                ctx,
+                inp,
+                sorted_out_ptr,
+                input_indices.data<IndexType>(),
+                sorted_indices_ptr,
+                num_cols * num_rows,
+                num_rows,
+                segment_offsets_t,
+                segment_offsets_t + 1,
+                0,
+                sizeof(T) * 8,
+                ctx.stream());
   }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
-
-  DenseTensor temp_storage;
-  int64_t temp_size = temp_storage_bytes;
-  temp_storage.Resize({temp_size});
-  ctx.template Alloc<uint8_t>(&temp_storage);
-
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
-
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
 }
 
 template <typename T, typename Context>
-void ArgsortKernel(const Context& dev_ctx,
-                   const DenseTensor& input,
+void ArgsortKernel(const Context &dev_ctx,
+                   const DenseTensor &input,
                    int axis,
                    bool descending,
-                   DenseTensor* output,
-                   DenseTensor* indices) {
+                   DenseTensor *output,
+                   DenseTensor *indices) {
   auto in_dims = input.dims();
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
-  const T* in_data = input.data<T>();
+
+  const T *in_data = input.data<T>();
   auto size = input.numel();
-  T* out_data = dev_ctx.template Alloc<T>(output);
-  int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
+  T *out_data = dev_ctx.template Alloc<T>(output);
+  int64_t *ids_data = dev_ctx.template Alloc<int64_t>(indices);
 
   // Use thrust for parallel acceleration when the input size is equal to the
   // length of the ‘axis’ dimension.
@@ -239,13 +423,23 @@ void ArgsortKernel(const Context& dev_ctx,
     const int64_t input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t input_width = in_dims[in_dims.size() - 1];
-    ArgFullSort<T, int64_t>(dev_ctx,
-                            &input,
-                            output,
-                            indices,
-                            input_height,
-                            input_width,
-                            descending);
+    if (input_height <= 4) {
+      ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
+                                         &input,
+                                         output,
+                                         indices,
+                                         input_height,
+                                         input_width,
+                                         descending);
+    } else {
+      ArgFullSort<T, int64_t>(dev_ctx,
+                              &input,
+                              output,
+                              indices,
+                              input_height,
+                              input_width,
+                              descending);
+    }
   } else {
     // if not full sort, do transpose first
     std::vector<int> trans;
@@ -264,7 +458,7 @@ void ArgsortKernel(const Context& dev_ctx,
 
     DenseTensor trans_inp;
     trans_inp.Resize(trans_dims);
-    T* trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
+    T *trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
     // Do transpose
     TransposeKernel<T, Context>(dev_ctx, input, trans, &trans_inp);
 
@@ -282,13 +476,23 @@ void ArgsortKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<int64_t>(&tmp_indices);
     dev_ctx.template Alloc<int64_t>(indices);
 
-    ArgFullSort<T, int64_t>(dev_ctx,
-                            &trans_inp,
-                            &tmp_out,
-                            &tmp_indices,
-                            input_height,
-                            input_width,
-                            descending);
+    if (input_height <= 4) {
+      ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
+                                         &trans_inp,
+                                         &tmp_out,
+                                         &tmp_indices,
+                                         input_height,
+                                         input_width,
+                                         descending);
+    } else {
+      ArgFullSort<T, int64_t>(dev_ctx,
+                              &trans_inp,
+                              &tmp_out,
+                              &tmp_indices,
+                              input_height,
+                              input_width,
+                              descending);
+    }
 
     TransposeKernel<T, Context>(dev_ctx, tmp_indices, trans, indices);
     // transpose back
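
The sketch below is not part of the patch; it is a minimal illustration of the query-then-run CUB calling convention that the CUB_WRAPPER macro introduced above encapsulates: each cub::DeviceRadixSort / cub::DeviceSegmentedRadixSort entry point is invoked once with a null temporary-storage pointer so CUB only reports the required workspace size, and then again with the allocated workspace to perform the actual sort. Plain cudaMalloc stands in for the DenseTensor-based workspace used in the kernel, and the function name sort_pairs_sketch is invented for this example.

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Minimal sketch of the two-phase pattern wrapped by CUB_WRAPPER.
void sort_pairs_sketch(const float *d_keys_in, float *d_keys_out,
                       const int *d_values_in, int *d_values_out,
                       int num_items, cudaStream_t stream) {
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // First call: d_temp_storage is nullptr, so CUB only writes the required
  // workspace size into temp_storage_bytes and returns.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_values_in, d_values_out,
                                  num_items, 0, sizeof(float) * 8, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Second call: with the workspace provided, the ascending key/value sort runs.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_values_in, d_values_out,
                                  num_items, 0, sizeof(float) * 8, stream);
  cudaFree(d_temp_storage);
}

Centralizing this double call in one macro is what lets RadixSortPairs, RadixSortKeys, and the segmented-sort branches in ArgFullSort stay short instead of repeating the size query, workspace allocation, and second dispatch at every call site.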