Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add truncation support #2573

Merged
merged 7 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/package_reference/util.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

```eval_rst
.. automodule:: sentence_transformers.util
:members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get
:members: cos_sim, dot_score, paraphrase_mining, semantic_search, community_detection, http_get, truncate_embeddings
```
13 changes: 9 additions & 4 deletions examples/training/matryoshka/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,28 @@ loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512,

## Inference

After a model has been trained using a Matryoshka loss, you can then run inference with it using <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformers.encode</code></a>. You must then truncate the resulting embeddings, and it is recommended to renormalize the embeddings.
After a model has been trained using a Matryoshka loss, you can then run inference with it using <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformers.encode</code></a>.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

matryoshka_dim = 64
model = SentenceTransformer(
"nomic-ai/nomic-embed-text-v1.5",
trust_remote_code=True,
truncate_dim=matryoshka_dim,
)

embeddings = model.encode(
[
"search_query: What is TSNE?",
"search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
"search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
]
)
embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions
assert embeddings.shape[-1] == matryoshka_dim

similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])
Expand All @@ -86,6 +90,7 @@ See the following scripts as examples of how to apply the <a href="../../../docs

* **[matryoshka_nli.py](matryoshka_nli.py)**: This example uses the MultipleNegativesRankingLoss with MatryoshkaLoss to train a strong embedding model using Natural Language Inference (NLI) data. It is an adaptation of the [NLI](../nli/README) documentation.
* **[matryoshka_nli_reduced_dim.py](matryoshka_nli_reduced_dim.py)**: This example uses the MultipleNegativesRankingLoss with MatryoshkaLoss to train a strong embedding model with a small maximum output dimension of 256. It trains using Natural Language Inference (NLI) data, and is an adaptation of the [NLI](../nli/README) documentation.
* **[matryoshka_eval_stsb.py](matryoshka_eval_stsb.py)**: This example evaluates the embedding model trained with MatryoshkaLoss in [matryoshka_nli.py](matryoshka_nli.py) on the test set of the STSBenchmark dataset, and compares it to a non-Matryoshka trained model.
* **[matryoshka_sts.py](matryoshka_sts.py)**: This example uses the CoSENTLoss with MatryoshkaLoss to train an embedding model on the training set of the STSBenchmark dataset. It is an adaptation of the [STS](../sts/README) documentation.

And the following scripts to see how to apply <a href="../../../docs/package_reference/losses.html#matryoshka2dloss"><code>Matryoshka2dLoss</code></a>:
Expand Down
47 changes: 2 additions & 45 deletions examples/training/matryoshka/matryoshka_eval_stsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
"""

import argparse
from contextlib import contextmanager
from functools import wraps
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Dict, List, Optional, Tuple, cast

from datasets import load_dataset
import numpy as np
Expand All @@ -17,50 +15,9 @@
EmbeddingSimilarityEvaluator,
SimilarityFunction,
)
import torch
from tqdm.auto import tqdm


# Util to truncate
# Should patch instance, not the class b/c maybe there are other models floating around
# that shouldn't get truncated
@contextmanager
def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable):
original_method = getattr(obj, method_name)
# Need to use __get__ when patching instance methods
# https://stackoverflow.com/a/28127947/18758987
try:
setattr(obj, method_name, new_method.__get__(obj, obj.__class__))
yield
finally:
setattr(obj, method_name, original_method.__get__(obj, obj.__class__))


@contextmanager
def truncate_embeddings(model: SentenceTransformer, dim: int):
"""
In this context, the `model` outputs embeddings truncated at dimension `dim`.

Parameters
----------
model : SentenceTransformer
model where `model.encode` outputs a (D,) or (N, D) array or tensor of
embeddings given text(s)
dim : int
dimension to truncate at. So a (N, D) array becomes (N, `dim`)
"""

original_encode = model.encode

@wraps(original_encode)
def encode(self, *args, **kwargs) -> Union[np.ndarray, torch.Tensor]:
embeddings = original_encode(*args, **kwargs)
return embeddings[..., :dim]

with _monkeypatch_instance_method(model, "encode", encode):
yield


# Dimension plot
def _grouped_barplot_ratios(
group_name_to_x_to_y: Dict[str, Dict[int, float]], ax: Optional[plt.Axes] = None
Expand Down Expand Up @@ -202,7 +159,7 @@ def plot_across_dimensions(
for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"):
output_path = os.path.join(model_name, f"dim-{dim}")
os.makedirs(output_path)
with truncate_embeddings(model, dim):
with model.truncate_sentence_embeddings(dim):
score = test_evaluator(model, output_path=output_path)
print(f"Saved results to {output_path}")
dim_to_score[dim] = score
Expand Down
55 changes: 51 additions & 4 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import json
import logging
import os
Expand Down Expand Up @@ -31,6 +32,7 @@
load_file_path,
save_to_hub_args_decorator,
get_device_name,
truncate_embeddings,
)
from .quantization import quantize_embeddings
from .models import Transformer, Pooling, Normalize
Expand Down Expand Up @@ -68,6 +70,8 @@ class SentenceTransformer(nn.Sequential):
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
:param token: Hugging Face authentication token to download private models.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
only applicable during inference when `.encode` is called.
"""

