Skip to content

Commit

Permalink
Adapt Question Answering processing for non-batched evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Aug 3, 2023
1 parent 93ba276 commit fb99f4c
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,48 +87,78 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
if (pyTorchResult.getInferenceResult().length < 1) {
throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
}

// The result format is pairs of 'start' and 'end' logits,
// one pair for each span.
// Multiple spans occur where the context text is longer than
// the max sequence length, so the input must be windowed with
// overlap and evaluated in multiple calls.
// Note the response format changed in 8.9 due to the change in
// pytorch_inference to not process requests in batches.

// The output tensor is a 3d array of doubles.
// 1. The 1st index is the pairs of start and end for each span.
// If there is 1 span there will be 2 elements in this dimension,
// for 2 spans 4 elements
// 2. The 2nd index is the number results per span.
// This dimension is always equal to 1.
// 3. The 3rd index is the actual scores.
// This is an array of doubles equal in size to the number of
// input tokens plus and delimiters (e.g. SEP and CLS tokens)
// added by the tokenizer.
//
// inferenceResult[span_index_start_end][0][scores]

// Should be a collection of "starts" and "ends"
if (pyTorchResult.getInferenceResult().length != 2) {
if (pyTorchResult.getInferenceResult().length % 2 != 0) {
throw new ElasticsearchStatusException(
"question answering result has invalid dimension, expected 2 found [{}]",
"question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
pyTorchResult.getInferenceResult().length
);
}
double[][] starts = pyTorchResult.getInferenceResult()[0];
double[][] ends = pyTorchResult.getInferenceResult()[1];
if (starts.length != ends.length) {
throw new ElasticsearchStatusException(
"question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
starts.length,
ends.length
);
}

final int numAnswersToGather = Math.max(numTopClasses, 1);
ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
if (starts.length != tokensList.size()) {

int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
if (numberOfSpans != tokensList.size()) {
throw new ElasticsearchStatusException(
"question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
"question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
starts.length,
numberOfSpans,
tokensList.size()
);
}
final int numAnswersToGather = Math.max(numTopClasses, 1);

ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
for (int i = 0; i < starts.length; i++) {
for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
assert starts.length == 1;
assert ends.length == 1;

if (starts.length != ends.length) {
throw new ElasticsearchStatusException(
"question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
starts.length,
ends.length
);
}

topScores(
starts[i],
ends[i],
starts[0], // always 1 element in this dimension
ends[0],
numAnswersToGather,
finalEntries::insertWithOverflow,
tokensList.get(i).seqPairOffset(),
tokensList.get(i).tokenIds().length,
tokensList.get(spanIndex).seqPairOffset(),
tokensList.get(spanIndex).tokenIds().length,
maxAnswerLength,
i
spanIndex
);
}

QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList =
new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
for (int i = numAnswersToGather - 1; i >= 0; i--) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public Map<Integer, List<Tokens>> getTokensBySequenceId() {
return tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
}

List<Tokens> getTokens() {
public List<Tokens> getTokens() {
return tokens;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.DoubleStream;

import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
import static org.hamcrest.Matchers.closeTo;
Expand Down Expand Up @@ -168,4 +169,67 @@ public void testTopScoresMoreThanOne() {
assertThat(topScores[1].endToken(), equalTo(5));
}

public void testProcessorMuliptleSpans() throws IOException {
String question = "is Elasticsearch fun?";
String input = "Pancake day is fun with Elasticsearch and little red car";
int span = 4;
int maxSequenceLength = 14;
int numberTopClasses = 3;

BertTokenization tokenization = new BertTokenization(false, true, maxSequenceLength, Tokenization.Truncate.NONE, span);
BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
QuestionAnsweringConfig config = new QuestionAnsweringConfig(
question,
numberTopClasses,
10,
new VocabularyConfig("index_name"),
tokenization,
"prediction"
);
QuestionAnsweringProcessor processor = new QuestionAnsweringProcessor(tokenizer);
TokenizationResult tokenizationResult = processor.getRequestBuilder(config)
.buildRequest(List.of(input), "1", Tokenization.Truncate.NONE, span)
.tokenization();
assertThat(tokenizationResult.anyTruncated(), is(false));

// now we know what the tokenization looks like
// (number of spans and size of each) fake the
// question answering response

int numberSpans = tokenizationResult.getTokens().size();
double[][][] modelTensorOutput = new double[numberSpans * 2][][];
for (int i = 0; i < numberSpans; i++) {
var windowTokens = tokenizationResult.getTokens().get(i);
// size of output
int outputSize = windowTokens.tokenIds().length;
// generate low value -ve logits that will not mark
// the expected result with a high degree of probability
double[] starts = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
double[] ends = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
modelTensorOutput[i * 2] = new double[][] { starts };
modelTensorOutput[(i * 2) + 1] = new double[][] { ends };
}

int spanContainingTheAnswer = randomIntBetween(0, numberSpans - 1);

// insert numbers to mark the answer in the chosen span
int answerStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).seqPairOffset(); // first token of second sequence
// last token of the second sequence ignoring the final SEP added by the BERT tokenizer
int answerEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokenIds().length - 2;
modelTensorOutput[spanContainingTheAnswer * 2][0][answerStart] = 0.5;
modelTensorOutput[(spanContainingTheAnswer * 2) + 1][0][answerEnd] = 1.0;

NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(modelTensorOutput);
QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
tokenizationResult,
pyTorchResult
);

int expectedStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(0).startOffset();
int lastTokenPosition = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).size() - 1;
int expectedEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(lastTokenPosition).endOffset();

assertThat(result.getAnswer(), equalTo(input.substring(expectedStart, expectedEnd)));
}
}

0 comments on commit fb99f4c

Please sign in to comment.