Skip to content

Commit

Permalink
simplying the engine
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Dec 18, 2024
1 parent 9ac6ebf commit 332bb73
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions viscy/representation/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
key, grid, self.current_epoch, dataformats="HWC"
)

def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]):
"""Common method for logging step samples"""
if batch_idx < self.log_batches_per_epoch:
output_list = (
self.training_step_outputs
if stage == "train"
else self.validation_step_outputs
)
output_list.extend(detach_sample(samples, self.log_samples_per_batch))

def log_embedding_umap(self, embeddings: Tensor, tag: str):
_logger.debug(f"Computing UMAP for {tag} embeddings.")
umap = UMAP(n_components=2)
Expand Down Expand Up @@ -157,17 +167,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
# Note: we assume the two augmented views are the anchor and positive samples
embeddings = torch.cat((anchor_projection, positive_projection))
loss = self.loss_function(embeddings, labels)
self._log_metrics(
loss=loss,
anchor=anchor_projection,
positive=positive_projection,
negative=None,
stage="train",
)
if batch_idx < self.log_batches_per_epoch:
self.training_step_outputs.extend(
detach_sample((anchor_img, pos_img), self.log_samples_per_batch)
)
self._log_step_samples(batch_idx, (anchor_img, pos_img), "train")
else:
neg_img = batch["negative"]
negative_projection = self(neg_img)
Expand Down

0 comments on commit 332bb73

Please sign in to comment.