removing our implementation of NTXentLoss and using pytorch metric
edyoshikun committed Nov 14, 2024
1 parent 33de41b commit 74b61af
Showing 1 changed file with 5 additions and 66 deletions.
71 changes: 5 additions & 66 deletions viscy/representation/engine.py
@@ -1,77 +1,20 @@
 import logging
-from typing import Literal, Sequence, TypedDict, Tuple
+from typing import Literal, Sequence, Tuple, TypedDict
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from lightning.pytorch import LightningModule
+from pytorch_metric_learning.losses import NTXentLoss
 from torch import Tensor, nn
 
 from viscy.data.typing import TrackingIndex, TripletSample
 from viscy.representation.contrastive import ContrastiveEncoder
 from viscy.utils.log_images import detach_sample, render_images
-from pytorch_metric_learning.losses import SelfSupervisedLoss
-from pytorch_metric_learning.losses import NTXentLoss as NTXentLoss_pml
 
 _logger = logging.getLogger("lightning.pytorch")
 
 
-class NTXentLoss_viscy(torch.nn.Module):
-    """
-    Normalized Temperature-scaled Cross Entropy Loss
-    From Chen et.al, https://arxiv.org/abs/2002.05709
-    """
-
-    def __init__(
-        self,
-        temperature=0.5,
-        criterion=torch.nn.CrossEntropyLoss(reduction="sum"),
-    ):
-        super(NTXentLoss_viscy, self).__init__()
-        self.temperature = temperature
-        self.criterion = criterion
-
-    def _get_correlated_mask(self, batch_size):
-        mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool)
-        mask = mask.fill_diagonal_(0)
-        for i in range(batch_size):
-            mask[i, batch_size + i] = 0
-            mask[batch_size + i, i] = 0
-        _logger.info(f"mask: {mask}")
-        return mask
-
-    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
-    def forward(self, embeddings, labels):
-        """
-        embeddings = [zis, zjs]
-        zis and zjs are the output projections from the two augmented views.
-        Here, we assume the two augmented views are the anchor and positive samples
-        """
-        # Get the batch size from tensor
-        batch_size = embeddings.shape[0] // 2
-
-        zis, zjs = torch.split(embeddings, batch_size, dim=0)
-
-        # Cosine similarity
-        similarity_matrix = F.cosine_similarity(
-            embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2
-        )
-        # Temperature scaling
-        similarity_matrix = similarity_matrix / self.temperature
-
-        mask = self._get_correlated_mask(batch_size).to(similarity_matrix.device)
-
-        # Mask out unwanted pairs
-        similarity_matrix = similarity_matrix[mask].view(2 * batch_size, -1)
-
-        # Calculate NT-Xent Loss as cross-entropy
-        loss = self.criterion(similarity_matrix, labels)
-        loss /= 2 * batch_size
-
-        return loss
-
-
 class ContrastivePrediction(TypedDict):
     features: Tensor
     projections: Tensor
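
For reference, both the removed class's docstring and pytorch-metric-learning's `NTXentLoss` point to the same NT-Xent objective from Chen et al. (https://arxiv.org/abs/2002.05709). For a positive pair $(i, j)$ among the $2N$ projections, with cosine similarity $\mathrm{sim}$ and temperature $\tau$:

$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}$$

Delegating this to the library removes the hand-rolled mask construction and cross-entropy bookkeeping above.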
@@ -85,11 +28,7 @@ def __init__(
         self,
         encoder: nn.Module | ContrastiveEncoder,
         loss_function: (
-            nn.Module
-            | nn.CosineEmbeddingLoss
-            | nn.TripletMarginLoss
-            | NTXentLoss_pml
-            | NTXentLoss_viscy
+            nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss
         ) = nn.TripletMarginLoss(margin=0.5),
         lr: float = 1e-3,
         schedule: Literal["WarmupCosine", "Constant"] = "Constant",
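
A minimal usage sketch of the narrowed signature, not part of this commit: the enclosing class is assumed to be viscy's `ContrastiveModule`, and the `Sequential` encoder is a placeholder where viscy's `ContrastiveEncoder` would be used in practice.

```python
# Hypothetical sketch: constructing the Lightning module with the library loss.
import torch.nn as nn
from pytorch_metric_learning.losses import NTXentLoss

from viscy.representation.engine import ContrastiveModule

module = ContrastiveModule(
    encoder=nn.Sequential(nn.Flatten(), nn.LazyLinear(128)),  # placeholder encoder
    loss_function=NTXentLoss(temperature=0.5),  # matches the removed class's default
    lr=1e-3,
    schedule="Constant",
)
```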
@@ -175,7 +114,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
         pos_img = batch["positive"]
         anchor_projection = self(anchor_img)
         positive_projection = self(pos_img)
-        if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)):
+        if isinstance(self.loss_function, NTXentLoss):
             indices = torch.arange(
                 0, anchor_projection.size(0), device=anchor_projection.device
             )
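
The hunk is collapsed after the indices are built, but pytorch-metric-learning's `NTXentLoss` treats embeddings that share a label as positive pairs, so the pairing convention presumably looks like this sketch (an assumption with dummy tensors, not the exact continuation of viscy's `training_step`):

```python
# Sketch of the NTXentLoss pairing convention: equal labels mark
# anchor/positive pairs, unequal labels act as negatives.
import torch
from pytorch_metric_learning.losses import NTXentLoss

loss_fn = NTXentLoss(temperature=0.5)
anchor_projection = torch.randn(16, 32)    # (batch, projection_dim)
positive_projection = torch.randn(16, 32)
indices = torch.arange(0, anchor_projection.size(0))

embeddings = torch.cat([anchor_projection, positive_projection])
labels = torch.cat([indices, indices])  # i-th anchor pairs with i-th positive
loss = loss_fn(embeddings, labels)
```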
@@ -226,7 +165,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
         pos_img = batch["positive"]
         anchor_projection = self(anchor)
         positive_projection = self(pos_img)
-        if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)):
+        if isinstance(self.loss_function, NTXentLoss):
             indices = torch.arange(
                 0, anchor_projection.size(0), device=anchor_projection.device
             )
