Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WSIReader defaults and tensor conversion #6058

Merged
merged 28 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bce46a5
Update wsi reader to have object defaults for mode and dtype
drbeh Feb 22, 2023
d0c0318
update comments
drbeh Feb 23, 2023
12a353f
add device
drbeh Feb 23, 2023
fe57e79
add unittests
drbeh Feb 23, 2023
5ae1822
Merge branch 'dev' into fix-6056-wsireader-defaults
drbeh Feb 23, 2023
0f73d95
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-6056-ws…
drbeh Mar 1, 2023
d48d3c9
add supported backends
drbeh Mar 1, 2023
d0748cc
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-6056-ws…
drbeh Mar 2, 2023
153b14c
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-6056-ws…
drbeh Mar 6, 2023
0ae2536
formatting
drbeh Mar 6, 2023
cf1f93b
Merge branch 'dev' into fix-6056-wsireader-defaults
drbeh Mar 7, 2023
e4702d7
device conversion for numpy dtype and added tests
drbeh Mar 7, 2023
a36fb63
fix device comparison
drbeh Mar 8, 2023
ddb9fdc
remove redundant type
drbeh Mar 8, 2023
0f3355b
remove redundant type
drbeh Mar 8, 2023
9e17c01
change from not cpu to cuda
drbeh Mar 8, 2023
e242b65
torch dtype
drbeh Mar 8, 2023
c4acdc5
simplify get_data args, cleanup and update docstring
drbeh Mar 9, 2023
a06f95d
typing
drbeh Mar 9, 2023
3a7baff
Merge branch 'dev' into fix-6056-wsireader-defaults
drbeh Mar 9, 2023
8b85a10
add unittests for none device and dtype
drbeh Mar 9, 2023
762e6a3
update dtype float32
drbeh Mar 9, 2023
6effd94
correct numpy default for None
drbeh Mar 9, 2023
fec192f
Merge branch 'dev' into fix-6056-wsireader-defaults
drbeh Mar 13, 2023
0d0d6c1
remove test case 0 to check on ubuntu quick test
drbeh Mar 13, 2023
e51fe48
read at level=8 for test case 0 to save memory
drbeh Mar 13, 2023
dab7c41
remove default conversion of None device
drbeh Mar 14, 2023
bf37fda
Merge branch 'dev' into fix-6056-wsireader-defaults
wyli Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 109 additions & 18 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from typing import Any

import numpy as np
import torch

from monai.config import DtypeLike, PathLike
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
from monai.data.image_reader import ImageReader, _stack_images
from monai.data.utils import is_supported_format
from monai.utils import WSIPatchKeys, ensure_tuple, optional_import, require_pkg
from monai.utils import WSIPatchKeys, dtype_torch_to_numpy, ensure_tuple, optional_import, require_pkg

CuImage, _ = optional_import("cucim", name="CuImage")
OpenSlide, _ = optional_import("openslide", name="OpenSlide")
Expand All @@ -34,12 +35,19 @@ class BaseWSIReader(ImageReader):
"""
An abstract class that defines APIs to load patches from whole slide image files.

Args:
level: the whole slide image level at which the image is extracted.
channel_dim: the desired dimension for color channel.
dtype: the data type of output image.
mode: the output image color mode, e.g., "RGB" or "RGBA".
kwargs: additional args for the reader
drbeh marked this conversation as resolved.
Show resolved Hide resolved

Typical usage of a concrete implementation of this class is:

.. code-block:: python

image_reader = MyWSIReader()
wsi = image_reader.read(, **kwargs)
wsi = image_reader.read(filepath, **kwargs)
img_data, meta_data = image_reader.get_data(wsi)

- The `read` call converts an image filename into whole slide image object,
Expand All @@ -59,10 +67,21 @@ class BaseWSIReader(ImageReader):
supported_suffixes: list[str] = []
backend = ""

def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
def __init__(
self,
level: int,
channel_dim: int,
dtype: DtypeLike | torch.dtype,
device: torch.device | str,
mode: str,
**kwargs,
):
super().__init__()
self.level = level
self.channel_dim = channel_dim
self.dtype = dtype
self.device = device
self.mode = mode
self.kwargs = kwargs
self.metadata: dict[Any, Any] = {}

Expand Down Expand Up @@ -139,7 +158,7 @@ def _get_patch(
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

def _get_metadata(
self, wsi, patch: np.ndarray, location: tuple[int, int], size: tuple[int, int], level: int
self, wsi, patch: NdarrayOrTensor, location: tuple[int, int], size: tuple[int, int], level: int
) -> dict:
"""
Returns metadata of the extracted patch from the whole slide image.
Expand Down Expand Up @@ -176,8 +195,9 @@ def get_data(
location: tuple[int, int] = (0, 0),
size: tuple[int, int] | None = None,
level: int | None = None,
dtype: DtypeLike = np.uint8,
mode: str = "RGB",
dtype: DtypeLike | torch.dtype | None = None,
device: torch.device | str | None = None,
mode: str | None = None,
) -> tuple[np.ndarray, dict]:
"""
Verifies inputs, extracts patches from WSI image and generates metadata, and return them.
Expand All @@ -186,15 +206,21 @@ def get_data(
wsi: a whole slide image object loaded from a file or a list of such objects
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
If not provided or None, it is set to the full image size at the given level.
level: the level number. Defaults to 0
dtype: the data type of output image
mode: the output image mode, 'RGB' or 'RGBA'
dtype: the data type of output image. If not provided the default of `np.uint8` is used.
mode: the output image color mode, "RGB" or "RGBA". If not provided the default of "RGB" is used.

Returns:
a tuples, where the first element is an image patch [CxHxW] or stack of patches,
and second element is a dictionary of metadata
"""
if dtype is None:
dtype = self.dtype
if device is None:
device = self.device
if mode is None:
mode = self.mode
patch_list: list = []
metadata_list: list = []
# CuImage object is iterable, so ensure_tuple won't work on single object
Expand Down Expand Up @@ -224,8 +250,17 @@ def get_data(
if size[0] <= 0 or size[1] <= 0:
raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}")

