diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 321366276c611..67ca945642385 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -52,7 +52,7 @@ def transpose(x, perm, name=None): perm[i]-th dimension of `input`. Args: - x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float32, float64, int32. + x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float16, bfloat16, float32, float64, int8, int16, int32, int64, uint8, uint16, complex64, complex128. perm (list|tuple): Permute the input according to the data of perm. name (str, optional): The name of this layer. For more information, please refer to :ref:`api_guide_Name`. Default is None. @@ -119,8 +119,12 @@ def transpose(x, perm, name=None): [ 'bool', 'float16', + 'bfloat16', 'float32', 'float64', + 'int8', + 'uint8', + 'int16', 'int32', 'int64', 'uint16', diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py index 1ba2c4b980741..d779e4bd7f398 100644 --- a/test/legacy_test/test_transpose_op.py +++ b/test/legacy_test/test_transpose_op.py @@ -21,7 +21,7 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -499,9 +499,12 @@ def initTestCase(self): class TestTransposeOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data( name='x', shape=[-1, 10, 5, 3], dtype='float64' ) @@ -512,15 +515,6 @@ def test_x_Variable_check(): self.assertRaises(TypeError, test_x_Variable_check) - def test_x_dtype_check(): - # the Input(x)'s dtype must be one of [bool, float16, float32, float64, int32, int64] - x1 = paddle.static.data( - name='x1', shape=[-1, 10, 5, 3], dtype='int8' - ) - paddle.transpose(x1, perm=[1, 0, 2]) - - self.assertRaises(TypeError, test_x_dtype_check) - def test_perm_list_check(): # Input(perm)'s type must be list paddle.transpose(x, perm="[1, 0, 2]")