diff --git a/docs/package_reference/util.md b/docs/package_reference/util.md
index 64b86ea80..a3f30fb15 100644
--- a/docs/package_reference/util.md
+++ b/docs/package_reference/util.md
@@ -3,5 +3,5 @@
```eval_rst
.. automodule:: sentence_transformers.util
- :members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get
+ :members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get, truncate_embeddings
```
diff --git a/examples/training/matryoshka/README.md b/examples/training/matryoshka/README.md
index fb92df240..62fb2e623 100644
--- a/examples/training/matryoshka/README.md
+++ b/examples/training/matryoshka/README.md
@@ -54,16 +54,20 @@ loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512,
## Inference
-After a model has been trained using a Matryoshka loss, you can then run inference with it using SentenceTransformers.encode
. You must then truncate the resulting embeddings, and it is recommended to renormalize the embeddings.
+After a model has been trained using a Matryoshka loss, you can then run inference with it using SentenceTransformers.encode
.
```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F
-model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
-
matryoshka_dim = 64
+model = SentenceTransformer(
+ "nomic-ai/nomic-embed-text-v1.5",
+ trust_remote_code=True,
+ truncate_dim=matryoshka_dim,
+)
+
embeddings = model.encode(
[
"search_query: What is TSNE?",
@@ -71,7 +75,7 @@ embeddings = model.encode(
"search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
]
)
-embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions
+assert embeddings.shape[-1] == matryoshka_dim
similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])
@@ -86,6 +90,7 @@ See the following scripts as examples of how to apply the Matryoshka2dLoss
:
diff --git a/examples/training/matryoshka/matryoshka_eval_stsb.py b/examples/training/matryoshka/matryoshka_eval_stsb.py
index f7153bb94..4c5cb511d 100644
--- a/examples/training/matryoshka/matryoshka_eval_stsb.py
+++ b/examples/training/matryoshka/matryoshka_eval_stsb.py
@@ -4,10 +4,8 @@
"""
import argparse
-from contextlib import contextmanager
-from functools import wraps
import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+from typing import Dict, List, Optional, Tuple, cast
from datasets import load_dataset
import numpy as np
@@ -17,50 +15,9 @@
EmbeddingSimilarityEvaluator,
SimilarityFunction,
)
-import torch
from tqdm.auto import tqdm
-# Util to truncate
-# Should patch instance, not the class b/c maybe there are other models floating around
-# that shouldn't get truncated
-@contextmanager
-def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable):
- original_method = getattr(obj, method_name)
- # Need to use __get__ when patching instance methods
- # https://stackoverflow.com/a/28127947/18758987
- try:
- setattr(obj, method_name, new_method.__get__(obj, obj.__class__))
- yield
- finally:
- setattr(obj, method_name, original_method.__get__(obj, obj.__class__))
-
-
-@contextmanager
-def truncate_embeddings(model: SentenceTransformer, dim: int):
- """
- In this context, the `model` outputs embeddings truncated at dimension `dim`.
-
- Parameters
- ----------
- model : SentenceTransformer
- model where `model.encode` outputs a (D,) or (N, D) array or tensor of
- embeddings given text(s)
- dim : int
- dimension to truncate at. So a (N, D) array becomes (N, `dim`)
- """
-
- original_encode = model.encode
-
- @wraps(original_encode)
- def encode(self, *args, **kwargs) -> Union[np.ndarray, torch.Tensor]:
- embeddings = original_encode(*args, **kwargs)
- return embeddings[..., :dim]
-
- with _monkeypatch_instance_method(model, "encode", encode):
- yield
-
-
# Dimension plot
def _grouped_barplot_ratios(
group_name_to_x_to_y: Dict[str, Dict[int, float]], ax: Optional[plt.Axes] = None
@@ -202,7 +159,7 @@ def plot_across_dimensions(
for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"):
output_path = os.path.join(model_name, f"dim-{dim}")
os.makedirs(output_path)
- with truncate_embeddings(model, dim):
+ with model.truncate_sentence_embeddings(dim):
score = test_evaluator(model, output_path=output_path)
print(f"Saved results to {output_path}")
dim_to_score[dim] = score
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 3163bb10a..6e7ec7622 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
import json
import logging
import os
@@ -31,6 +32,7 @@
load_file_path,
save_to_hub_args_decorator,
get_device_name,
+ truncate_embeddings,
)
from .quantization import quantize_embeddings
from .models import Transformer, Pooling, Normalize
@@ -68,6 +70,8 @@ class SentenceTransformer(nn.Sequential):
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
:param token: Hugging Face authentication token to download private models.
+ :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
+ only applicable during inference when `.encode` is called.
"""
def __init__(
@@ -82,10 +86,12 @@ def __init__(
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
+ truncate_dim: Optional[int] = None,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
+ self.truncate_dim = truncate_dim
self._model_card_vars = {}
self._model_card_text = None
self._model_config = {}
@@ -253,7 +259,7 @@ def encode(
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
- output_value: str = "sentence_embedding",
+ output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding",
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
@@ -289,7 +295,8 @@ def encode(
:return: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string
input is provided, then the output is a 1d array with shape [output_dimension]. If `convert_to_tensor`, a
- torch Tensor is returned instead.
+ torch Tensor is returned instead. If `self.truncate_dim <= output_dimension` then output_dimension is
+ `self.truncate_dim`.
"""
self.eval()
if show_progress_bar is None:
@@ -355,6 +362,9 @@ def encode(
with torch.no_grad():
out_features = self.forward(features)
+ out_features["sentence_embedding"] = truncate_embeddings(
+ out_features["sentence_embedding"], self.truncate_dim
+ )
if output_value == "token_embeddings":
embeddings = []
@@ -572,11 +582,48 @@ def get_sentence_features(self, *features):
return self._first_module().get_sentence_features(*features)
def get_sentence_embedding_dimension(self):
+ """
+ :return: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
+ """
+ output_dim = None
for mod in reversed(self._modules.values()):
sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
if callable(sent_embedding_dim_method):
- return sent_embedding_dim_method()
- return None
+ output_dim = sent_embedding_dim_method()
+ break
+ if self.truncate_dim is not None:
+ # The user requested truncation. If they set it to a dim greater than output_dim,
+ # no truncation will actually happen. So return output_dim insead of self.truncate_dim
+ return min(output_dim or np.inf, self.truncate_dim)
+ return output_dim
+
+ @contextmanager
+ def truncate_sentence_embeddings(self, truncate_dim: Optional[int]):
+ """
+ In this context, `model.encode` outputs sentence embeddings truncated at dimension `truncate_dim`.
+
+ This may be useful when you are using the same model for different applications where different dimensions
+ are needed.
+
+ :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation.
+
+ Example::
+
+ from sentence_transformers import SentenceTransformer
+
+ model = SentenceTransformer("model-name")
+
+ with model.truncate_sentence_embeddings(truncate_dim=16):
+ embeddings_truncated = model.encode(["hello there", "hiya"])
+ assert embeddings_truncated.shape[-1] == 16
+
+ """
+ original_output_dim = self.truncate_dim
+ try:
+ self.truncate_dim = truncate_dim
+ yield
+ finally:
+ self.truncate_dim = original_output_dim
def _first_module(self):
"""Returns the first module of this sequential embedder"""
diff --git a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
index 130d047b6..75c3484d7 100644
--- a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
+++ b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
from . import SentenceEvaluator, SimilarityFunction
import logging
import os
@@ -33,6 +34,7 @@ def __init__(
show_progress_bar: bool = False,
write_csv: bool = True,
precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = None,
+ truncate_dim: Optional[int] = None,
):
"""
Constructs an evaluator based for the dataset
@@ -45,12 +47,15 @@ def __init__(
:param write_csv: Write results to a CSV file
:param precision: The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or
"ubinary". Defaults to None.
+ :param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's current
+ truncation dimension. Defaults to None.
"""
self.sentences1 = sentences1
self.sentences2 = sentences2
self.scores = scores
self.write_csv = write_csv
self.precision = precision
+ self.truncate_dim = truncate_dim
assert len(self.sentences1) == len(self.sentences2)
assert len(self.sentences1) == len(self.scores)
@@ -107,22 +112,23 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =
logger.info("EmbeddingSimilarityEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
- embeddings1 = model.encode(
- self.sentences1,
- batch_size=self.batch_size,
- show_progress_bar=self.show_progress_bar,
- convert_to_numpy=True,
- precision=self.precision,
- normalize_embeddings=bool(self.precision),
- )
- embeddings2 = model.encode(
- self.sentences2,
- batch_size=self.batch_size,
- show_progress_bar=self.show_progress_bar,
- convert_to_numpy=True,
- precision=self.precision,
- normalize_embeddings=bool(self.precision),
- )
+ with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
+ embeddings1 = model.encode(
+ self.sentences1,
+ batch_size=self.batch_size,
+ show_progress_bar=self.show_progress_bar,
+ convert_to_numpy=True,
+ precision=self.precision,
+ normalize_embeddings=bool(self.precision),
+ )
+ embeddings2 = model.encode(
+ self.sentences2,
+ batch_size=self.batch_size,
+ show_progress_bar=self.show_progress_bar,
+ convert_to_numpy=True,
+ precision=self.precision,
+ normalize_embeddings=bool(self.precision),
+ )
# Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics
if self.precision == "binary":
embeddings1 = (embeddings1 + 128).astype(np.uint8)
diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index ecebc953a..f66b494fc 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -10,7 +10,7 @@
import numpy as np
import queue
import logging
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, overload
from transformers import is_torch_npu_available
from huggingface_hub import snapshot_download, hf_hub_download
@@ -142,6 +142,25 @@ def normalize_embeddings(embeddings: Tensor) -> Tensor:
return torch.nn.functional.normalize(embeddings, p=2, dim=1)
+@overload
+def truncate_embeddings(embeddings: np.ndarray, truncate_dim: Optional[int]) -> np.ndarray: ...
+
+
+@overload
+def truncate_embeddings(embeddings: torch.Tensor, truncate_dim: Optional[int]) -> torch.Tensor: ...
+
+
+def truncate_embeddings(
+ embeddings: Union[np.ndarray, torch.Tensor], truncate_dim: Optional[int]
+) -> Union[np.ndarray, torch.Tensor]:
+ """
+ :param embeddings: Embeddings to truncate.
+ :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation.
+ :return: Truncated embeddings.
+ """
+ return embeddings[..., :truncate_dim]
+
+
def paraphrase_mining(
model, sentences: List[str], show_progress_bar: bool = False, batch_size: int = 32, *args, **kwargs
) -> List[List[Union[float, int]]]:
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index 7328c6afe..770a737e5 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -2,12 +2,15 @@
Tests general behaviour of the SentenceTransformer class
"""
+from functools import partial
import json
import logging
import os
from pathlib import Path
import re
import tempfile
+from typing import Dict, List, Literal, Optional, Union, cast
+
import numpy as np
import pytest
@@ -15,6 +18,7 @@
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Transformer, Pooling
+from sentence_transformers import util
def test_load_with_safetensors() -> None:
@@ -380,3 +384,96 @@ def test_encode_quantization(
else:
assert embeddings[0].dtype == expected_torch_dtype
assert isinstance(embeddings, list)
+
+
+@pytest.mark.parametrize("sentences", ("Single sentence", ["One sentence", "Another sentence"]))
+@pytest.mark.parametrize("convert_to_tensor", [True, False])
+@pytest.mark.parametrize("convert_to_numpy", [True, False])
+@pytest.mark.parametrize("normalize_embeddings", [True, False])
+@pytest.mark.parametrize("output_value", ["sentence_embedding", None])
+def test_encode_truncate(
+ sentences: Union[str, List[str]],
+ convert_to_tensor: bool,
+ convert_to_numpy: bool,
+ normalize_embeddings: bool,
+ output_value: Optional[Literal["sentence_embedding"]],
+) -> None:
+ model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+ embeddings_full_unnormalized: torch.Tensor = model.encode(
+ sentences, convert_to_numpy=False, convert_to_tensor=True
+ ) # These are raw embeddings which serve as the reference to test against
+
+ def test(model: SentenceTransformer, expected_dim: int):
+ outputs = model.encode(
+ sentences,
+ output_value=output_value,
+ convert_to_tensor=convert_to_tensor,
+ convert_to_numpy=convert_to_numpy,
+ normalize_embeddings=normalize_embeddings,
+ )
+
+ # Extract the sentence embeddings out of outputs
+ if output_value is None:
+ # We get the whole plate
+ if not isinstance(outputs, List):
+ embeddings = outputs["sentence_embedding"]
+ else:
+ outputs = cast(List[Dict[str, torch.Tensor]], outputs)
+ # TODO: can overload model.encode if ppl want type checker compatibility
+ embeddings = [out_features["sentence_embedding"] for out_features in outputs]
+ else:
+ embeddings = outputs
+
+ # Test shape
+ if isinstance(embeddings, list): # list of tensors
+ embeddings_shape = (len(embeddings), embeddings[0].shape[-1])
+ else:
+ embeddings_shape = embeddings.shape
+ expected_shape = (expected_dim,) if isinstance(sentences, str) else (len(sentences), expected_dim)
+ assert embeddings_shape == expected_shape
+ assert model.get_sentence_embedding_dimension() == expected_dim
+
+ # Convert embeddings to a torch Tensor for ease of testing
+ if isinstance(embeddings, list):
+ embeddings = torch.stack(embeddings)
+ elif isinstance(embeddings, np.ndarray):
+ embeddings = torch.from_numpy(embeddings).to(embeddings_full_unnormalized.device)
+ # On a non-cpu device, the device of torch.from_numpy(embeddings) is always CPU
+
+ # Test content
+ if normalize_embeddings:
+ if output_value is None:
+ # Currently, normalization is not performed; it's the raw output of the forward pass
+ pass
+ else:
+ normalize = partial(torch.nn.functional.normalize, p=2, dim=-1)
+ assert torch.allclose(
+ embeddings,
+ normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)),
+ )
+ else:
+ assert torch.allclose(embeddings, util.truncate_embeddings(embeddings_full_unnormalized, expected_dim))
+
+ # Test init w/o setting truncate_dim (it's None)
+ original_output_dim: int = model.get_sentence_embedding_dimension()
+ test(model, expected_dim=original_output_dim)
+
+ # Test init w/ a set truncate_dim
+ truncate_dim = int(original_output_dim / 4)
+ model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", truncate_dim=truncate_dim)
+ test(model, expected_dim=truncate_dim)
+
+ # Test setting the attribute after init to a greater dimension
+ new_truncate_dim = 2 * truncate_dim
+ model.truncate_dim = new_truncate_dim
+ test(model, expected_dim=new_truncate_dim)
+
+ # Test context manager
+ final_truncate_dim = int(original_output_dim / 8)
+ with model.truncate_sentence_embeddings(final_truncate_dim):
+ test(model, expected_dim=final_truncate_dim)
+ test(model, expected_dim=new_truncate_dim) # b/c we've exited the context
+
+ # Test w/ an ouptut_dim that's larger than the original_output_dim. No truncation ends up happening
+ model.truncate_dim = 2 * original_output_dim
+ test(model, expected_dim=original_output_dim)