From c389a57d911fc78a477f7c8077d6ce5138fcf10f Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 9 Dec 2024 17:27:35 +0100 Subject: [PATCH] modify _preprocess --- haystack/components/readers/extractive.py | 8 ++++++-- test/components/readers/test_extractive.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 460c87e083..c7bd51d7e5 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -205,7 +205,7 @@ def _flatten_documents( return flattened_queries, flattened_documents, query_ids def _preprocess( # pylint: disable=too-many-positional-arguments - self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int + self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]: """ Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs. @@ -600,7 +600,11 @@ def run( # pylint: disable=too-many-positional-arguments queries, nested_documents ) input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess( - flattened_queries, flattened_documents, max_seq_length, query_ids, stride + queries=flattened_queries, + documents=flattened_documents, + max_seq_length=max_seq_length, + query_ids=query_ids, + stride=stride, ) num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1 diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index 9c42c44254..5e9ff24d21 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -321,7 +321,7 @@ def test_flatten_documents(mock_reader: ExtractiveReader): def test_preprocess(mock_reader: ExtractiveReader): _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 3, example_documents[0], 384, [1, 1, 1], 0 + queries=example_queries * 3, documents=example_documents[0], max_seq_length=384, query_ids=[1, 1, 1], stride=0 ) expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int) expected_seq_ids[:, :16] = 0 @@ -333,7 +333,11 @@ def test_preprocess(mock_reader: ExtractiveReader): def test_preprocess_splitting(mock_reader: ExtractiveReader): _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0 + queries=example_queries * 4, + documents=example_documents[0] + [Document(content="a" * 64)], + max_seq_length=96, + query_ids=[1, 1, 1, 1], + stride=0, ) assert seq_ids.shape[0] == 5 assert query_ids == [1, 1, 1, 1, 1]