diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 75b8199314..54fcdc8d59 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1207,7 +1207,7 @@ class MapLabelValue: """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: """ @@ -1215,33 +1215,42 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL orig_labels: original labels that map to others. target_labels: expected label values, 1: 1 map to the `orig_labels`. dtype: convert the output data to dtype, default to float32. + if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend. """ if len(orig_labels) != len(target_labels): raise ValueError("orig_labels and target_labels must have the same length.") - if all(o == z for o, z in zip(orig_labels, target_labels)): - raise ValueError("orig_labels and target_labels are exactly the same, should be different to map.") self.orig_labels = orig_labels self.target_labels = target_labels - self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) + self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t) + type_dtype = type(dtype) + if getattr(type_dtype, "__module__", "") == "torch": + self.use_numpy = False + self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) + else: + self.use_numpy = True + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) def __call__(self, img: NdarrayOrTensor): - img_np, *_ = convert_data_type(img, np.ndarray) - img_flat = img_np.flatten() - try: - out_flat = np.array(img_flat, dtype=self.dtype) - except ValueError: - # can't copy unchanged labels as the expected dtype is not supported, must map all the label values - out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) - - for o, t in zip(self.orig_labels, self.target_labels): - if o == t: - continue - np.place(out_flat, img_flat == o, t) - - reshaped = out_flat.reshape(img_np.shape) - out, *_ = convert_to_dst_type(src=reshaped, dst=img, dtype=self.dtype) + if self.use_numpy: + img_np, *_ = convert_data_type(img, np.ndarray) + _out_shape = img_np.shape + img_flat = img_np.flatten() + try: + out_flat = img_flat.astype(self.dtype) + except ValueError: + # can't copy unchanged labels as the expected dtype is not supported, must map all the label values + out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) + for o, t in self.pair: + out_flat[img_flat == o] = t + out_t = out_flat.reshape(_out_shape) + else: + img_t, *_ = convert_data_type(img, torch.Tensor) + out_t = img_t.detach().clone().to(self.dtype) # type: ignore + for o, t in self.pair: + out_t[img_t == o] = t + out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype) return out diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 15025ac961..7f4f22a475 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1523,6 +1523,7 @@ def __init__( orig_labels: original labels that map to others. target_labels: expected label values, 1: 1 map to the `orig_labels`. dtype: convert the output data to dtype, default to float32. + if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend. allow_missing_keys: don't raise exception if key is missing. """ diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index 32f5fccdb6..6b8121b6df 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -42,6 +42,13 @@ p([2, 0, 0, 1]), ] ) + TESTS.append( + [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": torch.int8}, + p([3.5, 1.5, 1.5, 2.5]), + p([2, 0, 0, 1]), + ] + ) TESTS.extend( [ [ diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index 8c91adaa49..fa0d094393 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -14,9 +14,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MapLabelValued +from tests.utils import assert_allclose TEST_CASE_1 = [ {"keys": "seg", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, @@ -47,6 +49,11 @@ {"seg": np.array([3.5, 1.5, 1.5, 2.5])}, np.array([2, 0, 0, 1]), ] +TEST_CASE_5_1 = [ + {"keys": "seg", "orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": torch.int8}, + {"seg": torch.as_tensor([3.5, 1.5, 1.5, 2.5])}, + torch.as_tensor([2.0, 0.0, 0.0, 1.0]), +] TEST_CASE_6 = [ {"keys": "seg", "orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, @@ -62,10 +69,15 @@ class TestMapLabelValued(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_5_1, TEST_CASE_6, TEST_CASE_7] + ) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValued(**input_param)(input_data) - np.testing.assert_equal(result["seg"], expected_value) + if isinstance(expected_value, torch.Tensor): + assert_allclose(result["seg"], expected_value) + else: + np.testing.assert_equal(result["seg"], expected_value) self.assertTupleEqual(result["seg"].shape, expected_value.shape)