[feat] Add truncation support #2573

Merged · 7 commits · Apr 8, 2024
Changes from 2 commits
13 changes: 9 additions & 4 deletions examples/training/matryoshka/README.md
@@ -54,24 +54,28 @@ loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512,

## Inference

After a model has been trained using a Matryoshka loss, you can then run inference with it using <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformers.encode</code></a>. You must then truncate the resulting embeddings, and it is recommended to renormalize the embeddings.
After a model has been trained using a Matryoshka loss, you can then run inference with it using <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformers.encode</code></a>.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

matryoshka_dim = 64
model = SentenceTransformer(
"nomic-ai/nomic-embed-text-v1.5",
trust_remote_code=True,
output_dim=matryoshka_dim,
)

embeddings = model.encode(
[
"search_query: What is TSNE?",
"search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
"search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
]
)
embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions
assert embeddings.shape[-1] == matryoshka_dim

similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])
@@ -86,6 +90,7 @@ See the following scripts as examples of how to apply the <a href="../../../docs

* **[matryoshka_nli.py](matryoshka_nli.py)**: This example uses the MultipleNegativesRankingLoss with MatryoshkaLoss to train a strong embedding model using Natural Language Inference (NLI) data. It is an adaptation of the [NLI](../nli/README) documentation.
* **[matryoshka_nli_reduced_dim.py](matryoshka_nli_reduced_dim.py)**: This example uses the MultipleNegativesRankingLoss with MatryoshkaLoss to train a strong embedding model with a small maximum output dimension of 256. It trains using Natural Language Inference (NLI) data, and is an adaptation of the [NLI](../nli/README) documentation.
* **[matryoshka_eval_stsb.py](matryoshka_eval_stsb.py)**: This example evaluates the embedding model trained with MatryoshkaLoss in [matryoshka_nli.py](matryoshka_nli.py) on the test set of the STSBenchmark dataset, and compares it to a non-Matryoshka trained model.
* **[matryoshka_sts.py](matryoshka_sts.py)**: This example uses the CoSENTLoss with MatryoshkaLoss to train an embedding model on the training set of the STSBenchmark dataset. It is an adaptation of the [STS](../sts/README) documentation.

And the following scripts to see how to apply <a href="../../../docs/package_reference/losses.html#matryoshka2dloss"><code>Matryoshka2dLoss</code></a>:
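For reviewers comparing this with the pre-existing workflow: the same truncation (plus the renormalization the old README text recommended) can also be applied after encoding, using the `truncate_embeddings` helper this PR adds to `sentence_transformers.util` together with the existing `normalize_embeddings`. A minimal sketch; the model name and dimension simply mirror the README example above:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import normalize_embeddings, truncate_embeddings

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

# Encode to a torch Tensor so the util helpers below can operate on it directly
embeddings = model.encode(
    [
        "search_query: What is TSNE?",
        "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
    ],
    convert_to_tensor=True,
)
embeddings = truncate_embeddings(embeddings, 64)  # keep only the first 64 dimensions
embeddings = normalize_embeddings(embeddings)     # renormalize so cosine/dot-product scores stay well-behaved
```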
47 changes: 2 additions & 45 deletions examples/training/matryoshka/matryoshka_eval_stsb.py
@@ -4,10 +4,8 @@
"""

import argparse
from contextlib import contextmanager
from functools import wraps
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Dict, List, Optional, Tuple, cast

from datasets import load_dataset
import numpy as np
@@ -17,50 +15,9 @@
EmbeddingSimilarityEvaluator,
SimilarityFunction,
)
import torch
from tqdm.auto import tqdm


# Util to truncate
# Should patch instance, not the class b/c maybe there are other models floating around
# that shouldn't get truncated
@contextmanager
def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable):
original_method = getattr(obj, method_name)
# Need to use __get__ when patching instance methods
# https://stackoverflow.com/a/28127947/18758987
try:
setattr(obj, method_name, new_method.__get__(obj, obj.__class__))
yield
finally:
setattr(obj, method_name, original_method.__get__(obj, obj.__class__))


@contextmanager
def truncate_embeddings(model: SentenceTransformer, dim: int):
"""
In this context, the `model` outputs embeddings truncated at dimension `dim`.

Parameters
----------
model : SentenceTransformer
model where `model.encode` outputs a (D,) or (N, D) array or tensor of
embeddings given text(s)
dim : int
dimension to truncate at. So a (N, D) array becomes (N, `dim`)
"""

original_encode = model.encode

@wraps(original_encode)
def encode(self, *args, **kwargs) -> Union[np.ndarray, torch.Tensor]:
embeddings = original_encode(*args, **kwargs)
return embeddings[..., :dim]

with _monkeypatch_instance_method(model, "encode", encode):
yield


# Dimension plot
def _grouped_barplot_ratios(
group_name_to_x_to_y: Dict[str, Dict[int, float]], ax: Optional[plt.Axes] = None
@@ -202,7 +159,7 @@ def plot_across_dimensions(
for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"):
output_path = os.path.join(model_name, f"dim-{dim}")
os.makedirs(output_path)
with truncate_embeddings(model, dim):
with model.truncate_sentence_embeddings(dim):
score = test_evaluator(model, output_path=output_path)
print(f"Saved results to {output_path}")
dim_to_score[dim] = score
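The monkeypatching helpers removed above are superseded by the `truncate_sentence_embeddings` context manager this PR adds to the model itself. A small sketch of the replacement pattern, reusing the tiny test model from the PR's test suite (the dimension 64 is illustrative and assumed to be smaller than the model's native size):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

# Inside the context, model.encode slices embeddings down to the requested dimension;
# on exit the original output_dim is restored, even if an exception was raised.
with model.truncate_sentence_embeddings(64):
    embeddings = model.encode(["One sentence", "Another sentence"])
    assert embeddings.shape == (2, 64)
```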
53 changes: 48 additions & 5 deletions sentence_transformers/SentenceTransformer.py
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import json
import logging
import os
@@ -31,6 +32,7 @@
load_file_path,
save_to_hub_args_decorator,
get_device_name,
truncate_embeddings,
)
from .quantization import quantize_embeddings
from .models import Transformer, Pooling, Normalize
@@ -68,6 +70,7 @@ class SentenceTransformer(nn.Sequential):
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
:param token: Hugging Face authentication token to download private models.
:param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation.
"""

def __init__(
@@ -82,10 +85,12 @@ def __init__(
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
output_dim: Optional[int] = None,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
self.output_dim = output_dim
self._model_card_vars = {}
self._model_card_text = None
self._model_config = {}
@@ -253,7 +258,7 @@ def encode(
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding",
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
@@ -289,7 +294,7 @@

:return: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string
input is provided, then the output is a 1d array with shape [output_dimension]. If `convert_to_tensor`, a
torch Tensor is returned instead.
torch Tensor is returned instead. If `self.output_dim` is not `None`, then output_dimension is `self.output_dim`.
"""
self.eval()
if show_progress_bar is None:
@@ -370,8 +375,9 @@
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
else: # Sentence embeddings
embeddings = out_features[output_value]
embeddings: torch.Tensor = out_features[output_value]
embeddings = embeddings.detach()
embeddings = truncate_embeddings(embeddings, self.output_dim)
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

@@ -572,11 +578,48 @@ def get_sentence_features(self, *features):
return self._first_module().get_sentence_features(*features)

def get_sentence_embedding_dimension(self):
"""
:return: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
"""
output_dim = None
for mod in reversed(self._modules.values()):
sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
if callable(sent_embedding_dim_method):
return sent_embedding_dim_method()
return None
output_dim = sent_embedding_dim_method()
break
if self.output_dim is not None:
# The user requested truncation. If they set it to a dim greater than output_dim,
# no truncation will actually happen. So return output_dim instead of self.output_dim
return min(output_dim or np.inf, self.output_dim)
return output_dim

@contextmanager
def truncate_sentence_embeddings(self, output_dim: Optional[int]):
"""
In this context, `model.encode` outputs sentence embeddings truncated at dimension `output_dim`.

This may be useful when you are using the same model for different applications where different dimensions
are needed.

:param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation.

Example::

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model-name")

with model.truncate_sentence_embeddings(output_dim=16):
embeddings_truncated = model.encode(["hello there", "hiya"])
assert embeddings_truncated.shape[-1] == 16

"""
original_output_dim = self.output_dim
try:
self.output_dim = output_dim
yield
finally:
self.output_dim = original_output_dim

def _first_module(self):
"""Returns the first module of this sequential embedder"""
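A quick illustration of the updated `get_sentence_embedding_dimension` behaviour; the 128-dimensional native size of the tiny test model is an assumption inferred from the PR's tests:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert model.get_sentence_embedding_dimension() == 128  # native pooling dimension (assumed)

model.output_dim = 32
assert model.get_sentence_embedding_dimension() == 32  # truncation requested: the smaller value is reported

model.output_dim = 512
assert model.get_sentence_embedding_dimension() == 128  # larger than the native size: no truncation happens
```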
21 changes: 20 additions & 1 deletion sentence_transformers/util.py
@@ -10,7 +10,7 @@
import numpy as np
import queue
import logging
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, overload

from transformers import is_torch_npu_available
from huggingface_hub import snapshot_download, hf_hub_download
@@ -142,6 +142,25 @@ def normalize_embeddings(embeddings: Tensor) -> Tensor:
return torch.nn.functional.normalize(embeddings, p=2, dim=1)


@overload
def truncate_embeddings(embeddings: np.ndarray, output_dim: Optional[int]) -> np.ndarray: ...


@overload
def truncate_embeddings(embeddings: torch.Tensor, output_dim: Optional[int]) -> torch.Tensor: ...


def truncate_embeddings(
embeddings: Union[np.ndarray, torch.Tensor], output_dim: Optional[int]
) -> Union[np.ndarray, torch.Tensor]:
"""
:param embeddings: Embeddings to truncate.
:param output_dim: Number of dimensions to truncate sentence embeddings to. `None` does no truncation.
:return: Truncated embeddings.
"""
return embeddings[..., :output_dim]


def paraphrase_mining(
model, sentences: List[str], show_progress_bar: bool = False, batch_size: int = 32, *args, **kwargs
) -> List[List[Union[float, int]]]:
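The `@overload` stubs only document that truncation preserves the input type, and because slicing with `[..., :None]` returns the full array, passing `None` disables truncation. A short sketch of the expected behaviour:

```python
import numpy as np
import torch

from sentence_transformers.util import truncate_embeddings

np_embeddings = np.random.rand(2, 768)
torch_embeddings = torch.rand(2, 768)

assert truncate_embeddings(np_embeddings, 64).shape == (2, 64)     # ndarray in, ndarray out
assert truncate_embeddings(torch_embeddings, 64).shape == (2, 64)  # Tensor in, Tensor out
assert truncate_embeddings(np_embeddings, None).shape == (2, 768)  # None leaves the embeddings untouched
```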
74 changes: 74 additions & 0 deletions tests/test_sentence_transformer.py
@@ -2,19 +2,23 @@
Tests general behaviour of the SentenceTransformer class
"""

from functools import partial
import json
import logging
import os
from pathlib import Path
import re
import tempfile
from typing import List, Union

import numpy as np
import pytest

from huggingface_hub import HfApi, RepoUrl, GitRefs, GitRefInfo
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Transformer, Pooling
from sentence_transformers import util


def test_load_with_safetensors() -> None:
@@ -380,3 +384,73 @@ def test_encode_quantization(
else:
assert embeddings[0].dtype == expected_torch_dtype
assert isinstance(embeddings, list)


@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_numpy", [True, False])
@pytest.mark.parametrize("normalize_embeddings", [True, False])
@pytest.mark.parametrize("sentences", ("Single sentence", ["One sentence", "Another sentence"]))
def test_encode_truncate(
sentences: Union[str, List[str]], convert_to_tensor: bool, convert_to_numpy: bool, normalize_embeddings: bool
) -> None:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
embeddings_full_unnormalized: torch.Tensor = model.encode(
sentences, convert_to_numpy=False, convert_to_tensor=True
) # These are raw embeddings which serve as the reference to test against

def test(model: SentenceTransformer, expected_dim: int):
embeddings = model.encode(
sentences,
convert_to_tensor=convert_to_tensor,
convert_to_numpy=convert_to_numpy,
normalize_embeddings=normalize_embeddings,
)
# Test shape
if isinstance(embeddings, list): # list of tensors
embeddings_shape = (len(embeddings), embeddings[0].shape[-1])
else:
embeddings_shape = embeddings.shape
expected_shape = (expected_dim,) if isinstance(sentences, str) else (len(sentences), expected_dim)
assert embeddings_shape == expected_shape
assert model.get_sentence_embedding_dimension() == expected_dim

# Convert embeddings to a torch Tensor for ease of testing
if isinstance(embeddings, list):
embeddings = torch.stack(embeddings)
elif isinstance(embeddings, np.ndarray):
embeddings = torch.from_numpy(embeddings).to(embeddings_full_unnormalized.device)
# On a non-cpu device, the device of torch.from_numpy(embeddings) is always CPU

# Test content
if not normalize_embeddings:
assert torch.allclose(embeddings, util.truncate_embeddings(embeddings_full_unnormalized, expected_dim))
else:
normalize = partial(torch.nn.functional.normalize, p=2, dim=-1)
assert torch.allclose(
embeddings,
normalize(util.truncate_embeddings(embeddings_full_unnormalized, expected_dim)),
)

# Test init w/o setting output_dim (it's None)
original_output_dim: int = model.get_sentence_embedding_dimension()
test(model, expected_dim=original_output_dim)

# Test init w/ a set output_dim
output_dim = int(original_output_dim / 4)
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", output_dim=output_dim)
test(model, expected_dim=output_dim)

# Test setting the attribute after init to a greater dimension
new_output_dim = 2 * output_dim
model.output_dim = new_output_dim
test(model, expected_dim=new_output_dim)

# Test context manager
final_output_dim = int(original_output_dim / 8)
with model.truncate_sentence_embeddings(final_output_dim):
test(model, expected_dim=final_output_dim)
test(model, expected_dim=new_output_dim)  # b/c we've exited the context

# Test w/ an output_dim that's larger than the original_output_dim. No truncation ends up happening
model.output_dim = 2 * original_output_dim
test(model, expected_dim=original_output_dim)