diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 051eab9e0e..b35fa5d585 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1713,6 +1713,15 @@ Post-processing (Dict) :members: :special-members: __call__ +Signal (Dict) +^^^^^^^^^^^^^ + +`SignalFillEmptyd` +"""""""""""""""""" +.. autoclass:: SignalFillEmptyd + :members: + :special-members: __call__ + Spatial (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index eb8c5af19e..51fd5c6288 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -345,6 +345,7 @@ SignalRandShift, SignalRemoveFrequency, ) +from .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict from .smooth_field.array import ( RandSmoothDeform, RandSmoothFieldAdjustContrast, diff --git a/monai/transforms/signal/dictionary.py b/monai/transforms/signal/dictionary.py new file mode 100644 index 0000000000..469014d867 --- /dev/null +++ b/monai/transforms/signal/dictionary.py @@ -0,0 +1,52 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of dictionary-based wrappers around the signal operations defined in :py:class:`monai.transforms.signal.array`. + +Class names are ended with 'd' to denote dictionary-based transforms. +""" + +from __future__ import annotations + +from collections.abc import Hashable, Mapping + +from monai.config.type_definitions import KeysCollection, NdarrayOrTensor +from monai.transforms.signal.array import SignalFillEmpty +from monai.transforms.transform import MapTransform + +__all__ = ["SignalFillEmptyd", "SignalFillEmptyD", "SignalFillEmptyDict"] + + +class SignalFillEmptyd(MapTransform): + """ + Applies the SignalFillEmptyd transform on the input. All NaN values will be replaced with the + replacement value. + + Args: + keys: keys of the corresponding items to model output. + allow_missing_keys: don't raise exception if key is missing. + replacement: The value that the NaN entries shall be mapped to. + """ + + backend = SignalFillEmpty.backend + + def __init__(self, keys: KeysCollection = None, allow_missing_keys: bool = False, replacement=0.0): + super().__init__(keys, allow_missing_keys) + self.signal_fill_empty = SignalFillEmpty(replacement=replacement) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + for key in self.key_iterator(data): + data[key] = self.signal_fill_empty(data[key]) # type: ignore + + return data + + +SignalFillEmptyD = SignalFillEmptyDict = SignalFillEmptyd diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index f44e4ba29a..ee606d960c 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -32,7 +32,7 @@ def test_correct_parameters_multi_channels(self): sig[:, 123] = np.NAN fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) - self.assertTrue(not np.isnan(fillemptysignal.any())) + self.assertTrue(not np.isnan(fillemptysignal).any()) @SkipIfBeforePyTorchVersion((1, 9)) @@ -43,7 +43,7 @@ def test_correct_parameters_multi_channels(self): sig[:, 123] = convert_to_tensor(np.NAN) fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) - self.assertTrue(not torch.isnan(fillemptysignal.any())) + self.assertTrue(not torch.isnan(fillemptysignal).any()) if __name__ == "__main__": diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py new file mode 100644 index 0000000000..5b12055e7d --- /dev/null +++ b/tests/test_signal_fillemptyd.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +import numpy as np +import torch + +from monai.transforms import SignalFillEmptyd +from monai.utils.type_conversion import convert_to_tensor +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy") + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestSignalFillEmptyNumpy(unittest.TestCase): + def test_correct_parameters_multi_channels(self): + self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) + sig = np.load(TEST_SIGNAL) + sig[:, 123] = np.NAN + data = {} + data["signal"] = sig + fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) + data_ = fillempty(data) + + self.assertTrue(np.isnan(sig).any()) + self.assertTrue(not np.isnan(data_["signal"]).any()) + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestSignalFillEmptyTorch(unittest.TestCase): + def test_correct_parameters_multi_channels(self): + self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) + sig = convert_to_tensor(np.load(TEST_SIGNAL)) + sig[:, 123] = convert_to_tensor(np.NAN) + data = {} + data["signal"] = sig + fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) + data_ = fillempty(data) + + self.assertTrue(np.isnan(sig).any()) + self.assertTrue(not torch.isnan(data_["signal"]).any()) + + +if __name__ == "__main__": + unittest.main()