diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 983cb54f..b15d305f 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -111,6 +111,16 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
             key, grid, self.current_epoch, dataformats="HWC"
         )
 
+    def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]):
+        """Common method for logging step samples"""
+        if batch_idx < self.log_batches_per_epoch:
+            output_list = (
+                self.training_step_outputs
+                if stage == "train"
+                else self.validation_step_outputs
+            )
+            output_list.extend(detach_sample(samples, self.log_samples_per_batch))
+
     def log_embedding_umap(self, embeddings: Tensor, tag: str):
         _logger.debug(f"Computing UMAP for {tag} embeddings.")
         umap = UMAP(n_components=2)
@@ -157,17 +167,14 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
             # Note: we assume the two augmented views are the anchor and positive samples
             embeddings = torch.cat((anchor_projection, positive_projection))
             loss = self.loss_function(embeddings, labels)
             self._log_metrics(
                 loss=loss,
                 anchor=anchor_projection,
                 positive=positive_projection,
                 negative=None,
                 stage="train",
             )
-            if batch_idx < self.log_batches_per_epoch:
-                self.training_step_outputs.extend(
-                    detach_sample((anchor_img, pos_img), self.log_samples_per_batch)
-                )
+            self._log_step_samples(batch_idx, (anchor_img, pos_img), "train")
         else:
             neg_img = batch["negative"]
             negative_projection = self(neg_img)
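
For context, here is a self-contained sketch of the buffering pattern the new helper introduces. The `detach_sample` stub below is a stand-in that only mirrors the call signature seen in the diff (the real viscy utility differs), and `SampleLoggingDemo` is illustrative scaffolding, not viscy's actual Lightning module:

```python
from typing import Literal, Sequence

import torch


def detach_sample(imgs: Sequence[torch.Tensor], n: int) -> list[torch.Tensor]:
    """Stand-in for viscy's detach_sample utility: keep up to ``n`` samples
    per tensor, detached and moved to CPU. Only the call signature is taken
    from the diff; the real implementation differs."""
    return [img[:n].detach().cpu() for img in imgs]


class SampleLoggingDemo:
    """Toy stand-in for the Lightning module, showing the dispatch-by-stage
    buffering pattern that ``_log_step_samples`` introduces."""

    def __init__(
        self, log_batches_per_epoch: int = 2, log_samples_per_batch: int = 3
    ) -> None:
        self.log_batches_per_epoch = log_batches_per_epoch
        self.log_samples_per_batch = log_samples_per_batch
        # Per-epoch buffers, consumed by the epoch-end logging hooks.
        self.training_step_outputs: list[torch.Tensor] = []
        self.validation_step_outputs: list[torch.Tensor] = []

    def _log_step_samples(
        self,
        batch_idx: int,
        samples: Sequence[torch.Tensor],
        stage: Literal["train", "val"],
    ) -> None:
        # Buffer samples from only the first few batches of each epoch,
        # routed to the output list that matches the current stage.
        if batch_idx < self.log_batches_per_epoch:
            output_list = (
                self.training_step_outputs
                if stage == "train"
                else self.validation_step_outputs
            )
            output_list.extend(detach_sample(samples, self.log_samples_per_batch))


module = SampleLoggingDemo()
anchor, positive = torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32)
for batch_idx in range(4):
    module._log_step_samples(batch_idx, (anchor, positive), "train")
# Only batches 0 and 1 are buffered, two tensors each.
assert len(module.training_step_outputs) == 4
```

Routing both stages through one helper means `training_step` and `validation_step` share a single code path for sample buffering, so future changes to how samples are selected or detached only need to happen in one place.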