Skip to content

Commit

Permalink
5206 channel_dim param has higher priority in EnsureChannelFirst (#…
Browse files Browse the repository at this point in the history
…5208)

Signed-off-by: Wenqi Li <[email protected]>

Fixes #5206

### Description
changes to `EnsureChannelFirst` so that the input `channel_dim` has a
higher priority than `metatensor.meta['original_channel_dim']`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Sep 25, 2022
1 parent 84e271e commit 7f16a15
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
28 changes: 15 additions & 13 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,15 @@ class EnsureChannelFirst(Transform):
Args:
strict_check: whether to raise an error when the meta information is insufficient.
channel_dim: If the input image `img` is not a MetaTensor or `meta_dict` is not given,
this argument can be used to specify the original channel dimension (integer) of the input array.
If the input array doesn't have a channel dim, this value should be ``'no_channel'`` (default).
channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.
It overrides the `original_channel_dim` from provided MetaTensor input.
If the input array doesn't have a channel dim, this value should be ``'no_channel'``.
If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = "no_channel"):
def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = None):
self.strict_check = strict_check
self.input_channel_dim = channel_dim

Expand All @@ -239,17 +239,19 @@ def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> to
meta_dict = img.meta

channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None
if self.input_channel_dim is not None:
channel_dim = self.input_channel_dim

if channel_dim is None:
if self.input_channel_dim is None:
msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`."
if self.strict_check:
raise ValueError(msg)
warnings.warn(msg)
return img
channel_dim = self.input_channel_dim
if isinstance(meta_dict, dict):
meta_dict["original_channel_dim"] = self.input_channel_dim
msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`."
if self.strict_check:
raise ValueError(msg)
warnings.warn(msg)
return img

# track the original channel dim
if isinstance(meta_dict, dict):
meta_dict["original_channel_dim"] = channel_dim

if channel_dim == "no_channel":
result = img[None]
Expand Down
8 changes: 4 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,17 +301,17 @@ def __init__(
meta_key_postfix: str = DEFAULT_POST_FIX,
strict_check: bool = True,
allow_missing_keys: bool = False,
channel_dim="no_channel",
channel_dim=None,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
strict_check: whether to raise an error when the meta information is insufficient.
allow_missing_keys: don't raise exception if key is missing.
channel_dim: If the input image `img` is not a MetaTensor or `meta_dict` is not given,
this argument can be used to specify the original channel dimension (integer) of the input array.
If the input array doesn't have a channel dim, this value should be ``'no_channel'`` (default).
channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.
It overrides the `original_channel_dim` from provided MetaTensor input.
If the input array doesn't have a channel dim, this value should be ``'no_channel'``.
If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.
"""
super().__init__(keys, allow_missing_keys)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_ensure_channel_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def test_load_png(self):
result = LoadImage(image_only=True)(filename)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], 3)
result = EnsureChannelFirst(channel_dim=-1)(result)
self.assertEqual(result.shape, (6, 3, 6))

def test_check(self):
im = torch.zeros(1, 2, 3)
Expand All @@ -83,7 +85,7 @@ def test_check(self):
with self.assertRaises(ValueError): # no meta
EnsureChannelFirst(channel_dim=None)(MetaTensor(im))
with self.assertRaises(ValueError): # no meta channel
EnsureChannelFirst(channel_dim=None)(im_nodim)
EnsureChannelFirst()(im_nodim)

with self.assertWarns(Warning):
EnsureChannelFirst(strict_check=False, channel_dim=None)(im)
Expand Down

0 comments on commit 7f16a15

Please sign in to comment.