Skip to content

Commit

Permalink
fix: make ExtractiveReader handle situations where token_to_chars
Browse files Browse the repository at this point in the history
… returns None instead of a (start, end) tuple (#6382)

* fix reader bug

* add test

* log

* fix logging

* improve error message
  • Loading branch information
ZanSara authored Nov 24, 2023
1 parent f3b7303 commit c45d8c3
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 4 deletions.
36 changes: 32 additions & 4 deletions haystack/preview/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import math
import warnings
import logging
import os

from haystack.preview import component, default_to_dict, ComponentError, Document, ExtractedAnswer
Expand All @@ -13,6 +14,9 @@
import torch


logger = logging.getLogger(__name__)


@component
class ExtractiveReader:
"""
Expand Down Expand Up @@ -214,14 +218,38 @@ def _postprocess(
start_candidates = start_candidates.cpu()
end_candidates = end_candidates.cpu()

start_candidates_char_indices = [
[encoding.token_to_chars(start)[0] for start in candidates]
start_candidates_tokens_to_chars = [
[encoding.token_to_chars(start) for start in candidates]
for candidates, encoding in zip(start_candidates, encodings)
]
end_candidates_char_indices = [
[encoding.token_to_chars(end)[1] for end in candidates]
if missing_start_tokens := [
(batch, index)
for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
for index, pair in enumerate(token_to_chars)
if pair is None
]:
logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
start_candidates_char_indices = [
[token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
for candidates in start_candidates_tokens_to_chars
]

end_candidates_tokens_to_chars = [
[encoding.token_to_chars(end) for end in candidates]
for candidates, encoding in zip(end_candidates, encodings)
]
if missing_end_tokens := [
(batch, index)
for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
for index, pair in enumerate(token_to_chars)
if pair is None
]:
logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
end_candidates_char_indices = [
[token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
for candidates in end_candidates_tokens_to_chars
]

probabilities = candidates.values.cpu()

return start_candidates_char_indices, end_candidates_char_indices, probabilities
Expand Down
73 changes: 73 additions & 0 deletions test/preview/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,79 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")


@pytest.mark.unit
def test_missing_token_to_chars_values():
# See https://github.com/deepset-ai/haystack/issues/6098

def mock_tokenize(
texts: List[str],
text_pairs: List[str],
padding: bool,
truncation: bool,
max_length: int,
return_tensors: str,
return_overflowing_tokens: bool,
stride: int,
):
assert padding
assert truncation
assert return_tensors == "pt"
assert return_overflowing_tokens

tokens = Mock()

num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
num_samples = sum(num_splits)
tokens.encodings = [Mock() for _ in range(num_samples)]
sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
for encoding in tokens.encodings:
encoding.sequence_ids = sequence_ids
encoding.token_to_chars = lambda i: None
tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
attention_mask[:32] = 1
tokens.attention_mask = attention_mask
return tokens

class MockModel(torch.nn.Module):
def to(self, device):
assert device == "cpu:0"
self.device_set = True
return self

def forward(self, input_ids, attention_mask, *args, **kwargs):
assert input_ids.device == torch.device("cpu")
assert attention_mask.device == torch.device("cpu")
assert self.device_set
start = torch.zeros(input_ids.shape[:2])
end = torch.zeros(input_ids.shape[:2])
start[:, 27] = 1
end[:, 31] = 1
end[:, 32] = 1
prediction = Mock()
prediction.start_logits = start
prediction.end_logits = end
return prediction

with patch("haystack.preview.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
"haystack.preview.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
) as model:
tokenizer.return_value = mock_tokenize
model.return_value = MockModel()
reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
reader.warm_up()

answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
"answers"
] # [0] Uncomment and remove first two indices when batching support is reintroduced
for doc, answer in zip(example_documents[0], answers[:3]):
assert answer.start is None
assert answer.end is None
assert doc.content is not None
assert answer.data == doc.content


@pytest.mark.integration
def test_t5():
reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
Expand Down

0 comments on commit c45d8c3

Please sign in to comment.