From 34e84a346a2cb50bcf618f8cb496ae12c0c3ad7d Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Wed, 24 May 2023 13:13:15 +0200 Subject: [PATCH] Removing L2-norm in contrastive loss (L2-norm already present in cosine-similarity computation) Signed-off-by: Lucas Robinet --- monai/losses/contrastive.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 6213091bf6..a74f303ec6 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: temperature_tensor = torch.as_tensor(self.temperature).to(input.device) batch_size = input.shape[0] - norm_i = F.normalize(input, dim=1) - norm_j = F.normalize(target, dim=1) - negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool) negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) - repr = torch.cat([norm_i, norm_j], dim=0) + repr = torch.cat([input, target], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) sim_ij = torch.diag(sim_matrix, batch_size) sim_ji = torch.diag(sim_matrix, -batch_size)