Skip to content

Commit

Permalink
improve sample image logging in fcmae
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Oct 31, 2024
1 parent f7b585c commit 42c49f5
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions viscy/translation/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import random
from typing import Literal, Sequence, Union

import numpy as np
Expand All @@ -9,7 +10,7 @@
from lightning.pytorch import LightningModule
from monai.data.utils import collate_meta_tensor
from monai.optimizers import WarmupCosineSchedule
from monai.transforms import Compose, DivisiblePad, Rotate90
from monai.transforms import DivisiblePad, Rotate90
from torch import Tensor, nn
from torch.optim.lr_scheduler import ConstantLR
from torchmetrics.functional import (
Expand Down Expand Up @@ -472,14 +473,10 @@ class FcmaeUNet(VSUNet):
def __init__(
self,
fit_mask_ratio: float = 0.0,
train_transforms=[],
validation_transforms=[],
**kwargs,
):
super().__init__(architecture="fcmae", **kwargs)
self.fit_mask_ratio = fit_mask_ratio
self.train_transforms = Compose(train_transforms)
self.validation_transforms = Compose(validation_transforms)
self.save_hyperparameters()

def on_fit_start(self):
Expand Down Expand Up @@ -512,6 +509,8 @@ def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor:
for dataset_batch, dm in zip(batch, self.datamodules):
dataset_batch = dm.train_gpu_transforms(dataset_batch)
transformed.extend(dataset_batch)
# shuffle references in place for better logging
random.shuffle(transformed)
return collate_meta_tensor(transformed)["source"]

@torch.no_grad()
Expand All @@ -525,10 +524,9 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor:
x = self.train_transform_and_collate(batch)
pred, mask, loss = self.forward_fit(x)
if batch_idx < self.log_batches_per_epoch:
target = x * mask.unsqueeze(2)
self.training_step_outputs.extend(
detach_sample(
(x, x * mask.unsqueeze(2), pred), self.log_samples_per_batch
)
detach_sample((x, target, pred), self.log_samples_per_batch)
)
self.log(
"loss/train",
Expand Down

0 comments on commit 42c49f5

Please sign in to comment.