Skip to content

Commit

Permalink
fixing the dataloader using torch collate_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Oct 4, 2024
1 parent 2a0eeb6 commit f86a9e7
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions viscy/data/hcs_ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from monai.data.utils import collate_meta_tensor

from viscy.data.hcs import _read_norm_meta
from viscy.data.typing import ChannelMap, DictTransform, Sample
Expand All @@ -35,6 +36,25 @@ def _stack_channels(
# sample_images is a list['Phase3D'].shape = (1,3,256,256)
return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images]

def _collate_samples(batch: Sequence[Sample]) -> Sample:
    """Merge a sequence of per-item samples into one batch sample.

    The keys of the first sample define the keys of the result. For every
    key, the values of all samples are gathered into a flat list: a value
    that is a ``Sequence`` contributes each of its elements, any other
    value contributes itself as a single element. The flat list is then
    handed to ``collate_meta_tensor``.

    :param Sequence[Sample] batch: a sequence of dictionaries,
        where each key may point to a single tensor or a list of tensors,
        as is the case with ``train_patches_per_stack > 1``.
    :return Sample: batch sample keyed like the first element of ``batch``
    """
    merged: Sample = {}
    # Only the first sample is consulted for keys; every other sample is
    # indexed with those same keys.
    for name in batch[0]:
        flat = []
        for item in batch:
            value = item[name]
            # ``list +=`` extends in place, so sequence-valued entries are
            # flattened while scalar-like entries are wrapped and appended.
            flat += value if isinstance(value, Sequence) else [value]
        merged[name] = collate_meta_tensor(flat)
    return merged



class CachedDataset(Dataset):
Expand Down Expand Up @@ -118,7 +138,7 @@ def __getitem__(self, index: int) -> Sample:
sample_images["norm_meta"] = norm_meta
if self.transform:
# FIX ME: check why the transforms return a list?
sample_images = self.transform(sample_images)[0]
sample_images = self.transform(sample_images)
if "weight" in sample_images:
del sample_images["weight"]
sample = {
Expand Down Expand Up @@ -206,7 +226,7 @@ def _train_transform(self) -> list[Callable]:
else:
self.augmentations=[]

_logger.info(f'Training augmentations: {self.augmentations}')
_logger.debug(f'Training augmentations: {self.augmentations}')
return list(self.augmentations)

def _fit_transform(self) -> tuple[Compose, Compose]:
Expand Down Expand Up @@ -267,7 +287,9 @@ def train_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
persistent_workers=bool(self.num_workers),
shuffle=True,
timeout=self.timeout
timeout=self.timeout,
collate_fn=_collate_samples,
drop_last=True
)

def val_dataloader(self) -> DataLoader:
Expand All @@ -277,5 +299,6 @@ def val_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
persistent_workers=bool(self.num_workers),
shuffle=False,
timeout=self.timeout
timeout=self.timeout,

)

0 comments on commit f86a9e7

Please sign in to comment.