From 42c49f571c967c8c42f886433b3e8709a96407bb Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 10:39:34 -0700 Subject: [PATCH] improve sample image logging in fcmae --- viscy/translation/engine.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 66e18f8b..5c49aa10 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,5 +1,6 @@ import logging import os +import random from typing import Literal, Sequence, Union import numpy as np @@ -9,7 +10,7 @@ from lightning.pytorch import LightningModule from monai.data.utils import collate_meta_tensor from monai.optimizers import WarmupCosineSchedule -from monai.transforms import Compose, DivisiblePad, Rotate90 +from monai.transforms import DivisiblePad, Rotate90 from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -472,14 +473,10 @@ class FcmaeUNet(VSUNet): def __init__( self, fit_mask_ratio: float = 0.0, - train_transforms=[], - validation_transforms=[], **kwargs, ): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio - self.train_transforms = Compose(train_transforms) - self.validation_transforms = Compose(validation_transforms) self.save_hyperparameters() def on_fit_start(self): @@ -512,6 +509,8 @@ def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor: for dataset_batch, dm in zip(batch, self.datamodules): dataset_batch = dm.train_gpu_transforms(dataset_batch) transformed.extend(dataset_batch) + # shuffle references in place for better logging + random.shuffle(transformed) return collate_meta_tensor(transformed)["source"] @torch.no_grad() @@ -525,10 +524,9 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: x = self.train_transform_and_collate(batch) pred, mask, loss = self.forward_fit(x) if batch_idx < self.log_batches_per_epoch: + target = x * mask.unsqueeze(2) self.training_step_outputs.extend( - detach_sample( - (x, x * mask.unsqueeze(2), pred), self.log_samples_per_batch - ) + detach_sample((x, target, pred), self.log_samples_per_batch) ) self.log( "loss/train",