diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index dcfee03f40cfa..5911b496a32d2 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -586,6 +586,16 @@ def test_logical_not(self): self.assertEqual(out.shape, []) + def test_searchsorted(self): + x = paddle.to_tensor([1, 3, 5, 7, 9]) + y = paddle.rand([]) + + # only has forward kernel + out = paddle.searchsorted(x, y) + + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 0) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -656,6 +666,17 @@ def test_logical_not(self): res = self.exe.run(prog, fetch_list=[out]) self.assertEqual(res[0].shape, ()) + @prog_scope() + def test_searchsorted(self): + x = paddle.full([10], 1.0, 'float32') + y = paddle.full([], 1.0, 'float32') + out = paddle.searchsorted(x, y) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 0) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index a0925207c957e..3be5c315f3bcc 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -414,6 +414,16 @@ def test_logical_not(self): self.assertEqual(out.shape, []) + def test_searchsorted(self): + x = paddle.to_tensor([1, 3, 5, 7, 9]) + y = paddle.rand([]) + + # only has forward kernel + out = paddle.searchsorted(x, y) + + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 0) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase):