Skip to content

Commit

Permalink
modify _preprocess
Browse files · Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 9, 2024
1 parent 16cd484 commit c389a57
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 6 additions & 2 deletions haystack/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _flatten_documents(
return flattened_queries, flattened_documents, query_ids

def _preprocess( # pylint: disable=too-many-positional-arguments
self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]:
"""
Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.
Expand Down Expand Up @@ -600,7 +600,11 @@ def run( # pylint: disable=too-many-positional-arguments
queries, nested_documents
)
input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
flattened_queries, flattened_documents, max_seq_length, query_ids, stride
queries=flattened_queries,
documents=flattened_documents,
max_seq_length=max_seq_length,
query_ids=query_ids,
stride=stride,
)

num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
Expand Down
8 changes: 6 additions & 2 deletions test/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_flatten_documents(mock_reader: ExtractiveReader):

def test_preprocess(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 3, example_documents[0], 384, [1, 1, 1], 0
queries=example_queries * 3, documents=example_documents[0], max_seq_length=384, query_ids=[1, 1, 1], stride=0
)
expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int)
expected_seq_ids[:, :16] = 0
Expand All @@ -333,7 +333,11 @@ def test_preprocess(mock_reader: ExtractiveReader):

def test_preprocess_splitting(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0
queries=example_queries * 4,
documents=example_documents[0] + [Document(content="a" * 64)],
max_seq_length=96,
query_ids=[1, 1, 1, 1],
stride=0,
)
assert seq_ids.shape[0] == 5
assert query_ids == [1, 1, 1, 1, 1]
Expand Down

0 comments on commit c389a57

Please sign in to comment.