Zarr Merger #6633

Merged: 24 commits from zarr-mergers into dev, Jun 28, 2023

Commits (24)
b3c4ad4
add zarr dependency
drbeh Jun 19, 2023
5b7ea1d
Implement ZarrAvgMerger
drbeh Jun 19, 2023
571498c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2023
fe1cbc8
update docs
drbeh Jun 20, 2023
0a1aa3d
update docs
drbeh Jun 20, 2023
8e8594e
update docstring
drbeh Jun 20, 2023
1a9d0f8
add info about Zarr in docstring
drbeh Jun 21, 2023
5c51200
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2023
fdb3969
add unit tests
drbeh Jun 21, 2023
70bf188
exclude zarr tests from min tests
drbeh Jun 21, 2023
e74658b
add thread locks
drbeh Jun 26, 2023
bf33fb5
remove compression for temp zarr arrays
drbeh Jun 26, 2023
5bc1f94
add flexibility to define compressors for all zarr arrays
drbeh Jun 26, 2023
9bd6ce1
change to skip unless
drbeh Jun 26, 2023
3a348be
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into zar…
drbeh Jun 26, 2023
2a19fb5
Merge branch 'dev' into zarr-mergers
drbeh Jun 27, 2023
58176e7
Merge branch 'zarr-mergers' of https://github.com/drbeh/MONAI into za…
drbeh Jun 28, 2023
6c231ea
make thread locking optional
drbeh Jun 28, 2023
06c4d25
Merge branch 'dev' into zarr-mergers
drbeh Jun 28, 2023
8145c08
Update monai/inferers/merger.py
drbeh Jun 28, 2023
3b84aee
Update monai/inferers/merger.py
drbeh Jun 28, 2023
a7f4f93
Update monai/inferers/merger.py
drbeh Jun 28, 2023
52d8898
Merge branch 'dev' into zarr-mergers
drbeh Jun 28, 2023
7cb71a0
unblock premerge download test
wyli Jun 28, 2023
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -37,3 +37,4 @@ optuna
opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
5 changes: 5 additions & 0 deletions docs/source/inferers.rst
@@ -77,6 +77,11 @@ Mergers
:members:
:special-members: __call__

`ZarrAvgMerger`
~~~~~~~~~~~~~~~
.. autoclass:: ZarrAvgMerger
:members:
:special-members: __call__


Sliding Window Inference Function
4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
```

which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
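
For example, with the new `zarr` extra registered in `setup.cfg` (see that diff below), `pip install 'monai[zarr]'` would pull in just the zarr dependency via the same extras syntax (an illustrative command based on this PR's changes).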
2 changes: 1 addition & 1 deletion monai/inferers/__init__.py
@@ -20,6 +20,6 @@
SlidingWindowInferer,
SlidingWindowInfererAdapt,
)
from .merger import AvgMerger, Merger
from .merger import AvgMerger, Merger, ZarrAvgMerger
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
from .utils import sliding_window_inference
219 changes: 213 additions & 6 deletions monai/inferers/merger.py
@@ -11,15 +11,24 @@

from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from monai.utils import ensure_tuple_size
from monai.utils import ensure_tuple_size, optional_import, require_pkg

__all__ = ["Merger", "AvgMerger"]
if TYPE_CHECKING:
import zarr
else:
zarr, _ = optional_import("zarr")


__all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"]


class Merger(ABC):
@@ -97,9 +106,9 @@ def __init__(
self,
merged_shape: Sequence[int],
cropped_shape: Sequence[int] | None = None,
device: torch.device | str = "cpu",
value_dtype: torch.dtype = torch.float32,
count_dtype: torch.dtype = torch.uint8,
device: torch.device | str = "cpu",
) -> None:
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
if not self.merged_shape:
@@ -152,12 +161,21 @@ def finalize(self) -> torch.Tensor:

return self.values

def get_output(self) -> torch.Tensor:
"""
Get the final merged output.

Returns:
torch.Tensor: merged output.
"""
return self.finalize()

def get_values(self) -> torch.Tensor:
"""
Get the accumulated values during aggregation or final averaged values after it is finalized.

Returns:
Merged (averaged) output tensor.
torch.Tensor: aggregated values.

Notes:
- If called before calling `finalize()`, this method returns the accumulating values.
@@ -170,6 +188,195 @@ def get_counts(self) -> torch.Tensor:
Get the aggregator tensor for number of samples.

Returns:
torch.Tensor: Number of accumulated samples at each location.
torch.Tensor: number of accumulated samples at each location.
"""
return self.counts


@require_pkg(pkg_name="zarr")
class ZarrAvgMerger(Merger):
"""Merge patches by taking average of the overlapping area and store the results in zarr array.

Zarr is a format for the storage of chunked, compressed, N-dimensional arrays.
Zarr data can be stored in any storage system that can be represented as a key-value store,
like POSIX file systems, cloud object storage, zip files, and relational and document databases.
See https://zarr.readthedocs.io/en/stable/ for more details.
It is particularly useful for storing N-dimensional arrays too large to fit into memory.
One specific use case of this class is to merge patches extracted from whole slide images (WSI),
where the merged results do not fit into memory and need to be stored on a file system.

Args:
merged_shape: the shape of the tensor required to merge the patches.
cropped_shape: the shape of the final merged output tensor.
If not provided, it will be the same as `merged_shape`.
dtype: the dtype for the final merged result. Default is `float32`.
value_dtype: the dtype for the value aggregating tensor. Default is `float32`.
count_dtype: the dtype for sample counting tensor. Default is `uint8`.
store: the zarr store to save the final results. Default is "merged.zarr".
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
compressor: the compressor for final merged zarr array. Default is "default".
value_compressor: the compressor for value aggregating zarr array. Default is None.
count_compressor: the compressor for sample counting zarr array. Default is None.
chunks: int or tuple of ints that defines the chunk shape, or boolean. Default is True.
If True, chunk shape will be guessed from `shape` and `dtype`.
If False, it will be set to `shape`, i.e., single chunk for the whole array.
If an int, the chunk size in each dimension will be given by the value of `chunks`.
thread_locking: whether to use a thread lock to protect the in-place accumulation during `aggregate`. Default is True.
"""

def __init__(
self,
merged_shape: Sequence[int],
cropped_shape: Sequence[int] | None = None,
dtype: np.dtype | str = "float32",
value_dtype: np.dtype | str = "float32",
count_dtype: np.dtype | str = "uint8",
store: zarr.storage.Store | str = "merged.zarr",
value_store: zarr.storage.Store | str | None = None,
count_store: zarr.storage.Store | str | None = None,
compressor: str = "default",
value_compressor: str | None = None,
count_compressor: str | None = None,
chunks: Sequence[int] | bool = True,
thread_locking: bool = True,
) -> None:
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
if not self.merged_shape:
raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.")
self.output_dtype = dtype
self.value_dtype = value_dtype
self.count_dtype = count_dtype
self.store = store
self.value_store = zarr.storage.TempStore() if value_store is None else value_store
self.count_store = zarr.storage.TempStore() if count_store is None else count_store
self.chunks = chunks
self.compressor = compressor
self.value_compressor = value_compressor
self.count_compressor = count_compressor
self.output = zarr.empty(
shape=self.merged_shape,
chunks=self.chunks,
dtype=self.output_dtype,
compressor=self.compressor,
store=self.store,
overwrite=True,
)
self.values = zarr.zeros(
shape=self.merged_shape,
chunks=self.chunks,
dtype=self.value_dtype,
compressor=self.value_compressor,
store=self.value_store,
overwrite=True,
)
self.counts = zarr.zeros(
shape=self.merged_shape,
chunks=self.chunks,
dtype=self.count_dtype,
compressor=self.count_compressor,
store=self.count_store,
overwrite=True,
)
self.lock: threading.Lock | nullcontext
if thread_locking:
# use lock to protect the in-place addition during aggregation
self.lock = threading.Lock()
else:
# use nullcontext to avoid the locking if not needed
self.lock = nullcontext()

def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
"""
Aggregate values for merging.

Args:
values: a tensor of shape BCHW[D], representing the values of inference output.
location: a tuple/list giving the top left location of the patch in the original image.
"""
if self.is_finalized:
raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.")
patch_size = values.shape[2:]
map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))
map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
with self.lock:
self.values[map_slice] += values.numpy()
self.counts[map_slice] += 1

def finalize(self) -> zarr.Array:
"""
Finalize merging by dividing values by counts and return the merged zarr array.

Notes:
To avoid creating a new tensor for the final results (to save memory space),
after this method is called, `get_values()` method will return the "final" averaged values,
and not the accumulating values. Also calling `finalize()` multiple times does not have any effect.

Returns:
zarr.Array: a zarr array of merged patches.
"""
# guard against multiple calls to finalize
if not self.is_finalized:
# use chunks for division to fit into memory
for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
self.output[chunk] = self.values[chunk] / self.counts[chunk]
# finalize the shape
self.output.resize(self.cropped_shape)
# set finalize flag to protect performing in-place division again
self.is_finalized = True

return self.output

def get_output(self) -> zarr.Array:
"""
Get the final merged output.

Returns:
zarr.Array: Merged (averaged) output tensor.
"""
return self.output

def get_values(self) -> zarr.Array:
"""
Get the accumulated values during aggregation.

Returns:
zarr.Array: aggregated values.

"""
return self.values

def get_counts(self) -> zarr.Array:
"""
Get the aggregator tensor for number of samples.

Returns:
zarr.Array: Number of accumulated samples at each location.
"""
return self.counts


def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()):
"""
Iterate over chunks of a given shape.

Args:
chunks: the chunk shape
cdata_shape: the shape of the data in chunks
slice_tuple: the slice tuple to be used for indexing

Raises:
ValueError: When the length of chunks and cdata_shape are not the same.

Yields:
slices of the data
"""
if len(chunks) != len(cdata_shape):
raise ValueError("chunks and cdata_shape must have the same length")
if len(chunks) == 1:
for i in range(cdata_shape[0]):
yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
else:
for i in range(cdata_shape[0]):
yield from iterate_over_chunks(
chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
)
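
To see how the pieces added above fit together, here is a minimal usage sketch (illustrative only, not part of the diff; it assumes `zarr` is installed): two overlapping patches are aggregated, the overlap is averaged on `finalize()`, and `iterate_over_chunks` is shown walking the same chunk grid that `finalize()` uses for its chunk-wise division.

```python
# Illustrative sketch of the ZarrAvgMerger API added in this PR (not part of the diff).
import torch
from monai.inferers import ZarrAvgMerger
from monai.inferers.merger import iterate_over_chunks

# Merge two overlapping 1x1x2x2 patches into a 1x1x2x3 output.
# thread_locking=True (the default) guards the in-place "+=" in aggregate(),
# so patches could also be fed from multiple worker threads.
merger = ZarrAvgMerger(merged_shape=(1, 1, 2, 3), store="merged.zarr")
merger.aggregate(torch.ones(1, 1, 2, 2), location=(0, 0))      # covers columns 0-1
merger.aggregate(3 * torch.ones(1, 1, 2, 2), location=(0, 1))  # covers columns 1-2

output = merger.finalize()  # zarr.Array persisted in "merged.zarr" on disk
print(output[:])            # [[[[1. 2. 3.] [1. 2. 3.]]]] -- overlap averaged to 2.0

# finalize() divides values by counts one chunk at a time; iterate_over_chunks
# yields the slice tuple addressing each chunk of the aggregation array.
for chunk_slice in iterate_over_chunks(merger.values.chunks, merger.values.cdata_shape):
    print(chunk_slice)
```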
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -52,3 +52,4 @@ onnx>=1.13.0
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
3 changes: 3 additions & 0 deletions setup.cfg
@@ -79,6 +79,7 @@ all =
optuna
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
nibabel =
nibabel
ninja =
@@ -142,6 +143,8 @@ optuna =
onnx =
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr =
zarr
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -202,6 +202,7 @@ def run_testsuit():
"test_metrics_reloaded",
"test_spatial_combine_transforms",
"test_bundle_workflow",
"test_zarr_avg_merger",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

1 change: 1 addition & 0 deletions tests/test_download_url_yandex.py
@@ -29,6 +29,7 @@


class TestDownloadUrlYandex(unittest.TestCase):
@unittest.skip("data source unstable")
def test_verify(self):
with tempfile.TemporaryDirectory() as tempdir:
download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt"))