Skip to content

Commit

Permalink
exposing prefetch and persistent worker (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun authored Nov 13, 2024
1 parent db80819 commit 820c805
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def __init__(
augmentations: list[MapTransform] = [],
caching: bool = False,
ground_truth_masks: Path | None = None,
persistent_workers=False,
prefetch_factor=None,
):
super().__init__()
self.data_path = Path(data_path)
Expand All @@ -334,6 +336,8 @@ def __init__(
self.caching = caching
self.ground_truth_masks = ground_truth_masks
self.prepare_data_per_node = True
self.persistent_workers = persistent_workers
self.prefetch_factor = prefetch_factor

@property
def cache_path(self):
Expand Down Expand Up @@ -521,8 +525,8 @@ def train_dataloader(self):
batch_size=self.batch_size // self.train_patches_per_stack,
num_workers=self.num_workers,
shuffle=True,
persistent_workers=bool(self.num_workers),
prefetch_factor=4 if self.num_workers else None,
prefetch_factor=self.prefetch_factor if self.num_workers else None,
persistent_workers=self.persistent_workers,
collate_fn=_collate_samples,
drop_last=True,
)
Expand All @@ -533,8 +537,8 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
prefetch_factor=4 if self.num_workers else None,
persistent_workers=bool(self.num_workers),
prefetch_factor=self.prefetch_factor if self.num_workers else None,
persistent_workers=self.persistent_workers,
)

def test_dataloader(self):
Expand Down

0 comments on commit 820c805

Please sign in to comment.