From 7f16a15c8c360d276a4c72533fe52fd162ef1fb3 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Sun, 25 Sep 2022 20:46:03 +0100 Subject: [PATCH] 5206 channel_dim param has higher priority in `EnsureChannelFirst ` (#5208) Signed-off-by: Wenqi Li Fixes #5206 ### Description changes to `EnsureChannelFirst` so that the input `channel_dim` has a higher priority than `metatensor.meta['original_channel_dim']` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li --- monai/transforms/utility/array.py | 28 ++++++++++++++------------ monai/transforms/utility/dictionary.py | 8 ++++---- tests/test_ensure_channel_first.py | 4 +++- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bfea091f20..8d891958de 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -209,15 +209,15 @@ class EnsureChannelFirst(Transform): Args: strict_check: whether to raise an error when the meta information is insufficient. - channel_dim: If the input image `img` is not a MetaTensor or `meta_dict` is not given, - this argument can be used to specify the original channel dimension (integer) of the input array. - If the input array doesn't have a channel dim, this value should be ``'no_channel'`` (default). + channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array. + It overrides the `original_channel_dim` from provided MetaTensor input. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = "no_channel"): + def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = None): self.strict_check = strict_check self.input_channel_dim = channel_dim @@ -239,17 +239,19 @@ def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> to meta_dict = img.meta channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None + if self.input_channel_dim is not None: + channel_dim = self.input_channel_dim if channel_dim is None: - if self.input_channel_dim is None: - msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`." - if self.strict_check: - raise ValueError(msg) - warnings.warn(msg) - return img - channel_dim = self.input_channel_dim - if isinstance(meta_dict, dict): - meta_dict["original_channel_dim"] = self.input_channel_dim + msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`." + if self.strict_check: + raise ValueError(msg) + warnings.warn(msg) + return img + + # track the original channel dim + if isinstance(meta_dict, dict): + meta_dict["original_channel_dim"] = channel_dim if channel_dim == "no_channel": result = img[None] diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 53b0646379..513367033a 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -301,7 +301,7 @@ def __init__( meta_key_postfix: str = DEFAULT_POST_FIX, strict_check: bool = True, allow_missing_keys: bool = False, - channel_dim="no_channel", + channel_dim=None, ) -> None: """ Args: @@ -309,9 +309,9 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` strict_check: whether to raise an error when the meta information is insufficient. allow_missing_keys: don't raise exception if key is missing. - channel_dim: If the input image `img` is not a MetaTensor or `meta_dict` is not given, - this argument can be used to specify the original channel dimension (integer) of the input array. - If the input array doesn't have a channel dim, this value should be ``'no_channel'`` (default). + channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array. + It overrides the `original_channel_dim` from provided MetaTensor input. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension. """ super().__init__(keys, allow_missing_keys) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index fca2f90139..e671edc0a7 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -73,6 +73,8 @@ def test_load_png(self): result = LoadImage(image_only=True)(filename) result = EnsureChannelFirst()(result) self.assertEqual(result.shape[0], 3) + result = EnsureChannelFirst(channel_dim=-1)(result) + self.assertEqual(result.shape, (6, 3, 6)) def test_check(self): im = torch.zeros(1, 2, 3) @@ -83,7 +85,7 @@ def test_check(self): with self.assertRaises(ValueError): # no meta EnsureChannelFirst(channel_dim=None)(MetaTensor(im)) with self.assertRaises(ValueError): # no meta channel - EnsureChannelFirst(channel_dim=None)(im_nodim) + EnsureChannelFirst()(im_nodim) with self.assertWarns(Warning): EnsureChannelFirst(strict_check=False, channel_dim=None)(im)