From 9772ffb86957c3c2447cffe671aa21ad4b213a06 Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Wed, 22 Mar 2023 13:08:14 +0000
Subject: [PATCH 1/6] add scatter composite rule.

---
 paddle/fluid/prim/api/api.yaml                |   1 +
 .../composite_backward_api.h                  |  21 ++++
 paddle/phi/api/yaml/backward.yaml             |   1 +
 .../fluid/tests/unittests/test_scatter_op.py  | 119 ++++++++++++++----
 4 files changed, 116 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index c5eadec1e079d..1be463e2902f3 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -23,6 +23,7 @@
 - concat
 - elementwise_pow
 - floor
+- gather
 - gather_nd
 - log
 - max
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 57203557fb5d1..7e7429a798bfd 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1200,6 +1200,27 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void scatter_grad(const Tensor& index,
+                  const Tensor& updates,
+                  const Tensor& out_grad,
+                  bool overwrite,
+                  Tensor* x_grad,
+                  Tensor* updates_grad) {
+  if (x_grad) {
+    auto zero_tensor =
+        full<T>(phi::vectorize(updates.dims()), 0.0, updates.dtype());
+    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
+    set_output<T>(tmp_grad, x_grad);
+  }
+
+  if (updates_grad) {
+    Scalar tmp_zero = 0;
+    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
+    set_output<T>(tmp_updates_grad, updates_grad);
+  }
+}
+
 template <typename T>
 void batch_norm_grad(const Tensor& x,
                      const Tensor& scale,
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 0c05a4e806a97..69344777b46ab 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1252,6 +1252,7 @@
   kernel :
     func : scatter_grad
   no_need_buffer : updates
+  composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)
 
 - backward_op : scatter_nd_add_grad
   forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 99c3bbfbc2582..85f9036092439 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -28,6 +28,7 @@ class TestScatterOp(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -46,10 +47,12 @@ def _set_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=False, check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out", check_eager=False)
+        self.check_grad(
+            ["X", "Updates"], "Out", check_eager=False, check_prim=True
+        )
 
 
 class TestScatterFP16Op(TestScatterOp):
@@ -69,13 +72,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, check_eager=False)
+            self.check_output_with_place(
+                place, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -83,6 +92,7 @@ class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -102,10 +112,12 @@ def _set_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=False, check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out", check_eager=False)
+        self.check_grad(
+            ["X", "Updates"], "Out", check_eager=False, check_prim=True
+        )
 
 
 class TestScatterFP16Op0(TestScatterOp0):
@@ -125,13 +137,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, check_eager=False)
+            self.check_output_with_place(
+                place, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -139,6 +157,7 @@ class TestScatterOp1(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -161,10 +180,12 @@ def _set_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=False, check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out", check_eager=False)
+        self.check_grad(
+            ["X", "Updates"], "Out", check_eager=False, check_prim=True
+        )
 
 
 class TestScatterFP16Op1(TestScatterOp1):
@@ -184,13 +205,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, check_eager=False)
+            self.check_output_with_place(
+                place, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -201,6 +228,7 @@ class TestScatterOp2(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -221,13 +249,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3, check_eager=False)
+            self.check_output_with_place(
+                place, atol=1e-3, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -256,6 +290,7 @@ class TestScatterOp3(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -280,13 +315,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3, check_eager=False)
+            self.check_output_with_place(
+                place, atol=1e-3, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -312,6 +353,7 @@ class TestScatterOp4(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -330,10 +372,12 @@ def _set_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=False, check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
+        self.check_grad(
+            ['X', 'Updates'], 'Out', check_eager=False, check_prim=True
+        )
 
 
 class TestScatterFP16Op4(TestScatterOp4):
@@ -353,13 +397,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, check_eager=False)
+            self.check_output_with_place(
+                place, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )
 
 
@@ -370,6 +420,7 @@ class TestScatterOp5(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -390,13 +441,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3, check_eager=False)
+            self.check_output_with_place(
+                place, atol=1e-3, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
            )
 
 
@@ -422,6 +479,8 @@ class TestScatterOp6(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
        self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -440,10 +499,12 @@ def _set_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=False, check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out", check_eager=False)
+        self.check_grad(
+            ["X", "Updates"], "Out", check_eager=False, check_prim=True
+        )
 
 
 class TestScatterFP16Op6(TestScatterOp6):
@@ -463,13 +524,19 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, check_eager=False)
+            self.check_output_with_place(
+                place, check_eager=False, check_prim=True
+            )
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_eager=False
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_eager=False,
+                check_prim=True,
             )

From 3b4864a658dadb318f95ee6da6480b60bd1cb21e Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Thu, 23 Mar 2023 03:05:20 +0000
Subject: [PATCH 2/6] add public_python_api

---
 python/paddle/fluid/tests/unittests/test_scatter_op.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 85f9036092439..c9250d43a93ae 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -28,6 +28,7 @@ class TestScatterOp(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -92,6 +93,7 @@ class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -157,6 +159,7 @@ class TestScatterOp1(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -228,6 +231,7 @@ class TestScatterOp2(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -290,6 +294,7 @@ class TestScatterOp3(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -353,6 +358,7 @@ class TestScatterOp4(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -420,6 +426,7 @@ class TestScatterOp5(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
@@ -479,6 +486,7 @@ class TestScatterOp6(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
         self.prim_op_type = "prim"
         self.enable_cinn = False
         self._set_dtype()

From 40c0e8b4f716f4e8e4e17e255e2b054cdf0478d3 Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Thu, 23 Mar 2023 07:22:34 +0000
Subject: [PATCH 3/6] add python uint16 support.
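[Editor's note, not part of the patch series: the 'uint16' entry this patch adds to scatter's dtype whitelist is how Paddle's NumPy-facing layer carries bfloat16 — a bf16 value is the upper half of an IEEE-754 float32, stored in an np.uint16. A minimal sketch of that encoding, with an illustrative helper name; Paddle's own test utilities ship a rounding converter for the same purpose:

    import numpy as np

    def float32_to_bf16_bits(x):
        # bfloat16 keeps float32's sign bit, all 8 exponent bits, and the
        # top 7 mantissa bits, i.e. its upper 16 bits. Truncation is used
        # here for brevity; a real converter would round to nearest.
        x = np.asarray(x, dtype=np.float32)
        return (x.view(np.uint32) >> 16).astype(np.uint16)

    print(float32_to_bf16_bits(np.array([1.0, -2.5], dtype=np.float32)))
    # -> [16256 49184], the raw uint16 carriers a BF16 test case feeds in

This is why the BF16 test classes later in the series set self.dtype = np.uint16 rather than a dedicated bf16 dtype.]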
---
 python/paddle/tensor/manipulation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 41a8cfa856f8c..63f380212b99a 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -2906,7 +2906,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
         check_variable_and_dtype(
             x,
             'dtype',
-            ['float32', 'float64', 'float16', 'int32', 'int64'],
+            ['float32', 'float64', 'float16', 'int32', 'int64', 'unit16'],
             'scatter',
         )
         check_type(overwrite, 'overwrite', bool, 'scatter')

From 4eb1eb5261fb34c8be77767ed235114e5628286b Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Thu, 23 Mar 2023 12:44:33 +0000
Subject: [PATCH 4/6] fix code style.

---
 python/paddle/tensor/manipulation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 63f380212b99a..bb4b78809cd8a 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -2906,7 +2906,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
         check_variable_and_dtype(
             x,
             'dtype',
-            ['float32', 'float64', 'float16', 'int32', 'int64', 'unit16'],
+            ['float32', 'float64', 'float16', 'int32', 'int64', 'uint16'],
             'scatter',
         )
         check_type(overwrite, 'overwrite', bool, 'scatter')

From 767136c8f7debd2c64aa61aed9d98998080e51bb Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Tue, 28 Mar 2023 06:27:54 +0000
Subject: [PATCH 5/6] add cinn to makelist

---
 .../fluid/tests/unittests/CMakeLists.txt     |  1 +
 .../fluid/tests/unittests/test_scatter_op.py | 55 +++++--------------
 2 files changed, 14 insertions(+), 42 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index a30b02f26a33e..bd7dd57080bd4 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1212,6 +1212,7 @@ set(TEST_CINN_OPS
     test_mean_op
     test_unsqueeze2_op
     test_meshgrid_op
+    test_scatter_op
     test_gather_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 0c0400605685b..8d3903c86ac25 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -51,9 +51,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True
-        )
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op(TestScatterOp):
@@ -73,10 +71,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, check_prim=True
-            )
-
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -89,7 +84,6 @@ def test_check_grad(self):
             )
 
 
-
 class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
@@ -118,9 +112,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True
-        )
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op0(TestScatterOp0):
@@ -140,9 +132,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, check_prim=True
-            )
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -186,9 +176,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True
-        )
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op1(TestScatterOp1):
@@ -208,9 +196,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, check_prim=True
-            )
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -252,9 +238,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, atol=1e-3, check_prim=True
-            )
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -318,10 +302,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, atol=1e-3, check_prim=True
-            )
-
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -379,9 +360,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(
-            ['X', 'Updates'], 'Out', check_prim=True
-        )
+        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
 
 
 class TestScatterFP16Op4(TestScatterOp4):
@@ -401,9 +380,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, check_prim=True
-            )
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -445,9 +422,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, atol=1e-3, check_prim=True
-            )
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
@@ -506,9 +481,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True
-        )
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op6(TestScatterOp6):
@@ -528,9 +501,7 @@ def _set_dtype(self):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(
-                place, check_prim=True
-            )
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():

From b6c20bf030a06f7368fc7b894c42883be8d8eb03 Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Wed, 29 Mar 2023 09:55:51 +0000
Subject: [PATCH 6/6] CINN does not support uint16, so disable CINN when
 dtype == uint16.
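[Editor's note, not part of the patch series: patch 5 registered test_scatter_op under TEST_CINN_OPS, so every case in the file also runs through CINN compilation unless a case opts out; CINN has no kernels for the uint16 (bf16 carrier) dtype, hence the per-class opt-out this patch adds. The pattern each BF16 case follows — class and attribute names are from this test file, the comment is editorial:

    import numpy as np

    class TestScatterBF16Op(TestScatterOp):
        def _set_dtype(self):
            self.dtype = np.uint16
            # CINN cannot compile uint16 (bf16) programs, so skip the CINN
            # pass for this dtype; fp32/fp16 cases keep enable_cinn's default.
            self.enable_cinn = False

Setting the flag per subclass keeps the opt-out next to the dtype choice that causes it, at the cost of repeating the same line in all seven BF16 classes.]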
---
 python/paddle/fluid/tests/unittests/test_scatter_op.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 8d3903c86ac25..34c30e6591df8 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -67,6 +67,7 @@ def _set_dtype(self):
 class TestScatterBF16Op(TestScatterOp):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
@@ -128,6 +129,7 @@ def _set_dtype(self):
 class TestScatterBF16Op0(TestScatterOp0):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
@@ -192,6 +194,7 @@ def _set_dtype(self):
 class TestScatterBF16Op1(TestScatterOp1):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
@@ -267,6 +270,7 @@ def _set_dtype(self):
 class TestScatterBF16Op2(TestScatterOp2):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 @unittest.skipIf(
@@ -331,6 +335,7 @@ def _set_dtype(self):
 class TestScatterBF16Op3(TestScatterOp3):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp4(OpTest):
@@ -376,6 +381,7 @@ def _set_dtype(self):
 class TestScatterBF16Op4(TestScatterOp4):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
@@ -451,6 +457,7 @@ def _set_dtype(self):
 class TestScatterBF16Op5(TestScatterOp5):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp6(OpTest):
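[Editor's note, not part of the patch series: what the composite scatter_grad in PATCH 1/6 computes, rendered in NumPy under the assumption of unique 1-D indices scattering rows along axis 0. scatter(out_grad, index, zero_tensor, overwrite=false) zeroes the indexed rows of out_grad — Paddle's overwrite=False semantics first zero the indexed rows and then accumulate the updates, which are all zero here — and gather(out_grad, index, 0) picks out each update row's gradient:

    import numpy as np

    def scatter_grad_composite(index, out_grad):
        # x_grad: out_grad with the rows that `updates` produced zeroed out,
        # i.e. scatter(out_grad, index, zero_tensor, overwrite=False).
        x_grad = out_grad.copy()
        x_grad[index] = 0
        # updates_grad: gather(out_grad, index, axis=0) — each update row
        # receives the gradient of the output row it was written to.
        updates_grad = out_grad[index]
        return x_grad, updates_grad

    out_grad = np.arange(12, dtype=np.float32).reshape(4, 3)
    index = np.array([1, 3])
    x_grad, updates_grad = scatter_grad_composite(index, out_grad)
    assert (x_grad[[1, 3]] == 0).all()
    assert (x_grad[[0, 2]] == out_grad[[0, 2]]).all()
    assert (updates_grad == out_grad[[1, 3]]).all()

This decomposition is what check_prim=True exercises in the tests above: instead of the handwritten scatter_grad kernel, the prim pass runs the scatter/gather/full program declared in composite_backward_api.h and compares the resulting gradients.]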