From 94e9e1783536326fa99c202cba4e0896b37778eb Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 16 Feb 2023 08:45:52 +0000 Subject: [PATCH] 6007 reverse_indexing for PILReader (#6008) Fixes #6007 ### Description - reverse_indexing = False: to support consistency with PIL/torchvision ```py img = LoadImage(image_only=True, ensure_channel_first=True, reverse_indexing=False)("MONAI-logo_color.png") # PILReader torchvision.utils.save_image(img, "MONAI-logo_color_torchvision.png", normalize=True) ``` - reverse_indexing = True: to support consistency with other backends in monai ```py img = LoadImage(image_only=True, ensure_channel_first=True, reader="PILReader", reverse_indexing=True)(filename) # PIL backend img_1 = LoadImage(image_only=True, ensure_channel_first=True, reader="ITKReader")(filename) # itk backend np.testing.assert_allclose(img, img_1) # true ``` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- monai/data/image_reader.py | 12 ++++++++---- monai/transforms/io/array.py | 6 +++--- monai/transforms/io/dictionary.py | 6 +++--- tests/test_pil_reader.py | 9 +++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index a892b5b8fd..736baff538 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1141,13 +1141,17 @@ class PILReader(ImageReader): Args: converter: additional function to convert the image data after `read()`. for example, use `converter=lambda image: image.convert("LA")` to convert image format. + reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default, + so that output of the reader is consistent with the other readers. Set this option to ``False`` to use + the PIL backend's original spatial axes convention. kwargs: additional args for `Image.open` API in `read()`, mode details about available args: https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open """ - def __init__(self, converter: Callable | None = None, **kwargs): + def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs): super().__init__() self.converter = converter + self.reverse_indexing = reverse_indexing self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -1194,8 +1198,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. - Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL - is different from other common medical packages. + Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading + the array because the spatial axes definition in PIL is different from other common medical packages. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. @@ -1207,7 +1211,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: for i in ensure_tuple(img): header = self._get_meta_dict(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - data = np.moveaxis(np.asarray(i), 0, 1) + data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i) img_array.append(data) header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index a91f21e220..a3bd8a2d18 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -116,9 +116,9 @@ class LoadImage(Transform): - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader). - Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after - loading the array because the `HW` definition for non-medical specific file formats is different - from other common medical packages. + Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after + loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition + for non-medical specific file formats is different from other common medical packages. See also: diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 9f75444dc6..cb832b59e0 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -53,9 +53,9 @@ class LoadImaged(MapTransform): - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). - Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after - loading the array because the `HW` definition for non-medical specific file formats is different - from other common medical packages. + Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after + loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition + for non-medical specific file formats is different from other common medical packages. Note: diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index ab4c5c40ca..dfa5eb725d 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -25,7 +25,7 @@ TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)] -TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)] +TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False] TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)] @@ -38,20 +38,21 @@ class TestPNGReader(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape): + def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True): test_image = np.random.randint(0, 256, size=data_shape) with tempfile.TemporaryDirectory() as tempdir: for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - reader = PILReader(mode="r") + reader = PILReader(mode="r", reverse_indexing=reverse) result = reader.get_data(reader.read(filenames)) # load image by PIL and compare the result test_image = np.asarray(Image.open(filenames[0])) self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) - test_image = np.moveaxis(test_image, 0, 1) + if reverse: + test_image = np.moveaxis(test_image, 0, 1) if result[0].shape == test_image.shape: np.testing.assert_allclose(result[0], test_image) else: