diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py
index 4cf7772058..4f88139677 100644
--- a/monai/data/wsi_reader.py
+++ b/monai/data/wsi_reader.py
@@ -17,11 +17,19 @@ from typing import Any
 
 import numpy as np
+import torch
 
-from monai.config import DtypeLike, PathLike
+from monai.config import DtypeLike, NdarrayOrTensor, PathLike
 from monai.data.image_reader import ImageReader, _stack_images
 from monai.data.utils import is_supported_format
-from monai.utils import WSIPatchKeys, ensure_tuple, optional_import, require_pkg
+from monai.utils import (
+    WSIPatchKeys,
+    dtype_numpy_to_torch,
+    dtype_torch_to_numpy,
+    ensure_tuple,
+    optional_import,
+    require_pkg,
+)
 
 OpenSlide, _ = optional_import("openslide", name="OpenSlide")
 TiffFile, _ = optional_import("tifffile", name="TiffFile")
@@ -33,12 +41,21 @@ class BaseWSIReader(ImageReader):
     """
     An abstract class that defines APIs to load patches from whole slide image files.
 
+    Args:
+        level: the whole slide image level at which the image is extracted.
+        channel_dim: the desired dimension for color channel.
+        dtype: the data type of output image.
+        device: target device to put the extracted patch. Note that if device is "cuda",
+            the output will be converted to a torch tensor and sent to the GPU, even if the dtype is numpy.
+        mode: the output image color mode, e.g., "RGB" or "RGBA".
+        kwargs: additional args for the reader.
+
     Typical usage of a concrete implementation of this class is:
 
     .. code-block:: python
 
         image_reader = MyWSIReader()
-        wsi = image_reader.read(, **kwargs)
+        wsi = image_reader.read(filepath, **kwargs)
         img_data, meta_data = image_reader.get_data(wsi)
 
     - The `read` call converts an image filename into whole slide image object,
@@ -58,13 +75,37 @@ class BaseWSIReader(ImageReader):
     supported_suffixes: list[str] = []
     backend = ""
 
-    def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
+    def __init__(
+        self,
+        level: int,
+        channel_dim: int,
+        dtype: DtypeLike | torch.dtype,
+        device: torch.device | str | None,
+        mode: str,
+        **kwargs,
+    ):
         super().__init__()
         self.level = level
         self.channel_dim = channel_dim
+        self.set_dtype(dtype)
+        self.set_device(device)
+        self.mode = mode
         self.kwargs = kwargs
         self.metadata: dict[Any, Any] = {}
 
+    def set_dtype(self, dtype):
+        self.dtype: torch.dtype | np.dtype
+        if isinstance(dtype, torch.dtype):
+            self.dtype = dtype
+        else:
+            self.dtype = np.dtype(dtype)
+
+    def set_device(self, device):
+        if device is None or isinstance(device, (torch.device, str)):
+            self.device = device
+        else:
+            raise ValueError(f"`device` must be `torch.device`, `str` or `None` but {type(device)} is given.")
+
     @abstractmethod
     def get_size(self, wsi, level: int | None = None) -> tuple[int, int]:
         """
@@ -138,7 +179,7 @@ def _get_patch(
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
 
     def _get_metadata(
-        self, wsi, patch: np.ndarray, location: tuple[int, int], size: tuple[int, int], level: int
+        self, wsi, patch: NdarrayOrTensor, location: tuple[int, int], size: tuple[int, int], level: int
    ) -> dict:
         """
         Returns metadata of the extracted patch from the whole slide image.
@@ -175,8 +216,7 @@ def get_data(
         location: tuple[int, int] = (0, 0),
         size: tuple[int, int] | None = None,
         level: int | None = None,
-        dtype: DtypeLike = np.uint8,
-        mode: str = "RGB",
+        mode: str | None = None,
     ) -> tuple[np.ndarray, dict]:
         """
         Verifies inputs, extracts patches from WSI image and generates metadata, and returns them.
@@ -185,15 +225,16 @@ def get_data(
             wsi: a whole slide image object loaded from a file or a list of such objects
             location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
             size: (height, width) tuple giving the patch size at the given level (`level`).
-                If None, it is set to the full image size at the given level.
+                If not provided or None, it is set to the full image size at the given level.
             level: the level number. Defaults to 0
-            dtype: the data type of output image
-            mode: the output image mode, 'RGB' or 'RGBA'
+            mode: the output image color mode, "RGB" or "RGBA". If not provided, the default of "RGB" is used.
 
         Returns:
             a tuple, where the first element is an image patch [CxHxW] or stack of patches,
                 and the second element is a dictionary of metadata
         """
+        if mode is None:
+            mode = self.mode
         patch_list: list = []
         metadata_list: list = []
         # CuImage object is iterable, so ensure_tuple won't work on single object
@@ -223,8 +264,25 @@ def get_data(
             if size[0] <= 0 or size[1] <= 0:
                 raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}")
 
+            # Get numpy dtype if it is not already.
+            dtype_np = dtype_torch_to_numpy(self.dtype) if isinstance(self.dtype, torch.dtype) else self.dtype
             # Extract a patch or the entire image
-            patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
+            patch: NdarrayOrTensor
+            patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype_np, mode=mode)
+
+            # Convert the patch to torch.Tensor if dtype is torch
+            if isinstance(self.dtype, torch.dtype) or (
+                self.device is not None and torch.device(self.device).type == "cuda"
+            ):
+                # Ensure dtype is torch.dtype if the device is not "cpu"
+                dtype_torch = (
+                    dtype_numpy_to_torch(self.dtype) if not isinstance(self.dtype, torch.dtype) else self.dtype
+                )
+                # Copy the numpy array if it is not writable
+                if patch.flags["WRITEABLE"]:
+                    patch = torch.as_tensor(patch, dtype=dtype_torch, device=self.device)
+                else:
+                    patch = torch.tensor(patch, dtype=dtype_torch, device=self.device)
 
             # check if the image has three dimensions (2D + color)
             if patch.ndim != 3:
@@ -280,26 +338,53 @@ class WSIReader(BaseWSIReader):
         backend: the name of backend whole slide image reader library, the default is cuCIM.
         level: the level at which patches are extracted.
         channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+        dtype: the data type of output image. Defaults to `np.uint8`.
+        mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
+        device: target device to put the extracted patch. Note that if device is "cuda",
+            the output will be converted to a torch tensor and sent to the GPU, even if the dtype is numpy.
         num_workers: number of workers for multi-thread image loading (cucim backend only).
         kwargs: additional arguments to be passed to the backend library
 
     """
 
-    def __init__(self, backend="cucim", level: int = 0, channel_dim: int = 0, **kwargs):
-        super().__init__(level, channel_dim, **kwargs)
+    supported_backends = ["cucim", "openslide", "tifffile"]
+
+    def __init__(
+        self,
+        backend="cucim",
+        level: int = 0,
+        channel_dim: int = 0,
+        dtype: DtypeLike | torch.dtype = np.uint8,
+        device: torch.device | str | None = None,
+        mode: str = "RGB",
+        **kwargs,
+    ):
         self.backend = backend.lower()
         self.reader: CuCIMWSIReader | OpenSlideWSIReader | TiffFileWSIReader
         if self.backend == "cucim":
-            self.reader = CuCIMWSIReader(level=level, channel_dim=channel_dim, **kwargs)
+            self.reader = CuCIMWSIReader(
+                level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
+            )
         elif self.backend == "openslide":
-            self.reader = OpenSlideWSIReader(level=level, channel_dim=channel_dim, **kwargs)
+            self.reader = OpenSlideWSIReader(
+                level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
+            )
         elif self.backend == "tifffile":
-            self.reader = TiffFileWSIReader(level=level, channel_dim=channel_dim, **kwargs)
+            self.reader = TiffFileWSIReader(
+                level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
+            )
         else:
             raise ValueError(
                 f"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given."
             )
         self.supported_suffixes = self.reader.supported_suffixes
+        self.level = self.reader.level
+        self.channel_dim = self.reader.channel_dim
+        self.dtype = self.reader.dtype
+        self.device = self.reader.device
+        self.mode = self.reader.mode
+        self.kwargs = self.reader.kwargs
+        self.metadata = self.reader.metadata
 
     def get_level_count(self, wsi) -> int:
         """
@@ -402,6 +487,10 @@ class CuCIMWSIReader(BaseWSIReader):
         level: the whole slide image level at which the image is extracted. (default=0)
             This is overridden if the level argument is provided in `get_data`.
         channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+        dtype: the data type of output image. Defaults to `np.uint8`.
+        device: target device to put the extracted patch. Note that if device is "cuda",
+            the output will be converted to a torch tensor and sent to the GPU, even if the dtype is numpy.
+        mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
         num_workers: number of workers for multi-thread image loading
         kwargs: additional args for `cucim.CuImage` module:
             https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h
@@ -411,8 +500,17 @@ class CuCIMWSIReader(BaseWSIReader):
     supported_suffixes = ["tif", "tiff", "svs"]
     backend = "cucim"
 
-    def __init__(self, level: int = 0, channel_dim: int = 0, num_workers: int = 0, **kwargs):
-        super().__init__(level, channel_dim, **kwargs)
+    def __init__(
+        self,
+        level: int = 0,
+        channel_dim: int = 0,
+        dtype: DtypeLike | torch.dtype = np.uint8,
+        device: torch.device | str | None = None,
+        mode: str = "RGB",
+        num_workers: int = 0,
+        **kwargs,
+    ):
+        super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
         self.num_workers = num_workers
 
     @staticmethod
@@ -551,6 +649,10 @@ class OpenSlideWSIReader(BaseWSIReader):
         level: the whole slide image level at which the image is extracted. (default=0)
             This is overridden if the level argument is provided in `get_data`.
         channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+        dtype: the data type of output image. Defaults to `np.uint8`.
+        device: target device to put the extracted patch. Note that if device is "cuda",
+            the output will be converted to a torch tensor and sent to the GPU, even if the dtype is numpy.
+        mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
         kwargs: additional args for `openslide.OpenSlide` module.
 
     """
@@ -558,6 +660,17 @@ class OpenSlideWSIReader(BaseWSIReader):
     supported_suffixes = ["tif", "tiff", "svs"]
     backend = "openslide"
 
+    def __init__(
+        self,
+        level: int = 0,
+        channel_dim: int = 0,
+        dtype: DtypeLike | torch.dtype = np.uint8,
+        device: torch.device | str | None = None,
+        mode: str = "RGB",
+        **kwargs,
+    ):
+        super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
+
     @staticmethod
     def get_level_count(wsi) -> int:
         """
@@ -695,6 +808,10 @@ class TiffFileWSIReader(BaseWSIReader):
         level: the whole slide image level at which the image is extracted. (default=0)
             This is overridden if the level argument is provided in `get_data`.
         channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+        dtype: the data type of output image. Defaults to `np.uint8`.
+        device: target device to put the extracted patch. Note that if device is "cuda",
+            the output will be converted to a torch tensor and sent to the GPU, even if the dtype is numpy.
+        mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
         kwargs: additional args for `tifffile.TiffFile` module.
 
     """
@@ -702,6 +819,17 @@ class TiffFileWSIReader(BaseWSIReader):
     supported_suffixes = ["tif", "tiff", "svs"]
     backend = "tifffile"
 
+    def __init__(
+        self,
+        level: int = 0,
+        channel_dim: int = 0,
+        dtype: DtypeLike | torch.dtype = np.uint8,
+        device: torch.device | str | None = None,
+        mode: str = "RGB",
+        **kwargs,
+    ):
+        super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
+
     @staticmethod
     def get_level_count(wsi) -> int:
         """
diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py
index d435902c28..06106d3161 100644
--- a/tests/test_wsireader.py
+++ b/tests/test_wsireader.py
@@ -17,6 +17,7 @@ from unittest import skipUnless
 
 import numpy as np
+import torch
 from numpy.testing import assert_array_equal
 from parameterized import parameterized
 
@@ -27,7 +28,7 @@ from monai.transforms import Compose, LoadImaged, ToTensord
 from monai.utils import deprecated, first, optional_import
 from monai.utils.enums import PostFix, WSIPatchKeys
 
-from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config
+from tests.utils import assert_allclose, download_url_or_skip_test, skip_if_no_cuda, testing_data_config
 
 cucim, has_cucim = optional_import("cucim")
 has_cucim = has_cucim and hasattr(cucim, "CuImage")
@@ -44,10 +45,13 @@
 HEIGHT = 32914
 WIDTH = 46000
 
-TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)]
+TEST_CASE_WHOLE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)]
 
 TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)]
 
+# ----------------------------------------------------------------------------
+# Test cases for *deprecated* monai.data.image_reader.WSIReader
+# ----------------------------------------------------------------------------
 TEST_CASE_DEP_1 = [
     FILE_PATH,
     {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
@@ -83,39 +87,135 @@
     np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]),
 ]
 
+# ----------------------------------------------------------------------------
+# Test cases for monai.data.wsi_reader.WSIReader
+# ----------------------------------------------------------------------------
+
+TEST_CASE_0 = [
+    FILE_PATH,
+    {"level": 8, "dtype": None},
+    {"location": (0, 0), "size": (2, 1)},
+    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float64),
+]
+
 TEST_CASE_1 = [
     FILE_PATH,
     {},
     {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
-    np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
+    np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8),
 ]
 
 TEST_CASE_2 = [
     FILE_PATH,
     {},
     {"location": (0, 0), "size": (2, 1), "level": 2},
-    np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8),
 ]
 
 TEST_CASE_3 = [
     FILE_PATH,
     {"channel_dim": -1},
     {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
-    np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), 0, -1),
+    np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8), 0, -1),
 ]
 
 TEST_CASE_4 = [
     FILE_PATH,
     {"channel_dim": 2},
     {"location": (0, 0), "size": (2, 1), "level": 2},
-    np.moveaxis(np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), 0, -1),
+    np.moveaxis(np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8), 0, -1),
 ]
 
 TEST_CASE_5 = [
     FILE_PATH,
     {"level": 2},
     {"location": (0, 0), "size": (2, 1)},
-    np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8),
+]
+
+TEST_CASE_6 = [
+    FILE_PATH,
+    {"level": 2, "dtype": np.int32},
+    {"location": (0, 0), "size": (2, 1)},
+    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.int32),
+]
+
+TEST_CASE_7 = [
+    FILE_PATH,
+    {"level": 2, "dtype": np.float32},
+    {"location": (0, 0), "size": (2, 1)},
+    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.float32),
+]
+
+TEST_CASE_8 = [
+    FILE_PATH,
+    {"level": 2, "dtype": torch.uint8},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.uint8),
+]
+
+TEST_CASE_9 = [
+    FILE_PATH,
+    {"level": 2, "dtype": torch.float32},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.float32),
+]
+
+# device tests
+TEST_CASE_DEVICE_1 = [
+    FILE_PATH,
+    {"level": 2, "dtype": torch.float32, "device": "cpu"},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.float32),
+    "cpu",
+]
+
+TEST_CASE_DEVICE_2 = [
+    FILE_PATH,
+    {"level": 2, "dtype": torch.float32, "device": "cuda"},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.float32),
+    "cuda",
+]
+
+TEST_CASE_DEVICE_3 = [
+    FILE_PATH,
+    {"level": 2, "dtype": np.float32, "device": "cpu"},
+    {"location": (0, 0), "size": (2, 1)},
+    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.float32),
+    "cpu",
+]
+
+TEST_CASE_DEVICE_4 = [
+    FILE_PATH,
+    {"level": 2, "dtype": np.float32, "device": "cuda"},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.float32),
+    "cuda",
+]
+
+TEST_CASE_DEVICE_5 = [
+    FILE_PATH,
+    {"level": 2, "device": "cuda"},
+    {"location": (0, 0), "size": (2, 1)},
+    torch.tensor([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=torch.uint8),
+    "cuda",
+]
+
+TEST_CASE_DEVICE_6 = [
+    FILE_PATH,
{"level": 2}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8), + "cpu", +] + +TEST_CASE_DEVICE_7 = [ + FILE_PATH, + {"level": 2, "device": None}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8), + "cpu", ] TEST_CASE_MULTI_WSI = [ @@ -186,7 +286,7 @@ class WSIReaderDeprecatedTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand([TEST_CASE_WHOLE_0]) def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReaderDeprecated(self.backend, level=level) with reader.read(file_path) as img_obj: @@ -261,7 +361,9 @@ def test_with_dataloader( ): train_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReaderDeprecated, backend=self.backend, level=level), + LoadImaged( + keys=["image"], reader=WSIReaderDeprecated, backend=self.backend, level=level, image_only=False + ), ToTensord(keys=["image"]), ] ) @@ -277,7 +379,7 @@ class WSIReaderTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand([TEST_CASE_WHOLE_0]) def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) with reader.read(file_path) as img_obj: @@ -286,10 +388,23 @@ def test_read_whole_image(self, file_path, level, expected_shape): self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0)) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + ] + ) def test_read_region(self, file_path, kwargs, patch_info, expected_img): reader = WSIReader(self.backend, **kwargs) level = patch_info.get("level", kwargs.get("level")) @@ -300,14 +415,15 @@ def test_read_region(self, file_path, kwargs, patch_info, expected_img): img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) - self.assertIsNone(assert_array_equal(img, img2)) + assert_allclose(img, img2) self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + assert_allclose(img, expected_img) + self.assertEqual(img.dtype, expected_img.dtype) self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"]) + assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False) @parameterized.expand([TEST_CASE_MULTI_WSI]) def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): @@ -320,14 +436,14 @@ def 
             for img_obj in img_obj_list:
                 img_obj.close()
             self.assertTupleEqual(img.shape, img2.shape)
-            self.assertIsNone(assert_array_equal(img, img2))
+            assert_allclose(img, img2)
             self.assertTupleEqual(img.shape, expected_img.shape)
-            self.assertIsNone(assert_array_equal(img, expected_img))
+            assert_allclose(img, expected_img)
             self.assertEqual(meta["backend"], self.backend)
             self.assertEqual(meta[WSIPatchKeys.PATH][0].lower(), str(os.path.abspath(file_path_list[0])).lower())
             self.assertEqual(meta[WSIPatchKeys.LEVEL][0], patch_info["level"])
-            assert_array_equal(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:])
-            assert_array_equal(meta[WSIPatchKeys.LOCATION][0], patch_info["location"])
+            assert_allclose(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:], type_test=False)
+            assert_allclose(meta[WSIPatchKeys.LOCATION][0], patch_info["location"], type_test=False)
 
         @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
         @skipUnless(has_tiff, "Requires tifffile.")
@@ -346,8 +462,8 @@ def test_read_rgba(self, img_expected):
                 with reader.read(file_path) as img_obj:
                     image[mode], _ = reader.get_data(img_obj)
 
-            self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
-            self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))
+            assert_allclose(image["RGB"], img_expected)
+            assert_allclose(image["RGBA"], img_expected)
 
         @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])
         @skipUnless(has_tiff, "Requires tifffile.")
@@ -368,7 +484,7 @@ def test_with_dataloader(
         ):
             train_transform = Compose(
                 [
-                    LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
+                    LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level, image_only=False),
                     ToTensord(keys=["image"]),
                 ]
             )
@@ -385,7 +501,7 @@ def test_with_dataloader_batch(
         ):
             train_transform = Compose(
                 [
-                    LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
+                    LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level, image_only=False),
                     ToTensord(keys=["image"]),
                 ]
             )
@@ -397,7 +513,7 @@ def test_with_dataloader_batch(
                 assert_allclose(s, expected_spatial_shape, type_test=False)
             self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:]))
 
-        @parameterized.expand([TEST_CASE_0])
+        @parameterized.expand([TEST_CASE_WHOLE_0])
         def test_read_whole_image_multi_thread(self, file_path, level, expected_shape):
             if self.backend == "cucim":
                 reader = WSIReader(self.backend, level=level, num_workers=4)
@@ -407,8 +523,8 @@ def test_read_whole_image_multi_thread(self, file_path, level, expected_shape):
             self.assertEqual(meta["backend"], self.backend)
             self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
             self.assertEqual(meta[WSIPatchKeys.LEVEL], level)
-            assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:])
-            assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0))
+            assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False)
+            assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False)
 
         @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
         def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_img):
@@ -419,14 +535,14 @@ def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_
                 img, meta = reader.get_data(img_obj, **patch_info)
                 img2 = reader.get_data(img_obj, **patch_info)[0]
             self.assertTupleEqual(img.shape, img2.shape)
-            self.assertIsNone(assert_array_equal(img, img2))
+            assert_allclose(img, img2)
             self.assertTupleEqual(img.shape, expected_img.shape)
-            self.assertIsNone(assert_array_equal(img, expected_img))
+            assert_allclose(img, expected_img)
             self.assertEqual(meta["backend"], self.backend)
             self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
             self.assertEqual(meta[WSIPatchKeys.LEVEL], patch_info["level"])
-            assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"])
-            assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"])
+            assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False)
+            assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False)
 
         @parameterized.expand([TEST_CASE_MPP_0])
         def test_resolution_mpp(self, file_path, level, expected_mpp):
@@ -435,6 +551,42 @@ def test_resolution_mpp(self, file_path, level, expected_mpp):
             reader = WSIReader(self.backend, level=level)
             with reader.read(file_path) as img_obj:
                 mpp = reader.get_mpp(img_obj, level)
             self.assertTupleEqual(mpp, expected_mpp)
 
+        @parameterized.expand(
+            [
+                TEST_CASE_DEVICE_1,
+                TEST_CASE_DEVICE_2,
+                TEST_CASE_DEVICE_3,
+                TEST_CASE_DEVICE_4,
+                TEST_CASE_DEVICE_5,
+                TEST_CASE_DEVICE_6,
+                TEST_CASE_DEVICE_7,
+            ]
+        )
+        @skip_if_no_cuda
+        def test_read_region_device(self, file_path, kwargs, patch_info, expected_img, device):
+            reader = WSIReader(self.backend, **kwargs)
+            level = patch_info.get("level", kwargs.get("level"))
+            if self.backend == "tifffile" and level < 2:
+                return
+            with reader.read(file_path) as img_obj:
+                # Read twice to check multiple calls
+                img, meta = reader.get_data(img_obj, **patch_info)
+                img2 = reader.get_data(img_obj, **patch_info)[0]
+            self.assertTupleEqual(img.shape, img2.shape)
+            assert_allclose(img, img2)
+            self.assertTupleEqual(img.shape, expected_img.shape)
+            assert_allclose(img, expected_img)
+            self.assertEqual(img.dtype, expected_img.dtype)
+            if isinstance(img, torch.Tensor):
+                self.assertEqual(img.device.type, device)
+            else:
+                self.assertEqual("cpu", device)
+            self.assertEqual(meta["backend"], self.backend)
+            self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
+            self.assertEqual(meta[WSIPatchKeys.LEVEL], level)
+            assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False)
+            assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False)
+
 
 @skipUnless(has_cucim, "Requires cucim")
 class TestCuCIMDeprecated(WSIReaderDeprecatedTests.Tests):
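
Reviewer note (illustration only, not part of the patch): below is a minimal sketch of how the new
`dtype`/`device` options introduced above are expected to behave, assuming a cucim-readable slide at
the hypothetical local path "CMU-1.tiff". It follows the conversion block added to
`BaseWSIReader.get_data` and the TEST_CASE_DEVICE_* cases: a torch dtype, or any "cuda" device,
yields a `torch.Tensor`; otherwise the patch stays a numpy array on the host.

    import numpy as np
    import torch
    from monai.data.wsi_reader import WSIReader

    # Default behavior: numpy uint8 patch on the host, as before this change.
    reader = WSIReader(backend="cucim", level=2)
    wsi = reader.read("CMU-1.tiff")  # hypothetical local slide
    patch, meta = reader.get_data(wsi, location=(0, 0), size=(256, 256))
    assert isinstance(patch, np.ndarray) and patch.dtype == np.uint8

    # Torch dtype plus a cuda device: the patch comes back as a float32 tensor on the GPU.
    reader = WSIReader(backend="cucim", level=2, dtype=torch.float32, device="cuda")
    wsi = reader.read("CMU-1.tiff")
    patch, meta = reader.get_data(wsi, location=(0, 0), size=(256, 256))
    assert isinstance(patch, torch.Tensor) and patch.device.type == "cuda"

Per the patch, a numpy `dtype` combined with `device="cuda"` is also converted to the matching torch
dtype via `dtype_numpy_to_torch` (e.g. `np.float32` becomes `torch.float32`, as TEST_CASE_DEVICE_4
checks), so a patch is never returned as a numpy array on a CUDA device.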