# Get numpy dtype if it is not already.
np_dtype = dtype_torch_to_numpy(dtype) if isinstance(dtype, torch.dtype) else dtype
# Extract a patch or the entire image
patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
patch: NdarrayOrTensor
patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=np_dtype, mode=mode)
# Convert the patch to torch.Tensor if dtype is torch
if isinstance(dtype, torch.dtype):
if patch.flags["WRITEABLE"]:
patch = torch.as_tensor(patch, dtype=dtype, device=device)
else:
patch = torch.tensor(patch, dtype=dtype, device=device)

# check if the image has three dimensions (2D + color)
if patch.ndim != 3:
Expand Down Expand Up @@ -281,21 +316,40 @@ class WSIReader(BaseWSIReader):
backend: the name of backend whole slide image reader library, the default is cuCIM.
level: the level at which patches are extracted.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
drbeh marked this conversation as resolved.
Show resolved Hide resolved
num_workers: number of workers for multi-thread image loading (cucim backend only).
kwargs: additional arguments to be passed to the backend library

"""

def __init__(self, backend="cucim", level: int = 0, channel_dim: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
supported_backends = ["cucim", "openslide", "tifffile"]

def __init__(
self,
backend="cucim",
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str = "cpu",
mode: str = "RGB",
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
self.backend = backend.lower()
self.reader: CuCIMWSIReader | OpenSlideWSIReader | TiffFileWSIReader
if self.backend == "cucim":
self.reader = CuCIMWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = CuCIMWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
elif self.backend == "openslide":
self.reader = OpenSlideWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = OpenSlideWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
elif self.backend == "tifffile":
self.reader = TiffFileWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = TiffFileWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
else:
raise ValueError(
f"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given."
Expand Down Expand Up @@ -403,6 +457,8 @@ class CuCIMWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
drbeh marked this conversation as resolved.
Show resolved Hide resolved
num_workers: number of workers for multi-thread image loading
kwargs: additional args for `cucim.CuImage` module:
https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h
Expand All @@ -412,8 +468,17 @@ class CuCIMWSIReader(BaseWSIReader):
supported_suffixes = ["tif", "tiff", "svs"]
backend = "cucim"

def __init__(self, level: int = 0, channel_dim: int = 0, num_workers: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str = "cpu",
mode: str = "RGB",
num_workers: int = 0,
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
self.num_workers = num_workers

@staticmethod
Expand Down Expand Up @@ -551,13 +616,26 @@ class OpenSlideWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
drbeh marked this conversation as resolved.
Show resolved Hide resolved
kwargs: additional args for `openslide.OpenSlide` module.

"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "openslide"

def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str = "cpu",
mode: str = "RGB",
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)

@staticmethod
def get_level_count(wsi) -> int:
"""
Expand Down Expand Up @@ -695,13 +773,26 @@ class TiffFileWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
kwargs: additional args for `tifffile.TiffFile` module.

"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "tifffile"

def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str = "cpu",
mode: str = "RGB",
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)

@staticmethod
def get_level_count(wsi) -> int:
"""
Expand Down
Loading