
path for if not ddp
edyoshikun committed Oct 24, 2024
1 parent 0b005cf commit daa6860
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions viscy/data/hcs_ram.py
@@ -6,7 +6,6 @@

import numpy as np
import torch
import torch.distributed as dist
from iohub.ngff import Position, open_ome_zarr
from lightning.pytorch import LightningDataModule
from monai.data import set_track_meta
@@ -20,9 +19,11 @@
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from viscy.data.distributed import ShardedDistributedSampler
from viscy.data.hcs import _read_norm_meta
from viscy.data.typing import ChannelMap, DictTransform, Sample
from viscy.data.distributed import ShardedDistributedSampler
from torch.distributed import get_rank
import torch.distributed as dist

_logger = logging.getLogger("lightning.pytorch")

@@ -71,12 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
collated[key] = collate_meta_tensor(data)
return collated


def is_ddp_enabled() -> bool:
    """Check if distributed data parallel (DDP) is initialized."""
    return dist.is_available() and dist.is_initialized()


class CachedDataset(Dataset):
    """
    A dataset that caches the data in RAM.
@@ -318,7 +317,11 @@ def _setup_fit(self, dataset_settings: dict) -> None:
        )

    def train_dataloader(self) -> DataLoader:
        sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True)
        if is_ddp_enabled():
            sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True)
        else:
            sampler = None
            _logger.info("Using standard sampler for non-distributed training")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // self.train_patches_per_stack,
@@ -333,7 +336,12 @@ def train_dataloader(self) -> DataLoader:
        )

    def val_dataloader(self) -> DataLoader:
        sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False)
        if is_ddp_enabled():
            sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False)
        else:
            sampler = None
            _logger.info("Using standard sampler for non-distributed validation")

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
@@ -342,5 +350,5 @@
            pin_memory=True,
            shuffle=False,
            timeout=self.timeout,
            sampler=sampler,
            sampler=sampler
        )
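
The change above guards sampler creation behind a DDP check so that single-process runs fall back to sampler=None instead of constructing a distributed sampler without an initialized process group. Below is a minimal, self-contained sketch of the same pattern; it substitutes torch.utils.data.DistributedSampler for viscy's ShardedDistributedSampler, and ToyDataset and make_train_loader are hypothetical names used only for illustration.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def is_ddp_enabled() -> bool:
    # True only when a distributed process group has been initialized.
    return dist.is_available() and dist.is_initialized()


class ToyDataset(Dataset):
    # Stand-in for the cached dataset; returns one scalar tensor per index.
    def __len__(self) -> int:
        return 16

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.tensor([index], dtype=torch.float32)


def make_train_loader(dataset: Dataset, batch_size: int = 4) -> DataLoader:
    if is_ddp_enabled():
        # Each rank draws from its own shard; shuffle must stay False here
        # because an explicit sampler and shuffle=True are mutually exclusive.
        sampler = DistributedSampler(dataset, shuffle=True)
        shuffle = False
    else:
        # Single-process path introduced by this commit: no sampler at all.
        sampler = None
        shuffle = True
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=shuffle)


if __name__ == "__main__":
    loader = make_train_loader(ToyDataset())
    print(sum(batch.numel() for batch in loader))  # 16 in single-process mode

In the non-DDP branch the sketch re-enables DataLoader-level shuffling; whether the original module does the same depends on arguments not visible in this diff.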
