from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor, nn

from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class MultipleNegativesRankingLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim) -> None:
        """
        This loss expects as input a batch consisting of sentence pairs ``(a_1, p_1), (a_2, p_2)..., (a_n, p_n)``
        where we assume that ``(a_i, p_i)`` are a positive pair and ``(a_i, p_j)`` for ``i != j`` a negative pair.

        For each ``a_i``, it uses all other ``p_j`` as negative samples, i.e., for ``a_i``, we have 1 positive example
        (``p_i``) and ``n-1`` negative examples (``p_j``). It then minimizes the negative log-likelihood for
        softmax-normalized scores.

        This loss function works well for training embeddings in retrieval setups where you have positive pairs
        (e.g. (query, relevant_doc)), as it samples ``n-1`` negative docs randomly in each batch.
        The performance usually increases with larger batch sizes.

        You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
        ``(a_1, p_1, n_1), (a_2, p_2, n_2)``. Then, ``n_1`` is a hard negative for ``(a_1, p_1)``. The loss will use,
        for the pair ``(a_i, p_i)``, all ``p_j`` with ``j != i`` and all ``n_j`` as negatives.

        Args:
            model: SentenceTransformer model
            scale: Output of similarity function is multiplied by scale value
            similarity_fct: similarity function between sentence embeddings. By default, cos_sim.
                Can also be set to dot product (and then set scale to 1)

        References:
            - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
            - `Training Examples > Natural Language Inference <../../examples/training/nli/README.html>`_
            - `Training Examples > Paraphrase Data <../../examples/training/paraphrases/README.html>`_
            - `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_
            - `Training Examples > MS MARCO <../../examples/training/ms_marco/README.html>`_
            - `Unsupervised Learning > SimCSE <../../examples/unsupervised_learning/SimCSE/README.html>`_
            - `Unsupervised Learning > GenQ <../../examples/unsupervised_learning/query_generation/README.html>`_

        Requirements:
            1. (anchor, positive) pairs or (anchor, positive, negative) triplets

        Inputs:
            +-------------------------------------------------+--------+
            | Texts                                           | Labels |
            +=================================================+========+
            | (anchor, positive) pairs                        | none   |
            +-------------------------------------------------+--------+
            | (anchor, positive, negative) triplets           | none   |
            +-------------------------------------------------+--------+
            | (anchor, positive, negative_1, ..., negative_n) | none   |
            +-------------------------------------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that no in-batch negatives are duplicates of the anchor or positive samples.

        Relations:
            - :class:`CachedMultipleNegativesRankingLoss` is equivalent to this loss, but it uses caching that allows for
              much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly
              slower.
            - :class:`MultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but with an additional loss term.
            - :class:`GISTEmbedLoss` is equivalent to this loss, but uses a guide model to guide the in-batch negative
              sample selection. `GISTEmbedLoss` yields a stronger training signal at the cost of some training overhead.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
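
    # Shape walkthrough (illustrative only; batch_size=2 with one hard negative
    # per example is an assumption made just for the numbers): forward() then
    # receives three embedding tensors of shape (2, dim) for anchors, positives
    # and negatives. `candidates` has shape (4, dim) and `scores` shape (2, 4).
    # Anchor i should score highest against column i (its own positive); every
    # other column acts as an in-batch or hard negative.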

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        anchors = embeddings[0]  # (batch_size, embedding_dim)
        candidates = torch.cat(embeddings[1:])  # (batch_size * (1 + num_negatives), embedding_dim)

        # For every anchor, we compute the similarity to all other candidates (positives and negatives),
        # also from other anchors. This gives us a lot of in-batch negatives.
        scores = self.similarity_fct(anchors, candidates) * self.scale
        # (batch_size, batch_size * (1 + num_negatives))

        # anchor[i] should be most similar to candidates[i], as that is the paired positive,
        # so the label for anchor[i] is i
        range_labels = torch.arange(0, scores.size(0), device=scores.device)

        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        return """
@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""