don't perform backward pass in evaluation mode
Marcel Brunnbauer committed Dec 10, 2024
Parent: c2d397f · Commit: 0292b9b
Showing 1 changed file with 11 additions and 5 deletions.
sentence_transformers/losses/MatryoshkaLoss.py (16 changes: 11 additions & 5 deletions)

@@ -4,6 +4,7 @@
 from collections.abc import Iterable
 from typing import Any

+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

@@ -92,15 +93,20 @@ def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
             weight = self.matryoshka_weights[idx]

             truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
+            compute_gradients = torch.is_grad_enabled()
             # we need to detach the truncated embeddings,
             # otherwise the first backward pass of the underlying function will clear the computation graph of the embedding truncation
-            detached = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
-            loss += weight * self.fn(detached, *args, **kwargs)
+            if compute_gradients:
+                matryoshka_reps = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
+            else:
+                matryoshka_reps = truncated
+            loss += weight * self.fn(matryoshka_reps, *args, **kwargs)
             # After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation;
             # the gradients must be multiplied by the weights, because otherwise the matryoshka weights are not considered in the backward pass
-            for t_minibatch, d_minibatch in zip(truncated, detached):
-                for t, d in zip(t_minibatch, d_minibatch):
-                    t.backward(weight * d.grad)
+            if compute_gradients:
+                for t_minibatch, d_minibatch in zip(truncated, matryoshka_reps):
+                    for t, d in zip(t_minibatch, d_minibatch):
+                        t.backward(weight * d.grad)
         return loss
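To make the pattern above concrete, here is a minimal, self-contained sketch of the detach-and-reattach trick that this commit now gates behind torch.is_grad_enabled(). It uses a single tensor in place of the nested minibatch lists, and a squared-sum loss as a stand-in for self.fn; all names below are illustrative, not taken from the repository.

import torch

full = torch.randn(4, 8, requires_grad=True)   # stands in for the full embeddings
truncated = full[:, :4]                        # stands in for shrink(r, dim)
weight = 0.5

if torch.is_grad_enabled():  # training mode: the truncation is part of a graph
    # Detach so the stand-in loss's backward pass cannot free the
    # computation graph of the truncation itself.
    detached = truncated.detach().requires_grad_()
    loss = detached.pow(2).sum()               # stand-in for self.fn(...)
    loss.backward()                            # fills detached.grad; full's graph is untouched
    # Continue the backward pass through the truncation, scaled by the weight.
    truncated.backward(weight * detached.grad)
else:  # evaluation mode (e.g. under torch.no_grad()): no graph exists, skip backward
    loss = truncated.pow(2).sum()

This matches the commit message: under torch.no_grad() the truncated tensors carry no computation graph and d.grad stays None, so the manual backward pass would both fail and be unnecessary, which is why the new compute_gradients check skips it entirely.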


