Skip to content

Commit

Permalink
update paddle.full
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyinglia committed Dec 29, 2022
1 parent df63aee commit fa3e474
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
12 changes: 6 additions & 6 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,14 +785,14 @@ def test_reshape_tensor(self):
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])

new_shape = paddle.full([], 1, "int32")
new_shape = paddle.full([1], 1, "int32")
out = paddle.reshape(x, new_shape)
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
self.assertEqual(out.shape, [1])
self.assertEqual(out.grad.shape, [1])

new_shape = paddle.full([], -1, "int32")
new_shape = paddle.full([1], -1, "int32")
out = paddle.reshape(x, new_shape)
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
Expand Down Expand Up @@ -825,11 +825,11 @@ def test_reshape__tensor(self):
out = paddle.reshape_(x, [])
self.assertEqual(out.shape, [])

new_shape = paddle.full([], 1, "int32")
new_shape = paddle.full([1], 1, "int32")
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1])

new_shape = paddle.full([], -1, "int32")
new_shape = paddle.full([1], -1, "int32")
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1])

Expand Down Expand Up @@ -1130,11 +1130,11 @@ def test_reshape_tensor(self):
x2.stop_gradient = False
x3.stop_gradient = False

new_shape = paddle.full([], 1, "int32")
new_shape = paddle.full([1], 1, "int32")
out1 = paddle.reshape(x1, new_shape)
paddle.static.append_backward(out1)

new_shape = paddle.full([], -1, "int32")
new_shape = paddle.full([1], -1, "int32")
out2 = paddle.reshape(x2, new_shape)
paddle.static.append_backward(out2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,27 +634,18 @@ def test_reshape__tensor(self):
out = paddle.reshape_(x, [])
self.assertEqual(out.shape, [])

new_shape = paddle.full([], 1, "int32")
new_shape = paddle.full([1], 1, "int32")
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1])

new_shape = paddle.full([], -1, "int32")
new_shape = paddle.full([1], -1, "int32")
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1])

new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1, 1])

def test_reverse(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.reverse(x, axis=[])
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down

0 comments on commit fa3e474

Please sign in to comment.