diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 3937464fbce493..dc57d0551baf95 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -93,8 +93,8 @@ func : angle_grad - backward_op : argsort_grad - forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices) - args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending) + forward : argsort (Tensor x, int axis, bool descending, bool stable) -> Tensor(out), Tensor(indices) + args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending, bool stable) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 0d6fbfc83691a1..97090fb070c539 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -169,7 +169,7 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface - op : argsort - args : (Tensor x, int axis=-1, bool descending=false) + args : (Tensor x, int axis=-1, bool descending=false, bool stable=false) output : Tensor(out), Tensor(indices) infer_meta : func : ArgsortInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a152bc152ae6bb..8df9c5838dda41 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -331,6 +331,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, void ArgsortInferMeta(const MetaTensor& input, int axis, bool descending, + bool stable, MetaTensor* output, MetaTensor* indices) { auto in_dims = input.dims(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 29fc97955e87ae..d75638ba668f61 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -55,6 +55,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, void ArgsortInferMeta(const MetaTensor& input, int axis, bool descending, + bool stable, MetaTensor* output, MetaTensor* indices); diff --git a/paddle/phi/kernels/argsort_grad_kernel.h b/paddle/phi/kernels/argsort_grad_kernel.h index b91bd69911351d..c9495a50b90a82 100644 --- a/paddle/phi/kernels/argsort_grad_kernel.h +++ b/paddle/phi/kernels/argsort_grad_kernel.h @@ -25,6 +25,7 @@ void ArgsortGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int axis, bool descending, + bool stable, DenseTensor* in_grad); } // namespace phi diff --git a/paddle/phi/kernels/argsort_kernel.h b/paddle/phi/kernels/argsort_kernel.h index 519f2b88547f68..23e37588b4851d 100644 --- a/paddle/phi/kernels/argsort_kernel.h +++ b/paddle/phi/kernels/argsort_kernel.h @@ -33,6 +33,9 @@ namespace phi { * algorithm how to sort the input data. * If descending is true, will sort by descending order, * else if false, sort by ascending order + * @param stable Indicate whether to use stable sorting algorithm, which + * guarantees that the order of equivalent elements is + * preserved. * @param out The sorted tensor of Argsort op, with the same shape as * x * @param indices The indices of a tensor giving the sorted order, with @@ -43,6 +46,7 @@ void ArgsortKernel(const Context& dev_ctx, const DenseTensor& input, int axis, bool descending, + bool stable, DenseTensor* output, DenseTensor* indices); diff --git a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc index 92135f1eb02346..64fc09974e49e7 100644 --- a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc @@ -56,6 +56,7 @@ void ArgsortGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int axis, bool descending UNUSED, + bool stable UNUSED, DenseTensor* in_grad) { auto in_dims = indices.dims(); auto rank = input.dims().size(); diff --git a/paddle/phi/kernels/cpu/argsort_kernel.cc b/paddle/phi/kernels/cpu/argsort_kernel.cc index 7e3ab23a44dfb1..59c654a3df4065 100644 --- a/paddle/phi/kernels/cpu/argsort_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_kernel.cc @@ -30,7 +30,8 @@ static void FullSort(Type input_height, const DenseTensor* input, T* t_out, Type* t_indices, - bool descending) { + bool descending, + bool stable) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif @@ -48,18 +49,34 @@ static void FullSort(Type input_height, col_vec.push_back(std::pair(e_input(i, j), j)); } } - std::sort(col_vec.begin(), - col_vec.end(), - [&](const std::pair& l, const std::pair& r) { - if (descending) - return (std::isnan(static_cast(l.first)) && - !std::isnan(static_cast(r.first))) || - (l.first > r.first); - else - return (!std::isnan(static_cast(l.first)) && - std::isnan(static_cast(r.first))) || - (l.first < r.first); - }); + if (stable) { + std::stable_sort( + col_vec.begin(), + col_vec.end(), + [&](const std::pair& l, const std::pair& r) { + if (descending) + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + else + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + }); + } else { + std::sort(col_vec.begin(), + col_vec.end(), + [&](const std::pair& l, const std::pair& r) { + if (descending) + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + else + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + }); + } for (Type j = 0; j < input_width; ++j) { t_out[i * input_width + j] = col_vec[j].first; @@ -73,6 +90,7 @@ void ArgsortKernel(const Context& dev_ctx, const DenseTensor& input, int axis, bool descending, + bool stable, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); @@ -100,7 +118,8 @@ void ArgsortKernel(const Context& dev_ctx, &input, out_data, ids_data, - descending); + descending, + stable); } else { // If not full sort do transpose std::vector trans; @@ -141,7 +160,8 @@ void ArgsortKernel(const Context& dev_ctx, &trans_inp, t_out, t_ind, - descending); + descending, + stable); dev_ctx.template Alloc(indices); TransposeKernel(dev_ctx, tmp_indices, trans, indices); diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index 673e2937c93a5f..bdb36b84a0254d 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -149,6 +149,7 @@ void ArgsortGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int axis, bool descending, + bool stable, DenseTensor* in_grad) { dev_ctx.template Alloc(in_grad); phi::funcs::set_constant(dev_ctx, in_grad, static_cast(0.0)); diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu index 1fc367a5a88c64..3b51a3fceddbb3 100644 --- a/paddle/phi/kernels/gpu/argsort_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -230,6 +230,7 @@ void ArgsortKernel(const Context& dev_ctx, const DenseTensor& input, int axis, bool descending, + bool stable, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); @@ -251,14 +252,30 @@ void ArgsortKernel(const Context& dev_ctx, // Compared to the following 'Special case for full sort', ascending sort is // 34 times faster and descending sort is 31 times faster. if (size == in_dims[axis]) { - thrust::sequence(thrust::device, ids_data, ids_data + size); - thrust::copy(thrust::device, in_data, in_data + size, out_data); - thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data); - if (descending) { - thrust::reverse(thrust::device, out_data, out_data + size); - thrust::reverse(thrust::device, ids_data, ids_data + size); + if (stable) { + thrust::sequence(thrust::device, ids_data, ids_data + size); + thrust::copy(thrust::device, in_data, in_data + size, out_data); + if (descending) { + thrust::stable_sort_by_key(thrust::device, + out_data, + out_data + size, + ids_data, + thrust::greater()); + } else { + thrust::stable_sort_by_key( + thrust::device, out_data, out_data + size, ids_data); + } + return; + } else { + thrust::sequence(thrust::device, ids_data, ids_data + size); + thrust::copy(thrust::device, in_data, in_data + size, out_data); + thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data); + if (descending) { + thrust::reverse(thrust::device, out_data, out_data + size); + thrust::reverse(thrust::device, ids_data, ids_data + size); + } + return; } - return; } // Special case for full sort, speedup ~190x. diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 7d619ca5e2e8ad..27788ae9b93969 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -36,7 +36,7 @@ __all__ = [] -def argsort(x, axis=-1, descending=False, name=None): +def argsort(x, axis=-1, descending=False, stable=False, name=None): """ Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. @@ -49,6 +49,9 @@ def argsort(x, axis=-1, descending=False, name=None): descending (bool, optional) : Descending is a flag, if set to true, algorithm will sort by descending order, else sort by ascending order. Default is false. + stable (bool, optional): Whether to use stable sorting algorithm or not. + When using stable sorting algorithm, the order of equivalent elements + will be preserved. Default is False. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -100,7 +103,7 @@ def argsort(x, axis=-1, descending=False, name=None): [0, 2, 1, 1]]]) """ if in_dynamic_or_pir_mode(): - _, ids = _C_ops.argsort(x, axis, descending) + _, ids = _C_ops.argsort(x, axis, descending, stable) return ids else: check_variable_and_dtype( @@ -129,7 +132,7 @@ def argsort(x, axis=-1, descending=False, name=None): type='argsort', inputs={'X': x}, outputs={'Out': out, 'Indices': ids}, - attrs={'axis': axis, 'descending': descending}, + attrs={'axis': axis, 'descending': descending, 'stable': stable}, ) return ids @@ -500,7 +503,7 @@ def nonzero(x, as_tuple=False): return tuple(list_out) -def sort(x, axis=-1, descending=False, name=None): +def sort(x, axis=-1, descending=False, stable=False, name=None): """ Sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. @@ -514,6 +517,9 @@ def sort(x, axis=-1, descending=False, name=None): descending (bool, optional) : Descending is a flag, if set to true, algorithm will sort by descending order, else sort by ascending order. Default is false. + stable (bool, optional): Whether to use stable sorting algorithm or not. + When using stable sorting algorithm, the order of equivalent elements + will be preserved. Default is False. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -557,7 +563,7 @@ def sort(x, axis=-1, descending=False, name=None): [5. 7. 7. 9.]]] """ if in_dynamic_or_pir_mode(): - outs, _ = _C_ops.argsort(x, axis, descending) + outs, _ = _C_ops.argsort(x, axis, descending, stable) return outs else: helper = LayerHelper("sort", **locals()) @@ -571,7 +577,7 @@ def sort(x, axis=-1, descending=False, name=None): type='argsort', inputs={'X': x}, outputs={'Out': out, 'Indices': ids}, - attrs={'axis': axis, 'descending': descending}, + attrs={'axis': axis, 'descending': descending, 'stable': stable}, ) return out diff --git a/test/legacy_test/test_argsort_op.py b/test/legacy_test/test_argsort_op.py index 52102c20036c7f..2769d53e1fb71a 100644 --- a/test/legacy_test/test_argsort_op.py +++ b/test/legacy_test/test_argsort_op.py @@ -449,6 +449,133 @@ def init(self): self.axis = 1 +class TestStableArgsort(unittest.TestCase): + def setUp(self): + self.input_shape = [ + 30, + ] + self.axis = 0 + self.data = np.array([100.0, 50.0, 10.0] * 10) + + def cpu_place(self): + self.place = core.CPUPlace() + + def gpu_place(self): + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + @test_with_pir_api + def test_api_static1_cpu(self): + self.cpu_place() + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name="input", shape=self.input_shape, dtype="float64" + ) + output = paddle.argsort(input, axis=self.axis, stable=True) + np_result = np.argsort(self.data, axis=self.axis, kind='stable') + exe = paddle.static.Executor(self.place) + result = exe.run( + paddle.static.default_main_program(), + feed={'input': self.data}, + fetch_list=[output], + ) + + self.assertEqual((result == np_result).all(), True) + + @test_with_pir_api + def test_api_static1_gpu(self): + self.gpu_place() + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name="input", shape=self.input_shape, dtype="float64" + ) + output = paddle.argsort(input, axis=self.axis, stable=True) + np_result = np.argsort(self.data, axis=self.axis, kind='stable') + exe = paddle.static.Executor(self.place) + result = exe.run( + paddle.static.default_main_program(), + feed={'input': self.data}, + fetch_list=[output], + ) + + self.assertEqual((result == np_result).all(), True) + + @test_with_pir_api + def test_api_static2_cpu(self): + self.cpu_place() + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name="input", shape=self.input_shape, dtype="float64" + ) + output2 = paddle.argsort( + input, axis=self.axis, descending=True, stable=True + ) + np_result2 = np.argsort(-self.data, axis=self.axis, kind='stable') + exe = paddle.static.Executor(self.place) + result2 = exe.run( + paddle.static.default_main_program(), + feed={'input': self.data}, + fetch_list=[output2], + ) + + self.assertEqual((result2 == np_result2).all(), True) + + @test_with_pir_api + def test_api_static2_gpu(self): + self.gpu_place() + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name="input", shape=self.input_shape, dtype="float64" + ) + output2 = paddle.argsort( + input, axis=self.axis, descending=True, stable=True + ) + np_result2 = np.argsort(-self.data, axis=self.axis, kind='stable') + exe = paddle.static.Executor(self.place) + result2 = exe.run( + paddle.static.default_main_program(), + feed={'input': self.data}, + fetch_list=[output2], + ) + + self.assertEqual((result2 == np_result2).all(), True) + + +class TestStableArgsort2(TestStableArgsort): + def init(self): + self.input_shape = [30, 1] + self.axis = 0 + + +class TestStableArgsort3(TestStableArgsort): + def init(self): + self.input_shape = [1, 30] + self.axis = 1 + + +class TestStableArgsort4(TestStableArgsort): + def init(self): + self.input_shape = [40, 3, 4] + self.axis = 0 + self.data = np.array( + [ + [ + [100.0, 50.0, -10.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [100.0, 50.0, -10.0, 1.0], + ], + [ + [70.0, -30.0, 60.0, 100.0], + [0.0, 0.0, 1.0, 1.0], + [100.0, 50.0, -10.0, 1.0], + ], + ] + * 20 + ) + + class TestArgsortImperative(unittest.TestCase): def init(self): self.input_shape = [ @@ -496,6 +623,89 @@ def init(self): self.axis = 1 +class TestStableArgsortImperative(unittest.TestCase): + def setUp(self): + self.input_shape = [ + 30, + ] + self.axis = 0 + self.input_data = np.array([100.0, 50.0, 10.0] * 10) + + def cpu_place(self): + self.place = core.CPUPlace() + + def gpu_place(self): + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + def test_api_cpu(self): + self.cpu_place() + paddle.disable_static(self.place) + var_x = paddle.to_tensor(self.input_data) + out = paddle.argsort(var_x, axis=self.axis, stable=True) + expect = np.argsort(self.input_data, axis=self.axis, kind='stable') + self.assertEqual((expect == out.numpy()).all(), True) + + out2 = paddle.argsort( + var_x, axis=self.axis, descending=True, stable=True + ) + expect2 = np.argsort(-self.input_data, axis=self.axis, kind='stable') + self.assertEqual((expect2 == out2.numpy()).all(), True) + + paddle.enable_static() + + def test_api_gpu(self): + self.gpu_place() + paddle.disable_static(self.place) + var_x = paddle.to_tensor(self.input_data) + out = paddle.argsort(var_x, axis=self.axis, stable=True) + expect = np.argsort(self.input_data, axis=self.axis, kind='stable') + self.assertEqual((expect == out.numpy()).all(), True) + + out2 = paddle.argsort( + var_x, axis=self.axis, descending=True, stable=True + ) + expect2 = np.argsort(-self.input_data, axis=self.axis, kind='stable') + self.assertEqual((expect2 == out2.numpy()).all(), True) + + paddle.enable_static() + + +class TestStableArgsortImperative2(TestStableArgsortImperative): + def init(self): + self.input_shape = [30, 1] + self.axis = 0 + + +class TestStableArgsortImperative3(TestStableArgsortImperative): + def init(self): + self.input_shape = [1, 30] + self.axis = 1 + + +class TestStableArgsortImperative4(TestStableArgsortImperative): + def init(self): + self.input_shape = [40, 3, 4] + self.axis = 0 + self.data = np.array( + [ + [ + [100.0, 50.0, -10.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [100.0, 50.0, -10.0, 1.0], + ], + [ + [70.0, -30.0, 60.0, 100.0], + [0.0, 0.0, 1.0, 1.0], + [100.0, 50.0, -10.0, 1.0], + ], + ] + * 20 + ) + + class TestArgsortWithInputNaN(unittest.TestCase): def init(self): self.axis = 0 diff --git a/test/legacy_test/test_sort_op.py b/test/legacy_test/test_sort_op.py index 6559f966b46859..ac77f2db4e44f7 100644 --- a/test/legacy_test/test_sort_op.py +++ b/test/legacy_test/test_sort_op.py @@ -64,6 +64,22 @@ def test_api_1(self): np_result = np.sort(result, axis=1) self.assertEqual((result == np_result).all(), True) + @test_with_pir_api + def test_api_2(self): + with base.program_guard(base.Program()): + input = paddle.static.data( + name="input", shape=[30], dtype="float32" + ) + output = paddle.sort(x=input, axis=0, stable=True) + exe = base.Executor(self.place) + data = np.array( + [100.0, 50.0, 10.0] * 10, + dtype='float32', + ) + (result,) = exe.run(feed={'input': data}, fetch_list=[output]) + np_result = np.sort(result, axis=0, kind='stable') + self.assertEqual((result == np_result).all(), True) + class TestSortOnGPU(TestSortOnCPU): def init_place(self): @@ -97,6 +113,21 @@ def test_api_1(self): ) paddle.enable_static() + def test_api_2(self): + paddle.disable_static(self.place) + var_x = paddle.to_tensor(np.array([100.0, 50.0, 10.0] * 10)) + out = paddle.sort(var_x, axis=0) + self.assertEqual( + ( + np.sort( + np.array([100.0, 50.0, 10.0] * 10), axis=0, kind='stable' + ) + == out.numpy() + ).all(), + True, + ) + paddle.enable_static() + if __name__ == '__main__': unittest.main()