diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py
index db43581b03..d7f78b54f6 100644
--- a/haystack/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -187,6 +187,8 @@ def warm_up(self):
                 model_kwargs=self.model_kwargs,
                 tokenizer_kwargs=self.tokenizer_kwargs,
             )
+            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
+                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
 
     @component.output_types(documents=List[Document])
     def run(self, documents: List[Document]):
diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py
index 96e5356402..e29b2d439c 100644
--- a/haystack/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -173,6 +173,8 @@ def warm_up(self):
                 model_kwargs=self.model_kwargs,
                 tokenizer_kwargs=self.tokenizer_kwargs,
             )
+            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
+                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
 
     @component.output_types(embedding=List[float])
     def run(self, text: str):
diff --git a/releasenotes/notes/update-max-seq-lenght-st-1dc3d7a9c9a3bdcd.yaml b/releasenotes/notes/update-max-seq-lenght-st-1dc3d7a9c9a3bdcd.yaml
new file mode 100644
index 0000000000..3a5595ee6c
--- /dev/null
+++ b/releasenotes/notes/update-max-seq-lenght-st-1dc3d7a9c9a3bdcd.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Updates SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder so that model_max_length passed through tokenizer_kwargs also updates the max_seq_length of the underlying SentenceTransformer model.
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py
index 0e70085710..a5e6af8278 100644
--- a/test/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -226,10 +226,14 @@ def test_from_dict_none_device(self):
     )
     def test_warmup(self, mocked_factory):
         embedder = SentenceTransformersDocumentEmbedder(
-            model="model", token=None, device=ComponentDevice.from_str("cpu")
+            model="model",
+            token=None,
+            device=ComponentDevice.from_str("cpu"),
+            tokenizer_kwargs={"model_max_length": 512},
         )
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
+        assert embedder.embedding_backend.model.max_seq_length == 512
         mocked_factory.get_embedding_backend.assert_called_once_with(
             model="model",
             device="cpu",
@@ -237,7 +241,7 @@ def test_warmup(self, mocked_factory):
             trust_remote_code=False,
             truncate_dim=None,
             model_kwargs=None,
-            tokenizer_kwargs=None,
+            tokenizer_kwargs={"model_max_length": 512},
         )
 
     @patch(
diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py
index c739650ef2..2f043de237 100644
--- a/test/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -201,9 +201,15 @@ def test_from_dict_none_device(self):
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
     def test_warmup(self, mocked_factory):
-        embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
+        embedder = SentenceTransformersTextEmbedder(
+            model="model",
+            token=None,
+            device=ComponentDevice.from_str("cpu"),
+            tokenizer_kwargs={"model_max_length": 512},
+        )
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
+        assert embedder.embedding_backend.model.max_seq_length == 512
         mocked_factory.get_embedding_backend.assert_called_once_with(
             model="model",
             device="cpu",
@@ -211,7 +217,7 @@ def test_warmup(self, mocked_factory):
             trust_remote_code=False,
             truncate_dim=None,
             model_kwargs=None,
-            tokenizer_kwargs=None,
+            tokenizer_kwargs={"model_max_length": 512},
        )
 
     @patch(