diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index aa7ac24b..698ff412 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -378,7 +378,6 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self): super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - self.validation_step_outputs = [] # average within each dataloader loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] self.log( @@ -386,6 +385,8 @@ def on_validation_epoch_end(self): torch.tensor(loss_means).mean().to(self.device), sync_dist=True, ) + self.validation_step_outputs.clear() + self.validation_losses.clear() def on_test_start(self): """Load CellPose model for segmentation."""