diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 3e74e7a579a35..b0b7a8c8050f0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -907,15 +907,15 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): value = float(value) if isinstance(shape, (list, tuple)): shape = paddle.utils.convert_shape_to_list(shape) - else: + paddle.utils.check_shape(shape) if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): shape = paddle.utils.get_int_tensor_list(shape, place) elif isinstance(shape, paddle.pir.Value): pass else: - TypeError("Shape only supports OpResult, or list, or tuple.") + raise TypeError("Shape only supports Value, or list, or tuple.") if out is None: out = _C_ops.full(shape, value, dtype, place) diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index d61ed75aa4e2b..4c0950a3da558 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -30,6 +30,7 @@ _current_expected_place, in_dygraph_mode, ) +from ..pir import Value def convert_to_list(value, n, name, dtype=int): @@ -496,11 +497,11 @@ def check_shape(shape): """ Check shape type and shape elements type before passing it to fill_constant """ - if isinstance(shape, Variable): + if isinstance(shape, (Variable, Value)): check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant') - else: + elif isinstance(shape, (list, tuple)): for ele in shape: - if not isinstance(ele, Variable): + if not isinstance(ele, (Variable, Value)): if ele < 0: raise ValueError( "All elements in ``shape`` must be positive when it's a list or tuple" @@ -509,6 +510,13 @@ def check_shape(shape): raise TypeError( "All elements in ``shape`` must be integers when it's a list or tuple" ) + else: + check_dtype( + ele.dtype, + 'element of shape', + ['int32', 'int64'], + 'fill_constant', + ) def try_set_static_shape_tensor(tensor, shape): diff --git a/test/legacy_test/test_full_op.py b/test/legacy_test/test_full_op.py index 0281d41252a27..60e7d01c7f237 100644 --- a/test/legacy_test/test_full_op.py +++ b/test/legacy_test/test_full_op.py @@ -18,7 +18,6 @@ import paddle from paddle import base -from paddle.base import Program, program_guard from paddle.pir_utils import test_with_pir_api @@ -26,6 +25,7 @@ class TestFullAPI(unittest.TestCase): @test_with_pir_api def test_api(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) @@ -98,6 +98,7 @@ def test_api(self): np.testing.assert_array_equal( res_7, np.full([1, 2], 1.1, dtype="float32") ) + paddle.disable_static() def test_api_eager(self): with base.dygraph.base.guard(): @@ -184,8 +185,12 @@ def test_api_eager(self): class TestFullOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): - with program_guard(Program(), Program()): + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # for ci coverage self.assertRaises( TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4' @@ -216,6 +221,7 @@ def test_shape_tensor_list_dtype(): paddle.full(shape=[shape, 2], dtype="float32", fill_value=1) self.assertRaises(TypeError, test_shape_tensor_list_dtype) + paddle.disable_static() if __name__ == "__main__":