diff --git a/docs/changelog/98167.yaml b/docs/changelog/98167.yaml
new file mode 100644
index 0000000000000..4622c9fc2b037
--- /dev/null
+++ b/docs/changelog/98167.yaml
@@ -0,0 +1,6 @@
+pr: 98167
+summary: Fix failure processing Question Answering model output where the input has been spanned over multiple sequences
+area: Machine Learning
+type: bug
+issues:
+ - 97917
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java
index 0014360fb61ff..e15f0402db3b1 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java
@@ -87,48 +87,78 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
         if (pyTorchResult.getInferenceResult().length < 1) {
             throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
         }
+
+        // The result format is pairs of 'start' and 'end' logits,
+        // one pair for each span.
+        // Multiple spans occur where the context text is longer than
+        // the max sequence length, so the input must be windowed with
+        // overlap and evaluated in multiple calls.
+        // Note the response format changed in 8.9 due to the change in
+        // pytorch_inference to not process requests in batches.
+
+        // The output tensor is a 3d array of doubles.
+        // 1. The 1st index is the pairs of start and end for each span.
+        //    If there is 1 span there will be 2 elements in this dimension,
+        //    for 2 spans 4 elements.
+        // 2. The 2nd index is the number of results per span.
+        //    This dimension is always equal to 1.
+        // 3. The 3rd index is the actual scores.
+        //    This is an array of doubles equal in size to the number of
+        //    input tokens plus any delimiters (e.g. SEP and CLS tokens)
+        //    added by the tokenizer.
+        //
+        // inferenceResult[span_index_start_end][0][scores]
+
         // Should be a collection of "starts" and "ends"
-        if (pyTorchResult.getInferenceResult().length != 2) {
+        if (pyTorchResult.getInferenceResult().length % 2 != 0) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimension, expected 2 found [{}]",
+                "question answering result has invalid dimension, number of dimensions must be a multiple of 2; found [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
                 pyTorchResult.getInferenceResult().length
             );
         }
-        double[][] starts = pyTorchResult.getInferenceResult()[0];
-        double[][] ends = pyTorchResult.getInferenceResult()[1];
-        if (starts.length != ends.length) {
-            throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
-                RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
-                ends.length
-            );
-        }
+
+        final int numAnswersToGather = Math.max(numTopClasses, 1);
+        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
         List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
-        if (starts.length != tokensList.size()) {
+
+        int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
+        if (numberOfSpans != tokensList.size()) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
+                "question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
+                numberOfSpans,
                 tokensList.size()
             );
         }
-        final int numAnswersToGather = Math.max(numTopClasses, 1);
-        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
-        for (int i = 0; i < starts.length; i++) {
+        for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
+            double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
+            double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
+            assert starts.length == 1;
+            assert ends.length == 1;
+
+            if (starts.length != ends.length) {
+                throw new ElasticsearchStatusException(
+                    "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    starts.length,
+                    ends.length
+                );
+            }
+
             topScores(
-                starts[i],
-                ends[i],
+                starts[0], // always 1 element in this dimension
+                ends[0],
                 numAnswersToGather,
                 finalEntries::insertWithOverflow,
-                tokensList.get(i).seqPairOffset(),
-                tokensList.get(i).tokenIds().length,
+                tokensList.get(spanIndex).seqPairOffset(),
+                tokensList.get(spanIndex).tokenIds().length,
                 maxAnswerLength,
-                i
+                spanIndex
             );
         }
+
         QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList = new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
         for (int i = numAnswersToGather - 1; i >= 0; i--) {
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
index ac54abdc73924..9c7f7cb607630 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
@@ -48,7 +48,7 @@ public Map<Integer, List<Tokens>> getTokensBySequenceId() {
         return tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
     }

-    List<Tokens> getTokens() {
+    public List<Tokens> getTokens() {
         return tokens;
     }

diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
index ab8bdf4870973..48d8c154ae59e 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.DoubleStream;

 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
 import static org.hamcrest.Matchers.closeTo;
@@ -168,4 +169,68 @@ public void testTopScoresMoreThanOne() {
         assertThat(topScores[1].endToken(), equalTo(5));
     }

+    public void testProcessorMultipleSpans() throws IOException {
+        String question = "is Elasticsearch fun?";
+        String input = "Pancake day is fun with Elasticsearch and little red car";
+        int span = 4;
+        int maxSequenceLength = 14;
+        int numberTopClasses = 3;
+
+        BertTokenization tokenization = new BertTokenization(false, true, maxSequenceLength, Tokenization.Truncate.NONE, span);
+        BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
+        QuestionAnsweringConfig config = new QuestionAnsweringConfig(
+            question,
+            numberTopClasses,
+            10,
+            new VocabularyConfig("index_name"),
+            tokenization,
+            "prediction"
+        );
+        QuestionAnsweringProcessor processor = new QuestionAnsweringProcessor(tokenizer);
+        TokenizationResult tokenizationResult = processor.getRequestBuilder(config)
+            .buildRequest(List.of(input), "1", Tokenization.Truncate.NONE, span)
+            .tokenization();
+        assertThat(tokenizationResult.anyTruncated(), is(false));
+
+        // Now that we know what the tokenization looks like
+        // (the number of spans and the size of each), fake the
+        // question answering response.
+
+        int numberSpans = tokenizationResult.getTokens().size();
+        double[][][] modelTensorOutput = new double[numberSpans * 2][][];
+        for (int i = 0; i < numberSpans; i++) {
+            var windowTokens = tokenizationResult.getTokens().get(i);
+            // size of output
+            int outputSize = windowTokens.tokenIds().length;
+            // generate low negative scores so that, with a high degree of
+            // probability, they will not mark the expected result
+            double[] starts = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            double[] ends = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            modelTensorOutput[i * 2] = new double[][] { starts };
+            modelTensorOutput[(i * 2) + 1] = new double[][] { ends };
+        }
+
+        int spanContainingTheAnswer = randomIntBetween(0, numberSpans - 1);
+
+        // insert numbers to mark the answer in the chosen span
+        int answerStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).seqPairOffset(); // first token of second sequence
+        // last token of the second sequence ignoring the final SEP added by the BERT tokenizer
+        int answerEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokenIds().length - 2;
+        modelTensorOutput[spanContainingTheAnswer * 2][0][answerStart] = 0.5;
+        modelTensorOutput[(spanContainingTheAnswer * 2) + 1][0][answerEnd] = 1.0;
+
+        NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(modelTensorOutput);
+        QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
+            tokenizationResult,
+            pyTorchResult
+        );
+
+        // The expected answer is the full text of the span containing the answer
+        int expectedStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(0).startOffset();
+        int lastTokenPosition = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).size() - 1;
+        int expectedEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(lastTokenPosition).endOffset();
+
+        assertThat(result.getAnswer(), equalTo(input.substring(expectedStart, expectedEnd)));
+    }
 }
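
The tensor layout the new loop indexes can be sketched in isolation. A minimal standalone example, with a made-up class name and made-up logit values (this is an illustration of the layout described in the processor comments, not Elasticsearch code): the outer dimension alternates start/end rows per span, the middle dimension always has length 1, and the inner arrays hold one logit per input token including the CLS/SEP delimiters.

    // Hypothetical sketch of the 3d result tensor layout; values are invented.
    public class SpanLayoutSketch {
        public static void main(String[] args) {
            // Two spans => four outer elements: [start0, end0, start1, end1].
            double[][][] inferenceResult = {
                { { -0.2, 0.5, -0.9, -0.1 } }, // span 0 start logits
                { { -0.7, -0.3, 1.0, -0.4 } }, // span 0 end logits
                { { -0.1, -0.8, 0.3, -0.6 } }, // span 1 start logits
                { { -0.5, -0.2, -0.9, 0.7 } }  // span 1 end logits
            };
            int numberOfSpans = inferenceResult.length / 2;
            for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
                double[] starts = inferenceResult[spanIndex * 2][0];     // middle dimension is always 1
                double[] ends = inferenceResult[(spanIndex * 2) + 1][0];
                System.out.printf("span %d: %d start logits, %d end logits%n", spanIndex, starts.length, ends.length);
            }
        }
    }

This is why the old `length != 2` check was wrong: it only held for single-span input, and spanned input produces `2 * numberOfSpans` outer elements.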
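The spanning itself happens at tokenization time: when the context outgrows the room left after the question and special tokens, the tokenizer emits overlapping windows controlled by the `span` parameter. A rough back-of-the-envelope for the window count, mirroring the test's settings; the token counts and the stride formula here are illustrative assumptions, not BertTokenizer's exact accounting:

    // Approximate sketch of why maxSequenceLength 14 with span 4 yields
    // multiple windows for the test sentence. All counts are assumed.
    public class WindowCountSketch {
        public static void main(String[] args) {
            int maxSequenceLength = 14;
            int overlap = 4;        // the 'span' tokenization parameter
            int questionTokens = 5; // assumed token count for the question
            int specialTokens = 3;  // assumed: CLS plus two SEP delimiters
            int contextTokens = 11; // assumed token count for the context

            int windowSize = maxSequenceLength - questionTokens - specialTokens;
            int stride = windowSize - overlap;
            int windows = contextTokens <= windowSize
                ? 1
                : 1 + (int) Math.ceil((contextTokens - windowSize) / (double) stride);
            System.out.println("approximate window count: " + windows);
        }
    }

Each window repeats the question, which is why the processor reads each span's answer tokens relative to `seqPairOffset()`, the index of the first token of the second (context) sequence.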
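Across all spans, candidates are funnelled through `ScoreAndIndicesPriorityQueue.insertWithOverflow` so only the best `numAnswersToGather` survive, then drained backwards so the highest score lands first in the results array. A plain `java.util` analogue of that bounded top-k pattern, assuming a fixed-size min-heap (the `Candidate` record and scores are invented; the real queue class is not shown in this diff):

    import java.util.PriorityQueue;

    // Bounded top-k gathering: once the heap holds k entries, a new
    // candidate only enters by evicting the current minimum.
    public class TopKSketch {
        record Candidate(double score, int spanIndex) {}

        public static void main(String[] args) {
            int k = 3;
            PriorityQueue<Candidate> heap =
                new PriorityQueue<>((a, b) -> Double.compare(a.score(), b.score()));
            double[][] scoresBySpan = { { 0.1, 0.8 }, { 0.4, 0.9 }, { 0.2, 0.7 } };
            for (int span = 0; span < scoresBySpan.length; span++) {
                for (double score : scoresBySpan[span]) {
                    heap.offer(new Candidate(score, span));
                    if (heap.size() > k) {
                        heap.poll(); // overflow: drop the weakest candidate
                    }
                }
            }
            // Drain ascending and, like the processor's final loop, fill the
            // array backwards so the best answer ends up first.
            Candidate[] top = new Candidate[k];
            for (int i = k - 1; i >= 0; i--) {
                top[i] = heap.poll();
            }
            for (Candidate c : top) {
                System.out.printf("score=%.2f span=%d%n", c.score(), c.spanIndex());
            }
        }
    }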