diff --git a/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
index 6f77f4261f72d..2e60bfbe73c4d 100644
--- a/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
@@ -55,4 +55,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
                    phi::LabelSmoothGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/label_smooth_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_kernel.cu
index 11ec7e5cc31ee..a89224c8a2cfd 100644
--- a/paddle/phi/kernels/gpu/label_smooth_kernel.cu
+++ b/paddle/phi/kernels/gpu/label_smooth_kernel.cu
@@ -96,4 +96,5 @@ PD_REGISTER_KERNEL(label_smooth,
                    phi::LabelSmoothKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
index e1dd242d2b9b9..4b89ad612d528 100644
--- a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
+++ b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 class TestLabelSmoothOp(OpTest):
@@ -50,6 +51,39 @@ def test_check_grad(self):
         self.check_grad(["X"], "Out")
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
+    "core is not compiled with CUDA or the place does not support bfloat16",
+)
+class TestLabelSmoothOpBF16(OpTest):
+    def config(self):
+        self.op_type = "label_smooth"
+        self.python_api = paddle.nn.functional.label_smooth
+        self.epsilon = 0.1
+        self.dtype = np.uint16
+        batch_size, self.label_dim = 10, 12
+        self.label = np.zeros((batch_size, self.label_dim)).astype(np.float32)
+        nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
+        self.label[np.arange(batch_size), nonzero_index] = 1
+
+    def setUp(self):
+        self.config()
+        smoothed_label = (
+            1 - self.epsilon
+        ) * self.label + self.epsilon / self.label_dim
+        self.inputs = {'X': convert_float_to_uint16(self.label)}
+        self.attrs = {'epsilon': self.epsilon}
+        self.outputs = {'Out': convert_float_to_uint16(smoothed_label)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ["X"], "Out", check_eager=True)
+
+
 class TestLabelSmoothFP16OP(TestLabelSmoothOp):
     def init_dtype(self):
         self.dtype = np.float16
@@ -58,13 +92,31 @@ def init_dtype(self):
 class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
     def setUp(self):
         self.config()
-        dist = np.random.random((1, self.label_dim))
+        dist = np.random.random((1, self.label_dim)).astype(self.dtype)
         smoothed_label = (1 - self.epsilon) * self.label + self.epsilon * dist
         self.inputs = {'X': self.label, 'PriorDist': dist}
         self.attrs = {'epsilon': self.epsilon}
         self.outputs = {'Out': smoothed_label}
 
 
+class TestLabelSmoothFP16OPWithPriorDist(TestLabelSmoothOpWithPriorDist):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLabelSmoothBF16OPWithPriorDist(TestLabelSmoothOpBF16):
+    def setUp(self):
+        self.config()
+        dist = np.random.random((1, self.label_dim)).astype(np.float32)
+        smoothed_label = (1 - self.epsilon) * self.label + self.epsilon * dist
+        self.inputs = {
+            'X': convert_float_to_uint16(self.label),
+            'PriorDist': convert_float_to_uint16(dist),
+        }
+        self.attrs = {'epsilon': self.epsilon}
+        self.outputs = {'Out': convert_float_to_uint16(smoothed_label)}
+
+
 class TestLabelSmoothOp3D(TestLabelSmoothOp):
     def setUp(self):
         super().setUp()
@@ -76,6 +128,22 @@ def setUp(self):
         )
 
 
+class TestLabelSmoothOp3DBF16(TestLabelSmoothOpBF16):
+    def setUp(self):
+        super().setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]]
+        )
+        self.outputs['Out'] = self.outputs['Out'].reshape(
+            self.inputs['X'].shape
+        )
+
+
+class TestLabelSmoothFP16OP3D(TestLabelSmoothOp3D):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):
     def setUp(self):
         super().setUp()
@@ -87,6 +155,22 @@ def setUp(self):
         )
 
 
+class TestLabelSmoothFP16OPWithPriorDist3D(TestLabelSmoothOpWithPriorDist3D):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLabelSmoothBF16OpWithPriorDist3D(TestLabelSmoothBF16OPWithPriorDist):
+    def setUp(self):
+        super().setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]]
+        )
+        self.outputs['Out'] = self.outputs['Out'].reshape(
+            self.inputs['X'].shape
+        )
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index e3ea8faf810b2..ea51e21482619 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1954,7 +1954,10 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
         )
 
     check_variable_and_dtype(
-        label, 'label', ['float16', 'float32', 'float64'], 'label_smooth'
+        label,
+        'label',
+        ['uint16', 'float16', 'float32', 'float64'],
+        'label_smooth',
     )
 
     helper = LayerHelper("label_smooth", **locals())
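A minimal usage sketch of the bfloat16 path the diff enables follows. This is an editor's illustration, not part of the patch: it assumes a CUDA build of Paddle in which bfloat16 is supported, and the sample labels are made up.

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    # The new kernels are registered for GPU, so run on a CUDA device.
    paddle.set_device('gpu')

    # One-hot labels for four samples over ten classes, cast to bfloat16.
    label = paddle.to_tensor(np.eye(10)[[1, 3, 5, 7]].astype('float32'))
    label_bf16 = label.astype('bfloat16')

    smoothed = F.label_smooth(label_bf16, epsilon=0.1)

    # With epsilon=0.1 and no prior_dist, each target entry becomes
    # 1 - 0.1 + 0.1 / 10 = 0.91 and every other entry 0.1 / 10 = 0.01,
    # up to bfloat16 rounding.
    print(smoothed.astype('float32').numpy())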