Zarr Merger (#6633)

Fixes #6006 

### Description

This PR implements `ZarrAvgMerger`, which can be used for patch
inference. A use case is also demonstrated
[here](Project-MONAI/tutorials#1433).
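
A minimal, self-contained sketch of direct usage (the shapes, patch locations, and the
`merged.zarr` output path below are illustrative only, not taken from this PR; `zarr` must be installed):

```python
import torch

from monai.inferers import ZarrAvgMerger

# merge four 2x2 patches into a 1x1x4x4 output backed by a zarr store on disk
merger = ZarrAvgMerger(merged_shape=(1, 1, 4, 4), store="merged.zarr")
for location in [(0, 0), (0, 2), (2, 0), (2, 2)]:
    patch = torch.ones(1, 1, 2, 2)  # stand-in for a model's output on one patch
    merger.aggregate(patch, location)
output = merger.finalize()  # zarr.Array of shape (1, 1, 4, 4)
print(output[:])  # all ones: each location was covered exactly once
```

In practice the merger is typically driven by MONAI's patch inference utilities rather than called by hand; the linked tutorial shows an end-to-end example.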

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Behrooz <[email protected]>
drbeh authored Jun 28, 2023
1 parent d4b9552 commit ae95bf9
Showing 10 changed files with 549 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -37,3 +37,4 @@ optuna
opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
5 changes: 5 additions & 0 deletions docs/source/inferers.rst
@@ -77,6 +77,11 @@ Mergers
    :members:
    :special-members: __call__

`ZarrAvgMerger`
~~~~~~~~~~~~~~~
.. autoclass:: ZarrAvgMerger
    :members:
    :special-members: __call__


Sliding Window Inference Function
4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
```

which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
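
(With this change, `pip install 'monai[zarr]'` also becomes a valid extra that pulls in the `zarr` package on its own; see the `setup.cfg` section below.)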
2 changes: 1 addition & 1 deletion monai/inferers/__init__.py
@@ -20,6 +20,6 @@
    SlidingWindowInferer,
    SlidingWindowInfererAdapt,
)
-from .merger import AvgMerger, Merger
+from .merger import AvgMerger, Merger, ZarrAvgMerger
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
from .utils import sliding_window_inference
219 changes: 213 additions & 6 deletions monai/inferers/merger.py
@@ -11,15 +11,24 @@

from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from collections.abc import Sequence
-from typing import Any
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Any

import numpy as np
import torch

-from monai.utils import ensure_tuple_size
+from monai.utils import ensure_tuple_size, optional_import, require_pkg

-__all__ = ["Merger", "AvgMerger"]
if TYPE_CHECKING:
    import zarr
else:
    zarr, _ = optional_import("zarr")


+__all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"]


class Merger(ABC):
@@ -97,9 +106,9 @@ def __init__(
        self,
        merged_shape: Sequence[int],
        cropped_shape: Sequence[int] | None = None,
-       device: torch.device | str = "cpu",
        value_dtype: torch.dtype = torch.float32,
        count_dtype: torch.dtype = torch.uint8,
+       device: torch.device | str = "cpu",
    ) -> None:
        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
        if not self.merged_shape:
@@ -152,12 +161,21 @@ def finalize(self) -> torch.Tensor:

        return self.values

    def get_output(self) -> torch.Tensor:
        """
        Get the final merged output.

        Returns:
            torch.Tensor: merged output.
        """
        return self.finalize()

    def get_values(self) -> torch.Tensor:
        """
        Get the accumulated values during aggregation or final averaged values after it is finalized.

        Returns:
-           Merged (averaged) output tensor.
+           torch.tensor: aggregated values.

        Notes:
            - If called before calling `finalize()`, this method returns the accumulating values.
@@ -170,6 +188,195 @@ def get_counts(self) -> torch.Tensor:
        Get the aggregator tensor for number of samples.

        Returns:
-           torch.Tensor: Number of accumulated samples at each location.
+           torch.Tensor: number of accumulated samples at each location.
        """
        return self.counts


@require_pkg(pkg_name="zarr")
class ZarrAvgMerger(Merger):
    """Merge patches by taking the average of the overlapping area and store the results in a zarr array.

    Zarr is a format for the storage of chunked, compressed, N-dimensional arrays.
    Zarr data can be stored in any storage system that can be represented as a key-value store,
    like POSIX file systems, cloud object storage, zip files, and relational and document databases.
    See https://zarr.readthedocs.io/en/stable/ for more details.
    It is particularly useful for storing N-dimensional arrays too large to fit into memory.
    One specific use case of this class is to merge patches extracted from whole slide images (WSI),
    where the merged results do not fit into memory and need to be stored on a file system.

    Args:
        merged_shape: the shape of the tensor required to merge the patches.
        cropped_shape: the shape of the final merged output tensor.
            If not provided, it will be the same as `merged_shape`.
        dtype: the dtype for the final merged result. Default is `float32`.
        value_dtype: the dtype for the value aggregating tensor. Default is `float32`.
        count_dtype: the dtype for the sample counting tensor. Default is `uint8`.
        store: the zarr store to save the final results. Default is "merged.zarr".
        value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
        count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
        compressor: the compressor for the final merged zarr array. Default is "default".
        value_compressor: the compressor for the value aggregating zarr array. Default is None.
        count_compressor: the compressor for the sample counting zarr array. Default is None.
        chunks: int or tuple of ints that defines the chunk shape, or boolean. Default is True.
            If True, chunk shape will be guessed from `shape` and `dtype`.
            If False, it will be set to `shape`, i.e., single chunk for the whole array.
            If an int, the chunk size in each dimension will be given by the value of `chunks`.
        thread_locking: whether to use a thread lock for the in-place addition during aggregation.
            Default is True.
    """

    def __init__(
        self,
        merged_shape: Sequence[int],
        cropped_shape: Sequence[int] | None = None,
        dtype: np.dtype | str = "float32",
        value_dtype: np.dtype | str = "float32",
        count_dtype: np.dtype | str = "uint8",
        store: zarr.storage.Store | str = "merged.zarr",
        value_store: zarr.storage.Store | str | None = None,
        count_store: zarr.storage.Store | str | None = None,
        compressor: str = "default",
        value_compressor: str | None = None,
        count_compressor: str | None = None,
        chunks: Sequence[int] | bool = True,
        thread_locking: bool = True,
    ) -> None:
        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
        if not self.merged_shape:
            raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is given.")
        self.output_dtype = dtype
        self.value_dtype = value_dtype
        self.count_dtype = count_dtype
        self.store = store
        self.value_store = zarr.storage.TempStore() if value_store is None else value_store
        self.count_store = zarr.storage.TempStore() if count_store is None else count_store
        self.chunks = chunks
        self.compressor = compressor
        self.value_compressor = value_compressor
        self.count_compressor = count_compressor
        self.output = zarr.empty(
            shape=self.merged_shape,
            chunks=self.chunks,
            dtype=self.output_dtype,
            compressor=self.compressor,
            store=self.store,
            overwrite=True,
        )
        self.values = zarr.zeros(
            shape=self.merged_shape,
            chunks=self.chunks,
            dtype=self.value_dtype,
            compressor=self.value_compressor,
            store=self.value_store,
            overwrite=True,
        )
        self.counts = zarr.zeros(
            shape=self.merged_shape,
            chunks=self.chunks,
            dtype=self.count_dtype,
            compressor=self.count_compressor,
            store=self.count_store,
            overwrite=True,
        )
        self.lock: threading.Lock | nullcontext
        if thread_locking:
            # use a lock to protect the in-place addition during aggregation
            self.lock = threading.Lock()
        else:
            # use nullcontext to avoid locking if not needed
            self.lock = nullcontext()

    def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
        """
        Aggregate values for merging.

        Args:
            values: a tensor of shape BCHW[D], representing the values of inference output.
            location: a tuple/list giving the top left location of the patch in the original image.
        """
        if self.is_finalized:
            raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.")
        patch_size = values.shape[2:]
        map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))
        map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
        with self.lock:
            self.values[map_slice] += values.numpy()
            self.counts[map_slice] += 1

    def finalize(self) -> zarr.Array:
        """
        Finalize merging by dividing values by counts and return the merged array.

        Notes:
            The division is performed chunk by chunk (see `iterate_over_chunks`) so that the
            intermediate arrays do not need to fit into memory at once. Calling `finalize()`
            multiple times does not have any effect.

        Returns:
            zarr.Array: a zarr array of merged patches
        """
        # guard against multiple calls to finalize
        if not self.is_finalized:
            # use chunks for division to fit into memory
            for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
                self.output[chunk] = self.values[chunk] / self.counts[chunk]
            # finalize the shape
            self.output.resize(self.cropped_shape)
            # set the finalized flag to protect against performing the division again
            self.is_finalized = True

        return self.output

    def get_output(self) -> zarr.Array:
        """
        Get the final merged output.

        Returns:
            zarr.Array: merged (averaged) output.
        """
        return self.output

    def get_values(self) -> zarr.Array:
        """
        Get the accumulated values during aggregation.

        Returns:
            zarr.Array: aggregated values.
        """
        return self.values

    def get_counts(self) -> zarr.Array:
        """
        Get the aggregator array for the number of samples.

        Returns:
            zarr.Array: number of accumulated samples at each location.
        """
        return self.counts


def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()):
    """
    Iterate over chunks of a given shape.

    Args:
        chunks: the chunk shape
        cdata_shape: the shape of the data in chunks
        slice_tuple: the slice tuple to be used for indexing

    Raises:
        ValueError: when the lengths of `chunks` and `cdata_shape` are not the same.

    Yields:
        slices of the data
    """
    if len(chunks) != len(cdata_shape):
        raise ValueError("chunks and cdata_shape must have the same length")
    if len(chunks) == 1:
        for i in range(cdata_shape[0]):
            yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
    else:
        for i in range(cdata_shape[0]):
            yield from iterate_over_chunks(
                chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
            )
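
As a quick illustration of the chunk-wise iteration used by `finalize()` above, here is a small sketch (the array shape and chunking are made up for the example; `iterate_over_chunks` is a module-level helper in `monai.inferers.merger` and not part of the public `__all__`):

```python
import numpy as np
import zarr

from monai.inferers.merger import iterate_over_chunks

# a 4x6 array split into 2x3 chunks, so the grid of chunks (cdata_shape) is 2x2
arr = zarr.zeros(shape=(4, 6), chunks=(2, 3), dtype="float32")
for chunk_slices in iterate_over_chunks(arr.chunks, arr.cdata_shape):
    # each yielded item is a tuple of slices selecting one chunk-aligned block
    arr[chunk_slices] = np.ones((2, 3))  # touch one chunk at a time, as finalize() does
print(arr[:].sum())  # 24.0 -- every element visited exactly once
```

This chunk-at-a-time pattern is what lets `finalize()` compute `values / counts` without loading the whole array into memory. Note also the `nullcontext` choice in `__init__`: with `thread_locking=False`, the `with self.lock:` block in `aggregate()` becomes a no-op instead of acquiring a real lock.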
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -52,3 +52,4 @@ onnx>=1.13.0
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
3 changes: 3 additions & 0 deletions setup.cfg
@@ -79,6 +79,7 @@ all =
    optuna
    onnx>=1.13.0
    onnxruntime; python_version <= '3.10'
    zarr
nibabel =
    nibabel
ninja =
@@ -142,6 +143,8 @@ optuna =
onnx =
    onnx>=1.13.0
    onnxruntime; python_version <= '3.10'
zarr =
    zarr
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -202,6 +202,7 @@ def run_testsuit():
"test_metrics_reloaded",
"test_spatial_combine_transforms",
"test_bundle_workflow",
"test_zarr_avg_merger",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

1 change: 1 addition & 0 deletions tests/test_download_url_yandex.py
@@ -29,6 +29,7 @@


class TestDownloadUrlYandex(unittest.TestCase):
    @unittest.skip("data source unstable")
    def test_verify(self):
        with tempfile.TemporaryDirectory() as tempdir:
            download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt"))