From 9b68571ab35d2f09fa8c237edcd655a41e1db46f Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 23 Jan 2024 09:49:20 +0800 Subject: [PATCH 01/13] fix --- test/legacy_test/test_split_op.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 0a674009651d0..4f738a225f646 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -19,7 +19,7 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core from paddle.pir_utils import test_with_pir_api @@ -361,10 +361,13 @@ def test_api(self): np.testing.assert_array_equal(res_5, out[2]) -class TestSplitOpError(unittest.TestCase): - def test_errors(self): +class TestSplitOpErrorStatic(unittest.TestCase): + @test_with_pir_api + def test_errors_with_static(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # The type of axis in split_op should be int or Variable. def test_axis_type(): x6 = paddle.static.data( @@ -412,6 +415,9 @@ def test_axis_type_tensor(): self.assertRaises(TypeError, test_axis_type_tensor) paddle.disable_static() + +class TestSplitOpErrorDynamic(unittest.TestCase): + def test_errors_with_dynamic(self): with paddle.base.dygraph.guard(): def test_0_num_tensor(): From 84500e2474740fb8db4bfe17a27accdc66e6bfc2 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 24 Jan 2024 12:41:45 +0800 Subject: [PATCH 02/13] fix --- test/legacy_test/test_split_op.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 4f738a225f646..54fd9676382e0 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -370,25 +370,13 @@ def test_errors_with_static(self): ): # The type of axis in split_op should be int or Variable. def test_axis_type(): - x6 = paddle.static.data( + x5 = paddle.static.data( shape=[-1, 4], dtype='float16', name='x3' ) - paddle.split(x=x6, num_or_sections=2, axis=3.2) + paddle.split(x=x5, num_or_sections=2, axis=3.2) self.assertRaises(TypeError, test_axis_type) - # The type of axis in split_op should be int or Variable. - def test_axis_variable_type(): - x9 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x9' - ) - x10 = paddle.static.data( - shape=[-1, 1], dtype='float16', name='x10' - ) - paddle.split(x=x9, num_or_sections=2, axis=x10) - - self.assertRaises(TypeError, test_axis_variable_type) - # The type of num_or_sections in split_op should be int, tuple or list. def test_num_or_sections_type(): x6 = paddle.static.data( From 57825e689ac72e45593b5ebd5c0e1c596ea3dfbd Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 24 Jan 2024 14:51:12 +0800 Subject: [PATCH 03/13] fix --- test/legacy_test/test_split_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 54fd9676382e0..cb441976ed7e8 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -371,7 +371,7 @@ def test_errors_with_static(self): # The type of axis in split_op should be int or Variable. def test_axis_type(): x5 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x3' + shape=[-1, 4], dtype='float16', name='x5' ) paddle.split(x=x5, num_or_sections=2, axis=3.2) @@ -380,7 +380,7 @@ def test_axis_type(): # The type of num_or_sections in split_op should be int, tuple or list. def test_num_or_sections_type(): x6 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x4' + shape=[-1, 4], dtype='float16', name='x6' ) paddle.split(x=x6, num_or_sections=2.1, axis=3) @@ -388,7 +388,7 @@ def test_num_or_sections_type(): def test_num_or_sections_type_tensor(): x7 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x5' + shape=[-1, 4], dtype='float16', name='x7' ) paddle.split(input=x7, num_or_sections=2.1, dim=3) @@ -396,7 +396,7 @@ def test_num_or_sections_type_tensor(): def test_axis_type_tensor(): x8 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x6' + shape=[-1, 4], dtype='float16', name='x8' ) paddle.split(input=x8, num_or_sections=2, dim=3.2) From 4a2cfa30df48fb685b5eb3664982856d86481e2f Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 25 Jan 2024 09:25:38 +0800 Subject: [PATCH 04/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index cb441976ed7e8..b02908fe6ac52 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -384,7 +384,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(TypeError, test_num_or_sections_type) + self.assertRaises(IndexError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 7daaa730afb5e49a74efacb88e131889c2938f17 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 25 Jan 2024 10:17:24 +0800 Subject: [PATCH 05/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index b02908fe6ac52..cb441976ed7e8 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -384,7 +384,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(IndexError, test_num_or_sections_type) + self.assertRaises(TypeError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 18537a52bbbbcb7855610e5f5c20a94f56610710 Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 29 Jan 2024 16:23:27 +0800 Subject: [PATCH 06/13] fix --- test/legacy_test/test_split_op.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index cb441976ed7e8..6935167837e3b 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -377,6 +377,18 @@ def test_axis_type(): self.assertRaises(TypeError, test_axis_type) + # The type of axis in split_op should be int or Variable. + def test_axis_variable_type(): + x9 = paddle.static.data( + shape=[-1, 4], dtype='float16', name='x9' + ) + x10 = paddle.static.data( + shape=[-1, 1], dtype='float16', name='x10' + ) + paddle.split(x=x9, num_or_sections=2, axis=x10) + + self.assertRaises(TypeError, test_axis_variable_type) + # The type of num_or_sections in split_op should be int, tuple or list. def test_num_or_sections_type(): x6 = paddle.static.data( From 979359f8b1048a28a5c3e89f0c800093f63fb1d9 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 30 Jan 2024 16:13:28 +0800 Subject: [PATCH 07/13] fix --- test/legacy_test/test_split_op.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 6935167837e3b..9311a5f2d9957 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.framework import in_pir_mode from paddle.pir_utils import test_with_pir_api @@ -377,17 +378,18 @@ def test_axis_type(): self.assertRaises(TypeError, test_axis_type) - # The type of axis in split_op should be int or Variable. - def test_axis_variable_type(): - x9 = paddle.static.data( - shape=[-1, 4], dtype='float16', name='x9' - ) - x10 = paddle.static.data( - shape=[-1, 1], dtype='float16', name='x10' - ) - paddle.split(x=x9, num_or_sections=2, axis=x10) - - self.assertRaises(TypeError, test_axis_variable_type) + if not in_pir_mode(): + # The type of axis in split_op should be int or Variable. + def test_axis_variable_type(): + x9 = paddle.static.data( + shape=[-1, 4], dtype='float16', name='x9' + ) + x10 = paddle.static.data( + shape=[-1, 1], dtype='float16', name='x10' + ) + paddle.split(x=x9, num_or_sections=2, axis=x10) + + self.assertRaises(TypeError, test_axis_variable_type) # The type of num_or_sections in split_op should be int, tuple or list. def test_num_or_sections_type(): From 6046a7ccb20315f9b47cc7c8697ca29a5fe877a9 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 31 Jan 2024 11:54:53 +0800 Subject: [PATCH 08/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 9311a5f2d9957..5854437dbe0ad 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -398,7 +398,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(TypeError, test_num_or_sections_type) + self.assertRaises(IndexError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 1104f9375cb985134baa551385905ec033854a25 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 1 Feb 2024 14:35:10 +0800 Subject: [PATCH 09/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 5854437dbe0ad..9311a5f2d9957 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -398,7 +398,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(IndexError, test_num_or_sections_type) + self.assertRaises(TypeError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 8cad345f92d6d00d7a716bff964500a763b8accf Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 1 Feb 2024 17:05:16 +0800 Subject: [PATCH 10/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 9311a5f2d9957..5854437dbe0ad 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -398,7 +398,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(TypeError, test_num_or_sections_type) + self.assertRaises(IndexError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 09998cc62f6b5c613b79b7af472998efe26e19d6 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 2 Feb 2024 12:48:25 +0800 Subject: [PATCH 11/13] fix --- test/legacy_test/test_split_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 5854437dbe0ad..9311a5f2d9957 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -398,7 +398,7 @@ def test_num_or_sections_type(): ) paddle.split(x=x6, num_or_sections=2.1, axis=3) - self.assertRaises(IndexError, test_num_or_sections_type) + self.assertRaises(TypeError, test_num_or_sections_type) def test_num_or_sections_type_tensor(): x7 = paddle.static.data( From 150eb681245cb7c5eaf9e8bed13082d836beaf8d Mon Sep 17 00:00:00 2001 From: enkilee Date: Sun, 4 Feb 2024 15:05:31 +0800 Subject: [PATCH 12/13] CI From 3f4ddd4c4b086504933bd9cbaebd693223d9da68 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Mon, 19 Feb 2024 07:41:40 +0000 Subject: [PATCH 13/13] Fix check in pir mode --- python/paddle/tensor/manipulation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index d62771bb8ed53..05493cd976c94 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2484,6 +2484,12 @@ def split(x, num_or_sections, axis=0, name=None): dim = (len(input.shape) + dim) if dim < 0 else dim input_shape = input.shape + + if not isinstance(num_or_sections, (int, list, tuple)): + raise TypeError( + "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but " + "received %s." % (type(num_or_sections)) + ) if isinstance(num_or_sections, int): assert num_or_sections > 0, 'num_or_sections must be than 0.' if isinstance(dim, int) and input_shape[dim] > 0: