diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 9ee37185..8831bdd7 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -318,6 +318,8 @@ def __init__( augmentations: list[MapTransform] = [], caching: bool = False, ground_truth_masks: Path | None = None, + persistent_workers=False, + prefetch_factor=None, ): super().__init__() self.data_path = Path(data_path) @@ -334,6 +336,8 @@ def __init__( self.caching = caching self.ground_truth_masks = ground_truth_masks self.prepare_data_per_node = True + self.persistent_workers = persistent_workers + self.prefetch_factor = prefetch_factor @property def cache_path(self): @@ -521,8 +525,8 @@ def train_dataloader(self): batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, shuffle=True, - persistent_workers=bool(self.num_workers), - prefetch_factor=4 if self.num_workers else None, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + persistent_workers=self.persistent_workers, collate_fn=_collate_samples, drop_last=True, ) @@ -533,8 +537,8 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - prefetch_factor=4 if self.num_workers else None, - persistent_workers=bool(self.num_workers), + prefetch_factor=self.prefetch_factor if self.num_workers else None, + persistent_workers=self.persistent_workers, ) def test_dataloader(self):