diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index cb8d30186ff58..dc75e1b8da122 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -15,6 +15,9 @@ #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" #include "paddle/phi/core/kernel_registry.h" @@ -23,9 +26,11 @@ namespace phi { -template -__global__ void Range(T start, T step, int64_t size, T* out) { - CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +template +__global__ void Range(T start, T step, int64_t size, OUT_TYPE* out) { + CUDA_KERNEL_LOOP(index, size) { + out[index] = static_cast(start + step * index); + } } template @@ -34,9 +39,11 @@ void ArangeKernel(const Context& dev_ctx, const DenseTensor& end, const DenseTensor& step, DenseTensor* out) { - T start_value = GetValue(dev_ctx, start); - T end_value = GetValue(dev_ctx, end); - T step_value = GetValue(dev_ctx, step); + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType start_value = + static_cast(GetValue(dev_ctx, start)); + MPType end_value = static_cast(GetValue(dev_ctx, end)); + MPType step_value = static_cast(GetValue(dev_ctx, step)); int64_t size = 0; phi::funcs::GetSize(start_value, end_value, step_value, &size); @@ -49,7 +56,8 @@ void ArangeKernel(const Context& dev_ctx, return; } int64_t grid = (size + block - 1) / block; - Range<<>>(start_value, step_value, size, out_data); + Range + <<>>(start_value, step_value, size, out_data); } template @@ -78,8 +86,16 @@ template decltype(ArangeNullaryKernel) ArangeNullaryKernel; } // namespace phi -PD_REGISTER_KERNEL( - arange, GPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int64_t, int) { +PD_REGISTER_KERNEL(arange, + GPU, + ALL_LAYOUT, + phi::ArangeKernel, + float, + double, + int64_t, + int, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); diff --git a/python/paddle/fluid/tests/unittests/test_arange.py b/python/paddle/fluid/tests/unittests/test_arange.py index 1c20d62c0736f..a0d1ddc8b9eec 100644 --- a/python/paddle/fluid/tests/unittests/test_arange.py +++ b/python/paddle/fluid/tests/unittests/test_arange.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle.fluid import core @@ -58,6 +58,50 @@ def init_config(self): self.case = (0, 5, 1) +class TestFloa16ArangeOp(TestArangeOp): + def init_config(self): + self.dtype = np.float16 + self.python_api = paddle.arange + self.case = (0, 5, 1) + + def test_check_output(self): + self.check_output() + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestBFloat16ArangeOp(OpTest): + def setUp(self): + self.op_type = "range" + self.init_config() + self.inputs = { + 'Start': convert_float_to_uint16(self.start), + 'End': convert_float_to_uint16(self.end), + 'Step': convert_float_to_uint16(self.step), + } + + self.outputs = { + 'Out': convert_float_to_uint16( + np.arange(self.start, self.end, self.step) + ) + } + + def init_config(self): + self.dtype = np.uint16 + self.python_api = arange_wrapper + self.case = (0, 5, 1) + self.start = np.array([self.case[0]]).astype(np.float32) + self.end = np.array([self.case[1]]).astype(np.float32) + self.step = np.array([self.case[2]]).astype(np.float32) + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + class TestInt32ArangeOp(TestArangeOp): def init_config(self): self.dtype = np.int32 diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index d0fbe0d393a5e..ff74b4bc94497 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1233,7 +1233,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): check_dtype( dtype, 'dtype', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', 'float16', 'uint16'], 'range/arange', ) helper = LayerHelper('range', **locals())