From 414e7a86df4e5fd7aae8ba57e1c0164090f573eb Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 20 Nov 2024 13:47:10 +0100 Subject: [PATCH] Add some docstrings to the decorators in the MatryoshkaLoss --- .../losses/MatryoshkaLoss.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py index 25b23a630..fb71cc15d 100644 --- a/sentence_transformers/losses/MatryoshkaLoss.py +++ b/sentence_transformers/losses/MatryoshkaLoss.py @@ -27,6 +27,14 @@ def shrink(tensor: Tensor, dim: int) -> Tensor: class ForwardDecorator: + """ + This decorator is used to cache the output of the Sentence Transformer's forward pass, + so that it can be shrunk and reused for multiple loss calculations. This prevents the + model from recalculating the embeddings for each desired Matryoshka dimensionality. + + This decorator is applied to `SentenceTransformer.forward`. + """ + def __init__(self, fn) -> None: self.fn = fn @@ -56,6 +64,15 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]: class CachedLossDecorator: + """ + This decorator is used with the Cached... losses to compute the underlying loss function + for each Matryoshka dimensionality. This is done by shrinking the pre-computed embeddings + to the desired dimensionality and then passing them to the underlying loss function once + for each desired dimensionality. + + This decorator is applied to the `calculate_loss` method of the Cached... losses. + """ + def __init__( self, fn, matryoshka_dims: list[int], matryoshka_weights: list[float | int], n_dims_per_step: int = -1 ) -> None: @@ -158,21 +175,28 @@ def __init__( dims_weights = zip(matryoshka_dims, matryoshka_weights) self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True)) self.n_dims_per_step = n_dims_per_step + + # The Cached... 
losses require special treatment as their backward pass is incompatible with the + # ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each + # Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`. self.cached_losses = ( CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, CachedMultipleNegativesSymmetricRankingLoss, ) - if isinstance(loss, self.cached_losses): loss.calculate_loss = CachedLossDecorator( loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights ) def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: + # For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method, + # so we can directly call the loss function. if isinstance(self.loss, self.cached_losses): return self.loss(sentence_features, labels) + # Otherwise, we apply the ForwardDecorator to the model's forward pass, which will cache the output + # embeddings for each Matryoshka dimensionality, allowing them to be reused for the smaller dimensions. original_forward = self.model.forward try: decorated_forward = ForwardDecorator(original_forward)