Skip to content

Commit

Permalink
【Hackathon 5th No.33】 reshape more dtype and test -part (PaddlePaddle…
Browse files Browse the repository at this point in the history
…#58764)

* [Change] reshape more dtype and test

* [Fix] remove uint16

* [Change] use uint16 instead of bfloat16
  • Loading branch information
megemini authored and SecretXV committed Nov 28, 2023
1 parent 4bfbed4 commit 2826abc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3747,7 +3747,7 @@ def reshape(x, shape, name=None):
- 3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case, besides -1, 0 means the actual dimension value is going to be copied from the corresponding dimension of x.
Args:
x (Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32`` . If ``shape`` is a list or tuple, each element of it should be integer or Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor .
Expand Down Expand Up @@ -3892,6 +3892,10 @@ def get_attr_shape(list_shape):
'int64',
'bool',
'uint16',
'int8',
'uint8',
'complex64',
'complex128',
],
'reshape',
)
Expand Down
52 changes: 52 additions & 0 deletions test/legacy_test/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,61 @@ def _test_api(self):
np.testing.assert_array_equal(res_3, input.reshape([5, 10]))
np.testing.assert_array_equal(res_4, input.reshape(shape))

def _test_static_dtype(self):
places = [paddle.CPUPlace()] + (
[paddle.CUDAPlace(0)] if base.core.is_compiled_with_cuda() else []
)

dtypes = [
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'int8',
'uint8',
'complex64',
'complex128',
'bfloat16',
'bool',
]
for place in places:
for dtype in dtypes:
# core is not compiled with CUDA and not support the bfloat16
if (
dtype == 'bfloat16'
and not base.core.is_compiled_with_cuda()
):
continue

dtype_paddle = dtype
# numpy not support bfloat16, use uint16 instead
dtype_numpy = dtype if dtype != 'bfloat16' else 'uint16'

paddle.enable_static()
input = np.random.random([2, 25]).astype(dtype_numpy)
shape = [2, 5, 5]
main_prog = paddle.static.Program()
with paddle.static.program_guard(
main_prog, paddle.static.Program()
):
x = self.data(name="x", shape=[2, 25], dtype=dtype_paddle)
out_1 = self.reshape(x, shape)

exe = paddle.static.Executor(place=place)
res_1 = exe.run(
main_prog,
feed={"x": input},
fetch_list=[out_1],
)[0]

np.testing.assert_array_equal(res_1, input.reshape(shape))

def test_paddle_api(self):
self._set_paddle_api()
self._test_api()
self._test_static_dtype()

def test_imperative(self):
self._set_paddle_api()
Expand Down

0 comments on commit 2826abc

Please sign in to comment.