From aa7e7d308ba758074f3093b95d70cde0054fca87 Mon Sep 17 00:00:00 2001
From: YibLiu <68105073+YibinLiu666@users.noreply.github.com>
Date: Tue, 9 Jan 2024 16:23:36 +0800
Subject: [PATCH] Improve the performance of put_along_axis (#60618)

* fix bug of put_along_axis

* improve performance of put_along_axis
---
 .../kernels/funcs/gather_scatter_functor.cu  | 88 +++++++------------
 1 file changed, 34 insertions(+), 54 deletions(-)

diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
index 7939589d7c662..9ca3c1c460f24 100644
--- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu
+++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -92,6 +92,12 @@ class ReduceMin {
 };
 static ReduceMin reduce_min;
 
+__global__ void CudaMemsetAsync(int* dest, int value, size_t size) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid * sizeof(int) >= size) return;
+  dest[tid] = value;
+}
+
 template = numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -267,16 +266,6 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      shared_mem[i] = 0;  // thread_id
-      if (include_self)
-        shared_mem[numel_data + i] = 1;  // reduce size
-      else
-        shared_mem[numel_data + i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -384,6 +373,7 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
       ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -405,6 +395,14 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
          reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, sizeof(int) * self_size, stream);
+      if (include_self) {
+        int64_t grid_memset = (self_size * 2 + block - 1) / block;
+        CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+            shared_mem, 1, shared_mem_size);
+      } else {
+        cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+      }
       ScatterMeanGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -429,6 +427,9 @@ struct gpu_gather_scatter_functor {
       shared_mem_size = sizeof(int) * self_size;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, index_size + 1, shared_mem_size);
     }
     GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
         <<<grid, block, 0, stream>>>(self_data,
@@ -640,12 +641,6 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
                                              int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -682,13 +677,6 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
                                                 int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -762,6 +750,7 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterMulInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -781,6 +770,9 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
        reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    int64_t grid_memset = (grad_size + block - 1) / block;
+    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+        shared_mem, 1, shared_mem_size);
     ScatterMinMaxInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -816,13 +808,6 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 0;               // thread_ids
-      shared_mem[numel_grad + i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -879,6 +864,10 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, sizeof(int) * grad_size, stream);
+  int64_t grid_memset = (grad_size + block - 1) / block;
+  CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+      shared_mem + grad_size, 1, sizeof(int) * grad_size);
   ScatterMeanInputGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -910,12 +899,6 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -975,6 +958,7 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
   ScatterValueGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -1005,20 +989,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
                                               int64_t outer_dim_size_grad,
                                               int64_t numel,
                                               int64_t numel_self,
-                                              bool include_self,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_self; i++) {
-      if (include_self)
-        shared_mem[i] = 1;  // number of elements
-      else
-        shared_mem[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -1114,6 +1088,13 @@ void gpu_scatter_add_mean_value_grad_kernel(
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    if (include_self) {
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, 1, shared_mem_size);
+    } else {
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+    }
     ScatterMeanValueGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -1127,7 +1108,6 @@ void gpu_scatter_add_mean_value_grad_kernel(
                                      outer_dim_size_grad,
                                      index_size,
                                      self_size,
-                                     include_self,
                                      shared_mem);
     cudaFreeAsync(shared_mem, stream);
   } else if (reduce == "add") {
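
For reference, the pattern this patch applies is: the auxiliary counter/thread-id buffer is now initialized on the same CUDA stream before the compute kernel is launched, either with cudaMemsetAsync (zero fill) or with a small fill kernel for non-zero int values, instead of having thread 0 of the compute kernel fill the whole buffer in a serial loop followed by __syncthreads(), which only synchronizes one block and serializes the initialization. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions; it is not Paddle code, and the names FillIntAsync, CountPerSlot, and all sizes are made up for the example (cudaMallocAsync/cudaFreeAsync require the stream-ordered allocator, CUDA 11.2+).

// Sketch only: initialize an auxiliary buffer on the stream before the
// main kernel, rather than inside it from a single thread.
#include <cstdio>
#include <cuda_runtime.h>

// Grid-wide int fill, analogous in spirit to the patch's CudaMemsetAsync
// kernel: each thread writes one int of the buffer.
__global__ void FillIntAsync(int* dest, int value, size_t size_in_bytes) {
  size_t tid = threadIdx.x + blockIdx.x * (size_t)blockDim.x;
  if (tid * sizeof(int) >= size_in_bytes) return;
  dest[tid] = value;
}

// Toy "scatter count" kernel that assumes counts[] is already initialized.
__global__ void CountPerSlot(const int* index, int n, int* counts) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= n) return;
  atomicAdd(&counts[index[tid]], 1);
}

int main() {
  const int n = 1 << 20, slots = 256, block = 512;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int *index, *counts;
  cudaMallocAsync(reinterpret_cast<void**>(&index), n * sizeof(int), stream);
  cudaMallocAsync(reinterpret_cast<void**>(&counts), slots * sizeof(int), stream);
  cudaMemsetAsync(index, 0, n * sizeof(int), stream);  // all indices -> slot 0

  // Zero fill: plain cudaMemsetAsync suffices (byte pattern 0).
  cudaMemsetAsync(counts, 0, slots * sizeof(int), stream);
  // Non-zero fill (e.g. start every count at 1): a dedicated fill kernel,
  // since cudaMemsetAsync only replicates a single byte value. In this toy
  // it simply overwrites the zeros written above.
  int grid_fill = (slots + block - 1) / block;
  FillIntAsync<<<grid_fill, block, 0, stream>>>(counts, 1, slots * sizeof(int));

  int grid = (n + block - 1) / block;
  CountPerSlot<<<grid, block, 0, stream>>>(index, n, counts);

  int host_count0 = 0;
  cudaMemcpyAsync(&host_count0, counts, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  printf("slot 0 count: %d\n", host_count0);  // expect 1 + n

  cudaFreeAsync(index, stream);
  cudaFreeAsync(counts, stream);
  cudaStreamDestroy(stream);
  return 0;
}

Because cudaMemsetAsync can only replicate a byte value, a fill kernel is the usual choice for non-zero int patterns such as 1 or index_size + 1, which is why the patch adds its own CudaMemsetAsync kernel alongside the plain cudaMemsetAsync calls.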