Skip to content

Commit

Permalink
feat (v2): Update so model_max_length updates max_seq_length for …
Browse files Browse the repository at this point in the history
…Sentence Transformers (#8334)

* Update so model_max_length does what is expected

* Add release notes

* Some fixes

* Another test
  • Loading branch information
sjrl authored Sep 6, 2024
1 parent e98a6fe commit 06dd5c2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def warm_up(self):
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def warm_up(self):
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

@component.output_types(embedding=List[float])
def run(self, text: str):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
    Updates SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder so that a model_max_length passed through tokenizer_kwargs also updates the max_seq_length of the underlying SentenceTransformer model.
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,22 @@ def test_from_dict_none_device(self):
)
def test_warmup(self, mocked_factory):
    """warm_up() must build the embedding backend exactly once with the
    provided tokenizer_kwargs and propagate model_max_length onto the
    underlying SentenceTransformer model's max_seq_length."""
    embedder = SentenceTransformersDocumentEmbedder(
        model="model",
        token=None,
        device=ComponentDevice.from_str("cpu"),
        tokenizer_kwargs={"model_max_length": 512},
    )
    # Backend must be created lazily — not before warm_up().
    mocked_factory.get_embedding_backend.assert_not_called()
    embedder.warm_up()
    # Assert (the original diff mistakenly *assigned* here, which verified
    # nothing) that warm_up copied model_max_length to max_seq_length.
    assert embedder.embedding_backend.model.max_seq_length == 512
    mocked_factory.get_embedding_backend.assert_called_once_with(
        model="model",
        device="cpu",
        auth_token=None,
        trust_remote_code=False,
        truncate_dim=None,
        model_kwargs=None,
        tokenizer_kwargs={"model_max_length": 512},
    )

@patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,23 @@ def test_from_dict_none_device(self):
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
    """warm_up() must build the embedding backend exactly once with the
    provided tokenizer_kwargs and propagate model_max_length onto the
    underlying SentenceTransformer model's max_seq_length."""
    embedder = SentenceTransformersTextEmbedder(
        model="model",
        token=None,
        device=ComponentDevice.from_str("cpu"),
        tokenizer_kwargs={"model_max_length": 512},
    )
    # Backend must be created lazily — not before warm_up().
    mocked_factory.get_embedding_backend.assert_not_called()
    embedder.warm_up()
    # Assert (the original diff mistakenly *assigned* here, which verified
    # nothing) that warm_up copied model_max_length to max_seq_length.
    assert embedder.embedding_backend.model.max_seq_length == 512
    mocked_factory.get_embedding_backend.assert_called_once_with(
        model="model",
        device="cpu",
        auth_token=None,
        trust_remote_code=False,
        truncate_dim=None,
        model_kwargs=None,
        tokenizer_kwargs={"model_max_length": 512},
    )

@patch(
Expand Down

0 comments on commit 06dd5c2

Please sign in to comment.