diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu
index d982971029091..6bcc3d6ff4e29 100644
--- a/paddle/phi/kernels/gpu/flip_kernel.cu
+++ b/paddle/phi/kernels/gpu/flip_kernel.cu
@@ -13,126 +13,123 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/flip_kernel.h"
-
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/array.h"
 
 namespace phi {
 
-template <typename T>
+template <typename T, size_t Rank>
 __global__ void flip_cuda_kernel(const int N,
                                  const T* in_data,
                                  T* out_data,
-                                 int64_t* x_shape,
-                                 int64_t* x_stride,
-                                 int* flip_dims,
-                                 int flip_dims_size,
-                                 int total_dims) {
+                                 phi::Array<int64_t, Rank> shape,
+                                 phi::Array<int64_t, Rank> stride,
+                                 phi::Array<int, Rank> flip_dims,
+                                 int flip_dims_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx >= N) {
     return;
   }
 
   int cur_indices = idx, rem = 0, dst_offset = 0;
-  for (int i = 0; i < total_dims; ++i) {
+  for (int i = 0; i < Rank; ++i) {
     int64_t temp = cur_indices;
-    cur_indices = cur_indices / x_stride[i];
-    rem = temp - cur_indices * x_stride[i];
+    cur_indices = cur_indices / stride[i];
+    rem = temp - cur_indices * stride[i];
     // flip the indices if it is in flip_dims
     for (int j = 0; j < flip_dims_size; ++j) {
       if (i == flip_dims[j]) {
-        cur_indices = x_shape[i] - 1 - cur_indices;
+        cur_indices = shape[i] - 1 - cur_indices;
       }
     }
-    dst_offset += cur_indices * x_stride[i];
+    dst_offset += cur_indices * stride[i];
    cur_indices = rem;
   }
   out_data[idx] = in_data[dst_offset];
 }
 
-template <typename T, typename Context>
-void FlipKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const std::vector<int>& axis,
-                DenseTensor* out) {
-  const auto gplace = dev_ctx.GetPlace();
-  auto cplace = phi::CPUPlace();
-  std::vector<int> flip_dims = axis;
-
+template <typename T, typename Context, size_t N>
+void launch_flip_cuda_kernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const std::vector<int>& axis,
+                             DenseTensor* out) {
+  std::vector<int> flip_dims_v = axis;
   auto* in_data = x.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
 
-  const int flip_dims_size = static_cast<int>(flip_dims.size());
   auto x_dims = x.dims();
   const int total_dims = x_dims.size();
-  const int N = x.numel();
+  const int numel = x.numel();
 
   int block_size = 512;
   dim3 dim_block(block_size);
-  dim3 dim_grid((N + block_size - 1) / block_size);
+  dim3 dim_grid((numel + block_size - 1) / block_size);
 
-  for (size_t i = 0; i < flip_dims.size(); ++i) {
-    if (flip_dims[i] < 0) {
-      flip_dims[i] += total_dims;
+  for (size_t i = 0; i < flip_dims_v.size(); ++i) {
+    if (flip_dims_v[i] < 0) {
+      flip_dims_v[i] += total_dims;
     }
   }
 
   auto x_stride = phi::stride(x_dims);
-  std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
-  std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
-
-  int bytes = total_dims * sizeof(int64_t);
-  auto x_strides_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int64_t* x_strides_array_gpu =
-      reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       x_strides_array_gpu,
-                       cplace,
-                       x_stride_v.data(),
-                       bytes,
-                       dev_ctx.stream());
-
-  auto x_shape_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int64_t* x_shape_array_gpu =
-      reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       x_shape_array_gpu,
-                       cplace,
-                       x_dims_v.data(),
-                       bytes,
-                       dev_ctx.stream());
-  bytes = flip_dims_size * sizeof(int);
-  auto flip_dims_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int* flip_dims_array_gpu = reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       flip_dims_array_gpu,
-                       cplace,
-                       flip_dims.data(),
-                       bytes,
-                       dev_ctx.stream());
+  phi::Array<int64_t, N> stride_a;
+  phi::Array<int64_t, N> shape_a;
+  phi::Array<int, N> flip_dims_a;
+  size_t flip_dims_size = flip_dims_v.size();
+  for (size_t idx = 0; idx < N; ++idx) {
+    stride_a[idx] = x_stride[idx];
+    shape_a[idx] = x_dims[idx];
+    flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
+  }
+  flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
+      numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
+}
 
-  flip_cuda_kernel<T>
-      <<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(N,
-                                                     in_data,
-                                                     out_data,
-                                                     x_shape_array_gpu,
-                                                     x_strides_array_gpu,
-                                                     flip_dims_array_gpu,
-                                                     flip_dims_size,
-                                                     total_dims);
+template <typename T, typename Context>
+void FlipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int>& axis,
+                DenseTensor* out) {
+  const size_t total_dims = x.dims().size();
+  switch (total_dims) {
+    case 1:
+      launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
+      break;
+    case 2:
+      launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
+      break;
+    case 3:
+      launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
+      break;
+    case 4:
+      launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
+      break;
+    case 5:
+      launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
+      break;
+    case 6:
+      launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
+      break;
+    case 7:
+      launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
+      break;
+    case 8:
+      launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
+      break;
+    case 9:
+      launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "dims of input tensor should be less than 10, But received"
+          "%d",
+          x.dims().size()));
+  }
 }
 
 }  // namespace phi
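
The core change in this diff is that per-call device buffers for shape, stride, and flip-dim metadata (and the three host-to-device copies that filled them) are replaced by fixed-size phi::Array values passed directly in the kernel arguments, with the tensor rank dispatched as a compile-time template parameter. The standalone sketch below illustrates that technique; it is not part of the PR, and the Array helper struct plus the 2x3 test tensor are illustrative stand-ins rather than Paddle code.

// Standalone sketch, assuming plain CUDA: flip via by-value metadata arrays.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Minimal stand-in for phi::Array: a trivially copyable fixed-size array
// that can be passed to a kernel by value.
template <typename T, int N>
struct Array {
  T data[N];
  __host__ __device__ T& operator[](int i) { return data[i]; }
};

template <typename T, int Rank>
__global__ void flip_kernel(int numel,
                            const T* in,
                            T* out,
                            Array<int64_t, Rank> shape,
                            Array<int64_t, Rank> stride,
                            Array<int, Rank> flip_dims,
                            int flip_dims_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;
  int cur = idx, rem = 0, src = 0;
  for (int i = 0; i < Rank; ++i) {
    int64_t tmp = cur;
    cur = cur / stride[i];        // coordinate along dimension i
    rem = tmp - cur * stride[i];  // flat remainder for the later dimensions
    for (int j = 0; j < flip_dims_size; ++j) {
      if (i == flip_dims[j]) cur = shape[i] - 1 - cur;  // mirror this dim
    }
    src += cur * stride[i];
    cur = rem;
  }
  out[idx] = in[src];  // gather from the mirrored source position
}

int main() {
  // Flip a 2x3 row-major tensor along dim 1: each row gets reversed.
  const int numel = 6;
  float h_in[numel] = {0, 1, 2, 3, 4, 5}, h_out[numel];
  float *d_in, *d_out;
  cudaMalloc(reinterpret_cast<void**>(&d_in), sizeof(h_in));
  cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  Array<int64_t, 2> shape = {{2, 3}};
  Array<int64_t, 2> stride = {{3, 1}};
  Array<int, 2> flip_dims = {{1, 0}};  // only the first entry is used
  flip_kernel<float, 2><<<1, 64>>>(
      numel, d_in, d_out, shape, stride, flip_dims, /*flip_dims_size=*/1);

  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);  // expected: 2 1 0 5 4 3
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Because the rank is a compile-time constant, the metadata rides along in the kernel's launch arguments and the per-dimension loop has a fixed bound, so the launcher needs no allocation/copy round trip (paddle::memory::Alloc plus paddle::memory::Copy in the removed code) before every launch.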