diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index a9ff25d3..51eb8e90 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -19,11 +20,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from viscy.data.distributed import ShardedDistributedSampler -from torch.distributed import get_rank -import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -72,10 +71,12 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated + def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -341,7 +342,7 @@ def val_dataloader(self) -> DataLoader: else: sampler = None _logger.info("Using standard sampler for non-distributed validation") - + return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -350,5 +351,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler + sampler=sampler, )