Skip to content

Commit

Permalink
fixes #6007
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Feb 15, 2023
1 parent 9ddb14c commit e9acacf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 6 additions & 2 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,13 +1141,17 @@ class PILReader(ImageReader):
Args:
converter: additional function to convert the image data after `read()`.
for example, use `converter=lambda image: image.convert("LA")` to convert image format.
reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default,
so that output of the reader is consistent with the other readers. Set this option to ``False`` to use
the PIL backend's original spatial axes convention.
kwargs: additional args for `Image.open` API in `read()`, mode details about available args:
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
"""

def __init__(self, converter: Callable | None = None, **kwargs):
def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
super().__init__()
self.converter = converter
self.reverse_indexing = reverse_indexing
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -1207,7 +1211,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
data = np.moveaxis(np.asarray(i), 0, 1)
data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
img_array.append(data)
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
Expand Down
9 changes: 5 additions & 4 deletions tests/test_pil_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)]

TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)]
TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False]

TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)]

Expand All @@ -38,20 +38,21 @@

class TestPNGReader(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape):
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):
test_image = np.random.randint(0, 256, size=data_shape)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
reader = PILReader(mode="r")
reader = PILReader(mode="r", reverse_indexing=reverse)
result = reader.get_data(reader.read(filenames))
# load image by PIL and compare the result
test_image = np.asarray(Image.open(filenames[0]))

self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape)
self.assertTupleEqual(result[0].shape, expected_shape)
test_image = np.moveaxis(test_image, 0, 1)
if reverse:
test_image = np.moveaxis(test_image, 0, 1)
if result[0].shape == test_image.shape:
np.testing.assert_allclose(result[0], test_image)
else:
Expand Down

0 comments on commit e9acacf

Please sign in to comment.