diff --git a/haystack/preview/components/readers/extractive.py b/haystack/preview/components/readers/extractive.py
index 55f3e9cb56..7fbe5d3ad1 100644
--- a/haystack/preview/components/readers/extractive.py
+++ b/haystack/preview/components/readers/extractive.py
@@ -2,6 +2,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 import math
 import warnings
+import logging
 import os
 
 from haystack.preview import component, default_to_dict, ComponentError, Document, ExtractedAnswer
@@ -13,6 +14,9 @@
     import torch
 
 
+logger = logging.getLogger(__name__)
+
+
 @component
 class ExtractiveReader:
     """
@@ -214,14 +218,38 @@ def _postprocess(
         start_candidates = start_candidates.cpu()
         end_candidates = end_candidates.cpu()
 
-        start_candidates_char_indices = [
-            [encoding.token_to_chars(start)[0] for start in candidates]
+        start_candidates_tokens_to_chars = [
+            [encoding.token_to_chars(start) for start in candidates]
             for candidates, encoding in zip(start_candidates, encodings)
         ]
-        end_candidates_char_indices = [
-            [encoding.token_to_chars(end)[1] for end in candidates]
+        if missing_start_tokens := [
+            (batch, index)
+            for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
+            for index, pair in enumerate(token_to_chars)
+            if pair is None
+        ]:
+            logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
+        start_candidates_char_indices = [
+            [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
+            for candidates in start_candidates_tokens_to_chars
+        ]
+
+        end_candidates_tokens_to_chars = [
+            [encoding.token_to_chars(end) for end in candidates]
             for candidates, encoding in zip(end_candidates, encodings)
         ]
+        if missing_end_tokens := [
+            (batch, index)
+            for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
+            for index, pair in enumerate(token_to_chars)
+            if pair is None
+        ]:
+            logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
+        end_candidates_char_indices = [
+            [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
+            for candidates in end_candidates_tokens_to_chars
+        ]
+
         probabilities = candidates.values.cpu()
 
         return start_candidates_char_indices, end_candidates_char_indices, probabilities
diff --git a/test/preview/components/readers/test_extractive.py b/test/preview/components/readers/test_extractive.py
index 2c6ea04df5..438922ae8d 100644
--- a/test/preview/components/readers/test_extractive.py
+++ b/test/preview/components/readers/test_extractive.py
@@ -269,6 +269,79 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
 
 
+@pytest.mark.unit
+def test_missing_token_to_chars_values():
+    # See https://github.com/deepset-ai/haystack/issues/6098
+
+    def mock_tokenize(
+        texts: List[str],
+        text_pairs: List[str],
+        padding: bool,
+        truncation: bool,
+        max_length: int,
+        return_tensors: str,
+        return_overflowing_tokens: bool,
+        stride: int,
+    ):
+        assert padding
+        assert truncation
+        assert return_tensors == "pt"
+        assert return_overflowing_tokens
+
+        tokens = Mock()
+
+        num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
+        tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
+        num_samples = sum(num_splits)
+        tokens.encodings = [Mock() for _ in range(num_samples)]
+        sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
+        for encoding in tokens.encodings:
+            encoding.sequence_ids = sequence_ids
+            encoding.token_to_chars = lambda i: None
+        tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
+        attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
+        attention_mask[:32] = 1
+        tokens.attention_mask = attention_mask
+        return tokens
+
+    class MockModel(torch.nn.Module):
+        def to(self, device):
+            assert device == "cpu:0"
+            self.device_set = True
+            return self
+
+        def forward(self, input_ids, attention_mask, *args, **kwargs):
+            assert input_ids.device == torch.device("cpu")
+            assert attention_mask.device == torch.device("cpu")
+            assert self.device_set
+            start = torch.zeros(input_ids.shape[:2])
+            end = torch.zeros(input_ids.shape[:2])
+            start[:, 27] = 1
+            end[:, 31] = 1
+            end[:, 32] = 1
+            prediction = Mock()
+            prediction.start_logits = start
+            prediction.end_logits = end
+            return prediction
+
+    with patch("haystack.preview.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
+        "haystack.preview.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
+    ) as model:
+        tokenizer.return_value = mock_tokenize
+        model.return_value = MockModel()
+        reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
+        reader.warm_up()
+
+        answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
+            "answers"
+        ]  # [0] Uncomment and remove first two indices when batching support is reintroduced
+        for doc, answer in zip(example_documents[0], answers[:3]):
+            assert answer.start is None
+            assert answer.end is None
+            assert doc.content is not None
+            assert answer.data == doc.content
+
+
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
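
For reviewers, a minimal sketch (not part of the patch) of the failure mode behind issue 6098: with a Hugging Face fast tokenizer, Encoding.token_to_chars() returns None for tokens that have no span in the original text, such as special tokens, so the pre-patch indexing token_to_chars(start)[0] raised a TypeError. The model name below is borrowed from the existing tests; any fast tokenizer behaves the same way.

# Sketch only, assumes transformers is installed and the model is reachable.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
encoding = tokenizer("What is Haystack?", "Haystack is a framework.").encodings[0]

print(encoding.token_to_chars(1))  # regular token -> (start, end) character span
print(encoding.token_to_chars(0))  # <s> special token -> None; None[0] raises TypeError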