From 9cd99f7e6aa2c2d3f1b294b95ca4ac42f73c20f4 Mon Sep 17 00:00:00 2001 From: Infinity_lee Date: Tue, 14 Mar 2023 10:19:21 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90hackathon=204=20No53=E3=80=91label=5Fs?= =?UTF-8?q?mooth=20add=20fp16=20support=20=20(#51493)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/gpu/label_smooth_grad_kernel.cu | 13 ++++--- paddle/phi/kernels/gpu/label_smooth_kernel.cu | 34 +++++++++++++------ .../tests/unittests/test_label_smooth_op.py | 11 +++++- python/paddle/nn/functional/common.py | 4 +-- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu index 2ac6442967b38..6f77f4261f72d 100644 --- a/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu @@ -15,20 +15,22 @@ #include "paddle/phi/kernels/label_smooth_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" - namespace phi { template struct LabelSmoothGradFunctor { - T epsilon; + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType epsilon; __forceinline__ LabelSmoothGradFunctor(float epsilon_data) { - epsilon = static_cast(epsilon_data); + epsilon = static_cast(epsilon_data); } __device__ __forceinline__ T operator()(const T x) const { - return static_cast(1 - epsilon) * x; + return static_cast((static_cast(1) - epsilon) * + static_cast(x)); } }; @@ -52,4 +54,5 @@ PD_REGISTER_KERNEL(label_smooth_grad, ALL_LAYOUT, phi::LabelSmoothGradKernel, float, - double) {} + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/label_smooth_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_kernel.cu index ff2fff4445174..11ec7e5cc31ee 100644 --- a/paddle/phi/kernels/gpu/label_smooth_kernel.cu +++ b/paddle/phi/kernels/gpu/label_smooth_kernel.cu @@ -17,24 +17,27 @@ #include #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" - namespace phi { template struct LabelSmoothFunctor { - T epsilon; - T label_dim; + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType epsilon; + MPType label_dim; __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) { - epsilon = static_cast(epsilon_data); - label_dim = static_cast(label_dim_data); + epsilon = static_cast(epsilon_data); + label_dim = static_cast(label_dim_data); } __device__ __forceinline__ T operator()(const T x) const { - return (static_cast(1 - epsilon) * x + - static_cast(epsilon / label_dim)); + return static_cast( + static_cast(static_cast(1) - epsilon) * + static_cast(x) + + static_cast(epsilon / label_dim)); } }; @@ -45,10 +48,14 @@ __global__ void LabelSmoothRunDistKernel(const int N, const T* src, const T* dist_data, T* dst) { + using MPType = typename phi::dtype::MPTypeTrait::Type; CUDA_KERNEL_LOOP(idx, N) { int dist_idx = idx % dist_numel; - dst[idx] = static_cast(1 - epsilon) * src[idx] + - static_cast(epsilon) * dist_data[dist_idx]; + dst[idx] = + static_cast((static_cast(1) - static_cast(epsilon)) * + static_cast(src[idx]) + + static_cast(epsilon) * + static_cast(dist_data[dist_idx])); } } @@ -83,5 +90,10 @@ void LabelSmoothKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL( - label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {} +PD_REGISTER_KERNEL(label_smooth, + GPU, + ALL_LAYOUT, + phi::LabelSmoothKernel, + float, + double, + phi::dtype::float16) {} diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py index b62a75438a7fd..41d56526fff47 100644 --- a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py +++ b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py @@ -24,9 +24,10 @@ class TestLabelSmoothOp(OpTest): def config(self): self.op_type = "label_smooth" self.python_api = paddle.nn.functional.label_smooth + self.init_dtype() self.epsilon = 0.1 batch_size, self.label_dim = 10, 12 - self.label = np.zeros((batch_size, self.label_dim)).astype("float64") + self.label = np.zeros((batch_size, self.label_dim)).astype(self.dtype) nonzero_index = np.random.randint(self.label_dim, size=(batch_size)) self.label[np.arange(batch_size), nonzero_index] = 1 @@ -39,6 +40,9 @@ def setUp(self): self.attrs = {'epsilon': self.epsilon} self.outputs = {'Out': smoothed_label} + def init_dtype(self): + self.dtype = np.float64 + def test_check_output(self): self.check_output(check_eager=True) @@ -46,6 +50,11 @@ def test_check_grad(self): self.check_grad(["X"], "Out", check_eager=True) +class TestLabelSmoothFP16OP(TestLabelSmoothOp): + def init_dtype(self): + self.dtype = np.float16 + + class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp): def setUp(self): self.config() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index b957eb369ad8a..6204a4fdbb82b 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1923,7 +1923,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): label(Tensor): The input variable containing the label data. The label data should use one-hot representation. It's a multidimensional tensor with a shape of - :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64". + :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float16" "float32" and "float64". prior_dist(Tensor, optional): The prior distribution to be used to smooth labels. If not provided, an uniform distribution is used. It's a multidimensional tensor with a shape of @@ -1965,7 +1965,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): ) check_variable_and_dtype( - label, 'label', ['float32', 'float64'], 'label_smooth' + label, 'label', ['float16', 'float32', 'float64'], 'label_smooth' ) helper = LayerHelper("label_smooth", **locals())