Commit a755e2f: new nan_to_num impl
Signed-off-by: tiancaishaonvjituizi <[email protected]>
tiancaishaonvjituizi committed Aug 13, 2022
1 parent e96dae8 commit a755e2f
Showing 4 changed files with 248 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -219,6 +219,7 @@
from .tensor.math import square # noqa: F401
from .tensor.math import stanh # noqa: F401
from .tensor.math import sum # noqa: F401
from .tensor.math import nan_to_num # noqa: F401
from .tensor.math import nansum # noqa: F401
from .tensor.math import nanmean # noqa: F401
from .tensor.math import count_nonzero # noqa: F401
@@ -649,6 +650,7 @@
    'renorm',
    'take_along_axis',
    'put_along_axis',
    'nan_to_num',
    'heaviside',
    'tril_indices',
    'sgn',
200 changes: 200 additions & 0 deletions python/paddle/fluid/tests/unittests/test_nan_to_num_op.py
@@ -0,0 +1,200 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import Optional
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest


def np_nan_to_num(x: np.ndarray,
                  nan: float = 0.0,
                  posinf: Optional[float] = None,
                  neginf: Optional[float] = None) -> np.ndarray:
    return np.nan_to_num(x, True, nan=nan, posinf=posinf, neginf=neginf)


def np_nan_to_num_op(x: np.ndarray, nan: float, replace_posinf_with_max: bool,
                     posinf: float, replace_neginf_with_min: bool,
                     neginf: float) -> np.ndarray:
    if replace_posinf_with_max:
        posinf = None
    if replace_neginf_with_min:
        neginf = None
    return np.nan_to_num(x, True, nan=nan, posinf=posinf, neginf=neginf)
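
# Note: when replace_posinf_with_max / replace_neginf_with_min is set, the
# explicit posinf / neginf value above is discarded, and np.nan_to_num falls
# back to np.finfo(x.dtype).max / np.finfo(x.dtype).min. For example
# (illustrative attrs), {'replace_posinf_with_max': True, 'posinf': -1}
# replaces +inf with the dtype max, not with -1.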


def np_nan_to_num_grad(x: np.ndarray, dout: np.ndarray) -> np.ndarray:
    dx = np.copy(dout)
    dx[np.isnan(x) | (x == np.inf) | (x == -np.inf)] = 0
    return dx
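
# Gradient rule, with illustrative values: for x = [1.0, np.nan, np.inf] and
# dout = [1.0, 1.0, 1.0], this returns [1.0, 0.0, 0.0]; positions whose value
# was replaced (NaN or +/-inf) receive zero gradient.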


class TestNanToNum(unittest.TestCase):

    def setUp(self):
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    # def test_static(self):
    #     x_np = np.array([[1, np.nan, -2], [np.inf, 0,
    #                                        -np.inf]]).astype(np.float32)
    #     out1_np = np_nan_to_num(x_np)
    #     out2_np = np_nan_to_num(x_np, 1.)
    #     out3_np = np_nan_to_num(x_np, 1., 9.)
    #     out4_np = np_nan_to_num(x_np, 1., 9., -12.)
    #     paddle.enable_static()
    #     with paddle.static.program_guard(paddle.static.Program()):
    #         x = paddle.fluid.data('X', x_np.shape)
    #         out1 = paddle.nan_to_num(x)
    #         out2 = paddle.nan_to_num(x, 1.)
    #         out3 = paddle.nan_to_num(x, 1., 9.)
    #         out4 = paddle.nan_to_num(x, 1., 9., -12.)
    #         exe = paddle.static.Executor(self.place)
    #         res = exe.run(feed={'X': x_np}, fetch_list=[out1, out2, out3, out4])
    #
    #     self.assertTrue(np.allclose(out1_np, res[0]))
    #     self.assertTrue(np.allclose(out2_np, res[1]))
    #     self.assertTrue(np.allclose(out3_np, res[2]))
    #     self.assertTrue(np.allclose(out4_np, res[3]))
    #
    # def test_errors(self):
    #     paddle.enable_static()
    #     with paddle.static.program_guard(paddle.static.Program()):
    #
    #         def test_dtype():
    #             x = paddle.fluid.data('X2', [10, 12], 'bool')
    #             paddle.nan_to_num(x)
    #
    #         self.assertRaises(TypeError, test_dtype)

    def test_dygraph(self):

        paddle.disable_static(place=self.place)

        with paddle.fluid.dygraph.guard():
            x_np = np.array([[1, np.nan, -2], [np.inf, 0,
                                               -np.inf]]).astype(np.float64)
            # -np.inf]]).astype(np.float32)
            x_tensor = paddle.to_tensor(x_np, stop_gradient=False)

            out_tensor = paddle.nan_to_num(x_tensor)
            out_np = np_nan_to_num(x_np)
            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))

            out_tensor = paddle.nan_to_num(x_tensor, 1., None, None)
            out_np = np_nan_to_num(x_np, 1, None, None)
            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))

            out_tensor = paddle.nan_to_num(x_tensor, 1., 2., None)
            out_np = np_nan_to_num(x_np, 1, 2, None)
            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))

            out_tensor = paddle.nan_to_num(x_tensor, 1., None, -10.)
            out_np = np_nan_to_num(x_np, 1, None, -10)
            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))

            out_tensor = paddle.nan_to_num(x_tensor, 1., 100., -10.)
            out_np = np_nan_to_num(x_np, 1, 100, -10)
            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))

        paddle.enable_static()

    # def test_check_grad(self):
    #     paddle.disable_static(place=self.place)
    #     x_np = np.array([[1, np.nan, -2], [np.inf, 0,
    #                                        -np.inf]]).astype(np.float32)
    #     x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
    #
    #     y = paddle.nan_to_num(x_tensor)
    #     dx = paddle.grad(y, x_tensor)[0].numpy()
    #
    #     np_grad = np_nan_to_num_grad(x_np, np.ones_like(x_np))
    #     self.assertTrue(np.allclose(np_grad, dx))
    #
    #     paddle.enable_static()


# class BaseTestCases:
#
#     class BaseOpTest(OpTest):
#
#         def setUp(self):
#             self.op_type = "nan_to_num"
#             input = np.arange(100, dtype=np.float64)
#             input[5] = np.nan
#             input[29] = np.inf
#             input[97] = -np.inf
#             self.inputs = {'X': input}
#             self.attrs = self._attrs()
#             self.outputs = {
#                 'Out': np_nan_to_num_op(self.inputs['X'], **self.attrs)
#             }
#             paddle.enable_static()
#
#         def test_check_output(self):
#             self.check_output()
#
#         def test_check_grad(self):
#             input = self.inputs['X']
#             dout = np.ones_like(input) / input.size
#             self.check_grad(
#                 ['X'],
#                 'Out',
#                 user_defined_grads=[np_nan_to_num_grad(self.inputs['X'], dout)])
#
#         def _attrs(self):
#             raise NotImplementedError()
#
#
# class TestNanToNumOp1(BaseTestCases.BaseOpTest):
#
#     def _attrs(self):
#         return {
#             'nan': 0.0,
#             'replace_posinf_with_max': True,
#             'posinf': -1,
#             'replace_neginf_with_min': True,
#             'neginf': -10
#         }
#
#
# class TestNanToNumOp2(BaseTestCases.BaseOpTest):
#
#     def _attrs(self):
#         return {
#             'nan': 2.0,
#             'replace_posinf_with_max': False,
#             'posinf': -1,
#             'replace_neginf_with_min': True,
#             'neginf': -10
#         }
#
#
# class TestNanToNumOp3(BaseTestCases.BaseOpTest):
#
#     def _attrs(self):
#         return {
#             'nan': 0.0,
#             'replace_posinf_with_max': False,
#             'posinf': -1,
#             'replace_neginf_with_min': False,
#             'neginf': -10
#         }


if __name__ == "__main__":
unittest.main()
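
For readers who want to try the new API by hand, here is a minimal dygraph sketch (an editor's illustration, not part of the commit; it assumes a Paddle build that contains this change, and the expected values follow the nan_to_num docstring in the math.py hunk below):

import numpy as np
import paddle

x = paddle.to_tensor([float('nan'), 0.3, float('inf'), float('-inf')], dtype='float32')
out = paddle.nan_to_num(x, nan=1.0, posinf=5.0)  # -inf falls back to the float32 min
# expected: [1.0, 0.3, 5.0, -3.4028235e+38]
np.testing.assert_allclose(out.numpy(), [1.0, 0.3, 5.0, np.finfo(np.float32).min])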
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -166,6 +166,7 @@
from .math import square # noqa: F401
from .math import stanh # noqa: F401
from .math import sum # noqa: F401
from .math import nan_to_num # noqa: F401
from .math import nansum # noqa: F401
from .math import nanmean # noqa: F401
from .math import count_nonzero # noqa: F401
@@ -344,6 +345,7 @@
    'square',
    'stanh',
    'sum',
    'nan_to_num',
    'nansum',
    'nanmean',
    'count_nonzero',
44 changes: 44 additions & 0 deletions python/paddle/tensor/math.py
@@ -1224,6 +1224,50 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
    return out


def nan_to_num(x, nan=0.0, posinf=None, neginf=None, name=None):
    """
    Replaces NaN, positive infinity, and negative infinity values in the input tensor.

    Args:
        x (Tensor): An N-D Tensor, the data type is float32, float64.
        nan (float, optional): The value to replace NaNs with. Default is 0.
        posinf (float, optional): If a Number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by the input's dtype. Default is None.
        neginf (float, optional): If a Number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by the input's dtype. Default is None.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: Results of nan_to_num operation on the input Tensor ``x``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([float('nan'), 0.3, float('+inf'), float('-inf')], dtype='float32')
            out1 = paddle.nan_to_num(x)  # [0, 0.3, 3.4028235e+38, -3.4028235e+38]
            out2 = paddle.nan_to_num(x, nan=1)  # [1, 0.3, 3.4028235e+38, -3.4028235e+38]
            out3 = paddle.nan_to_num(x, posinf=5)  # [0, 0.3, 5, -3.4028235e+38]
            out4 = paddle.nan_to_num(x, nan=10, neginf=-99)  # [10, 0.3, 3.4028235e+38, -99]
    """
    # NOTE(tiancaishaonvjituizi): it seems that paddle handles the dtype of python
    # float numbers incorrectly, so we have to explicitly construct tensors here
    full_posinf = paddle.full_like(x, float("+inf"))
    full_neginf = paddle.full_like(x, float("-inf"))
    full_nan = paddle.full_like(x, nan)
    assert x.dtype in [paddle.float32, paddle.float64]
    is_float32 = x.dtype == paddle.float32
    if posinf is None:
        posinf = np.finfo(np.float32).max if is_float32 else np.finfo(np.float64).max
    posinf = paddle.full_like(x, posinf)
    if neginf is None:
        neginf = np.finfo(np.float32).min if is_float32 else np.finfo(np.float64).min
    neginf = paddle.full_like(x, neginf)
    x = paddle.where(paddle.isnan(x), full_nan, x)
    x = paddle.where(x == full_posinf, posinf, x)
    x = paddle.where(x == full_neginf, neginf, x)
    return x
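
# A step-by-step trace of the three paddle.where calls above, using the
# float32 input from the docstring example (expected intermediate values,
# shown for illustration only):
#   x = [nan, 0.3, +inf, -inf]
#   after the NaN step:    [0.0, 0.3, +inf, -inf]
#   after the posinf step: [0.0, 0.3, 3.4028235e+38, -inf]
#   after the neginf step: [0.0, 0.3, 3.4028235e+38, -3.4028235e+38]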


def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
    """
    Computes the sum of tensor elements over the given axis, treating Not a Numbers (NaNs) as zero.

