From 3dee389326b0f4690c00eece6d4b75d5d2d4f905 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 25 Apr 2024 19:50:19 +0200 Subject: [PATCH 1/2] Recompute the features if return_output --- sentence_transformers/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 8eb4877ba..400bfbd37 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -153,6 +153,8 @@ def compute_loss( loss_fn.model = model loss = loss_fn(features, labels) if return_outputs: + # Get fresh features, as the loss function has likely modified them + features, _ = self.collect_features(inputs) output = torch.cat([model(row)["sentence_embedding"][:, None] for row in features], dim=1) return loss, output return loss From b75d8e5ef946b55ad44ff61ae36e93a1067519fe Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 25 Apr 2024 19:50:51 +0200 Subject: [PATCH 2/2] Add SimilarityFunction to __init__, increment dev version --- sentence_transformers/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py index 7cd7a2230..3d772711b 100644 --- a/sentence_transformers/__init__.py +++ b/sentence_transformers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.8.0.dev0" +__version__ = "3.0.0.dev0" __MODEL_HUB_ORGANIZATION__ = "sentence-transformers" import importlib @@ -7,6 +7,7 @@ from .datasets import SentencesDataset, ParallelSentencesDataset from .LoggingHandler import LoggingHandler from .SentenceTransformer import SentenceTransformer +from .similarity_functions import SimilarityFunction from .readers import InputExample from .cross_encoder.CrossEncoder import CrossEncoder from .trainer import SentenceTransformerTrainer @@ -25,6 +26,7 @@ "SentencesDataset", "ParallelSentencesDataset", "SentenceTransformer", + "SimilarityFunction", "InputExample", "CrossEncoder", "SentenceTransformerTrainer",