
Commit 2ca134b

format and lint hcs_ram
ziw-liu committed Oct 30, 2024
1 parent daa6860 commit 2ca134b
Showing 1 changed file with 6 additions and 5 deletions.
viscy/data/hcs_ram.py (11 changes: 6 additions & 5 deletions)

--- a/viscy/data/hcs_ram.py
+++ b/viscy/data/hcs_ram.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 from iohub.ngff import Position, open_ome_zarr
 from lightning.pytorch import LightningDataModule
 from monai.data import set_track_meta
@@ -19,11 +20,9 @@
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
 
+from viscy.data.distributed import ShardedDistributedSampler
 from viscy.data.hcs import _read_norm_meta
 from viscy.data.typing import ChannelMap, DictTransform, Sample
-from viscy.data.distributed import ShardedDistributedSampler
-from torch.distributed import get_rank
-import torch.distributed as dist
 
 _logger = logging.getLogger("lightning.pytorch")
 
@@ -72,10 +71,12 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
         collated[key] = collate_meta_tensor(data)
     return collated
 
+
 def is_ddp_enabled() -> bool:
     """Check if distributed data parallel (DDP) is initialized."""
     return dist.is_available() and dist.is_initialized()
 
+
 class CachedDataset(Dataset):
     """
     A dataset that caches the data in RAM.
@@ -341,7 +342,7 @@ def val_dataloader(self) -> DataLoader:
         else:
             sampler = None
             _logger.info("Using standard sampler for non-distributed validation")
-
+
         return DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
@@ -350,5 +351,5 @@ def val_dataloader(self) -> DataLoader:
             pin_memory=True,
             shuffle=False,
             timeout=self.timeout,
-            sampler=sampler
+            sampler=sampler,
         )
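Beyond the import reordering and lint fixes (trailing comma, two blank lines between top-level definitions), the substantive pattern in this file is DDP-aware sampler selection: is_ddp_enabled() reports whether torch.distributed is both available and initialized, and val_dataloader uses it to pick between a ShardedDistributedSampler and no sampler at all. The following is a minimal sketch of that pattern, not the module's actual code: it assumes ShardedDistributedSampler accepts the same (dataset, shuffle=...) arguments as torch.utils.data.DistributedSampler, and make_val_dataloader is a hypothetical standalone stand-in for the data module's val_dataloader method.

import logging

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

from viscy.data.distributed import ShardedDistributedSampler

_logger = logging.getLogger("lightning.pytorch")


def is_ddp_enabled() -> bool:
    """Check if distributed data parallel (DDP) is initialized."""
    return dist.is_available() and dist.is_initialized()


def make_val_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
    """Hypothetical standalone version of the module's val_dataloader."""
    if is_ddp_enabled():
        # Assumed constructor, mirroring torch.utils.data.DistributedSampler:
        # each rank draws a distinct shard; no shuffling during validation.
        sampler = ShardedDistributedSampler(dataset, shuffle=False)
    else:
        # Single-process run: fall back to the DataLoader's default sampling.
        sampler = None
        _logger.info("Using standard sampler for non-distributed validation")
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=sampler,
    )

Checking dist.is_available() before dist.is_initialized() keeps the guard safe on PyTorch builds without distributed support, and passing sampler=None in the non-DDP branch lets the DataLoader use its default sequential sampling, matching the "standard sampler" log message in the diff.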
