[PHI decoupling] remove "gpu_device_function.h" in fluid. (#48117)
* move "paddle/phi/backends/gpu/gpu_device_function.h" to phi

* update copyright years

* rm "fluid/platform/device/gpu/gpu_device_function.h" in phi

* rm dependency on "gpu_device_function.h" in fluid

* rm gpu_device_function.h etc. in fluid

* fix rocm-compile bugs

* fix cuda_helper_test.cu bugs
huangjiyi authored Nov 22, 2022
1 parent 2995f74 commit 4da1a0f
Showing 17 changed files with 34 additions and 413 deletions.
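The per-file changes below follow one mechanical pattern: drop the fluid header "paddle/fluid/platform/device/gpu/gpu_device_function.h", include "paddle/phi/backends/gpu/gpu_device_function.h" instead, and qualify the warp-shuffle helpers with phi::backends::gpu rather than paddle::platform. A minimal sketch of a migrated call site (illustrative only; the helper name and namespace are taken from the hunks below, and the snippet assumes it is compiled inside the Paddle source tree with nvcc):

#include "paddle/phi/backends/gpu/gpu_device_function.h"

// After this PR, a fluid kernel sums a value across a warp like this:
template <typename T>
__device__ T WarpSum(T val, unsigned mask) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    // previously: paddle::platform::CudaShuffleDownSync(mask, val, offset)
    val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
  }
  return val;
}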
2 changes: 1 addition & 1 deletion paddle/fluid/operators/activation_op.kps
@@ -13,7 +13,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace paddle {
13 changes: 7 additions & 6 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -42,7 +42,7 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/gpu/elementwise_grad.h"

@@ -982,7 +982,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with warp
-val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+val += phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, val, i);
}

size_t idx_j = j + threadIdx.y;
@@ -1004,7 +1004,8 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with warp
-inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
+inter_val +=
+    phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
}
if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
}
@@ -1160,22 +1161,22 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
if (BcastY) {
if (dy) {
-val = paddle::platform::reduceSum(val, tid, h);
+val = phi::backends::gpu::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
-val = paddle::platform::reduceSum(val, tid, h);
+val = phi::backends::gpu::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
-inter_val = paddle::platform::reduceSum(inter_val, tid, h);
+inter_val = phi::backends::gpu::reduceSum(inter_val, tid, h);
if (threadIdx.x == 0) {
d_intermediate[j] = inter_val;
}
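The loops renamed above implement a warp-level butterfly reduction: CudaShuffleXorSync wraps CUDA's __shfl_xor_sync, so after log2(32) = 5 XOR-shuffle steps every lane of the warp holds the full sum. A self-contained sketch with the raw intrinsic (an assumption for illustration: plain CUDA 9+, no Paddle headers, a block of exactly one 32-thread warp):

#include <cstdio>

// Butterfly (XOR) warp reduction: after 5 steps every lane holds the sum
// of all 32 lanes, which is why the kernels above can read the result
// from any thread in the warp.
__device__ float WarpReduceSumXor(float val) {
  for (int i = 16; i > 0; i >>= 1) {
    val += __shfl_xor_sync(0xFFFFFFFF, val, i);
  }
  return val;
}

__global__ void SumKernel(const float* in, float* out) {
  float v = WarpReduceSumXor(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = v;  // lane 0 writes the warp total
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  SumKernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // expected: 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}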
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/fused_attention_op.cu
@@ -22,9 +22,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"

namespace paddle {
namespace operators {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/fused_dropout_common.h
@@ -22,10 +22,10 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"

@@ -25,8 +25,8 @@ namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/fused_gate_attention_op.cu
@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fused_gate_attention.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
@@ -26,9 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
2 changes: 1 addition & 1 deletion paddle/fluid/operators/group_norm_op.cu
@@ -21,7 +21,7 @@ namespace cub = hipcub;
#endif

#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"

namespace paddle {
4 changes: 2 additions & 2 deletions paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -25,8 +25,8 @@ namespace cub = hipcub;
#include <iostream>

#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

@@ -55,7 +55,7 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
-val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
+val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
2 changes: 1 addition & 1 deletion paddle/fluid/operators/math/beam_search.cu
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/beam_search.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"

namespace paddle {
namespace operators {
6 changes: 3 additions & 3 deletions paddle/fluid/operators/row_conv_op.cu
@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
@@ -242,7 +242,7 @@ __global__ void RowConvGradFilterImproved(const T *in,

for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
-val += platform::CudaShuffleDownSync(mask, val, offset);
+val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
__syncthreads();

@@ -307,7 +307,7 @@ __global__ void RowConvGradFilter(const T *in,

for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
-val += platform::CudaShuffleDownSync(mask, val, offset);
+val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
__syncthreads();

17 changes: 11 additions & 6 deletions paddle/fluid/operators/top_k_function_cuda.h
@@ -26,9 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"

#define FINAL_MASK 0xffffffff
@@ -283,8 +283,10 @@ __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
if (largest) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
-T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
-int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+T tmp_val =
+    phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+int tmp_id =
+    phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (input.v < tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
input.v = tmp_val;
input.id = tmp_id;
@@ -293,8 +295,10 @@
} else {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
-T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
-int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+T tmp_val =
+    phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+int tmp_id =
+    phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (input.v > tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
input.v = tmp_val;
input.id = tmp_id;
@@ -357,7 +361,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (tid_max / 32 == wid) {
-if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
+if (phi::backends::gpu::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
+    MaxLength)
break;
}
}
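For context on the WarpReduce hunks above: the kernel shuffles a (value, index) pair down the warp and keeps the better pair at each step, so lane 0 ends up with the warp-wide top element. A stripped-down sketch of that pattern using the raw CUDA intrinsics instead of Paddle's CudaShuffleDownSync wrapper (illustrative only; the struct and tie-breaking rule mirror the "largest" branch shown above, and the launch code is omitted):

// Warp-level arg-max over (value, index) pairs. Each lane starts with its
// own pair; after the loop, lane 0 holds the largest value in the warp
// (ties go to the smaller index, matching the comparison above).
struct ValueIndexPair {
  float v;
  int id;
};

__device__ ValueIndexPair WarpArgMax(ValueIndexPair p) {
  const unsigned kFullMask = 0xFFFFFFFFu;
  for (int offset = 16; offset > 0; offset >>= 1) {
    float tmp_v = __shfl_down_sync(kFullMask, p.v, offset);
    int tmp_id = __shfl_down_sync(kFullMask, p.id, offset);
    if (p.v < tmp_v || (p.v == tmp_v && p.id > tmp_id)) {
      p.v = tmp_v;
      p.id = tmp_id;
    }
  }
  return p;  // meaningful only on lane 0
}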