From 61ab73650f34046b11378ba92dd67b77cb8b1985 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 18 Mar 2024 21:38:57 +0800 Subject: [PATCH 1/7] test_errors_d_16 --- test/legacy_test/test_full_op.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_full_op.py b/test/legacy_test/test_full_op.py index 0281d41252a27..60e7d01c7f237 100644 --- a/test/legacy_test/test_full_op.py +++ b/test/legacy_test/test_full_op.py @@ -18,7 +18,6 @@ import paddle from paddle import base -from paddle.base import Program, program_guard from paddle.pir_utils import test_with_pir_api @@ -26,6 +25,7 @@ class TestFullAPI(unittest.TestCase): @test_with_pir_api def test_api(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) @@ -98,6 +98,7 @@ def test_api(self): np.testing.assert_array_equal( res_7, np.full([1, 2], 1.1, dtype="float32") ) + paddle.disable_static() def test_api_eager(self): with base.dygraph.base.guard(): @@ -184,8 +185,12 @@ def test_api_eager(self): class TestFullOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): - with program_guard(Program(), Program()): + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # for ci coverage self.assertRaises( TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4' @@ -216,6 +221,7 @@ def test_shape_tensor_list_dtype(): paddle.full(shape=[shape, 2], dtype="float32", fill_value=1) self.assertRaises(TypeError, test_shape_tensor_list_dtype) + paddle.disable_static() if __name__ == "__main__": From c4548310fe8a1ee32a023e11cf73723a824815a8 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 19 Mar 2024 12:10:13 +0800 Subject: [PATCH 2/7] fix typecheck for fill_constant --- python/paddle/tensor/creation.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 3e74e7a579a35..9fa82ee344d57 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -25,6 +25,7 @@ from ..base.data_feeder import ( check_dtype, + check_shape, check_type, check_variable_and_dtype, convert_dtype, @@ -903,11 +904,17 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if in_pir_mode() and isinstance(dtype, core.VarDesc.VarType): dtype = paddle.pir.core.vartype_to_datatype[dtype] + check_shape(shape, 'fill_constant') if in_dynamic_mode(): value = float(value) if isinstance(shape, (list, tuple)): shape = paddle.utils.convert_shape_to_list(shape) - + elif isinstance(shape, core.eager.Tensor): + pass + else: + raise TypeError( + "Shape only supports Tensor, or list, or tuple." + ) else: if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): @@ -915,7 +922,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): elif isinstance(shape, paddle.pir.Value): pass else: - TypeError("Shape only supports OpResult, or list, or tuple.") + TypeError("Shape only supports Value, or list, or tuple.") if out is None: out = _C_ops.full(shape, value, dtype, place) From 3d42232a3116f1a858951b94f0b240c59e1fa3e2 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 19 Mar 2024 14:56:12 +0800 Subject: [PATCH 3/7] add raise and revert change for dynamic mode --- python/paddle/tensor/creation.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 9fa82ee344d57..d4ac79f2e56f2 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -909,12 +909,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): value = float(value) if isinstance(shape, (list, tuple)): shape = paddle.utils.convert_shape_to_list(shape) - elif isinstance(shape, core.eager.Tensor): - pass - else: - raise TypeError( - "Shape only supports Tensor, or list, or tuple." - ) else: if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): @@ -922,7 +916,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): elif isinstance(shape, paddle.pir.Value): pass else: - TypeError("Shape only supports Value, or list, or tuple.") + raise TypeError("Shape only supports Value, or list, or tuple.") if out is None: out = _C_ops.full(shape, value, dtype, place) From 701d096b6d0f91e41b77111a8dd5877e56d308da Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 19 Mar 2024 19:54:45 +0800 Subject: [PATCH 4/7] try to pass to PR-CI-Coverage --- python/paddle/tensor/creation.py | 3 +-- python/paddle/utils/layers_utils.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index d4ac79f2e56f2..8ea617f05ff6e 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -25,7 +25,6 @@ from ..base.data_feeder import ( check_dtype, - check_shape, check_type, check_variable_and_dtype, convert_dtype, @@ -904,7 +903,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if in_pir_mode() and isinstance(dtype, core.VarDesc.VarType): dtype = paddle.pir.core.vartype_to_datatype[dtype] - check_shape(shape, 'fill_constant') + paddle.utils.check_shape(shape) if in_dynamic_mode(): value = float(value) if isinstance(shape, (list, tuple)): diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index d61ed75aa4e2b..4393e8f7c5a91 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -30,6 +30,7 @@ _current_expected_place, in_dygraph_mode, ) +from ..pir import Value def convert_to_list(value, n, name, dtype=int): @@ -496,11 +497,11 @@ def check_shape(shape): """ Check shape type and shape elements type before passing it to fill_constant """ - if isinstance(shape, Variable): + if isinstance(shape, (Variable, Value)): check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant') else: for ele in shape: - if not isinstance(ele, Variable): + if not isinstance(ele, (Variable, Value)): if ele < 0: raise ValueError( "All elements in ``shape`` must be positive when it's a list or tuple" From fc45ed47bd3c0cd0c4c7ee4b54c7328a6b87e84d Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 19 Mar 2024 21:13:59 +0800 Subject: [PATCH 5/7] avoid dynamic_mode --- python/paddle/tensor/creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 8ea617f05ff6e..b0b7a8c8050f0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -903,12 +903,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if in_pir_mode() and isinstance(dtype, core.VarDesc.VarType): dtype = paddle.pir.core.vartype_to_datatype[dtype] - paddle.utils.check_shape(shape) if in_dynamic_mode(): value = float(value) if isinstance(shape, (list, tuple)): shape = paddle.utils.convert_shape_to_list(shape) else: + paddle.utils.check_shape(shape) if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): shape = paddle.utils.get_int_tensor_list(shape, place) From 2a8f851b54ecf9f32c3724a86d9afb76876d9d3b Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Wed, 20 Mar 2024 11:57:16 +0800 Subject: [PATCH 6/7] add element of shape check --- python/paddle/utils/layers_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index 4393e8f7c5a91..f83a8b6eabcd4 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -510,6 +510,13 @@ def check_shape(shape): raise TypeError( "All elements in ``shape`` must be integers when it's a list or tuple" ) + else: + check_dtype( + ele.dtype, + 'element of shape', + ['int32', 'int64'], + 'fill_constant', + ) def try_set_static_shape_tensor(tensor, shape): From aebf6594816fc6a29b4bb3b3b94592c91a66e0d0 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Wed, 20 Mar 2024 17:03:25 +0800 Subject: [PATCH 7/7] fix ci coverage --- python/paddle/utils/layers_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index f83a8b6eabcd4..4c0950a3da558 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -499,7 +499,7 @@ def check_shape(shape): """ if isinstance(shape, (Variable, Value)): check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant') - else: + elif isinstance(shape, (list, tuple)): for ele in shape: if not isinstance(ele, (Variable, Value)): if ele < 0: