
Commit

Add some docstrings to the decorators in the MatryoshkaLoss
tomaarsen committed Nov 20, 2024
1 parent cad56d0 commit 414e7a8
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion sentence_transformers/losses/MatryoshkaLoss.py
@@ -27,6 +27,14 @@ def shrink(tensor: Tensor, dim: int) -> Tensor:


class ForwardDecorator:
"""
This decorator is used to cache the output of the Sentence Transformer's forward pass,
so that it can be shrunk and reused for multiple loss calculations. This prevents the
model from recalculating the embeddings for each desired Matryoshka dimensionality.
This decorator is applied to `SentenceTransformer.forward`.
"""

def __init__(self, fn) -> None:
self.fn = fn
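
As a rough sketch of the caching pattern this docstring describes (not the library's actual implementation; the class name CachingForward and its attributes are hypothetical), a forward-wrapping decorator can run the real forward once per batch, cache the full-size output, and serve truncated copies for whichever dimensionality is currently requested:

from __future__ import annotations

from torch import Tensor


class CachingForward:
    """Hypothetical sketch: cache a forward pass and replay it truncated."""

    def __init__(self, fn) -> None:
        self.fn = fn                              # the original forward method
        self.dim: int | None = None               # current target dimensionality
        self.cache: list[dict[str, Tensor]] = []  # full-size outputs, one per forward call
        self.idx = 0                              # replay position within the cache

    def set_dim(self, dim: int) -> None:
        self.dim = dim
        self.idx = 0                              # restart the replay for the new dimensionality

    def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        if self.idx < len(self.cache):
            output = self.cache[self.idx]         # reuse the cached full-size output
        else:
            output = self.fn(features)            # first pass: run the real forward once
            self.cache.append(output)
        self.idx += 1
        truncated = dict(output)
        truncated["sentence_embedding"] = output["sentence_embedding"][..., : self.dim]
        return truncated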

@@ -56,6 +64,15 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:


class CachedLossDecorator:
"""
This decorator is used with the Cached... losses to compute the underlying loss function
for each Matryoshka dimensionality. This is done by shrinking the pre-computed embeddings
to each desired dimensionality and passing the shrunk embeddings to the underlying loss
function, once per dimensionality.
This decorator is applied to the `calculate_loss` method of the Cached... losses.
"""

def __init__(
self, fn, matryoshka_dims: list[int], matryoshka_weights: list[float | int], n_dims_per_step: int = -1
) -> None:
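
Along the same lines, the per-dimensionality loss computation described here can be sketched as follows (hypothetical names, not the actual CachedLossDecorator; the re-normalization after truncation is an assumption about how the shrinking is done):

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor


def truncate(embeddings: Tensor, dim: int) -> Tensor:
    # Keep the first `dim` features and re-normalize to unit length (assumed convention).
    return F.normalize(embeddings[..., :dim], p=2, dim=-1)


class PerDimLossWrapper:
    """Hypothetical sketch: call a loss once per Matryoshka dimensionality on shrunk embeddings."""

    def __init__(self, fn, dims: list[int], weights: list[float]) -> None:
        self.fn = fn          # loss function that accepts pre-computed embeddings
        self.dims = dims
        self.weights = weights

    def __call__(self, embeddings: list[Tensor], *args, **kwargs) -> Tensor:
        total = 0.0
        for dim, weight in zip(self.dims, self.weights):
            shrunk = [truncate(reps, dim) for reps in embeddings]
            total = total + weight * self.fn(shrunk, *args, **kwargs)
        return total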
@@ -158,21 +175,28 @@ def __init__(
dims_weights = zip(matryoshka_dims, matryoshka_weights)
self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
self.n_dims_per_step = n_dims_per_step

# The Cached... losses require special treatment, as their backward pass is incompatible with the
# ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each
# Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`.
self.cached_losses = (
CachedMultipleNegativesRankingLoss,
CachedGISTEmbedLoss,
CachedMultipleNegativesSymmetricRankingLoss,
)

if isinstance(loss, self.cached_losses):
loss.calculate_loss = CachedLossDecorator(
loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights
)

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
# For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method,
# so we can directly call the loss function.
if isinstance(self.loss, self.cached_losses):
return self.loss(sentence_features, labels)

# Otherwise, we apply the ForwardDecorator to the model's forward pass, which caches the output
# embeddings so they can be shrunk and reused for each of the smaller Matryoshka dimensionalities.
original_forward = self.model.forward
try:
decorated_forward = ForwardDecorator(original_forward)
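
For context, MatryoshkaLoss wraps another loss and trains it at several truncation sizes at once; a minimal usage sketch (the model name and dimensionality list are only examples):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")  # produces 384-dimensional embeddings
base_loss = MultipleNegativesRankingLoss(model)
# Compute the base loss on embeddings truncated to each listed dimensionality.
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[384, 256, 128, 64])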
