diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index 747b4c540..d76d3338e 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -229,15 +229,16 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
         """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        loss = self.calculate_loss(reps, reps_guided)
-        loss.backward()
+        loss = self.calculate_loss(reps, reps_guided, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(
+        self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]], with_backward: bool = False
+    ) -> Tensor:
         """Generalized function to calculate the cross-entropy loss without caching gradients."""
         if len(reps) != len(reps_guided):
             raise ValueError("reps and reps_guided must have the same length")
@@ -291,6 +292,9 @@ def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor
             # Normalize the scores and calculate the cross-entropy loss
             scores = scores / self.temperature
             loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses)
diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
index dfdff2c82..74fe8ae64 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -213,15 +213,14 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        loss = self.calculate_loss(reps)
-        loss.backward()
+        loss = self.calculate_loss(reps, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
         """Calculate the cross-entropy loss. No need to cache the gradients."""
         embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
         embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
@@ -241,6 +240,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
             e = b + self.mini_batch_size
             scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
             loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses)
diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
index 20a3054d5..4b0995c96 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -182,15 +182,14 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the symmetric loss and cache gradients."""
-        loss = self.calculate_loss(reps)
-        loss.backward()
+        loss = self.calculate_loss(reps, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
         """Calculate the symmetric loss without caching gradients (for evaluation)."""
         embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
         embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
@@ -214,6 +213,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
             backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])
 
             loss_mbatch = (forward_loss + backward_loss) / 2
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses) / len(losses)
diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py
index fb71cc15d..4dfe0acb4 100644
--- a/sentence_transformers/losses/MatryoshkaLoss.py
+++ b/sentence_transformers/losses/MatryoshkaLoss.py
@@ -4,6 +4,7 @@
 from collections.abc import Iterable
 from typing import Any
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
@@ -81,7 +82,7 @@ def __init__(
         self.matryoshka_weights = matryoshka_weights
         self.n_dims_per_step = n_dims_per_step
 
-    def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
+    def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
         dim_indices = range(len(self.matryoshka_dims))
         if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
             dim_indices = random.sample(dim_indices, self.n_dims_per_step)
@@ -91,9 +92,21 @@ def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
             dim = self.matryoshka_dims[idx]
             weight = self.matryoshka_weights[idx]
 
-            truncated = [[shrink(r, dim) for r in rs] for rs in reps]
-            loss += weight * self.fn(truncated, *args)
-
+            truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
+            compute_gradients = torch.is_grad_enabled()
+            # We need to detach the truncated embeddings, otherwise the first backward pass of the underlying
+            # function will clear the computation graph of the embedding truncation
+            if compute_gradients:
+                matryoshka_reps = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
+            else:
+                matryoshka_reps = truncated
+            loss += weight * self.fn(matryoshka_reps, *args, **kwargs)
+            # After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation;
+            # the gradients must be multiplied with the weights because otherwise the matryoshka weights are not considered in the backward pass
+            if compute_gradients:
+                for t_minibatch, d_minibatch in zip(truncated, matryoshka_reps):
+                    for t, d in zip(t_minibatch, d_minibatch):
+                        t.backward(weight * d.grad)
 
         return loss
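
Note on the with_backward=True path added to the three cached losses: each mini-batch loss is backpropagated immediately and then detached, so only one mini-batch's computation graph is alive at a time, while the summed (detached) loss is still returned for logging. A minimal sketch of that pattern outside the library; accumulate_loss, loss_fn, and mini_batches are illustrative placeholders, not names from the patch:

import torch

def accumulate_loss(mini_batches, loss_fn, with_backward: bool = False) -> torch.Tensor:
    # Illustrative sketch of the per-mini-batch backward pattern used above.
    losses = []
    for scores, labels in mini_batches:
        loss_mbatch = loss_fn(scores, labels)
        if with_backward:
            # Backpropagate now and drop this mini-batch's graph immediately,
            # instead of keeping every mini-batch's graph alive for one final backward.
            loss_mbatch.backward()
            loss_mbatch = loss_mbatch.detach()
        losses.append(loss_mbatch)
    # With with_backward=True the returned value is detached; in the patch the caller
    # (calculate_loss_and_cache_gradients) re-attaches it via detach().requires_grad_().
    return sum(losses)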
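The MatryoshkaLoss change is the counterpart for wrapping a cached loss: the truncated embeddings are detached before being passed to the wrapped loss, so the wrapped loss's internal backward calls stop at the detach boundary instead of consuming the truncation graph, and afterwards each tensor's accumulated .grad is pushed back through the truncation with t.backward(weight * d.grad), which also applies the matryoshka weight. A hedged usage sketch of the combination this patch targets; the model name and mini_batch_size are placeholders:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

model = SentenceTransformer("all-mpnet-base-v2")  # placeholder model
inner_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)
loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=[768, 512, 256, 128, 64])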