From daa686028575868d7f87f34955f7451d9fcfac19 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 24 Oct 2024 10:49:49 -0700 Subject: [PATCH] path for if not ddp --- viscy/data/hcs_ram.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index aa24f28e..a9ff25d3 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,7 +6,6 @@ import numpy as np import torch -import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -20,9 +19,11 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from viscy.data.distributed import ShardedDistributedSampler +from torch.distributed import get_rank +import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -71,12 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated - def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() - class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -318,7 +317,11 @@ def _setup_fit(self, dataset_settings: dict) -> None: ) def train_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed training") return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, @@ -333,7 +336,12 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed validation") + return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -342,5 +350,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler, + sampler=sampler )