Skip to content

Commit

Permalink
add 0d tensor support for searchsorted for gpu and cpu (#48314)
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu authored Nov 25, 2022
1 parent db749ee commit 8c797ba
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
21 changes: 21 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,16 @@ def test_logical_not(self):

self.assertEqual(out.shape, [])

def test_searchsorted(self):
x = paddle.to_tensor([1, 3, 5, 7, 9])
y = paddle.rand([])

# only has forward kernel
out = paddle.searchsorted(x, y)

self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0)


class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -656,6 +666,17 @@ def test_logical_not(self):
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())

@prog_scope()
def test_searchsorted(self):
x = paddle.full([10], 1.0, 'float32')
y = paddle.full([], 1.0, 'float32')
out = paddle.searchsorted(x, y)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 0)


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ def test_logical_not(self):

self.assertEqual(out.shape, [])

def test_searchsorted(self):
x = paddle.to_tensor([1, 3, 5, 7, 9])
y = paddle.rand([])

# only has forward kernel
out = paddle.searchsorted(x, y)

self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0)


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down

0 comments on commit 8c797ba

Please sign in to comment.