def __init__(
Expand All @@ -82,10 +86,12 @@ def __init__(
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
self.truncate_dim = truncate_dim
self._model_card_vars = {}
self._model_card_text = None
self._model_config = {}
Expand Down Expand Up @@ -253,7 +259,7 @@ def encode(
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding",
kddubey marked this conversation as resolved.
Show resolved Hide resolved
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
Expand Down Expand Up @@ -289,7 +295,8 @@ def encode(

:return: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string
input is provided, then the output is a 1d array with shape [output_dimension]. If `convert_to_tensor`, a
torch Tensor is returned instead.
torch Tensor is returned instead. If `self.truncate_dim <= output_dimension` then output_dimension is
`self.truncate_dim`.
"""
self.eval()
if show_progress_bar is None:
Expand Down Expand Up @@ -355,6 +362,9 @@ def encode(

with torch.no_grad():
out_features = self.forward(features)
out_features["sentence_embedding"] = truncate_embeddings(
out_features["sentence_embedding"], self.truncate_dim
)

if output_value == "token_embeddings":
embeddings = []
Expand Down Expand Up @@ -572,11 +582,48 @@ def get_sentence_features(self, *features):
return self._first_module().get_sentence_features(*features)

def get_sentence_embedding_dimension(self):
"""
:return: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
"""
output_dim = None
for mod in reversed(self._modules.values()):
sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
if callable(sent_embedding_dim_method):
return sent_embedding_dim_method()
return None
kddubey marked this conversation as resolved.
Show resolved Hide resolved
output_dim = sent_embedding_dim_method()
break
if self.truncate_dim is not None:
# The user requested truncation. If they set it to a dim greater than output_dim,
# no truncation will actually happen. So return output_dim insead of self.truncate_dim
return min(output_dim or np.inf, self.truncate_dim)
return output_dim

@contextmanager
def truncate_sentence_embeddings(self, truncate_dim: Optional[int]):
"""
In this context, `model.encode` outputs sentence embeddings truncated at dimension `truncate_dim`.

This may be useful when you are using the same model for different applications where different dimensions
are needed.

:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation.

Example::

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model-name")

with model.truncate_sentence_embeddings(truncate_dim=16):
embeddings_truncated = model.encode(["hello there", "hiya"])
assert embeddings_truncated.shape[-1] == 16

"""
original_output_dim = self.truncate_dim
try:
self.truncate_dim = truncate_dim
yield
finally:
self.truncate_dim = original_output_dim

def _first_module(self):
"""Returns the first module of this sequential embedder"""
Expand Down
38 changes: 22 additions & 16 deletions sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from . import SentenceEvaluator, SimilarityFunction
import logging
import os
Expand Down Expand Up @@ -33,6 +34,7 @@ def __init__(
show_progress_bar: bool = False,
write_csv: bool = True,
precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = None,
truncate_dim: Optional[int] = None,
):
"""
Constructs an evaluator based for the dataset
Expand All @@ -45,12 +47,15 @@ def __init__(
:param write_csv: Write results to a CSV file
:param precision: The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or
"ubinary". Defaults to None.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's current
truncation dimension. Defaults to None.
"""
self.sentences1 = sentences1
self.sentences2 = sentences2
self.scores = scores
self.write_csv = write_csv
self.precision = precision
self.truncate_dim = truncate_dim

assert len(self.sentences1) == len(self.sentences2)
assert len(self.sentences1) == len(self.scores)
Expand Down Expand Up @@ -107,22 +112,23 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =

logger.info("EmbeddingSimilarityEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

embeddings1 = model.encode(
self.sentences1,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
embeddings2 = model.encode(
self.sentences2,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
embeddings1 = model.encode(
self.sentences1,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
embeddings2 = model.encode(
self.sentences2,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
# Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics
if self.precision == "binary":
embeddings1 = (embeddings1 + 128).astype(np.uint8)
Expand Down
21 changes: 20 additions & 1 deletion sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import queue
import logging
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, overload

from transformers import is_torch_npu_available
from huggingface_hub import snapshot_download, hf_hub_download
Expand Down Expand Up @@ -142,6 +142,25 @@ def normalize_embeddings(embeddings: Tensor) -> Tensor:
return torch.nn.functional.normalize(embeddings, p=2, dim=1)


@overload
def truncate_embeddings(embeddings: np.ndarray, truncate_dim: Optional[int]) -> np.ndarray: ...


@overload
def truncate_embeddings(embeddings: torch.Tensor, truncate_dim: Optional[int]) -> torch.Tensor: ...


def truncate_embeddings(
embeddings: Union[np.ndarray, torch.Tensor], truncate_dim: Optional[int]
) -> Union[np.ndarray, torch.Tensor]:
"""
:param embeddings: Embeddings to truncate.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation.
:return: Truncated embeddings.
"""
return embeddings[..., :truncate_dim]


def paraphrase_mining(
model, sentences: List[str], show_progress_bar: bool = False, batch_size: int = 32, *args, **kwargs
) -> List[List[Union[float, int]]]:
Expand Down
Loading
Loading