Skip to content

Commit

Permalink
Fix ImageFilter to allow Gaussian filter without filter_size (#8189)
Browse files Browse the repository at this point in the history
Fixes #8127

Update `ImageFilter` to handle Gaussian filter without requiring
`filter_size`.

* Modify `monai/transforms/utility/array.py` to allow Gaussian filter
without `filter_size`.
- Adjust `_check_filter_format` method to skip `filter_size` check for
Gaussian filter. Indeed Gauss filter is the only one in the list that
doesn't require a filter_size.
* Add unit test in `tests/test_image_filter.py` for Gaussian filter
without `filter_size`.
  - Verify output shape matches input shape.

Note that this method is compliant with the dictionnary version since
this one load the fixed version.

Signed-off-by: Eloi <[email protected]>

---------

Signed-off-by: Eloi Navet <[email protected]>
Signed-off-by: Eloi <[email protected]>
Signed-off-by: Eloi [email protected]
  • Loading branch information
EloiNavet authored Nov 7, 2024
1 parent c1ceea3 commit 530cc1f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 2 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,9 +1609,9 @@ def _check_all_values_uneven(self, x: tuple) -> None:

def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:
if isinstance(filter, str):
if not filter_size:
if filter != "gauss" and not filter_size: # Gauss is the only filter that does not require `filter_size`
raise ValueError("`filter_size` must be specified when specifying filters by string.")
if filter_size % 2 == 0:
if filter_size and filter_size % 2 == 0:
raise ValueError("`filter_size` should be a single uneven integer.")
if filter not in self.supported_filters:
raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.")
Expand Down
6 changes: 6 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def test_pass_empty_metadata_dict(self):
out_tensor = filter(image)
self.assertTrue(isinstance(out_tensor, MetaTensor))

def test_gaussian_filter_without_filter_size(self):
"Test Gaussian filter without specifying filter_size"
filter = ImageFilter("gauss", sigma=2)
out_tensor = filter(SAMPLE_IMAGE_2D)
self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])


class TestImageFilterDict(unittest.TestCase):

Expand Down

0 comments on commit 530cc1f

Please sign in to comment.