From fc4d5b56d9efec6aa6548ed68c18a5e967e41e50 Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Wed, 3 Apr 2024 09:26:10 -0700 Subject: [PATCH 1/7] Add truncation support --- examples/training/matryoshka/README.md | 13 +++- .../matryoshka/matryoshka_eval_stsb.py | 47 +----------- sentence_transformers/SentenceTransformer.py | 53 +++++++++++-- sentence_transformers/util.py | 21 +++++- tests/test_sentence_transformer.py | 74 +++++++++++++++++++ 5 files changed, 153 insertions(+), 55 deletions(-) diff --git a/examples/training/matryoshka/README.md b/examples/training/matryoshka/README.md index fb92df240..9a075f0b9 100644 --- a/examples/training/matryoshka/README.md +++ b/examples/training/matryoshka/README.md @@ -54,16 +54,20 @@ loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, ## Inference -After a model has been trained using a Matryoshka loss, you can then run inference with it using SentenceTransformers.encode. You must then truncate the resulting embeddings, and it is recommended to renormalize the embeddings. +After a model has been trained using a Matryoshka loss, you can then run inference with it using SentenceTransformers.encode. ```python from sentence_transformers import SentenceTransformer from sentence_transformers.util import cos_sim import torch.nn.functional as F -model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True) - matryoshka_dim = 64 +model = SentenceTransformer( + "nomic-ai/nomic-embed-text-v1.5", + trust_remote_code=True, + output_dim=matryoshka_dim, +) + embeddings = model.encode( [ "search_query: What is TSNE?", @@ -71,7 +75,7 @@ embeddings = model.encode( "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.", ] ) -embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions +assert embeddings.shape[-1] == matryoshka_dim similarities = cos_sim(embeddings[0], embeddings[1:]) # => tensor([[0.7839, 0.4933]]) @@ -86,6 +90,7 @@ See the following scripts as examples of how to apply the Matryoshka2dLoss: diff --git a/examples/training/matryoshka/matryoshka_eval_stsb.py b/examples/training/matryoshka/matryoshka_eval_stsb.py index f7153bb94..4c5cb511d 100644 --- a/examples/training/matryoshka/matryoshka_eval_stsb.py +++ b/examples/training/matryoshka/matryoshka_eval_stsb.py @@ -4,10 +4,8 @@ """ import argparse -from contextlib import contextmanager -from functools import wraps import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Dict, List, Optional, Tuple, cast from datasets import load_dataset import numpy as np @@ -17,50 +15,9 @@ EmbeddingSimilarityEvaluator, SimilarityFunction, ) -import torch from tqdm.auto import tqdm -# Util to truncate -# Should patch instance, not the class b/c maybe there are other models floating around -# that shouldn't get truncated -@contextmanager -def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable): - original_method = getattr(obj, method_name) - # Need to use __get__ when patching instance methods - # https://stackoverflow.com/a/28127947/18758987 - try: - setattr(obj, method_name, new_method.__get__(obj, obj.__class__)) - yield - finally: - setattr(obj, method_name, original_method.__get__(obj, obj.__class__)) - - -@contextmanager -def truncate_embeddings(model: SentenceTransformer, dim: int): - """ - In this context, the `model` outputs embeddings truncated at dimension `dim`. - - Parameters - ---------- - model : SentenceTransformer - model where `model.encode` outputs a (D,) or (N, D) array or tensor of - embeddings given text(s) - dim : int - dimension to truncate at. So a (N, D) array becomes (N, `dim`) - """ - - original_encode = model.encode - - @wraps(original_encode) - def encode(self, *args, **kwargs) -> Union[np.ndarray, torch.Tensor]: - embeddings = original_encode(*args, **kwargs) - return embeddings[..., :dim] - - with _monkeypatch_instance_method(model, "encode", encode): - yield - - # Dimension plot def _grouped_barplot_ratios( group_name_to_x_to_y: Dict[str, Dict[int, float]], ax: Optional[plt.Axes] = None @@ -202,7 +159,7 @@ def plot_across_dimensions( for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"): output_path = os.path.join(model_name, f"dim-{dim}") os.makedirs(output_path) - with truncate_embeddings(model, dim): + with model.truncate_sentence_embeddings(dim): score = test_evaluator(model, output_path=output_path) print(f"Saved results to {output_path}") dim_to_score[dim] = score diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 3163bb10a..c8e1c486c 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import json import logging import os @@ -31,6 +32,7 @@ load_file_path, save_to_hub_args_decorator, get_device_name, + truncate_embeddings, ) from .quantization import quantize_embeddings from .models import Transformer, Pooling, Normalize @@ -68,6 +70,7 @@ class SentenceTransformer(nn.Sequential): This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. :param token: Hugging Face authentication token to download private models. + :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. """ def __init__( @@ -82,10 +85,12 @@ def __init__( revision: Optional[str] = None, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, + output_dim: Optional[int] = None, ): # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name` self.prompts = prompts or {} self.default_prompt_name = default_prompt_name + self.output_dim = output_dim self._model_card_vars = {} self._model_card_text = None self._model_config = {} @@ -253,7 +258,7 @@ def encode( prompt: Optional[str] = None, batch_size: int = 32, show_progress_bar: bool = None, - output_value: str = "sentence_embedding", + output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding", precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", convert_to_numpy: bool = True, convert_to_tensor: bool = False, @@ -289,7 +294,7 @@ def encode( :return: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If `convert_to_tensor`, a - torch Tensor is returned instead. + torch Tensor is returned instead. If `self.output_dim is not None` then output_dimension is `self.output_dim`. """ self.eval() if show_progress_bar is None: @@ -370,8 +375,9 @@ def encode( row = {name: out_features[name][sent_idx] for name in out_features} embeddings.append(row) else: # Sentence embeddings - embeddings = out_features[output_value] + embeddings: torch.Tensor = out_features[output_value] embeddings = embeddings.detach() + embeddings = truncate_embeddings(embeddings, self.output_dim) if normalize_embeddings: embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) @@ -572,11 +578,48 @@ def get_sentence_features(self, *features): return self._first_module().get_sentence_features(*features) def get_sentence_embedding_dimension(self): + """ + :return: The number of dimensions in the output of `encode`. If it's not known, it's `None`. + """ + output_dim = None for mod in reversed(self._modules.values()): sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None) if callable(sent_embedding_dim_method): - return sent_embedding_dim_method() - return None + output_dim = sent_embedding_dim_method() + break + if self.output_dim is not None: + # The user requested truncation. If they set it to a dim greater than output_dim, + # no truncation will actually happen. So return output_dim insead of self.output_dim + return min(output_dim or np.inf, self.output_dim) + return output_dim + + @contextmanager + def truncate_sentence_embeddings(self, output_dim: int | None): + """ + In this context, `model.encode` outputs sentence embeddings truncated at dimension `output_dim`. + + This may be useful when you are using the same model for different applications where different dimensions + are needed. + + :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. + + Example:: + + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer("model-name") + + with model.truncate_sentence_embeddings(output_dim=16): + embeddings_truncated = model.encode(["hello there", "hiya"]) + assert embeddings_truncated.shape[-1] == 16 + + """ + original_output_dim = self.output_dim + try: + self.output_dim = output_dim + yield + finally: + self.output_dim = original_output_dim def _first_module(self): """Returns the first module of this sequential embedder""" diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index ecebc953a..a8bf6dcba 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -10,7 +10,7 @@ import numpy as np import queue import logging -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, overload from transformers import is_torch_npu_available from huggingface_hub import snapshot_download, hf_hub_download @@ -142,6 +142,25 @@ def normalize_embeddings(embeddings: Tensor) -> Tensor: return torch.nn.functional.normalize(embeddings, p=2, dim=1) +@overload +def truncate_embeddings(embeddings: np.ndarray, output_dim: Optional[int]) -> np.ndarray: ... + + +@overload +def truncate_embeddings(embeddings: torch.Tensor, output_dim: Optional[int]) -> torch.Tensor: ... + + +def truncate_embeddings( + embeddings: Union[np.ndarray, torch.Tensor], output_dim: Optional[int] +) -> Union[np.ndarray, torch.Tensor]: + """ + :param embeddings: Embeddings to truncate. + :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. + :return: Truncated embeddings. + """ + return embeddings[..., :output_dim] + + def paraphrase_mining( model, sentences: List[str], show_progress_bar: bool = False, batch_size: int = 32, *args, **kwargs ) -> List[List[Union[float, int]]]: diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 7328c6afe..788ae4e2e 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -2,12 +2,15 @@ Tests general behaviour of the SentenceTransformer class """ +from functools import partial import json import logging import os from pathlib import Path import re import tempfile +from typing import List, Union + import numpy as np import pytest @@ -15,6 +18,7 @@ import torch from sentence_transformers import SentenceTransformer from sentence_transformers.models import Normalize, Transformer, Pooling +from sentence_transformers import util def test_load_with_safetensors() -> None: @@ -380,3 +384,73 @@ def test_encode_quantization( else: assert embeddings[0].dtype == expected_torch_dtype assert isinstance(embeddings, list) + + +@pytest.mark.parametrize("convert_to_tensor", [True, False]) +@pytest.mark.parametrize("convert_to_numpy", [True, False]) +@pytest.mark.parametrize("normalize_embeddings", [True, False]) +@pytest.mark.parametrize("sentences", ("Single sentence", ["One sentence", "Another sentence"])) +def test_encode_truncate( + sentences: Union[str, List[str]], convert_to_tensor: bool, convert_to_numpy: bool, normalize_embeddings: bool +) -> None: + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") + embeddings_full_unnormalized: torch.Tensor = model.encode( + sentences, convert_to_numpy=False, convert_to_tensor=True + ) # These are raw embeddings which serve as the reference to test against + + def test(model: SentenceTransformer, expected_dim: int): + embeddings = model.encode( + sentences, + convert_to_tensor=convert_to_tensor, + convert_to_numpy=convert_to_numpy, + normalize_embeddings=normalize_embeddings, + ) + # Test shape + if isinstance(embeddings, list): # list of tensors + embeddings_shape = (len(embeddings), embeddings[0].shape[-1]) + else: + embeddings_shape = embeddings.shape + expected_shape = (expected_dim,) if isinstance(sentences, str) else (len(sentences), expected_dim) + assert embeddings_shape == expected_shape + assert model.get_sentence_embedding_dimension() == expected_dim + + # Convert embeddings to a torch Tensor for ease of testing + if isinstance(embeddings, list): + embeddings = torch.stack(embeddings) + elif isinstance(embeddings, np.ndarray): + embeddings = torch.from_numpy(embeddings).to(embeddings_full_unnormalized.device) + # On a non-cpu device, the device of torch.from_numpy(embeddings) is always CPU + + # Test content + if not normalize_embeddings: + assert torch.allclose(embeddings, util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)) + else: + normalize = partial(torch.nn.functional.normalize, p=2, dim=-1) + assert torch.allclose( + embeddings, + normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)), + ) + + # Test init w/o setting output_dim (it's None) + original_output_dim: int = model.get_sentence_embedding_dimension() + test(model, expected_dim=original_output_dim) + + # Test init w/ a set output_dim + output_dim = int(original_output_dim / 4) + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", output_dim=output_dim) + test(model, expected_dim=output_dim) + + # Test setting the attribute after init to a greater dimension + new_output_dim = 2 * output_dim + model.output_dim = new_output_dim + test(model, expected_dim=new_output_dim) + + # Test context manager + final_ouptut_dim = int(original_output_dim / 8) + with model.truncate_sentence_embeddings(final_ouptut_dim): + test(model, expected_dim=final_ouptut_dim) + test(model, expected_dim=new_output_dim) # b/c we've exited the context + + # Test w/ an ouptut_dim that's larger than the original_output_dim. No truncation ends up happening + model.output_dim = 2 * original_output_dim + test(model, expected_dim=original_output_dim) From d8e09dd33359e1c370bd6fa1916130f89f61b86a Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Wed, 3 Apr 2024 09:46:58 -0700 Subject: [PATCH 2/7] no pipe --- sentence_transformers/SentenceTransformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index c8e1c486c..78fb50f0b 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -594,7 +594,7 @@ def get_sentence_embedding_dimension(self): return output_dim @contextmanager - def truncate_sentence_embeddings(self, output_dim: int | None): + def truncate_sentence_embeddings(self, output_dim: Optional[int]): """ In this context, `model.encode` outputs sentence embeddings truncated at dimension `output_dim`. From 698b5df354a5ace0775ca6e42f9a39e835d59188 Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Thu, 4 Apr 2024 07:00:25 -0700 Subject: [PATCH 3/7] truncate_sentence_embeddings -> truncate_dim --- examples/training/matryoshka/matryoshka_eval_stsb.py | 2 +- sentence_transformers/SentenceTransformer.py | 4 ++-- tests/test_sentence_transformer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/training/matryoshka/matryoshka_eval_stsb.py b/examples/training/matryoshka/matryoshka_eval_stsb.py index 4c5cb511d..fcbaf100a 100644 --- a/examples/training/matryoshka/matryoshka_eval_stsb.py +++ b/examples/training/matryoshka/matryoshka_eval_stsb.py @@ -159,7 +159,7 @@ def plot_across_dimensions( for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"): output_path = os.path.join(model_name, f"dim-{dim}") os.makedirs(output_path) - with model.truncate_sentence_embeddings(dim): + with model.truncate_dim(dim): score = test_evaluator(model, output_path=output_path) print(f"Saved results to {output_path}") dim_to_score[dim] = score diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 78fb50f0b..aeaf671de 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -594,7 +594,7 @@ def get_sentence_embedding_dimension(self): return output_dim @contextmanager - def truncate_sentence_embeddings(self, output_dim: Optional[int]): + def truncate_dim(self, output_dim: Optional[int]): """ In this context, `model.encode` outputs sentence embeddings truncated at dimension `output_dim`. @@ -609,7 +609,7 @@ def truncate_sentence_embeddings(self, output_dim: Optional[int]): model = SentenceTransformer("model-name") - with model.truncate_sentence_embeddings(output_dim=16): + with model.truncate_dim(output_dim=16): embeddings_truncated = model.encode(["hello there", "hiya"]) assert embeddings_truncated.shape[-1] == 16 diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 788ae4e2e..b0332666e 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -447,7 +447,7 @@ def test(model: SentenceTransformer, expected_dim: int): # Test context manager final_ouptut_dim = int(original_output_dim / 8) - with model.truncate_sentence_embeddings(final_ouptut_dim): + with model.truncate_dim(final_ouptut_dim): test(model, expected_dim=final_ouptut_dim) test(model, expected_dim=new_output_dim) # b/c we've exited the context From 5d4e172cd9c5e2952209b53cb4b3a3300198bb70 Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Thu, 4 Apr 2024 07:57:32 -0700 Subject: [PATCH 4/7] output_dim -> truncate_dim --- examples/training/matryoshka/README.md | 2 +- .../matryoshka/matryoshka_eval_stsb.py | 2 +- sentence_transformers/SentenceTransformer.py | 31 ++++++++++--------- sentence_transformers/util.py | 10 +++--- tests/test_sentence_transformer.py | 26 ++++++++-------- 5 files changed, 36 insertions(+), 35 deletions(-) diff --git a/examples/training/matryoshka/README.md b/examples/training/matryoshka/README.md index 9a075f0b9..62fb2e623 100644 --- a/examples/training/matryoshka/README.md +++ b/examples/training/matryoshka/README.md @@ -65,7 +65,7 @@ matryoshka_dim = 64 model = SentenceTransformer( "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, - output_dim=matryoshka_dim, + truncate_dim=matryoshka_dim, ) embeddings = model.encode( diff --git a/examples/training/matryoshka/matryoshka_eval_stsb.py b/examples/training/matryoshka/matryoshka_eval_stsb.py index fcbaf100a..4c5cb511d 100644 --- a/examples/training/matryoshka/matryoshka_eval_stsb.py +++ b/examples/training/matryoshka/matryoshka_eval_stsb.py @@ -159,7 +159,7 @@ def plot_across_dimensions( for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"): output_path = os.path.join(model_name, f"dim-{dim}") os.makedirs(output_path) - with model.truncate_dim(dim): + with model.truncate_sentence_embeddings(dim): score = test_evaluator(model, output_path=output_path) print(f"Saved results to {output_path}") dim_to_score[dim] = score diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index aeaf671de..fb6f42c69 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -70,7 +70,7 @@ class SentenceTransformer(nn.Sequential): This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. :param token: Hugging Face authentication token to download private models. - :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. + :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. """ def __init__( @@ -85,12 +85,12 @@ def __init__( revision: Optional[str] = None, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, - output_dim: Optional[int] = None, + truncate_dim: Optional[int] = None, ): # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name` self.prompts = prompts or {} self.default_prompt_name = default_prompt_name - self.output_dim = output_dim + self.truncate_dim = truncate_dim self._model_card_vars = {} self._model_card_text = None self._model_config = {} @@ -294,7 +294,8 @@ def encode( :return: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If `convert_to_tensor`, a - torch Tensor is returned instead. If `self.output_dim is not None` then output_dimension is `self.output_dim`. + torch Tensor is returned instead. If `self.truncate_dim <= output_dimension` then output_dimension is + `self.truncate_dim`. """ self.eval() if show_progress_bar is None: @@ -377,7 +378,7 @@ def encode( else: # Sentence embeddings embeddings: torch.Tensor = out_features[output_value] embeddings = embeddings.detach() - embeddings = truncate_embeddings(embeddings, self.output_dim) + embeddings = truncate_embeddings(embeddings, self.truncate_dim) if normalize_embeddings: embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) @@ -587,21 +588,21 @@ def get_sentence_embedding_dimension(self): if callable(sent_embedding_dim_method): output_dim = sent_embedding_dim_method() break - if self.output_dim is not None: + if self.truncate_dim is not None: # The user requested truncation. If they set it to a dim greater than output_dim, - # no truncation will actually happen. So return output_dim insead of self.output_dim - return min(output_dim or np.inf, self.output_dim) + # no truncation will actually happen. So return output_dim insead of self.truncate_dim + return min(output_dim or np.inf, self.truncate_dim) return output_dim @contextmanager - def truncate_dim(self, output_dim: Optional[int]): + def truncate_sentence_embeddings(self, truncate_dim: Optional[int]): """ - In this context, `model.encode` outputs sentence embeddings truncated at dimension `output_dim`. + In this context, `model.encode` outputs sentence embeddings truncated at dimension `truncate_dim`. This may be useful when you are using the same model for different applications where different dimensions are needed. - :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. + :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Example:: @@ -609,17 +610,17 @@ def truncate_dim(self, output_dim: Optional[int]): model = SentenceTransformer("model-name") - with model.truncate_dim(output_dim=16): + with model.truncate_sentence_embeddings(truncate_dim=16): embeddings_truncated = model.encode(["hello there", "hiya"]) assert embeddings_truncated.shape[-1] == 16 """ - original_output_dim = self.output_dim + original_output_dim = self.truncate_dim try: - self.output_dim = output_dim + self.truncate_dim = truncate_dim yield finally: - self.output_dim = original_output_dim + self.truncate_dim = original_output_dim def _first_module(self): """Returns the first module of this sequential embedder""" diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index a8bf6dcba..f66b494fc 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -143,22 +143,22 @@ def normalize_embeddings(embeddings: Tensor) -> Tensor: @overload -def truncate_embeddings(embeddings: np.ndarray, output_dim: Optional[int]) -> np.ndarray: ... +def truncate_embeddings(embeddings: np.ndarray, truncate_dim: Optional[int]) -> np.ndarray: ... @overload -def truncate_embeddings(embeddings: torch.Tensor, output_dim: Optional[int]) -> torch.Tensor: ... +def truncate_embeddings(embeddings: torch.Tensor, truncate_dim: Optional[int]) -> torch.Tensor: ... def truncate_embeddings( - embeddings: Union[np.ndarray, torch.Tensor], output_dim: Optional[int] + embeddings: Union[np.ndarray, torch.Tensor], truncate_dim: Optional[int] ) -> Union[np.ndarray, torch.Tensor]: """ :param embeddings: Embeddings to truncate. - :param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation. + :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. :return: Truncated embeddings. """ - return embeddings[..., :output_dim] + return embeddings[..., :truncate_dim] def paraphrase_mining( diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index b0332666e..d17af6141 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -431,26 +431,26 @@ def test(model: SentenceTransformer, expected_dim: int): normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)), ) - # Test init w/o setting output_dim (it's None) + # Test init w/o setting truncate_dim (it's None) original_output_dim: int = model.get_sentence_embedding_dimension() test(model, expected_dim=original_output_dim) - # Test init w/ a set output_dim - output_dim = int(original_output_dim / 4) - model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", output_dim=output_dim) - test(model, expected_dim=output_dim) + # Test init w/ a set truncate_dim + truncate_dim = int(original_output_dim / 4) + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", truncate_dim=truncate_dim) + test(model, expected_dim=truncate_dim) # Test setting the attribute after init to a greater dimension - new_output_dim = 2 * output_dim - model.output_dim = new_output_dim - test(model, expected_dim=new_output_dim) + new_truncate_dim = 2 * truncate_dim + model.truncate_dim = new_truncate_dim + test(model, expected_dim=new_truncate_dim) # Test context manager - final_ouptut_dim = int(original_output_dim / 8) - with model.truncate_dim(final_ouptut_dim): - test(model, expected_dim=final_ouptut_dim) - test(model, expected_dim=new_output_dim) # b/c we've exited the context + final_truncate_dim = int(original_output_dim / 8) + with model.truncate_sentence_embeddings(final_truncate_dim): + test(model, expected_dim=final_truncate_dim) + test(model, expected_dim=new_truncate_dim) # b/c we've exited the context # Test w/ an ouptut_dim that's larger than the original_output_dim. No truncation ends up happening - model.output_dim = 2 * original_output_dim + model.truncate_dim = 2 * original_output_dim test(model, expected_dim=original_output_dim) From 8125182206bcde523da519d09e92f70ec9813312 Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Thu, 4 Apr 2024 14:26:41 -0700 Subject: [PATCH 5/7] Also truncate when output_value is None --- sentence_transformers/SentenceTransformer.py | 6 ++- tests/test_sentence_transformer.py | 45 +++++++++++++++----- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index fb6f42c69..df0041509 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -361,6 +361,9 @@ def encode( with torch.no_grad(): out_features = self.forward(features) + out_features["sentence_embedding"] = truncate_embeddings( + out_features["sentence_embedding"], self.truncate_dim + ) if output_value == "token_embeddings": embeddings = [] @@ -376,9 +379,8 @@ def encode( row = {name: out_features[name][sent_idx] for name in out_features} embeddings.append(row) else: # Sentence embeddings - embeddings: torch.Tensor = out_features[output_value] + embeddings = out_features[output_value] embeddings = embeddings.detach() - embeddings = truncate_embeddings(embeddings, self.truncate_dim) if normalize_embeddings: embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index d17af6141..770a737e5 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -9,7 +9,7 @@ from pathlib import Path import re import tempfile -from typing import List, Union +from typing import Dict, List, Literal, Optional, Union, cast import numpy as np import pytest @@ -386,12 +386,17 @@ def test_encode_quantization( assert isinstance(embeddings, list) +@pytest.mark.parametrize("sentences", ("Single sentence", ["One sentence", "Another sentence"])) @pytest.mark.parametrize("convert_to_tensor", [True, False]) @pytest.mark.parametrize("convert_to_numpy", [True, False]) @pytest.mark.parametrize("normalize_embeddings", [True, False]) -@pytest.mark.parametrize("sentences", ("Single sentence", ["One sentence", "Another sentence"])) +@pytest.mark.parametrize("output_value", ["sentence_embedding", None]) def test_encode_truncate( - sentences: Union[str, List[str]], convert_to_tensor: bool, convert_to_numpy: bool, normalize_embeddings: bool + sentences: Union[str, List[str]], + convert_to_tensor: bool, + convert_to_numpy: bool, + normalize_embeddings: bool, + output_value: Optional[Literal["sentence_embedding"]], ) -> None: model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") embeddings_full_unnormalized: torch.Tensor = model.encode( @@ -399,12 +404,26 @@ def test_encode_truncate( ) # These are raw embeddings which serve as the reference to test against def test(model: SentenceTransformer, expected_dim: int): - embeddings = model.encode( + outputs = model.encode( sentences, + output_value=output_value, convert_to_tensor=convert_to_tensor, convert_to_numpy=convert_to_numpy, normalize_embeddings=normalize_embeddings, ) + + # Extract the sentence embeddings out of outputs + if output_value is None: + # We get the whole plate + if not isinstance(outputs, List): + embeddings = outputs["sentence_embedding"] + else: + outputs = cast(List[Dict[str, torch.Tensor]], outputs) + # TODO: can overload model.encode if ppl want type checker compatibility + embeddings = [out_features["sentence_embedding"] for out_features in outputs] + else: + embeddings = outputs + # Test shape if isinstance(embeddings, list): # list of tensors embeddings_shape = (len(embeddings), embeddings[0].shape[-1]) @@ -422,14 +441,18 @@ def test(model: SentenceTransformer, expected_dim: int): # On a non-cpu device, the device of torch.from_numpy(embeddings) is always CPU # Test content - if not normalize_embeddings: - assert torch.allclose(embeddings, util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)) + if normalize_embeddings: + if output_value is None: + # Currently, normalization is not performed; it's the raw output of the forward pass + pass + else: + normalize = partial(torch.nn.functional.normalize, p=2, dim=-1) + assert torch.allclose( + embeddings, + normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)), + ) else: - normalize = partial(torch.nn.functional.normalize, p=2, dim=-1) - assert torch.allclose( - embeddings, - normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)), - ) + assert torch.allclose(embeddings, util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)) # Test init w/o setting truncate_dim (it's None) original_output_dim: int = model.get_sentence_embedding_dimension() From 9606521bfa0c3eda9288ccc98dfc0f380452e738 Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Thu, 4 Apr 2024 15:03:30 -0700 Subject: [PATCH 6/7] Truncate in EmbeddingSimilarityEvaluator --- sentence_transformers/SentenceTransformer.py | 3 +- .../EmbeddingSimilarityEvaluator.py | 38 +++++++++++-------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index df0041509..6e7ec7622 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -70,7 +70,8 @@ class SentenceTransformer(nn.Sequential): This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. :param token: Hugging Face authentication token to download private models. - :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. + :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is + only applicable during inference when `.encode` is called. """ def __init__( diff --git a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py index 130d047b6..75c3484d7 100644 --- a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py +++ b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from . import SentenceEvaluator, SimilarityFunction import logging import os @@ -33,6 +34,7 @@ def __init__( show_progress_bar: bool = False, write_csv: bool = True, precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = None, + truncate_dim: Optional[int] = None, ): """ Constructs an evaluator based for the dataset @@ -45,12 +47,15 @@ def __init__( :param write_csv: Write results to a CSV file :param precision: The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". Defaults to None. + :param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's current + truncation dimension. Defaults to None. """ self.sentences1 = sentences1 self.sentences2 = sentences2 self.scores = scores self.write_csv = write_csv self.precision = precision + self.truncate_dim = truncate_dim assert len(self.sentences1) == len(self.sentences2) assert len(self.sentences1) == len(self.scores) @@ -107,22 +112,23 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = logger.info("EmbeddingSimilarityEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt) - embeddings1 = model.encode( - self.sentences1, - batch_size=self.batch_size, - show_progress_bar=self.show_progress_bar, - convert_to_numpy=True, - precision=self.precision, - normalize_embeddings=bool(self.precision), - ) - embeddings2 = model.encode( - self.sentences2, - batch_size=self.batch_size, - show_progress_bar=self.show_progress_bar, - convert_to_numpy=True, - precision=self.precision, - normalize_embeddings=bool(self.precision), - ) + with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim): + embeddings1 = model.encode( + self.sentences1, + batch_size=self.batch_size, + show_progress_bar=self.show_progress_bar, + convert_to_numpy=True, + precision=self.precision, + normalize_embeddings=bool(self.precision), + ) + embeddings2 = model.encode( + self.sentences2, + batch_size=self.batch_size, + show_progress_bar=self.show_progress_bar, + convert_to_numpy=True, + precision=self.precision, + normalize_embeddings=bool(self.precision), + ) # Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics if self.precision == "binary": embeddings1 = (embeddings1 + 128).astype(np.uint8) From 4466dc344570baf1eec6c0d125dd0fd42a58f869 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 8 Apr 2024 10:59:07 +0200 Subject: [PATCH 7/7] Add truncate_embeddings to docs --- docs/package_reference/util.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/package_reference/util.md b/docs/package_reference/util.md index 64b86ea80..a3f30fb15 100644 --- a/docs/package_reference/util.md +++ b/docs/package_reference/util.md @@ -3,5 +3,5 @@ ```eval_rst .. automodule:: sentence_transformers.util - :members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get + :members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get, truncate_embeddings ```