-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.33】 reshape more dtype and test -part #58764
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3734,7 +3734,7 @@ def reshape(x, shape, name=None): | |
- 3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case, besides -1, 0 means the actual dimension value is going to be copied from the corresponding dimension of x. | ||
|
||
Args: | ||
x (Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32``, ``int64`` or ``bool`` | ||
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``uint16``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``. | ||
shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1. | ||
The data type is ``int32`` . If ``shape`` is a list or tuple, each element of it should be integer or Tensor with shape []. | ||
If ``shape`` is an Tensor, it should be an 1-D Tensor . | ||
|
@@ -3879,6 +3879,11 @@ def get_attr_shape(list_shape): | |
'int64', | ||
'bool', | ||
'uint16', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same issue as below. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'bool',
'uint16',
'int8',
'uint8',
'complex64',
'complex128',
],
'reshape',
)

But remove
def _test_static_dtype(self):
places = [paddle.CPUPlace()] + (
[paddle.CUDAPlace(0)] if base.core.is_compiled_with_cuda() else []
)
dtypes = [
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'int8',
'uint8',
'complex64',
'complex128',
'bfloat16',
'bool',
]
... right? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, there is a historical issue with the function.
||
'int8', | ||
'uint8', | ||
'complex64', | ||
'complex128', | ||
'bfloat16', | ||
], | ||
'reshape', | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -453,9 +453,62 @@ def _test_api(self): | |
np.testing.assert_array_equal(res_3, input.reshape([5, 10])) | ||
np.testing.assert_array_equal(res_4, input.reshape(shape)) | ||
|
||
def _test_static_dtype(self): | ||
places = [paddle.CPUPlace()] + ( | ||
[paddle.CUDAPlace(0)] if base.core.is_compiled_with_cuda() else [] | ||
) | ||
|
||
dtypes = [ | ||
'float16', | ||
'float32', | ||
'float64', | ||
'int16', | ||
'int32', | ||
'int64', | ||
'uint16', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same issue as below.
||
'int8', | ||
'uint8', | ||
'complex64', | ||
'complex128', | ||
'bfloat16', | ||
'bool', | ||
] | ||
for place in places: | ||
for dtype in dtypes: | ||
# core is not compiled with CUDA and not support the bfloat16 | ||
if ( | ||
dtype == 'bfloat16' | ||
and not base.core.is_compiled_with_cuda() | ||
): | ||
continue | ||
|
||
dtype_paddle = dtype | ||
# numpy not support bfloat16, use uint16 instead | ||
dtype_numpy = dtype if dtype != 'bfloat16' else 'uint16' | ||
|
||
paddle.enable_static() | ||
input = np.random.random([2, 25]).astype(dtype_numpy) | ||
shape = [2, 5, 5] | ||
main_prog = paddle.static.Program() | ||
with paddle.static.program_guard( | ||
main_prog, paddle.static.Program() | ||
): | ||
x = self.data(name="x", shape=[2, 25], dtype=dtype_paddle) | ||
out_1 = self.reshape(x, shape) | ||
|
||
exe = paddle.static.Executor(place=place) | ||
res_1 = exe.run( | ||
main_prog, | ||
feed={"x": input}, | ||
fetch_list=[out_1], | ||
)[0] | ||
|
||
np.testing.assert_array_equal(res_1, input.reshape(shape)) | ||
|
||
def test_paddle_api(self): | ||
self._set_paddle_api() | ||
self._test_api() | ||
self._test_static_dtype() | ||
|
||
def test_imperative(self): | ||
self._set_paddle_api() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
There is no official support for `uint16` — should we delete `uint16`?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle use
bfloat16
to holduint16
, so no official support foruint16
?