Commit ba8cb2e
[enhancement] Make MultipleNegativesRankingLoss easier to understand (#3100)

* Make MultipleNegativesRankingLoss easier to understand

Because this is one of the most common loss functions, I think it's useful to comment-spam it a bit.

* Reformat comment slightly
tomaarsen authored Dec 2, 2024
1 parent a49ffc5 commit ba8cb2e
Showing 1 changed file with 13 additions and 6 deletions: sentence_transformers/losses/MultipleNegativesRankingLoss.py
@@ -99,13 +99,20 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f
         self.cross_entropy_loss = nn.CrossEntropyLoss()
 
     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
-        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
-        embeddings_a = reps[0]
-        embeddings_b = torch.cat(reps[1:])
-
-        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
-        # Example a[i] should match with b[i]
+        # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
+        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
+        anchors = embeddings[0]  # (batch_size, embedding_dim)
+        candidates = torch.cat(embeddings[1:])  # (batch_size * (1 + num_negatives), embedding_dim)
+
+        # For every anchor, we compute the similarity to all other candidates (positives and negatives),
+        # also from other anchors. This gives us a lot of in-batch negatives.
+        scores = self.similarity_fct(anchors, candidates) * self.scale
+        # (batch_size, batch_size * (1 + num_negatives))
+
+        # anchor[i] should be most similar to candidates[i], as that is the paired positive,
+        # so the label for anchor[i] is i
         range_labels = torch.arange(0, scores.size(0), device=scores.device)
 
         return self.cross_entropy_loss(scores, range_labels)
 
     def get_config_dict(self) -> dict[str, Any]:
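The in-batch negatives scheme that the new comments describe can be sketched standalone. The snippet below is illustrative only: it uses random tensors in place of a SentenceTransformer model's `sentence_embedding` outputs, cosine similarity in place of the configurable `similarity_fct`, and the hypothetical sizes `batch_size = 4`, `embedding_dim = 8` with the default `scale = 20.0`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, embedding_dim, scale = 4, 8, 20.0

# Stand-ins for the model outputs: one anchor and one paired positive per example.
anchors = torch.randn(batch_size, embedding_dim)     # (batch_size, embedding_dim)
candidates = torch.randn(batch_size, embedding_dim)  # (batch_size, embedding_dim)

# Cosine similarity of every anchor against every candidate via broadcasting:
# row i holds anchor i's similarity to all candidates, so every other example's
# positive acts as an in-batch negative for anchor i.
scores = F.cosine_similarity(
    anchors.unsqueeze(1), candidates.unsqueeze(0), dim=-1
) * scale  # (batch_size, batch_size)

# anchor[i] is paired with candidates[i], so the target class for row i is i.
labels = torch.arange(batch_size)
loss = F.cross_entropy(scores, labels)
print(loss.item())
```

With hard negatives appended (the `1 + num_negatives` case in the diff), `candidates` simply gains extra rows and the score matrix widens, while the labels stay `arange(batch_size)` because the paired positives still occupy the first `batch_size` columns.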
