From e0dd4ee9093cf8d14687690213100cb0786e5188 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 22 Nov 2022 12:25:24 +0800
Subject: [PATCH] bf16 for interpolate, nhwc for bf16 (#48192)

---
 paddle/phi/kernels/gpu/interpolate_grad_kernel.cu |  7 ++++---
 paddle/phi/kernels/gpu/interpolate_kernel.cu      |  5 +++--
 paddle/phi/kernels/gpudnn/conv_grad_kernel.cu     |  6 ++++++
 paddle/phi/kernels/gpudnn/conv_kernel.cu          | 10 +++++++++-
 4 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
index 51a5f50560eac..b38cae829680b 100644
--- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
@@ -487,13 +487,13 @@ __global__ void KeBicubicInterpBw(T* in,
   T in_img_idy = align_corners
                      ? static_cast<T>(ratio_h * out_img_idy)
                      : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
-  int input_y = floorf(in_img_idy);
+  int input_y = floorf(static_cast<float>(in_img_idy));
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
   T in_img_idx = align_corners
                      ? static_cast<T>(ratio_w * out_img_idx)
                      : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
-  int input_x = floorf(in_img_idx);
+  int input_x = floorf(static_cast<float>(in_img_idx));
   const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
 
   T x_coeffs[4];
@@ -1577,7 +1577,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad,
                    phi::NearestInterpGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
 }
diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu
index 8135e73142fec..07e113ef7aa80 100644
--- a/paddle/phi/kernels/gpu/interpolate_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu
@@ -355,14 +355,14 @@ __global__ void KeBicubicInterpFw(const T* in,
   T in_img_idy = align_corners
                      ? static_cast<T>(ratio_h * out_img_idy)
                      : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
-  int input_y = floorf(in_img_idy);
+  int input_y = floorf(static_cast<float>(in_img_idy));
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
 
   T in_img_idx = align_corners
                      ? static_cast<T>(ratio_w * out_img_idx)
                      : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
-  int input_x = floorf(in_img_idx);
+  int input_x = floorf(static_cast<float>(in_img_idx));
   const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
 
   T coefficients[4];
@@ -1468,6 +1468,7 @@ PD_REGISTER_KERNEL(nearest_interp,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t) {
   kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
index 0d5f266d3d172..5d1a92a3119bc 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
@@ -454,8 +454,14 @@ void ConvCudnnGradKernel(const Context& ctx,
 #ifdef PADDLE_WITH_HIP
   // HIP MIOPEN ONLY SUPPORT NCHW format
   auto compute_format = paddle::platform::DataLayout::kNCHW;
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+  const bool compute_in_nhwc =
+      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
+      IsVoltaOrLater(ctx);
 #else
   const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(ctx);
+#endif
   auto compute_format = compute_in_nhwc && channel_last
                             ? paddle::platform::DataLayout::kNHWC
                             : paddle::platform::DataLayout::kNCHW;
diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu
index 3e3b1fb198da9..4044056653162 100644
--- a/paddle/phi/kernels/gpudnn/conv_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu
@@ -373,10 +373,18 @@ void ConvCudnnKernel(const Context& ctx,
 #ifdef PADDLE_WITH_HIP
   // HIP MIOPEN ONLY SUPPORT NCHW format
   auto compute_format = paddle::platform::DataLayout::kNCHW;
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+  // Tensor Cores, introduced with Volta GPUs, support faster conv ops
+  // with FP16 or BF16 in the NHWC data format.
+  const bool compute_in_nhwc =
+      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
+      IsVoltaOrLater(ctx);
 #else
   // Tensor Core introduced from Volta GPUs supports more faster conv op
-  // with FP16 in NHWC data format.
+  // with FP16 in NHWC data format. (BF16 requires cuDNN >= 8.1.0)
   const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(ctx);
+#endif
   // We will only do data format conversion from NHWC to NCHW.
   // cudnn will convert NCHW to NHWC automatically on Tensor Core.
   auto compute_format = compute_in_nhwc && channel_last
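
Why the floorf() casts: floorf() takes a float, and the interpolate kernels
are now instantiated with reduced-precision types such as
phi::dtype::bfloat16, so the patch widens the coordinate explicitly before
flooring rather than relying on an implicit conversion. Below is a minimal
standalone sketch of the same index/fraction split, with CUDA's
__nv_bfloat16 standing in for phi::dtype::bfloat16; the kernel and variable
names are illustrative, not Paddle's.

#include <cuda_bf16.h>

// Standalone sketch (not Paddle code): split a bf16 source coordinate
// into an integer cell index and a fractional offset, widening to float
// before floorf() exactly as the patch does.
__global__ void bicubic_index_sketch(const __nv_bfloat16* coords,
                                     int* idx,
                                     float* frac,
                                     int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  float c = __bfloat162float(coords[i]);   // widen before flooring
  int lo = floorf(c);                      // integer cell index
  idx[i] = lo;
  frac[i] = c - static_cast<float>(lo);    // fractional offset in [0, 1)
}

The fractional offset plays the role of the patch's x_t / y_t values, which
the real kernels compute in the MPTypeTrait<T>::Type compute precision
before casting back to T.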
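The conv changes gate the NHWC fast path on the cuDNN build version:
cuDNN 8.1.0 is the release that added CUDNN_DATA_BFLOAT16 and BF16 Tensor
Core kernels, so BF16 joins FP16 on the NHWC path only when built against
cuDNN >= 8.1.0 and running on Volta or later. A hedged sketch of that
decision in isolation follows; Layout, ChooseComputeFormat,
is_volta_or_later, and channel_last are stand-ins for Paddle's
paddle::platform::DataLayout, IsVoltaOrLater(ctx), and the kernels' local
state, and CUDNN_VERSION comes from cudnn.h (8100 means 8.1.0), which is
what the diff's CUDNN_VERSION_MIN(8, 1, 0) macro checks against.

#include <cudnn.h>

// Standalone sketch (not Paddle code) of the layout decision made by
// ConvCudnnKernel / ConvCudnnGradKernel above.
enum class Layout { kNCHW, kNHWC };

inline Layout ChooseComputeFormat(cudnnDataType_t dtype,
                                  bool is_volta_or_later,
                                  bool channel_last) {
#if CUDNN_VERSION >= 8100
  // cuDNN 8.1.0+: both FP16 and BF16 can use Tensor Core NHWC kernels.
  const bool compute_in_nhwc =
      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
      is_volta_or_later;
#else
  // Older cuDNN: only FP16 takes the NHWC fast path.
  const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && is_volta_or_later;
#endif
  return (compute_in_nhwc && channel_last) ? Layout::kNHWC : Layout::kNCHW;
}

Note that either branch only flips the compute format when the tensors are
already channel-last; NCHW inputs keep the NCHW compute path, and, per the
comment retained in the diff, cuDNN converts NCHW to NHWC internally on
Tensor Cores.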