
Add cache option in GridPatchDataset #7180

Merged: 32 commits, Nov 17, 2023
Changes from 10 commits
Commits (32)
2c16613
fix #6904
KumoLiu Nov 1, 2023
539be13
Merge branch 'dev' into dataset
KumoLiu Nov 1, 2023
5f61895
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into dat…
KumoLiu Nov 2, 2023
0e7a362
modify test
KumoLiu Nov 2, 2023
bf06497
fix ci
KumoLiu Nov 2, 2023
fa3da38
fix mypy
KumoLiu Nov 2, 2023
84a50c7
fix #6585
KumoLiu Nov 2, 2023
993ca74
minor fix
KumoLiu Nov 2, 2023
66fc0b4
fix flake8
KumoLiu Nov 2, 2023
cf40cbc
minor fix
KumoLiu Nov 2, 2023
6314848
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into dat…
KumoLiu Nov 3, 2023
7b79b34
Update monai/data/grid_dataset.py
KumoLiu Nov 3, 2023
397907d
Merge branch 'dataset' of https://github.com/KumoLiu/MONAI into dataset
KumoLiu Nov 3, 2023
f3b3b98
address comments
KumoLiu Nov 3, 2023
0282de4
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into dat…
KumoLiu Nov 3, 2023
caec700
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2023
e15aeab
add `_generate_patches`
KumoLiu Nov 8, 2023
acac03d
add unittests
KumoLiu Nov 8, 2023
28a82a0
Merge remote-tracking branch 'origin/dev' into dataset
KumoLiu Nov 8, 2023
06effe5
Update monai/data/grid_dataset.py
KumoLiu Nov 10, 2023
1bb1ae7
remove unused import
KumoLiu Nov 10, 2023
520e2b3
update docstring
KumoLiu Nov 10, 2023
32d4412
fix mypy
KumoLiu Nov 10, 2023
87283bf
Merge branch 'dev' into dataset
ericspod Nov 16, 2023
745e452
Merge branch 'dev' into dataset
ericspod Nov 16, 2023
13c25da
Update monai/data/grid_dataset.py
KumoLiu Nov 17, 2023
b4038f2
Update monai/data/grid_dataset.py
KumoLiu Nov 17, 2023
2b5e5df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
1f6b147
Merge branch 'dev' into dataset
KumoLiu Nov 17, 2023
f1f6360
address comments
KumoLiu Nov 17, 2023
d8e1f42
fix ci
KumoLiu Nov 17, 2023
872ea16
Merge branch 'dev' into dataset
ericspod Nov 17, 2023
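
For orientation, below is a minimal usage sketch of the cache option this PR adds to GridPatchDataset. The constructor arguments follow the signature in the diff further down; the toy images, patch size and transform are illustrative only, not taken from the PR.

import numpy as np

from monai.data import DataLoader, GridPatchDataset, PatchIter
from monai.transforms import RandShiftIntensity

# two small toy "images"; any sequence of arrays (or dicts, with PatchIterd) works
images = [np.arange(16, dtype=float).reshape(1, 4, 4) for _ in range(2)]
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
patch_transform = RandShiftIntensity(offsets=1.0, prob=1.0)

# cache=True pre-computes the deterministic prefix of the transform chain per patch;
# any random transforms still run on the fly during iteration
ds = GridPatchDataset(
    data=images,
    patch_iter=patch_iter,
    transform=patch_transform,
    with_coordinates=True,
    cache=True,
    cache_rate=1.0,
    num_workers=1,
)

for patch, coords in DataLoader(ds, batch_size=2, num_workers=0):
    print(patch.shape, coords.shape)  # e.g. (2, 1, 2, 2) and (2, 3, 2)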
216 changes: 184 additions & 32 deletions monai/data/grid_dataset.py
@@ -11,18 +11,35 @@

from __future__ import annotations

from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
import sys
import warnings
from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence
from copy import deepcopy
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from typing import TYPE_CHECKING

import numpy as np
import torch

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayTensor
from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple, first
from monai.data.utils import iter_patch, pickle_hashing
from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous
from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import

if TYPE_CHECKING:
from tqdm import tqdm

has_tqdm = True
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

cp, _ = optional_import("cupy")
lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")
kvikio_numpy, _ = optional_import("kvikio.numpy")

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]

@@ -193,25 +210,160 @@ def __init__(
patch_iter: Callable,
transform: Callable | None = None,
with_coordinates: bool = True,
cache: bool = False,
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int | None = 1,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
hash_func: Callable[..., bytes] = pickle_hashing,
) -> None:
super().__init__(data=data, transform=None)
if transform is not None and not isinstance(transform, Compose):
transform = Compose(transform)
self.patch_iter = patch_iter
self.patch_transform = transform
self.with_coordinates = with_coordinates
self.set_num = cache_num
self.set_rate = cache_rate
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
self.hash_func = hash_func
self.num_workers = num_workers
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self._cache: list | ListProxy = []
self._cache_other: list | ListProxy = []
self.cache = cache
if self.cache:
if isinstance(data, Iterator):
raise TypeError("Data can not be iterator when cache is True")
self.set_data(data) # type: ignore

def __iter__(self):
for image in super().__iter__():
for patch, *others in self.patch_iter(image):
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others[0]
else:
yield out_patch

def set_data(self, data: Sequence) -> None:
"""
Set the input data and run deterministic transforms to generate cache content.

Note: should call this func after an entire epoch and must set `persistent_workers=False`
in PyTorch DataLoader, because it needs to create new worker processes based on new
generated cache content.

"""
self.data = data

def _compute_cache_num(data_len: int):
self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)

# only compute cache for the unique items of dataset, and record the last index for duplicated items
mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
_compute_cache_num(len(mapping))
self._hash_keys = list(mapping)[: self.cache_num]
indices = list(mapping.values())[: self.cache_num]

self._cache = self._fill_cache(indices)
return

def _fill_cache(self, indices=None) -> list:
"""
Compute and fill the cache content from data source.

Args:
indices: target indices in the `self.data` source to compute cache.
if None, use the first `cache_num` items.

"""
if self.cache_num <= 0:
return []
if indices is None:
indices = list(range(self.cache_num))
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if self.progress and has_tqdm:
return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
return list(p.imap(self._load_cache_item, indices))

def _load_cache_item(self, idx: int):
"""
Args:
idx: the index of the input data sequence.
"""
item = self.data[idx] # type: ignore
patch_cache, other_cache = [], []
for patch, *others in self.patch_iter(item):
if self.patch_transform is not None:
first_random = self.patch_transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
patch = self.patch_transform(patch, end=first_random, threading=True)

if self.as_contiguous:
patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
other_cache.append(others[0])
patch_cache.append(patch)
self._cache_other.append(other_cache)
return patch_cache

def __iter__(self):
if self.cache:
cache_index = None
for image in super().__iter__():
key = self.hash_func(image)
if key in self._hash_keys:
# if existing in cache, try to get the index in cache
cache_index = self._hash_keys.index(key)
if cache_index is None:
# no cache for this index, execute all the transforms directly
for patch, *others in self.patch_iter(image):
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
if (
self.with_coordinates and len(others) > 0
): # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others[0]
else:
yield out_patch

if self._cache is None:
raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.")
data = self._cache[cache_index] # type: ignore
other = self._cache_other[cache_index] # type: ignore

# load data from cache and execute from the first random transform
if not isinstance(self.patch_transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.patch_transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
if first_random is not None:
data = deepcopy(data) if self.copy_cache is True else data
for out_patch, others in zip(data, other):
if self.patch_transform is not None:
out_patch = self.patch_transform(out_patch, start=first_random)
if (
self.with_coordinates and len(others) > 0
): # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others
else:
yield out_patch
else:
for image in super().__iter__():
for patch, *others in self.patch_iter(image):
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others[0]
else:
yield out_patch


class PatchDataset(Dataset):
class PatchDataset(IterableDataset):
"""
returns a patch from an image dataset.
The patches are generated by a user-specified callable `patch_func`,
@@ -263,26 +415,26 @@ def __init__(
samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
transform: transform applied to each patch.
"""
super().__init__(data=data, transform=transform)
super().__init__(data=data, transform=None)

self.patch_func = patch_func
if samples_per_image <= 0:
raise ValueError("sampler_per_image must be a positive integer.")
self.samples_per_image = int(samples_per_image)
self.patch_transform = transform

def __len__(self) -> int:
return len(self.data) * self.samples_per_image

def _transform(self, index: int):
image_id = int(index / self.samples_per_image)
image = self.data[image_id]
patches = self.patch_func(image)
if len(patches) != self.samples_per_image:
raise RuntimeWarning(
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
)
patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1)
patch = patches[patch_id]
if self.transform is not None:
patch = apply_transform(self.transform, patch, map_items=False)
return patch
return len(self.data) * self.samples_per_image # type: ignore

def __iter__(self):
for image in super().__iter__():
patches = self.patch_func(image)
if len(patches) != self.samples_per_image:
raise RuntimeWarning(
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
)
for patch in patches:
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
yield out_patch
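
The caching path above hinges on splitting the transform chain at the first random transform: everything before it is executed once in set_data()/_fill_cache() and stored, everything from it onwards is re-applied in __iter__() on each epoch. A small sketch of that split, mirroring the get_index_of_first call in the diff; EnsureType is just an arbitrary deterministic transform chosen for illustration.

from monai.transforms import Compose, EnsureType, RandShiftIntensity, RandomizableTrait, Transform

# a deterministic transform followed by a random one
xform = Compose([EnsureType(), RandShiftIntensity(offsets=1.0, prob=1.0)])

# index of the first randomizable (or non-Transform) entry: everything before it
# can be cached safely, everything from it onwards must run per iteration
first_random = xform.get_index_of_first(
    lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
print(first_random)  # 1 for this chain

# cached = xform(patch, end=first_random, threading=True)  # once, when filling the cache
# output = xform(cached, start=first_random)               # every epoch, when iterating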
16 changes: 6 additions & 10 deletions tests/test_grid_dataset.py
@@ -108,19 +108,19 @@ def test_shape(self):
self.assertEqual(sorted(output), sorted(expected))

def test_loading_array(self):
set_determinism(seed=1234)
# set_determinism(seed=1234)
# test sequence input data with images
images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
# image level
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)
# use the grid patch dataset
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -129,9 +129,7 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array(
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
),
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
@@ -164,7 +162,7 @@ def test_loading_dict(self):
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
np.testing.assert_allclose(
item[0]["image"],
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -173,9 +171,7 @@ def test_loading_dict(self):
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
np.testing.assert_allclose(
item[0]["image"],
np.array(
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
),
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
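
The expected arrays above change along with the seeding style: the test now seeds the transform instance directly with set_random_state rather than calling set_determinism globally, so a different random sequence drives the intensity shifts. A minimal sketch of that seeding style (the toy input is illustrative):

import numpy as np
from monai.transforms import RandShiftIntensity

x = np.arange(4, dtype=float).reshape(1, 2, 2)

# seeding the instance makes its output reproducible on its own,
# without relying on any global RNG state
t1 = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)
t2 = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)

np.testing.assert_allclose(t1(x), t2(x))  # identical shifts from identical seeds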
15 changes: 12 additions & 3 deletions tests/test_patch_dataset.py
@@ -37,7 +37,10 @@ def test_shape(self):
n_workers = 0 if sys.platform == "win32" else 2
for item in DataLoader(result, batch_size=3, num_workers=n_workers):
output.append("".join(item))
expected = ["vwx", "yzh", "ell", "owo", "rld"]
if n_workers == 0:
expected = ["vwx", "yzh", "ell", "owo", "rld"]
else:
expected = ["vwx", "hel", "yzw", "lo", "orl", "d"]
self.assertEqual(output, expected)

def test_loading_array(self):
@@ -61,7 +64,7 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
[[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]]
[[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]]
),
rtol=1e-5,
)
@@ -71,7 +74,13 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
[[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]]
[
[
[5.028125, 6.028125, 7.028125],
[9.028125, 10.028125, 11.028125],
[13.028125, 14.028125, 15.028125],
]
]
),
rtol=1e-5,
)
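
The worker-dependent expectation in test_shape above follows from PatchDataset now being iterable: MONAI's IterableDataset shards its stream across DataLoader workers, so with two workers the per-worker streams interleave and batch contents change. A rough sketch of the effect, with made-up data and patch function:

from monai.data import DataLoader, PatchDataset


def make_patches(x):
    # trivial patch function: each "image" yields itself and itself + 100
    return [x, x + 100]


ds = PatchDataset(data=list(range(10)), patch_func=make_patches, samples_per_image=2)

# single-process loading keeps the original order
print([b.tolist() for b in DataLoader(ds, batch_size=4, num_workers=0)])

# with workers > 0 the stream is split across worker processes, so the
# batch contents (and any partial batches) come out in a different order
print([b.tolist() for b in DataLoader(ds, batch_size=4, num_workers=2)])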