diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 842151d83b332..f18c4e1099b3b 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -252,6 +252,8 @@ view, view_as, unfold, + masked_fill, + masked_fill_, ) from .tensor.math import ( # noqa: F401 @@ -907,6 +909,8 @@ 'i1e', 'polygamma', 'polygamma_', + 'masked_fill', + 'masked_fill_', 'hypot', 'hypot_', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c8bfe99f91e6b..90d79844ad9a5 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -166,6 +166,8 @@ from .manipulation import view # noqa: F401 from .manipulation import view_as # noqa: F401 from .manipulation import unfold # noqa: F401 +from .manipulation import masked_fill # noqa: F401 +from .manipulation import masked_fill_ # noqa: F401 from .math import abs # noqa: F401 from .math import abs_ # noqa: F401 from .math import acos # noqa: F401 @@ -695,6 +697,8 @@ 'i1e', 'polygamma', 'polygamma_', + 'masked_fill', + 'masked_fill_', 'diag_embed', 'atan2', 'diagflat', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ae61880c997be..84658c3dfc737 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4720,6 +4720,76 @@ def moveaxis(x, source, destination, name=None): return out +def masked_fill(x, mask, value, name=None): + """ + Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. + + Args: + x (Tensor) : The Destination Tensor. Supported data types are float, + double, int, int64_t,float16 and bfloat16. + mask (Tensor): The boolean tensor indicate the position to be filled. + The data type of mask must be bool. + value (Scalar or 0-D Tensor): The value used to fill the target tensor. + Supported data types are float, double, int, int64_t,float16 and bfloat16. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor, same dimention and dtype with x. + Examples: + .. code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> x = paddle.ones((3, 3), dtype="float32") + >>> mask = paddle.to_tensor([[True, True, False]]) + >>> print(mask) + Tensor(shape=[1, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, + [[True , True , False]]) + >>> out = paddle.masked_fill(x, mask, 2) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 2., 1.], + [2., 2., 1.], + [2., 2., 1.]]) + """ + if np.isscalar(value): + value = paddle.full([], value, x.dtype) + + mask = paddle.logical_not(mask) + out = paddle.where(mask, x, value) + return out + + +@inplace_apis_in_dygraph_only +def masked_fill_(x, mask, value, name=None): + """ + Inplace version of ``masked_fill`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_masked_fill`. + + Examples: + .. code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> x = paddle.ones((3, 3), dtype="float32") + >>> mask = paddle.to_tensor([[True, False, False]]) + >>> out = paddle.masked_fill_(x, mask, 2) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 1., 1.], + [2., 1., 1.], + [2., 1., 1.]]) + """ + if np.isscalar(value): + value = paddle.full([], value, x.dtype) + + mask = paddle.logical_not(mask) + out = paddle.where_(mask, x, value) + return out + + def non_negative_axis(arr, axis): ndim = len(arr.shape) if axis >= 0: diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index cb45f2fd8969f..c5e94d9cf8930 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -250,6 +250,72 @@ def test_backward_success_2(self): np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a) +class TestDygraphInplaceMaskedFill(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.masked_fill(var, self.mask, self.value) + + def inplace_api_processing(self, var): + return paddle.masked_fill_(var, self.mask, self.value) + + def init_data(self): + self.dtype = "float32" + self.input_var_numpy = np.random.uniform(-5, 5, [30, 3]) + self.value = np.random.uniform(-10, 10) + self.value = paddle.to_tensor(self.value, dtype=self.dtype) + self.mask = np.random.randint(0, 2, [30, 3]).astype('bool') + self.mask = paddle.to_tensor(self.mask, dtype='bool') + + def test_forward_version(self): + with paddle.base.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 2) + + inplace_var[0] = 2 + self.assertEqual(var.inplace_version, 3) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 5) + + def test_backward_error(self): + # It raises an error because the inplace operator will result + # in incorrect gradient computation. + with paddle.base.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + + # Here, the gradient computation will use the value of var_b + var_c = var_b**2 + self.inplace_api_processing(var_b) + + loss = paddle.nn.functional.relu(var_c) + with self.assertRaisesRegex( + RuntimeError, + f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + ): + loss.backward() + + +class TestDygraphInplaceMaskedFill2(TestDygraphInplaceMaskedFill): + def non_inplace_api_processing(self, var): + return paddle.masked_fill(var, self.mask, self.value) + + def inplace_api_processing(self, var): + return paddle.masked_fill_(var, self.mask, self.value) + + def init_data(self): + self.dtype = "float32" + self.input_var_numpy = np.random.uniform(-5, 5, [30, 3]) + self.value = np.random.uniform(-10, 10) + self.value = paddle.to_tensor(self.value, dtype=self.dtype) + self.mask = np.random.randint(0, 2, [30, 1]).astype('bool') + self.mask = paddle.to_tensor(self.mask, dtype='bool') + + class TestDygraphInplaceWithContinuous(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) diff --git a/test/legacy_test/test_masked_fill.py b/test/legacy_test/test_masked_fill.py new file mode 100644 index 0000000000000..ec511f9b680e4 --- /dev/null +++ b/test/legacy_test/test_masked_fill.py @@ -0,0 +1,328 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import convert_float_to_uint16 + +import paddle +from paddle import base +from paddle.base import core + + +def np_masked_fill(x, mask, value): + if not np.isscalar(value): + value = value[0] + + x, mask = np.broadcast_arrays(x, mask) + result = np.copy(x) + for idx, m in np.ndenumerate(mask): + if m: + result[idx] = value + return result + + +paddle.enable_static() + + +class TestMaskedFillAPI(unittest.TestCase): + def setUp(self): + self.init() + + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = np.random.randn(1).astype(self.dtype) + self.out_np = np_masked_fill(self.x_np, self.mask_np, self.value_np) + + def init(self): + self.x_shape = (50, 3) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + def test_static_graph(self): + paddle.enable_static() + startup_program = base.Program() + train_program = base.Program() + with base.program_guard(startup_program, train_program): + x = paddle.static.data( + name='x', dtype=self.dtype, shape=self.x_shape + ) + mask = paddle.static.data( + name='mask', dtype='bool', shape=self.mask_shape + ) + value = paddle.static.data( + name='value', dtype=self.dtype, shape=self.value_np.shape + ) + out = paddle.masked_fill(x, mask, value) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + res = exe.run( + base.default_main_program(), + feed={ + 'x': self.x_np, + 'mask': self.mask_np, + 'value': self.value_np, + }, + fetch_list=[out], + ) + np.testing.assert_allclose( + res[0], self.out_np, atol=1e-5, rtol=1e-5 + ) + paddle.disable_static() + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('bool') + if self.scalar_value: + value = self.value_np[0] + else: + value = paddle.to_tensor(self.value_np, dtype=self.dtype) + result = paddle.masked_fill(x, mask, value) + np.testing.assert_allclose(self.out_np, result.numpy(), rtol=1e-05) + + paddle.enable_static() + + +class TestMaskedFillAPI1(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPI2(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPI3(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = True + + +class TestMaskedFillGrad(unittest.TestCase): + def setUp(self): + self.typelist = ['float32', 'float64', 'int32', 'int64'] + self.places = [base.CPUPlace()] + if base.core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.dtype = "float32" + + def test_backward(self): + expected_np = np.array( + [[2, 1, 1], [2, 1, 1], [2, 1, 1], [2, 1, 1]] + ).astype('float32') + expected_y_grad = np.array( + [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]] + ).astype('float32') + expected_v_grad = np.array(8).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.to_tensor(np.array(1).astype(self.dtype)) + x = paddle.ones((4, 3), dtype=self.dtype) + mask = paddle.to_tensor(np.array([0, 1, 1]).astype("bool")) + x.stop_gradient = False + v.stop_gradient = False + y = x * 2 + y.retain_grads() + ny = y.masked_fill(mask=mask, value=v) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True + ) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_y_grad).all(), + True, + ) + self.assertEqual( + (v.grad.numpy().astype('float32') == expected_v_grad).all(), + True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API1(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API2(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API3(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = True + + +class TestMaskedFillAPIBroadcast(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast2(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 3) + self.mask_shape = (1, 3) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast3(TestMaskedFillAPI): + def init(self): + self.x_shape = (120,) + self.mask_shape = (300, 120) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast4(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast5(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.scalar_value = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast2(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast3(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.scalar_value = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedFillBF16(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 1) + self.dtype = "uint16" + self.scalar_value = False + + def setUp(self): + self.init() + + self.x_np = convert_float_to_uint16( + np.random.random(self.x_shape).astype("float32") + ) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = convert_float_to_uint16( + np.random.randn(1).astype("float32") + ) + self.out_np = np_masked_fill(self.x_np, self.mask_np, self.value_np) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedFillBF16APIBroadcast2(TestMaskedFillBF16): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 3) + self.dtype = "uint16" + self.scalar_value = False + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()