Skip to content

Commit

Permalink
bandaid to cached dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Oct 4, 2024
1 parent 3058a8e commit 2a0eeb6
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion viscy/data/hcs_ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def __getitem__(self, index: int) -> Sample:
if norm_meta is not None:
sample_images["norm_meta"] = norm_meta
if self.transform:
sample_images = self.transform(sample_images)
# FIX ME: check why the transforms return a list?
sample_images = self.transform(sample_images)[0]
if "weight" in sample_images:
del sample_images["weight"]
sample = {
Expand Down Expand Up @@ -185,6 +186,11 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None:
raise NotImplementedError(f"Stage {stage} is not supported")

def _train_transform(self) -> list[Callable]:
""" Set the train augmentations
"""

if self.augmentations:
for aug in self.augmentations:
if isinstance(aug, MultiSampleTrait):
Expand All @@ -197,6 +203,10 @@ def _train_transform(self) -> list[Callable]:
f"transform type {type(aug)}."
)
self.train_patches_per_stack = num_samples
else:
self.augmentations=[]

_logger.info(f'Training augmentations: {self.augmentations}')
return list(self.augmentations)

def _fit_transform(self) -> tuple[Compose, Compose]:
Expand Down

0 comments on commit 2a0eeb6

Please sign in to comment.