From ba8cb2ef1d7a5ad6b9169eddee469e900456532c Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Mon, 2 Dec 2024 14:43:14 +0100
Subject: [PATCH] [`enhancement`] Make MultipleNegativesRankingLoss easier to
 understand (#3100)

* Make MultipleNegativesRankingLoss easier to understand

Because this is one of the most common loss functions, I think it's useful
to comment-spam it a bit.

* Reformat comment slightly
---
 .../losses/MultipleNegativesRankingLoss.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/sentence_transformers/losses/MultipleNegativesRankingLoss.py
index 1aea7acfa..d7e2d1c41 100644
--- a/sentence_transformers/losses/MultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/MultipleNegativesRankingLoss.py
@@ -99,13 +99,20 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f
         self.cross_entropy_loss = nn.CrossEntropyLoss()
 
     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
-        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
-        embeddings_a = reps[0]
-        embeddings_b = torch.cat(reps[1:])
-
-        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
-        # Example a[i] should match with b[i]
+        # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
+        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
+        anchors = embeddings[0]  # (batch_size, embedding_dim)
+        candidates = torch.cat(embeddings[1:])  # (batch_size * (1 + num_negatives), embedding_dim)
+
+        # For every anchor, we compute the similarity to all other candidates (positives and negatives),
+        # also from other anchors. This gives us a lot of in-batch negatives.
+        scores = self.similarity_fct(anchors, candidates) * self.scale
+        # (batch_size, batch_size * (1 + num_negatives))
+
+        # anchor[i] should be most similar to candidates[i], as that is the paired positive,
+        # so the label for anchor[i] is i
         range_labels = torch.arange(0, scores.size(0), device=scores.device)
+
         return self.cross_entropy_loss(scores, range_labels)
 
     def get_config_dict(self) -> dict[str, Any]:
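
Note: to see the annotated forward pass end to end, below is a minimal
standalone sketch of the same in-batch negatives computation. It assumes the
loss's defaults (cosine similarity as similarity_fct and scale = 20.0); the
helper name mnr_loss_sketch and the demo tensors are illustrative, not part
of the sentence-transformers API.

    import torch
    import torch.nn.functional as F

    def mnr_loss_sketch(anchors: torch.Tensor, candidates: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
        # anchors:    (batch_size, embedding_dim)
        # candidates: (batch_size * (1 + num_negatives), embedding_dim), laid out as
        # [positives..., hard negatives...] so that candidates[i] is anchor i's positive
        anchors = F.normalize(anchors, p=2, dim=1)
        candidates = F.normalize(candidates, p=2, dim=1)
        # Cosine similarity of every anchor against every candidate:
        # (batch_size, batch_size * (1 + num_negatives))
        scores = anchors @ candidates.T * scale
        # The paired positive for anchor i sits at column i, so label i = i
        labels = torch.arange(scores.size(0), device=scores.device)
        return F.cross_entropy(scores, labels)

    # Example: a batch of 4 (anchor, positive) pairs with 128-dim embeddings,
    # giving each anchor 1 positive and 3 in-batch negatives
    torch.manual_seed(0)
    anchors = torch.randn(4, 128)
    positives = torch.randn(4, 128)
    print(mnr_loss_sketch(anchors, positives))

The scale factor matters because cosine similarities are confined to [-1, 1];
multiplying by 20 before the softmax inside cross-entropy sharpens the
distribution so the loss produces useful gradients.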