WSIReader defaults and tensor conversion (#6058)
Fixes #6056 

### Description

This PR enables instantiating `WSIReader` classes with default values
for `mode` and `dtype`. It also accepts a `torch.dtype` as `dtype`; when a
`torch.dtype` is provided, the output will be a `torch.Tensor`, avoiding an
additional conversion step before inference (see the usage sketch below).
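
A minimal usage sketch of the new defaults and the torch output path (the `test.tiff` path is a placeholder and the cuCIM backend is assumed to be installed):

```python
import torch

from monai.data import WSIReader

# Defaults for `dtype` (np.uint8) and `mode` ("RGB") are now set at construction time.
reader = WSIReader(backend="cucim")
wsi = reader.read("test.tiff")  # hypothetical local slide
patch, meta = reader.get_data(wsi, location=(0, 0), size=(256, 256))
print(type(patch))  # numpy.ndarray by default

# Passing a torch dtype (or a CUDA device) makes `get_data` return a torch.Tensor,
# so no extra conversion step is needed before inference.
reader_t = WSIReader(backend="cucim", dtype=torch.float32, device="cpu")
patch_t, _ = reader_t.get_data(reader_t.read("test.tiff"), size=(256, 256))
print(type(patch_t), patch_t.dtype)  # torch.Tensor, torch.float32
```

With the same arguments, setting `device="cuda"` also moves the patch to the GPU, even when `dtype` is left as a numpy type.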

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Behrooz <[email protected]>
drbeh authored Mar 14, 2023
1 parent 589c711 commit 01b6d70
Showing 2 changed files with 331 additions and 51 deletions.
164 changes: 146 additions & 18 deletions monai/data/wsi_reader.py
@@ -17,11 +17,19 @@
from typing import Any

import numpy as np
import torch

from monai.config import DtypeLike, PathLike
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
from monai.data.image_reader import ImageReader, _stack_images
from monai.data.utils import is_supported_format
from monai.utils import WSIPatchKeys, ensure_tuple, optional_import, require_pkg
from monai.utils import (
WSIPatchKeys,
dtype_numpy_to_torch,
dtype_torch_to_numpy,
ensure_tuple,
optional_import,
require_pkg,
)

OpenSlide, _ = optional_import("openslide", name="OpenSlide")
TiffFile, _ = optional_import("tifffile", name="TiffFile")
@@ -33,12 +33,21 @@ class BaseWSIReader(ImageReader):
"""
An abstract class that defines APIs to load patches from whole slide image files.
Args:
level: the whole slide image level at which the image is extracted.
channel_dim: the desired dimension for color channel.
dtype: the data type of output image.
device: target device to put the extracted patch. Note that if device is "cuda",
the output will be converted to a torch tensor and sent to the GPU even if the dtype is numpy.
mode: the output image color mode, e.g., "RGB" or "RGBA".
kwargs: additional args for the reader
Typical usage of a concrete implementation of this class is:
.. code-block:: python
image_reader = MyWSIReader()
wsi = image_reader.read(, **kwargs)
wsi = image_reader.read(filepath, **kwargs)
img_data, meta_data = image_reader.get_data(wsi)
- The `read` call converts an image filename into whole slide image object,
@@ -58,13 +75,37 @@ class BaseWSIReader(ImageReader):
supported_suffixes: list[str] = []
backend = ""

def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
def __init__(
self,
level: int,
channel_dim: int,
dtype: DtypeLike | torch.dtype,
device: torch.device | str | None,
mode: str,
**kwargs,
):
super().__init__()
self.level = level
self.channel_dim = channel_dim
self.set_dtype(dtype)
self.set_device(device)
self.mode = mode
self.kwargs = kwargs
self.metadata: dict[Any, Any] = {}

def set_dtype(self, dtype):
self.dtype: torch.dtype | np.dtype
if isinstance(dtype, torch.dtype):
self.dtype = dtype
else:
self.dtype = np.dtype(dtype)

def set_device(self, device):
if device is None or isinstance(device, (torch.device, str)):
self.device = device
else:
raise ValueError(f"`device` must be `torch.device`, `str` or `None` but {type(device)} is given.")

@abstractmethod
def get_size(self, wsi, level: int | None = None) -> tuple[int, int]:
"""
@@ -138,7 +179,7 @@ def _get_patch(
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

def _get_metadata(
self, wsi, patch: np.ndarray, location: tuple[int, int], size: tuple[int, int], level: int
self, wsi, patch: NdarrayOrTensor, location: tuple[int, int], size: tuple[int, int], level: int
) -> dict:
"""
Returns metadata of the extracted patch from the whole slide image.
@@ -175,8 +216,7 @@ def get_data(
location: tuple[int, int] = (0, 0),
size: tuple[int, int] | None = None,
level: int | None = None,
dtype: DtypeLike = np.uint8,
mode: str = "RGB",
mode: str | None = None,
) -> tuple[np.ndarray, dict]:
"""
Verifies inputs, extracts patches from WSI image and generates metadata, and return them.
@@ -185,15 +225,16 @@
wsi: a whole slide image object loaded from a file or a list of such objects
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
If not provided or None, it is set to the full image size at the given level.
level: the level number. Defaults to 0.
dtype: the data type of output image
mode: the output image mode, 'RGB' or 'RGBA'
mode: the output image color mode, "RGB" or "RGBA". If not provided, the default of "RGB" is used.
Returns:
a tuple, where the first element is an image patch [CxHxW] or a stack of patches,
and the second element is a dictionary of metadata
"""
if mode is None:
mode = self.mode
patch_list: list = []
metadata_list: list = []
# CuImage object is iterable, so ensure_tuple won't work on single object
@@ -223,8 +264,25 @@ def get_data(
if size[0] <= 0 or size[1] <= 0:
raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}")

# Get numpy dtype if it is not already.
dtype_np = dtype_torch_to_numpy(self.dtype) if isinstance(self.dtype, torch.dtype) else self.dtype
# Extract a patch or the entire image
patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
patch: NdarrayOrTensor
patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype_np, mode=mode)

# Convert the patch to torch.Tensor if dtype is torch
if isinstance(self.dtype, torch.dtype) or (
self.device is not None and torch.device(self.device).type == "cuda"
):
# Ensure dtype is torch.dtype if the device is not "cpu"
dtype_torch = (
dtype_numpy_to_torch(self.dtype) if not isinstance(self.dtype, torch.dtype) else self.dtype
)
# Copy the numpy array if it is not writable
if patch.flags["WRITEABLE"]:
patch = torch.as_tensor(patch, dtype=dtype_torch, device=self.device)
else:
patch = torch.tensor(patch, dtype=dtype_torch, device=self.device)

# check if the image has three dimensions (2D + color)
if patch.ndim != 3:
@@ -280,26 +338,53 @@ class WSIReader(BaseWSIReader):
backend: the name of the backend whole slide image reader library; the default is cuCIM.
level: the level at which patches are extracted.
channel_dim: the desired dimension for color channel. Defaults to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
device: target device to put the extracted patch. Note that if device is "cuda",
the output will be converted to a torch tensor and sent to the GPU even if the dtype is numpy.
num_workers: number of workers for multi-thread image loading (cucim backend only).
kwargs: additional arguments to be passed to the backend library
"""

def __init__(self, backend="cucim", level: int = 0, channel_dim: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
supported_backends = ["cucim", "openslide", "tifffile"]

def __init__(
self,
backend="cucim",
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str | None = None,
mode: str = "RGB",
**kwargs,
):
self.backend = backend.lower()
self.reader: CuCIMWSIReader | OpenSlideWSIReader | TiffFileWSIReader
if self.backend == "cucim":
self.reader = CuCIMWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = CuCIMWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
elif self.backend == "openslide":
self.reader = OpenSlideWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = OpenSlideWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
elif self.backend == "tifffile":
self.reader = TiffFileWSIReader(level=level, channel_dim=channel_dim, **kwargs)
self.reader = TiffFileWSIReader(
level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs
)
else:
raise ValueError(
f"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given."
)
self.supported_suffixes = self.reader.supported_suffixes
self.level = self.reader.level
self.channel_dim = self.reader.channel_dim
self.dtype = self.reader.dtype
self.device = self.reader.device
self.mode = self.reader.mode
self.kwargs = self.reader.kwargs
self.metadata = self.reader.metadata

def get_level_count(self, wsi) -> int:
"""
@@ -402,6 +487,10 @@ class CuCIMWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Defaults to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
device: target device to put the extracted patch. Note that if device is "cuda",
the output will be converted to a torch tensor and sent to the GPU even if the dtype is numpy.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
num_workers: number of workers for multi-thread image loading
kwargs: additional args for `cucim.CuImage` module:
https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h
@@ -411,8 +500,17 @@
supported_suffixes = ["tif", "tiff", "svs"]
backend = "cucim"

def __init__(self, level: int = 0, channel_dim: int = 0, num_workers: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str | None = None,
mode: str = "RGB",
num_workers: int = 0,
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)
self.num_workers = num_workers

@staticmethod
@@ -551,13 +649,28 @@ class OpenSlideWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Defaults to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
device: target device to put the extracted patch. Note that if device is "cuda",
the output will be converted to a torch tensor and sent to the GPU even if the dtype is numpy.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
kwargs: additional args for `openslide.OpenSlide` module.
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "openslide"

def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str | None = None,
mode: str = "RGB",
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)

@staticmethod
def get_level_count(wsi) -> int:
"""
@@ -695,13 +808,28 @@ class TiffFileWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Defaults to 0 (channel first).
dtype: the data type of output image. Defaults to `np.uint8`.
device: target device to put the extracted patch. Note that if device is "cuda",
the output will be converted to a torch tensor and sent to the GPU even if the dtype is numpy.
mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB".
kwargs: additional args for `tifffile.TiffFile` module.
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "tifffile"

def __init__(
self,
level: int = 0,
channel_dim: int = 0,
dtype: DtypeLike | torch.dtype = np.uint8,
device: torch.device | str | None = None,
mode: str = "RGB",
**kwargs,
):
super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs)

@staticmethod
def get_level_count(wsi) -> int:
"""
