# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections.abc
import math
import pickle
import shutil
import sys
import tempfile
import threading
import time
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, cast
import numpy as np
import torch
from torch.multiprocessing import Manager
from torch.serialization import DEFAULT_PROTOCOL
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import (
Compose,
Randomizable,
RandomizableTrait,
ThreadUnsafe,
Transform,
apply_transform,
convert_to_contiguous,
reset_ops_id,
)
from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first
if TYPE_CHECKING:
from tqdm import tqdm
has_tqdm = True
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")
class Dataset(_TorchDataset):
"""
A generic dataset with a length property and an optional callable data transform
when fetching a data sample.
If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
For example, typical input data can be a list of dictionaries::
        [
            {'img': 'image1.nii.gz', 'seg': 'label1.nii.gz', 'extra': 123},
            {'img': 'image2.nii.gz', 'seg': 'label2.nii.gz', 'extra': 456},
            {'img': 'image3.nii.gz', 'seg': 'label3.nii.gz', 'extra': 789},
        ]
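    A minimal usage sketch (the file paths are placeholders, and `LoadImaged` /
    `EnsureChannelFirstd` stand in for any transform chain)::

        from monai.data import Dataset
        from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

        data = [{'img': 'image1.nii.gz', 'seg': 'label1.nii.gz'}]  # hypothetical paths
        ds = Dataset(data=data, transform=Compose([
            LoadImaged(keys=['img', 'seg']),
            EnsureChannelFirstd(keys=['img', 'seg']),
        ]))
        item = ds[0]    # loads the files and applies the transform
        sub = ds[0:1]   # slicing returns a torch.utils.data.Subset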
"""
def __init__(self, data: Sequence, transform: Callable | None = None) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.
"""
self.data = data
self.transform: Any = transform
def __len__(self) -> int:
return len(self.data)
def _transform(self, index: int):
"""
Fetch single data item from `self.data`.
"""
data_i = self.data[index]
return apply_transform(self.transform, data_i) if self.transform is not None else data_i
def __getitem__(self, index: int | slice | Sequence[int]):
"""
Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
"""
if isinstance(index, slice):
# dataset[:42]
start, stop, step = index.indices(len(self))
indices = range(start, stop, step)
return Subset(dataset=self, indices=indices)
if isinstance(index, collections.abc.Sequence):
# dataset[[1, 3, 4]]
return Subset(dataset=self, indices=index)
return self._transform(index)
class DatasetFunc(Dataset):
"""
    Execute a function on the input dataset and use the output as a new Dataset.
    It can be used to load / fetch the basic dataset items, like the list of `image, label` paths.
    Or chain multiple instances together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc.
    The `data` arg of `Dataset` will be passed as the first arg of the callable `func`.
Usage example::
data_list = DatasetFunc(
data="path to file",
func=monai.data.load_decathlon_datalist,
data_list_key="validation",
base_dir="path to base dir",
)
# partition dataset for every rank
data_partition = DatasetFunc(
data=data_list,
func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()],
num_partitions=torch.distributed.get_world_size(),
)
dataset = Dataset(data=data_partition, transform=transforms)
Args:
data: input data for the func to process, will apply to `func` as the first arg.
func: callable function to generate dataset items.
kwargs: other arguments for the `func` except for the first arg.
"""
def __init__(self, data: Any, func: Callable, **kwargs) -> None:
super().__init__(data=None, transform=None) # type:ignore
self.src = data
self.func = func
self.kwargs = kwargs
self.reset()
def reset(self, data: Any | None = None, func: Callable | None = None, **kwargs):
"""
Reset the dataset items with specified `func`.
Args:
data: if not None, execute `func` on it, default to `self.src`.
func: if not None, execute the `func` with specified `kwargs`, default to `self.func`.
kwargs: other arguments for the `func` except for the first arg.
"""
src = self.src if data is None else data
self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs)
class PersistentDataset(Dataset):
"""
    Persistent storage of pre-computed values, to efficiently manage larger-than-memory dictionary-format data;
    it can apply transforms to specific fields. Results from the non-random transform components are computed
    when first used, and stored in the `cache_dir` for rapid retrieval on subsequent uses.
If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
The transforms which are supposed to be cached must implement the `monai.transforms.Transform`
interface and should not be `Randomizable`. This dataset will cache the outcomes before the first
`Randomizable` `Transform` within a `Compose` instance.
For example, typical input data can be a list of dictionaries::
        [
            {'image': 'image1.nii.gz', 'label': 'label1.nii.gz', 'extra': 123},
            {'image': 'image2.nii.gz', 'label': 'label2.nii.gz', 'extra': 456},
            {'image': 'image3.nii.gz', 'label': 'label3.nii.gz', 'extra': 789},
        ]
For a composite transform like
.. code-block:: python
[ LoadImaged(keys=['image', 'label']),
Orientationd(keys=['image', 'label'], axcodes='RAS'),
ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96),
pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0),
ToTensord(keys=['image', 'label'])]
    Upon first use, a filename-based dataset will be processed by the non-random transforms
    [LoadImaged, Orientationd, ScaleIntensityRanged], and the resulting tensors written to
    the `cache_dir` before the remaining random-dependent transforms
    [RandCropByPosNegLabeld, ToTensord] are applied for use in the analysis.
    Subsequent uses of the dataset directly read the pre-processed results from `cache_dir`,
    then apply the random-dependent parts of the transform processing.
During training call `set_data()` to update input data and recompute cache content.
Note:
        The input data must be a list of file paths, which will be hashed as cache keys.
        The filenames of the cached files also encode the hash of the transforms where possible. In this
fashion, `PersistentDataset` should be robust to changes in transforms. This, however, is
not guaranteed, so caution should be used when modifying transforms to avoid unexpected
errors. If in doubt, it is advisable to clear the cache directory.
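    A minimal instantiation sketch, assuming `transforms` is a `Compose` chain like the one
    above and the data paths are placeholders::

        from monai.data import PersistentDataset

        ds = PersistentDataset(
            data=[{'image': 'image1.nii.gz', 'label': 'label1.nii.gz'}],  # hypothetical paths
            transform=transforms,
            cache_dir='./persistent_cache',
        )
        item = ds[0]  # first access computes and caches the deterministic part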
"""
def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable,
cache_dir: Path | str | None,
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module: str = "pickle",
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
) -> None:
"""
Args:
data: input data file paths to load and transform to generate dataset for model.
                `PersistentDataset` expects input data to be a list of serializable items,
                which it hashes as cache keys using `hash_func`.
transform: transforms to execute operations on input data.
cache_dir: If specified, this is the location for persistent storage
of pre-computed transformed data tensors. The cache_dir is computed once, and
persists on disk until explicitly removed. Different runs, programs, experiments
may share a common cache dir provided that the transforms pre-processing is consistent.
If `cache_dir` doesn't exist, will automatically create it.
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
            pickle_module: string representing the module used for pickling metadata and objects,
                defaults to `"pickle"`. Due to pickle limitations in the multi-processing of DataLoader,
                we can't pass the module object directly as the arg, so a string name is used instead.
                To use another pickle module at runtime, register it like:
>>> from monai.data import utils
>>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
            pickle_protocol: can be specified to override the default protocol, defaults to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
self.pickle_module = pickle_module
self.pickle_protocol = pickle_protocol
if self.cache_dir is not None:
if not self.cache_dir.exists():
self.cache_dir.mkdir(parents=True, exist_ok=True)
if not self.cache_dir.is_dir():
raise ValueError("cache_dir must be a directory.")
self.transform_hash: str = ""
if hash_transform is not None:
self.set_transform_hash(hash_transform)
self.reset_ops_id = reset_ops_id
def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
"""Get hashable transforms, and then hash them. Hashable transforms
are deterministic transforms that inherit from `Transform`. We stop
at the first non-deterministic transform, or first that does not
inherit from MONAI's `Transform` class."""
hashable_transforms = []
for _tr in self.transform.flatten().transforms:
if isinstance(_tr, RandomizableTrait) or not isinstance(_tr, Transform):
break
hashable_transforms.append(_tr)
# Try to hash. Fall back to a hash of their names
try:
transform_hash = hash_xform_func(hashable_transforms)
except TypeError as te:
if "is not JSON serializable" not in str(te):
raise te
names = "".join(tr.__class__.__name__ for tr in hashable_transforms)
transform_hash = hash_xform_func(names)
self.transform_hash = transform_hash.decode("utf-8")
def set_data(self, data: Sequence):
"""
Set the input data and delete all the out-dated cache content.
"""
self.data = data
if self.cache_dir is not None and self.cache_dir.exists():
shutil.rmtree(self.cache_dir, ignore_errors=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _pre_transform(self, item_transformed):
"""
Process the data from original state up to the first random element.
Args:
item_transformed: The data to be transformed
Returns:
the transformed element up to the first identified
random transform object
"""
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
if self.reset_ops_id:
reset_ops_id(item_transformed)
return item_transformed
def _post_transform(self, item_transformed):
"""
Process the data from before the first random transform to the final state ready for evaluation.
Args:
item_transformed: The data to be transformed (already processed up to the first random transform)
Returns:
the transformed element through the random transforms
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
start_post_randomize_run = False
for _transform in self.transform.transforms:
if (
start_post_randomize_run
or isinstance(_transform, RandomizableTrait)
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
return item_transformed
def _cachecheck(self, item_transformed):
"""
A function to cache the expensive input data transform operations
so that huge data sets (larger than computer memory) can be processed
on the fly as needed, and intermediate results written to disk for
future use.
Args:
item_transformed: The current data element to be mutated into transformed representation
Returns:
The transformed data_element, either from cache, or explicitly computing it.
Warning:
The current implementation does not encode transform information as part of the
hashing mechanism used for generating cache names when `hash_transform` is None.
If the transforms applied are changed in any way, the objects in the cache dir will be invalid.
"""
hashfile = None
if self.cache_dir is not None:
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
hashfile = self.cache_dir / f"{data_item_md5}.pt"
if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile)
except PermissionError as e:
if sys.platform != "win32":
raise e
_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
return _item_transformed
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
# to make the cache more robust to manual killing of parent process
# which may leave partially written cache files in an incomplete state
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(
obj=_item_transformed,
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not hashfile.is_file():
# On Unix, if target exists and is a file, it will be replaced silently if the user has permission.
# for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
try:
shutil.move(str(temp_hash_file), hashfile)
except FileExistsError:
pass
except PermissionError: # project-monai/monai issue #3613
pass
return _item_transformed
def _transform(self, index: int):
pre_random_item = self._cachecheck(self.data[index])
return self._post_transform(pre_random_item)
class CacheNTransDataset(PersistentDataset):
"""
    Extension of `PersistentDataset` that can also cache the results of the first N transforms, whether they are random or not.
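    For example, with a chain like the one shown for `PersistentDataset`, setting
    `cache_n_trans=3` caches the output of the first three transforms even if one of them
    is random (a sketch; `transforms` and the data paths are placeholders)::

        ds = CacheNTransDataset(
            data=[{'image': 'image1.nii.gz', 'label': 'label1.nii.gz'}],
            transform=transforms,
            cache_n_trans=3,             # cache the output of the first 3 transforms
            cache_dir='./n_trans_cache',
        )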
"""
def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable,
cache_n_trans: int,
cache_dir: Path | str | None,
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module: str = "pickle",
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
) -> None:
"""
Args:
data: input data file paths to load and transform to generate dataset for model.
                `PersistentDataset` expects input data to be a list of serializable items,
                which it hashes as cache keys using `hash_func`.
transform: transforms to execute operations on input data.
cache_n_trans: cache the result of first N transforms.
cache_dir: If specified, this is the location for persistent storage
of pre-computed transformed data tensors. The cache_dir is computed once, and
persists on disk until explicitly removed. Different runs, programs, experiments
may share a common cache dir provided that the transforms pre-processing is consistent.
If `cache_dir` doesn't exist, will automatically create it.
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
            pickle_module: string representing the module used for pickling metadata and objects,
                defaults to `"pickle"`. Due to pickle limitations in the multi-processing of DataLoader,
                we can't pass the module object directly as the arg, so a string name is used instead.
                To use another pickle module at runtime, register it like:
>>> from monai.data import utils
>>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
            pickle_protocol: can be specified to override the default protocol, defaults to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
"""
super().__init__(
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
pickle_module=pickle_module,
pickle_protocol=pickle_protocol,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
)
self.cache_n_trans = cache_n_trans
def _pre_transform(self, item_transformed):
"""
        Process the data from the original state through the first N transforms.
Args:
item_transformed: The data to be transformed
Returns:
            the transformed element after the first N transforms
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i == self.cache_n_trans:
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
reset_ops_id(item_transformed)
return item_transformed
def _post_transform(self, item_transformed):
"""
        Process the data from the (N + 1)-th transform to the final state, ready for evaluation.
        Args:
            item_transformed: The data to be transformed (already processed by the first N transforms)
Returns:
the final transformed result
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i >= self.cache_n_trans:
                item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
return item_transformed
class LMDBDataset(PersistentDataset):
"""
Extension of `PersistentDataset` using LMDB as the backend.
See Also:
:py:class:`monai.data.PersistentDataset`
Examples:
>>> items = [{"data": i} for i in range(5)]
# [{'data': 0}, {'data': 1}, {'data': 2}, {'data': 3}, {'data': 4}]
>>> lmdb_ds = monai.data.LMDBDataset(items, transform=monai.transforms.SimulateDelayd("data", delay_time=1))
>>> print(list(lmdb_ds)) # using the cached results
"""
def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable,
cache_dir: Path | str = "cache",
hash_func: Callable[..., bytes] = pickle_hashing,
db_name: str = "monai_cache",
progress: bool = True,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
lmdb_kwargs: dict | None = None,
) -> None:
"""
Args:
data: input data file paths to load and transform to generate dataset for model.
                `LMDBDataset` expects input data to be a list of serializable items,
                which it hashes as cache keys using `hash_func`.
transform: transforms to execute operations on input data.
cache_dir: if specified, this is the location for persistent storage
of pre-computed transformed data tensors. The cache_dir is computed once, and
persists on disk until explicitly removed. Different runs, programs, experiments
may share a common cache dir provided that the transforms pre-processing is consistent.
If the cache_dir doesn't exist, will automatically create it. Defaults to "./cache".
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
db_name: lmdb database file name. Defaults to "monai_cache".
progress: whether to display a progress bar.
pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
https://docs.python.org/3/library/pickle.html#pickle-protocols
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
super().__init__(
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
pickle_protocol=pickle_protocol,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
)
self.progress = progress
if not self.cache_dir:
raise ValueError("cache_dir must be specified.")
self.db_file = self.cache_dir / f"{db_name}.lmdb"
self.lmdb_kwargs = lmdb_kwargs or {}
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024**4 # default map_size
# lmdb is single-writer multi-reader by default
# the cache is created without multi-threading
self._read_env: Any | None = None
# this runs on the primary thread/process
self._fill_cache_start_reader(show_progress=self.progress)
print(f"Accessing lmdb file: {self.db_file.absolute()}.")
def set_data(self, data: Sequence):
"""
Set the input data and delete all the out-dated cache content.
"""
super().set_data(data=data)
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
def _fill_cache_start_reader(self, show_progress=True):
"""
        Check the LMDB cache and write the cache if needed. py-lmdb does not support concurrent writes well,
        so while this method can be used with multiple processes, doing so may degrade performance.
Args:
show_progress: whether to show the progress bar if possible.
"""
# create cache
self.lmdb_kwargs["readonly"] = False
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
if show_progress and not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
with env.begin(write=False) as search_txn:
for item in tqdm(self.data) if has_tqdm and show_progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
try:
with search_txn.cursor() as cursor:
done = cursor.set_key(key)
if done:
continue
if val is None:
val = self._pre_transform(deepcopy(item)) # keep the original hashed
val = pickle.dumps(val, protocol=self.pickle_protocol)
with env.begin(write=True) as txn:
txn.put(key, val)
done = True
except lmdb.MapFullError:
done, retry = False, retry - 1
size = env.info()["map_size"]
new_size = size * 2
warnings.warn(
f"Resizing the cache database from {int(size) >> 20}MB" f" to {int(new_size) >> 20}MB."
)
env.set_mapsize(new_size)
except lmdb.MapResizedError:
# the mapsize is increased by another process
# set_mapsize with a size of 0 to adopt the new size
env.set_mapsize(0)
if not done: # still has the map full error
size = env.info()["map_size"]
env.close()
raise ValueError(f"LMDB map size reached, increase size above current size of {size}.")
size = env.info()["map_size"]
env.close()
# read-only database env
self.lmdb_kwargs["readonly"] = True
self.lmdb_kwargs["map_size"] = size
if self.lmdb_kwargs.get("lock", None) is None:
self.lmdb_kwargs["lock"] = False
if self.lmdb_kwargs.get("readahead", None) is None:
self.lmdb_kwargs["readahead"] = False
return lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
def _cachecheck(self, item_transformed):
"""
        If the item is not found in the LMDB file, falls back to the default `PersistentDataset` caching behaviour.
"""
if self._read_env is None:
# this runs on multiple processes, each one should have its own env.
self._read_env = self._fill_cache_start_reader(show_progress=False)
with self._read_env.begin(write=False) as txn:
data = txn.get(self.hash_func(item_transformed))
if data is None:
warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
return super()._cachecheck(item_transformed)
try:
return pickle.loads(data)
except Exception as err:
raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
def info(self):
"""
Returns: dataset info dictionary.
"""
if self._read_env is None:
self._read_env = self._fill_cache_start_reader()
out = dict(self._read_env.info())
out["size"] = len(self.data)
out["filename"] = f"{self.db_file.absolute()}"
return out
class CacheDataset(Dataset):
"""
Dataset with cache mechanism that can load data and cache deterministic transforms' result during training.
By caching the results of non-random preprocessing transforms, it accelerates the training data pipeline.
If the requested data is not in the cache, all transforms will run normally
(see also :py:class:`monai.data.dataset.Dataset`).
Users can set the cache rate or number of items to cache.
It is recommended to experiment with different `cache_num` or `cache_rate` to identify the best training speed.
The transforms which are supposed to be cached must implement the `monai.transforms.Transform`
interface and should not be `Randomizable`. This dataset will cache the outcomes before the first
`Randomizable` `Transform` within a `Compose` instance.
So to improve the caching efficiency, please always put as many as possible non-random transforms
before the randomized ones when composing the chain of transforms.
If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
For example, if the transform is a `Compose` of::
transforms = Compose([
LoadImaged(),
EnsureChannelFirstd(),
Spacingd(),
Orientationd(),
ScaleIntensityRanged(),
RandCropByPosNegLabeld(),
ToTensord()
])
when `transforms` is used in a multi-epoch training pipeline, before the first training epoch,
this dataset will cache the results up to ``ScaleIntensityRanged``, as
all non-random transforms `LoadImaged`, `EnsureChannelFirstd`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`
can be cached. During training, the dataset will load the cached results and run
    ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform
    and its outcome is not cached.
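    A minimal instantiation sketch for the chain above (the data paths are placeholders)::

        ds = CacheDataset(
            data=[{'image': 'image1.nii.gz', 'label': 'label1.nii.gz'}],
            transform=transforms,
            cache_rate=1.0,    # cache every item
            num_workers=4,     # worker threads used to fill the cache at init time
        )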
During training call `set_data()` to update input data and recompute cache content, note that it requires
`persistent_workers=False` in the PyTorch DataLoader.
Note:
        `CacheDataset` executes non-random transforms and prepares the cache content in the main process before
        the first epoch, then all the DataLoader subprocesses read the same cache content from the main process
        during training. It may take a long time to prepare the cache content, depending on the size of the expected cache data.
So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0` to
temporarily skip caching.
"""
def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable | None = None,
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int | None = 1,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
hash_as_key: bool = False,
hash_func: Callable[..., bytes] = pickle_hashing,
runtime_cache: bool | str | list | ListProxy = False,
) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: transforms to execute operations on input data.
cache_num: number of items to be cached. Default is `sys.maxsize`.
will take the minimum of (cache_num, data_length x cache_rate, data_length).
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads if computing cache in the initialization.
If num_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
                may set `copy_cache=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
hash_as_key: whether to compute hash value of input data as the key to save cache,
if key exists, avoid saving duplicated content. it can help save memory when
the dataset has duplicated items or augmented dataset.
hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
            runtime_cache: mode of cache at runtime. Defaults to `False` to prepare
                the cache content for the entire ``data`` during initialization; this can significantly increase the
                time between the constructor being called and the first mini-batch being generated.
Three options are provided to compute the cache on the fly after the dataset initialization:
1. ``"threads"`` or ``True``: use a regular ``list`` to store the cache items.
2. ``"processes"``: use a ListProxy to store the cache items, it can be shared among processes.
                3. A list-like object: a user-provided container to store the cache items.
For `thread-based` caching (typically for caching cuda tensors), option 1 is recommended.
For single process workflows with multiprocessing data loading, option 2 is recommended.
For multiprocessing workflows (typically for distributed training),
where this class is initialized in subprocesses, option 3 is recommended,
and the list-like object should be prepared in the main process and passed to all subprocesses.
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
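                A sketch of option 3 (names are illustrative; the container must be created in the
                main process before the subprocesses are spawned):
                >>> from torch.multiprocessing import Manager
                >>> shared_cache = Manager().list([None] * expected_cache_num)  # in the main process
                >>> # each subprocess then builds the dataset around the shared container:
                >>> ds = CacheDataset(data=data, transform=transforms, runtime_cache=shared_cache)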
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
self.hash_as_key = hash_as_key
self.hash_func = hash_func
self.num_workers = num_workers
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self.runtime_cache = runtime_cache
self.cache_num = 0
self._cache: list | ListProxy = []
self._hash_keys: list = []
self.set_data(data)
def set_data(self, data: Sequence) -> None:
"""
Set the input data and run deterministic transforms to generate cache content.
        Note: this function should be called after an entire epoch, and `persistent_workers=False` must be set
        in the PyTorch DataLoader, because new worker processes need to be created based on the newly
        generated cache content.
"""
self.data = data
def _compute_cache_num(data_len: int):
self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)
if self.hash_as_key:
# only compute cache for the unique items of dataset, and record the last index for duplicated items
mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
_compute_cache_num(len(mapping))
self._hash_keys = list(mapping)[: self.cache_num]
indices = list(mapping.values())[: self.cache_num]
else:
_compute_cache_num(len(self.data))
indices = list(range(self.cache_num))
if self.runtime_cache in (False, None): # prepare cache content immediately
self._cache = self._fill_cache(indices)
return
if isinstance(self.runtime_cache, str) and "process" in self.runtime_cache:
# this must be in the main process, not in dataloader's workers
self._cache = Manager().list([None] * self.cache_num)
return
if (self.runtime_cache is True) or (isinstance(self.runtime_cache, str) and "thread" in self.runtime_cache):
self._cache = [None] * self.cache_num
return
self._cache = self.runtime_cache # type: ignore
return
def _fill_cache(self, indices=None) -> list:
"""
Compute and fill the cache content from data source.
Args:
indices: target indices in the `self.data` source to compute cache.
if None, use the first `cache_num` items.
"""
if self.cache_num <= 0:
return []
if indices is None:
indices = list(range(self.cache_num))
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if self.progress and has_tqdm:
return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
return list(p.imap(self._load_cache_item, indices))
def _load_cache_item(self, idx: int):
"""
Args:
idx: the index of the input data sequence.
"""
item = self.data[idx]
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item = self.transform.evaluate_with_overrides(item, _xform)
item = apply_transform(_xform, item)
item = self.transform.evaluate_with_overrides(item, None)
if self.as_contiguous:
item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
return item
def _transform(self, index: int):
cache_index = None
if self.hash_as_key:
key = self.hash_func(self.data[index])
if key in self._hash_keys:
# if existing in cache, try to get the index in cache
cache_index = self._hash_keys.index(key)
elif index % len(self) < self.cache_num: # support negative index
cache_index = index
if cache_index is None:
# no cache for this index, execute all the transforms directly
return super()._transform(index)
if self._cache is None:
raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.")
data = self._cache[cache_index]
# runtime cache computation
if data is None:
data = self._cache[cache_index] = self._load_cache_item(cache_index)
# load data from cache and execute from the first random transform
start_run = False
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
# only need to deep copy data on first non-deterministic transform
if not start_run:
start_run = True
if self.copy_cache:
data = deepcopy(data)
data = self.transform.evaluate_with_overrides(data, _transform)
data = apply_transform(_transform, data)
data = self.transform.evaluate_with_overrides(data, None)
return data
class SmartCacheDataset(Randomizable, CacheDataset):
"""
Re-implementation of the SmartCache mechanism in NVIDIA Clara-train SDK.
At any time, the cache pool only keeps a subset of the whole dataset. In each epoch, only the items
in the cache are used for training. This ensures that data needed for training is readily available,
keeping GPU resources busy. Note that cached items may still have to go through a non-deterministic
transform sequence before being fed to GPU. At the same time, another thread is preparing replacement
items by applying the transform sequence to items not in cache. Once one epoch is completed, Smart
Cache replaces the same number of items with replacement items.
Smart Cache uses a simple `running window` algorithm to determine the cache content and replacement items.
Let N be the configured number of objects in cache; and R be the number of replacement objects (R = ceil(N * r),
where r is the configured replace rate).
For more details, please refer to:
https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache
If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    For example, if we have 5 images: `[image1, image2, image3, image4, image5]`, with `cache_num=4` and `replace_rate=0.25`,
    the actual training images cached and replaced for every epoch are as below::
        epoch 1: [image1, image2, image3, image4]
        epoch 2: [image2, image3, image4, image5]
        epoch 3: [image3, image4, image5, image1]
        epoch 4: [image4, image5, image1, image2]
        epoch N: [image[N % 5] ...]
The usage of `SmartCacheDataset` contains 4 steps:
1. Initialize `SmartCacheDataset` object and cache for the first epoch.
2. Call `start()` to run replacement thread in background.
3. Call `update_cache()` before every epoch to replace training items.
4. Call `shutdown()` when training ends.
During training call `set_data()` to update input data and recompute cache content, note to call
`shutdown()` to stop first, then update data and call `start()` to restart.
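    A sketch of the four steps in a training loop (``train_loader`` and ``num_epochs`` are
    assumed to be defined elsewhere)::

        train_ds.start()                  # step 2: launch the replacement thread
        for epoch in range(num_epochs):
            for batch in train_loader:    # batches come from the current cache content
                ...                       # training step
            train_ds.update_cache()       # step 3: swap in replacement items for the next epoch
        train_ds.shutdown()               # step 4: stop the replacement thread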
Note:
This replacement will not work for below cases:
1. Set the `multiprocessing_context` of DataLoader to `spawn`.
2. Launch distributed data parallel with `torch.multiprocessing.spawn`.
        3. Run on Windows (the default multiprocessing method is `spawn`) with `num_workers` greater than 0.
4. Set the `persistent_workers` of DataLoader to `True` with `num_workers` greater than 0.
If using MONAI workflows, please add `SmartCacheHandler` to the handler list of trainer,
otherwise, please make sure to call `start()`, `update_cache()`, `shutdown()` during training.
Args:
data: input data to load and transform to generate dataset for model.
transform: transforms to execute operations on input data.
replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1).
cache_num: number of items to be cached. Default is `sys.maxsize`.
will take the minimum of (cache_num, data_length x cache_rate, data_length).
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_init_workers: the number of worker threads to initialize the cache for first epoch.
If num_init_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is specified, 1 will be used instead.
num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
If num_replace_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar when caching for the first epoch.