From 1e23c2b5aa1e2c95c4f9e3aa5e389834855ea81d Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Fri, 8 Nov 2024 09:47:56 -0800 Subject: [PATCH] translation: fix validation loss aggregation (#202) --- viscy/translation/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index aa7ac24b..698ff412 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -378,7 +378,6 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self): super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - self.validation_step_outputs = [] # average within each dataloader loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] self.log( @@ -386,6 +385,8 @@ def on_validation_epoch_end(self): torch.tensor(loss_means).mean().to(self.device), sync_dist=True, ) + self.validation_step_outputs.clear() + self.validation_losses.clear() def on_test_start(self): """Load CellPose model for segmentation."""