Skip to content

Commit

Permalink
translation: fix validation loss aggregation (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu authored Nov 8, 2024
1 parent 1cb0fc2 commit db80819
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion viscy/translation/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,15 @@ def on_train_epoch_end(self):
def on_validation_epoch_end(self):
super().on_validation_epoch_end()
self._log_samples("val_samples", self.validation_step_outputs)
self.validation_step_outputs = []
# average within each dataloader
loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
self.log(
"loss/validate",
torch.tensor(loss_means).mean().to(self.device),
sync_dist=True,
)
self.validation_step_outputs.clear()
self.validation_losses.clear()

def on_test_start(self):
"""Load CellPose model for segmentation."""
Expand Down

0 comments on commit db80819

Please sign in to comment.