From d7117f22ec9ef6e0ab2115486f7f610c0d7eba6f Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 7 Dec 2021 11:54:52 +0000 Subject: [PATCH] [ML] Track token positions and use source string to tag NER entities (#81275) By recording the position of the tokens in the original source string the entity labels is correctly constructed on the original text preserving case and accent characters that were otherwise stripped during normalisation. --- .../core/ml/inference/results/NerResults.java | 6 + .../ml/inference/nlp/BertRequestBuilder.java | 6 +- .../ml/inference/nlp/FillMaskProcessor.java | 54 +++++-- .../xpack/ml/inference/nlp/NerProcessor.java | 53 +++---- .../nlp/tokenizers/BasicTokenizer.java | 114 ++++++++++---- .../nlp/tokenizers/BertTokenizer.java | 111 ++++++++------ .../nlp/tokenizers/DelimitedToken.java | 72 +++++++++ .../nlp/tokenizers/NlpTokenizer.java | 6 +- .../nlp/tokenizers/TokenizationResult.java | 30 ++-- .../nlp/tokenizers/WordPieceTokenizer.java | 104 +++++-------- .../inference/nlp/FillMaskProcessorTests.java | 28 +++- .../ml/inference/nlp/NerProcessorTests.java | 129 +++++++++++----- .../nlp/tokenizers/BasicTokenizerTests.java | 145 ++++++++++++------ .../nlp/tokenizers/BertTokenizerTests.java | 69 +++++---- .../tokenizers/WordPieceTokenizerTests.java | 48 +++--- 15 files changed, 632 insertions(+), 343 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index dd0907c95ead2..fd1751387942c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -163,6 +164,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public String toString() { + return Strings.toString(this); + } + public Map toMap() { Map map = new LinkedHashMap<>(); map.put("entity", entity); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java index 7e829da85e9f1..60ab42fd300f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java @@ -34,7 +34,7 @@ public BertRequestBuilder(BertTokenizer tokenizer) { @Override public NlpTask.Request buildRequest(List inputs, String requestId, Tokenization.Truncate truncate) throws IOException { - if (tokenizer.getPadToken().isEmpty()) { + if (tokenizer.getPadTokenId().isEmpty()) { throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary"); } @@ -46,10 +46,10 @@ public NlpTask.Request buildRequest(List inputs, String requestId, Token @Override public NlpTask.Request buildRequest(TokenizationResult tokenization, 
String requestId) throws IOException { - if (tokenizer.getPadToken().isEmpty()) { + if (tokenizer.getPadTokenId().isEmpty()) { throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary"); } - return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId)); + return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId)); } static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index 3e54d7ab6ca16..72acd2891141b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; -import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; @@ -27,10 +26,10 @@ public class FillMaskProcessor implements NlpTask.Processor { - private final NlpTask.RequestBuilder requestBuilder; + private final NlpTokenizer tokenizer; FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) { - this.requestBuilder = tokenizer.requestBuilder(); + this.tokenizer = tokenizer; } @Override @@ -39,22 +38,23 @@ public void validateInputs(List inputs) { throw new IllegalArgumentException("input request is empty"); } + final String mask = tokenizer.getMaskToken(); for (String input : inputs) { - int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN); + int maskIndex = input.indexOf(mask); if (maskIndex < 0) { - throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found"); + throw new IllegalArgumentException("no " + mask + " token could be found"); } - maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length()); + maskIndex = input.indexOf(mask, maskIndex + mask.length()); if (maskIndex > 0) { - throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input"); + throw new IllegalArgumentException("only one " + mask + " token should exist in the input"); } } } @Override public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) { - return requestBuilder; + return tokenizer.requestBuilder(); } @Override @@ -64,25 +64,55 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) { return (tokenization, result) -> processResult( tokenization, result, + tokenizer, fillMaskConfig.getNumTopClasses(), fillMaskConfig.getResultsField() ); } else { - return (tokenization, result) -> processResult(tokenization, result, FillMaskConfig.DEFAULT_NUM_RESULTS, DEFAULT_RESULTS_FIELD); + return (tokenization, result) -> processResult( + tokenization, + result, + tokenizer, + FillMaskConfig.DEFAULT_NUM_RESULTS, + DEFAULT_RESULTS_FIELD + ); } } static InferenceResults processResult( TokenizationResult tokenization, PyTorchResult 
pyTorchResult, + NlpTokenizer tokenizer, int numResults, String resultsField ) { - if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) { + if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { return new WarningInferenceResults("No valid tokens for inference"); } - int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN); + if (tokenizer.getMaskTokenId().isEmpty()) { + return new WarningInferenceResults( + "The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token", + tokenizer.getMaskToken() + ); + } + + int maskTokenIndex = -1; + int maskTokenId = tokenizer.getMaskTokenId().getAsInt(); + for (int i = 0; i < tokenization.getTokenizations().get(0).getTokenIds().length; i++) { + if (tokenization.getTokenizations().get(0).getTokenIds()[i] == maskTokenId) { + maskTokenIndex = i; + break; + } + } + if (maskTokenIndex == -1) { + return new WarningInferenceResults( + "mask token id [{}] not found in the tokenization {}", + maskTokenId, + Arrays.asList(tokenization.getTokenizations().get(0).getTokenIds()) + ); + } + // TODO - process all results in the batch double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]); @@ -103,7 +133,7 @@ static InferenceResults processResult( tokenization.getTokenizations() .get(0) .getInput() - .replace(BertTokenizer.MASK_TOKEN, tokenization.getFromVocab(scoreAndIndices[0].index)), + .replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)), results, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), scoreAndIndices[0].score, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index 2674d47fdaa4d..3d33d577066c1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; @@ -193,7 +194,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor { @Override public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) { - if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) { + if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { return new WarningInferenceResults("no valid tokens to build result"); } // TODO - process all results in the batch @@ -213,6 +214,7 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe ? 
tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT) : tokenization.getTokenizations().get(0).getInput() ); + return new NerResults( resultsField, buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities), @@ -230,7 +232,7 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe static List tagTokens(TokenizationResult.Tokenization tokenization, double[][] scores, IobTag[] iobMap) { List taggedTokens = new ArrayList<>(); int startTokenIndex = 0; - while (startTokenIndex < tokenization.getTokens().length) { + while (startTokenIndex < tokenization.getTokenIds().length) { int inputMapping = tokenization.getTokenMap()[startTokenIndex]; if (inputMapping < 0) { // This token does not map to a token in the input (special tokens) @@ -238,15 +240,9 @@ static List tagTokens(TokenizationResult.Tokenization tokenization, continue; } int endTokenIndex = startTokenIndex; - StringBuilder word = new StringBuilder(tokenization.getTokens()[startTokenIndex]); - while (endTokenIndex < tokenization.getTokens().length - 1 + while (endTokenIndex < tokenization.getTokenMap().length - 1 && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) { endTokenIndex++; - // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens. - // It is probably more correct to implement detokenization on the tokenizer - // that does reverse lookup based on token IDs. - String endTokenWord = tokenization.getTokens()[endTokenIndex].substring(2); - word.append(endTokenWord); } double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length); for (int i = startTokenIndex + 1; i <= endTokenIndex; i++) { @@ -262,7 +258,7 @@ static List tagTokens(TokenizationResult.Tokenization tokenization, } int maxScoreIndex = NlpHelpers.argmax(avgScores); double score = avgScores[maxScoreIndex]; - taggedTokens.add(new TaggedToken(word.toString(), iobMap[maxScoreIndex], score)); + taggedTokens.add(new TaggedToken(tokenization.getTokens().get(inputMapping), iobMap[maxScoreIndex], score)); startTokenIndex = endTokenIndex + 1; } return taggedTokens; @@ -283,14 +279,12 @@ static List groupTaggedTokens(List tokens, } List entities = new ArrayList<>(); int startTokenIndex = 0; - int startFindInSeq = 0; while (startTokenIndex < tokens.size()) { TaggedToken token = tokens.get(startTokenIndex); if (token.tag.getEntity() == Entity.NONE) { startTokenIndex++; continue; } - StringBuilder entityWord = new StringBuilder(token.word); int endTokenIndex = startTokenIndex + 1; double scoreSum = token.score; while (endTokenIndex < tokens.size()) { @@ -298,43 +292,50 @@ static List groupTaggedTokens(List tokens, if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) { break; } - // TODO Here we add a space between tokens. - // It is probably more correct to implement detokenization on the tokenizer - // that does reverse lookup based on token IDs. - entityWord.append(" ").append(endToken.word); scoreSum += endToken.score; endTokenIndex++; } - String entity = entityWord.toString(); - int i = inputSeq.indexOf(entity, startFindInSeq); + + int startPos = token.token.getStartPos(); + int endPos = tokens.get(endTokenIndex - 1).token.getEndPos(); + String entity = inputSeq.substring(startPos, endPos); entities.add( new NerResults.EntityGroup( entity, token.tag.getEntity().toString(), scoreSum / (endTokenIndex - startTokenIndex), - i, - i == -1 ? 
-1 : i + entity.length() + startPos, + endPos ) ); startTokenIndex = endTokenIndex; - if (i != -1) { - startFindInSeq = i + entity.length(); - } } return entities; } static class TaggedToken { - private final String word; + private final DelimitedToken token; private final IobTag tag; private final double score; - TaggedToken(String word, IobTag tag, double score) { - this.word = word; + TaggedToken(DelimitedToken token, IobTag tag, double score) { + this.token = token; this.tag = tag; this.score = score; } + + @Override + public String toString() { + return new StringBuilder("{").append("token:") + .append(token) + .append(", ") + .append(tag) + .append(", ") + .append(score) + .append("}") + .toString(); + } } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java index d98f1000d42e4..77af60cab0119 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java @@ -15,6 +15,7 @@ import java.util.Locale; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; /** * Basic tokenization of text by whitespace with optional extras: @@ -45,7 +46,7 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i this.isLowerCase = isLowerCase; this.isTokenizeCjkChars = isTokenizeCjkChars; this.isStripAccents = isStripAccents; - this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenize); + this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenizeString); } public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) { @@ -74,46 +75,51 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) { * @param text The input text to tokenize * @return List of tokens */ - public List tokenize(String text) { + public List tokenize(String text) { return mergeNeverSplitTokens(doTokenize(text)); } - private List doTokenize(String text) { + private List doTokenizeString(String text) { + return doTokenize(text).stream().map(DelimitedToken::getToken).collect(Collectors.toList()); + } + + private List doTokenize(String text) { text = cleanText(text); if (isTokenizeCjkChars) { text = tokenizeCjkChars(text); } - String[] tokens = whiteSpaceTokenize(text); + List tokens = whiteSpaceTokenize(text); - List processedTokens = new ArrayList<>(tokens.length); - for (String token : tokens) { + List processedTokens = new ArrayList<>(tokens.size()); + for (DelimitedToken tokenRecord : tokens) { - if (Strings.EMPTY.equals(token)) { + String tokenStr = tokenRecord.getToken(); + if (Strings.EMPTY.equals(tokenStr)) { continue; } if (isLowerCase) { - token = token.toLowerCase(Locale.ROOT); + tokenStr = tokenStr.toLowerCase(Locale.ROOT); } if (isStripAccents) { - token = stripAccents(token); + tokenStr = stripAccents(tokenStr); } - processedTokens.addAll(splitOnPunctuation(token)); + processedTokens.addAll(splitOnPunctuation(new DelimitedToken(tokenRecord.getStartPos(), tokenRecord.getEndPos(), tokenStr))); } return processedTokens; } - private List mergeNeverSplitTokens(List tokens) { + private List mergeNeverSplitTokens(List tokens) { if (neverSplitTokenTrieRoot.isLeaf()) { return tokens; } - List mergedTokens = new 
ArrayList<>(tokens.size()); - List matchingTokens = new ArrayList<>(); + List mergedTokens = new ArrayList<>(tokens.size()); + List matchingTokens = new ArrayList<>(); TokenTrieNode current = neverSplitTokenTrieRoot; - for (String token : tokens) { - TokenTrieNode childNode = current.getChild(token); + for (DelimitedToken token : tokens) { + TokenTrieNode childNode = current.getChild(token.getToken()); if (childNode == null) { if (current != neverSplitTokenTrieRoot) { mergedTokens.addAll(matchingTokens); @@ -123,7 +129,7 @@ private List mergeNeverSplitTokens(List tokens) { mergedTokens.add(token); } else if (childNode.isLeaf()) { matchingTokens.add(token); - mergedTokens.add(String.join("", matchingTokens)); + mergedTokens.add(DelimitedToken.mergeTokens(matchingTokens)); matchingTokens = new ArrayList<>(); current = neverSplitTokenTrieRoot; } else { @@ -146,9 +152,53 @@ public boolean isTokenizeCjkChars() { return isTokenizeCjkChars; } - static String[] whiteSpaceTokenize(String text) { - text = text.trim(); - return text.split(" "); + /** + * Split the input text by whitespace. + * For the returned objects {@link DelimitedToken#getStartPos()} is the + * start character index inclusive and {@link DelimitedToken#getEndPos()} + * the index exclusive. The number of whitespace characters between 2 consecutive + * {@link DelimitedToken}s is the difference between the first's {@code endPos} + * and the second's {@code startPos}. + * + * The input should be normalized via a call to {@link #cleanText(String)} + * before it is passed to this function. + * + * @param text to tokenize + * @return White space separated strings + */ + static List whiteSpaceTokenize(String text) { + var tokens = new ArrayList(); + + // whitespace at beginning + int index = 0; + while (index < text.length() && text.charAt(index) == ' ') { + index++; + } + + int tokenStart = index; + + while (index < text.length()) { + if (text.charAt(index) == ' ') { + int tokenEnd = index; + index++; + // consume trail whitespace before the next word + // or end of text + while (index < text.length() && text.charAt(index) == ' ') { + index++; + } + + tokens.add(new DelimitedToken(tokenStart, tokenEnd, text.substring(tokenStart, tokenEnd))); + tokenStart = index; + } + index++; + } + + // trailing whitespace + if (tokenStart != text.length()) { + tokens.add(new DelimitedToken(tokenStart, text.length(), text.substring(tokenStart))); + } + + return tokens; } /** @@ -172,9 +222,9 @@ static String stripAccents(String word) { return new String(codePoints, 0, codePoints.length); } - static List splitOnPunctuation(String word) { - List split = new ArrayList<>(); - int[] codePoints = word.codePoints().toArray(); + static List splitOnPunctuation(DelimitedToken word) { + List splits = new ArrayList<>(); + int[] codePoints = word.getToken().codePoints().toArray(); int lastSplit = 0; for (int i = 0; i < codePoints.length; i++) { @@ -182,18 +232,30 @@ static List splitOnPunctuation(String word) { int charCount = i - lastSplit; if (charCount > 0) { // add a new string for what has gone before - split.add(new String(codePoints, lastSplit, i - lastSplit)); + splits.add( + new DelimitedToken( + word.getStartPos() + lastSplit, + word.getStartPos() + i, + new String(codePoints, lastSplit, i - lastSplit) + ) + ); } - split.add(new String(codePoints, i, 1)); + splits.add(new DelimitedToken(word.getStartPos() + i, word.getStartPos() + i + 1, new String(codePoints, i, 1))); lastSplit = i + 1; } } if (lastSplit < codePoints.length) { - split.add(new 
String(codePoints, lastSplit, codePoints.length - lastSplit)); + splits.add( + new DelimitedToken( + word.getStartPos() + lastSplit, + word.getStartPos() + codePoints.length, + new String(codePoints, lastSplit, codePoints.length - lastSplit) + ) + ); } - return split; + return splits; } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java index 837f10e1e3bfb..9163a8737b48c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder; @@ -80,7 +79,7 @@ protected BertTokenizer( } @Override - public OptionalInt getPadToken() { + public OptionalInt getPadTokenId() { Integer pad = vocab.get(PAD_TOKEN); if (pad != null) { return OptionalInt.of(pad); @@ -89,6 +88,21 @@ public OptionalInt getPadToken() { } } + @Override + public OptionalInt getMaskTokenId() { + Integer pad = vocab.get(MASK_TOKEN); + if (pad != null) { + return OptionalInt.of(pad); + } else { + return OptionalInt.empty(); + } + } + + @Override + public String getMaskToken() { + return MASK_TOKEN; + } + @Override public TokenizationResult buildTokenizationResult(List tokenizations) { TokenizationResult tokenizationResult = new TokenizationResult(originalVocab); @@ -112,16 +126,17 @@ public TokenizationResult buildTokenizationResult(List wordPieceTokens = innerResult.v1(); - List tokenPositionMap = innerResult.v2(); - int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size(); + List wordPieceTokenIds = innerResult.wordPieceTokenIds; + List tokenPositionMap = innerResult.tokenPositionMap; + int numTokens = withSpecialTokens ? wordPieceTokenIds.size() + 2 : wordPieceTokenIds.size(); boolean isTruncated = false; + if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: case SECOND: isTruncated = true; - wordPieceTokens = wordPieceTokens.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength); + wordPieceTokenIds = wordPieceTokenIds.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength); break; case NONE: throw ExceptionsHelper.badRequestException( @@ -132,78 +147,75 @@ public TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncat } numTokens = maxSequenceLength; } - String[] tokens = new String[numTokens]; + int[] tokenIds = new int[numTokens]; int[] tokenMap = new int[numTokens]; if (withSpecialTokens) { - tokens[0] = CLASS_TOKEN; tokenIds[0] = vocab.get(CLASS_TOKEN); tokenMap[0] = SPECIAL_TOKEN_POSITION; } int i = withSpecialTokens ? 1 : 0; final int decrementHandler = withSpecialTokens ? 
1 : 0; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIds) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMap.get(i - decrementHandler); i++; } if (withSpecialTokens) { - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; } - return new TokenizationResult.Tokenization(seq, isTruncated, tokens, tokenIds, tokenMap); + return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, tokenIds, tokenMap); } @Override public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate) { - var innerResult = innerTokenize(seq1); - List wordPieceTokenSeq1s = innerResult.v1(); - List tokenPositionMapSeq1 = innerResult.v2(); - innerResult = innerTokenize(seq2); - List wordPieceTokenSeq2s = innerResult.v1(); - List tokenPositionMapSeq2 = innerResult.v2(); + var innerResultSeq1 = innerTokenize(seq1); + List wordPieceTokenIdsSeq1 = innerResultSeq1.wordPieceTokenIds; + List tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap; + var innerResultSeq2 = innerTokenize(seq2); + List wordPieceTokenIdsSeq2 = innerResultSeq2.wordPieceTokenIds; + List tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap; if (withSpecialTokens == false) { throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens"); } // [CLS] seq1 [SEP] seq2 [SEP] - int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3; + int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + 3; boolean isTruncated = false; if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: isTruncated = true; - if (wordPieceTokenSeq2s.size() > maxSequenceLength - 3) { + if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the second sequence. " + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + "when taking special tokens into account", truncate.toString(), - wordPieceTokenSeq2s.size(), + wordPieceTokenIdsSeq2.size(), maxSequenceLength - 3 ); } - wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq2s.size()); + wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq2.size()); break; case SECOND: isTruncated = true; - if (wordPieceTokenSeq1s.size() > maxSequenceLength - 3) { + if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the first sequence. 
" + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + "when taking special tokens into account", truncate.toString(), - wordPieceTokenSeq2s.size(), + wordPieceTokenIdsSeq1.size(), maxSequenceLength - 3 ); } - wordPieceTokenSeq2s = wordPieceTokenSeq2s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq1s.size()); + wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq1.size()); break; case NONE: throw ExceptionsHelper.badRequestException( @@ -214,62 +226,71 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokeni } numTokens = maxSequenceLength; } - String[] tokens = new String[numTokens]; int[] tokenIds = new int[numTokens]; int[] tokenMap = new int[numTokens]; - tokens[0] = CLASS_TOKEN; tokenIds[0] = vocab.get(CLASS_TOKEN); tokenMap[0] = SPECIAL_TOKEN_POSITION; int i = 1; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq1s) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIdsSeq1) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMapSeq1.get(i - 1); i++; } - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; ++i; int j = 0; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq2s) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIdsSeq2) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMapSeq2.get(j); i++; j++; } - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; - return new TokenizationResult.Tokenization(seq1 + seq2, isTruncated, tokens, tokenIds, tokenMap); + List tokens = new ArrayList<>(innerResultSeq1.tokens); + tokens.addAll(innerResultSeq2.tokens); + return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, tokenIds, tokenMap); } - private Tuple, List> innerTokenize(String seq) { + private InnerTokenization innerTokenize(String seq) { BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit); - List delineatedTokens = basicTokenizer.tokenize(seq); - List wordPieceTokens = new ArrayList<>(); + var tokenSequences = basicTokenizer.tokenize(seq); + List wordPieceTokens = new ArrayList<>(); List tokenPositionMap = new ArrayList<>(); - for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) { - String token = delineatedTokens.get(sourceIndex); + for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); sourceIndex++) { + String token = tokenSequences.get(sourceIndex).getToken(); if (neverSplit.contains(token)) { - wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)))); + wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))); tokenPositionMap.add(sourceIndex); } else { - List tokens = wordPieceTokenizer.tokenize(token); + List tokens = wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex)); for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) { tokenPositionMap.add(sourceIndex); } wordPieceTokens.addAll(tokens); } } - return Tuple.tuple(wordPieceTokens, tokenPositionMap); + + return new InnerTokenization(tokenSequences, wordPieceTokens, tokenPositionMap); + } + + private static class InnerTokenization { + List tokens; + List wordPieceTokenIds; + List tokenPositionMap; + + InnerTokenization(List tokens, List 
wordPieceTokenIds, List tokenPositionMap) { + this.tokens = tokens; + this.wordPieceTokenIds = wordPieceTokenIds; + this.tokenPositionMap = tokenPositionMap; + } } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java new file mode 100644 index 0000000000000..74f1121cc467f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class DelimitedToken { + + /** + * Merges the list of tokens. + * + * Assumes that the tokens are in order. + * + * @param tokens + * @return The merged token + */ + public static DelimitedToken mergeTokens(List tokens) { + if (tokens.size() == 1) { + return tokens.get(0); + } + + String merged = tokens.stream().map(DelimitedToken::getToken).collect(Collectors.joining()); + return new DelimitedToken(tokens.get(0).getStartPos(), tokens.get(tokens.size() - 1).getEndPos(), merged); + } + + private final int startPos; + private final int endPos; + private final String token; + + DelimitedToken(int startPos, int endPos, String token) { + this.startPos = startPos; + this.endPos = endPos; + this.token = token; + } + + public int getStartPos() { + return startPos; + } + + public int getEndPos() { + return endPos; + } + + public String getToken() { + return token; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DelimitedToken that = (DelimitedToken) o; + return startPos == that.startPos && endPos == that.endPos && Objects.equals(token, that.token); + } + + @Override + public int hashCode() { + return Objects.hash(startPos, endPos, token); + } + + @Override + public String toString() { + return "{" + "startPos=" + startPos + ", endPos=" + endPos + ", token=" + token + '}'; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java index c039ab8de0fda..1146b2d8ae1c1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java @@ -30,7 +30,11 @@ public interface NlpTokenizer { NlpTask.RequestBuilder requestBuilder(); - OptionalInt getPadToken(); + OptionalInt getPadTokenId(); + + OptionalInt getMaskTokenId(); + + String getMaskToken(); static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) { ExceptionsHelper.requireNonNull(params, TOKENIZATION); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java index b50c9548504a9..862be3c43bf67 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java @@ -33,9 +33,9 @@ public List getTokenizations() { return tokenizations; } - public void addTokenization(String input, boolean isTruncated, String[] tokens, int[] tokenIds, int[] tokenMap) { + public void addTokenization(String input, boolean isTruncated, List tokens, int[] tokenIds, int[] tokenMap) { maxLength = Math.max(maxLength, tokenIds.length); - tokenizations.add(new Tokenization(input, isTruncated, tokens, tokenIds, tokenMap)); + tokenizations.add(new Tokenization(input, tokens, isTruncated, tokenIds, tokenMap)); } public void addTokenization(Tokenization tokenization) { @@ -49,16 +49,15 @@ public int getLongestSequenceLength() { public static class Tokenization { - private final String inputSeqs; - private final String[] tokens; + private final String input; + private final List tokens; private final int[] tokenIds; private final int[] tokenMap; private final boolean truncated; - public Tokenization(String input, boolean truncated, String[] tokens, int[] tokenIds, int[] tokenMap) { - assert tokens.length == tokenIds.length; + public Tokenization(String input, List tokens, boolean truncated, int[] tokenIds, int[] tokenMap) { assert tokenIds.length == tokenMap.length; - this.inputSeqs = input; + this.input = input; this.tokens = tokens; this.tokenIds = tokenIds; this.tokenMap = tokenMap; @@ -66,16 +65,7 @@ public Tokenization(String input, boolean truncated, String[] tokens, int[] toke } /** - * The token strings from the tokenization process - * - * @return A list of tokens - */ - public String[] getTokens() { - return tokens; - } - - /** - * The integer values of the tokens in {@link #getTokens()} + * The integer values of the tokens} * * @return A list of token Ids */ @@ -95,7 +85,11 @@ public int[] getTokenMap() { } public String getInput() { - return inputSeqs; + return input; + } + + public List getTokens() { + return tokens; } public boolean isTruncated() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java index a566007a8c5ae..b50e70f85f12a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -25,26 +26,7 @@ public class WordPieceTokenizer { private final String unknownToken; private final int maxInputCharsPerWord; - public static class TokenAndId { - private final String token; - private final int id; - - TokenAndId(String token, int id) { - this.token = token; - this.id = id; - } - - public int getId() { - return id; - } - - public String getToken() { - return token; - } - } - /** - * * @param vocab The token vocabulary * @param unknownToken If not found in the vocabulary * @param maxInputCharsPerWord Inputs tokens longer than this are 'unknown' @@ -58,63 +40,55 @@ public WordPieceTokenizer(Map vocab, String unknownToken, int m /** * Wordpiece tokenize the input text. 
* - * @param text A single token or whitespace separated tokens. - * Input should have been normalized by the {@link BasicTokenizer}. - * @return List of tokens + * @param token Word to tokenize + * @return List of token IDs */ - public List tokenize(String text) { - String[] tokens = BasicTokenizer.whiteSpaceTokenize(text); - - List output = new ArrayList<>(); - for (String token : tokens) { - if (token.length() > maxInputCharsPerWord) { - assert vocab.containsKey(unknownToken); - output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); - continue; - } + public List tokenize(DelimitedToken token) { + + if (token.getToken().length() > maxInputCharsPerWord) { + assert vocab.containsKey(unknownToken); + return Collections.singletonList(vocab.get(unknownToken)); + } - boolean isBad = false; - int start = 0; - List subTokens = new ArrayList<>(); - int length = token.length(); - while (start < length) { - int end = length; - - String currentValidSubStr = null; - - while (start < end) { - String subStr; - if (start > 0) { - subStr = CONTINUATION + token.substring(start, end); - } else { - subStr = token.substring(start, end); - } - - if (vocab.containsKey(subStr)) { - currentValidSubStr = subStr; - break; - } - - end--; + List output = new ArrayList<>(); + boolean isBad = false; + int start = 0; + int length = token.getToken().length(); + while (start < length) { + int end = length; + + String currentValidSubStr = null; + + while (start < end) { + String subStr; + if (start > 0) { + subStr = CONTINUATION + token.getToken().substring(start, end); + } else { + subStr = token.getToken().substring(start, end); } - if (currentValidSubStr == null) { - isBad = true; + if (vocab.containsKey(subStr)) { + currentValidSubStr = subStr; break; } - subTokens.add(new TokenAndId(currentValidSubStr, vocab.get(currentValidSubStr))); - - start = end; + end--; } - if (isBad) { - output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); - } else { - output.addAll(subTokens); + if (currentValidSubStr == null) { + isBad = true; + break; } + + output.add(vocab.get(currentValidSubStr)); + + start = end; } - return output; + if (isBad) { + return Collections.singletonList(vocab.get(unknownToken)); + } else { + return output; + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index e43d1f3e41d60..af41090e4567a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -14,18 +14,22 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.OptionalInt; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class FillMaskProcessorTests extends ESTestCase { @@ -45,17 +49,23 @@ public void testProcessResults() { String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris"; List vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France"); - String[] tokens = input.split(" "); + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + int[] tokenMap = new int[] { 0, 1, 2, 3, 4, 5 }; int[] tokenIds = new int[] { 0, 1, 2, 3, 4, 5 }; TokenizationResult tokenization = new TokenizationResult(vocab); tokenization.addTokenization(input, false, tokens, tokenIds, tokenMap); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn(BertTokenizer.MASK_TOKEN); + when(tokenizer.getMaskTokenId()).thenReturn(OptionalInt.of(3)); + String resultsField = randomAlphaOfLength(10); FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult( tokenization, new PyTorchResult("1", scores, 0L, null), + tokenizer, 4, resultsField ); @@ -72,12 +82,15 @@ public void testProcessResults() { } public void testProcessResults_GivenMissingTokens() { + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); + TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); - tokenization.addTokenization("", false, new String[] {}, new int[] {}, new int[] {}); + tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {}); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null); assertThat( - FillMaskProcessor.processResult(tokenization, pyTorchResult, 5, randomAlphaOfLength(10)), + FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)), instanceOf(WarningInferenceResults.class) ); } @@ -85,8 +98,10 @@ public void testProcessResults_GivenMissingTokens() { public void testValidate_GivenMissingMaskToken() { List input = List.of("The capital of France is Paris"); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null); - FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); + FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input)); assertThat(e.getMessage(), containsString("no [MASK] token could be found")); @@ -95,8 +110,11 @@ public void testValidate_GivenMissingMaskToken() { public void testProcessResults_GivenMultipleMaskTokens() { List input = List.of("The capital of [MASK] is [MASK]"); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null); - FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); + FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input)); assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input")); diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index 1375a16bc02f8..041d74e7656c4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -16,7 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; import java.util.ArrayList; @@ -160,23 +162,26 @@ public void testProcessResults_withIobMap() { } public void testGroupTaggedTokens() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Sarah", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Jessica", NerProcessor.IobTag.I_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("I", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("live", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("in", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Manchester", NerProcessor.IobTag.B_LOC, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("and", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("work", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("for", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Elastic", NerProcessor.IobTag.B_ORG, 1.0)); - - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens( - tokens, - "Hi Sarah Jessica, I live in Manchester and work for Elastic" - ); + String input = "Hi Sarah Jessica, I live in Manchester and work for Elastic"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + assertThat(tokens, hasSize(12)); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new 
NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_LOC, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0)); + + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, hasSize(3)); assertThat(entityGroups.get(0).getClassName(), equalTo("PER")); assertThat(entityGroups.get(0).getEntity(), equalTo("Sarah Jessica")); @@ -187,23 +192,32 @@ public void testGroupTaggedTokens() { } public void testGroupTaggedTokens_GivenNoEntities() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("there", NerProcessor.IobTag.O, 1.0)); + String input = "Hi there"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.O, 1.0)); - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(tokens, "Hi there"); + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, is(empty())); } public void testGroupTaggedTokens_GivenConsecutiveEntities() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Rita", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Sue", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("and", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Bob", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("too", NerProcessor.IobTag.O, 1.0)); - - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(tokens, "Rita, Sue, and Bob too"); + String input = "Rita, Sue, and Bob too"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + + List entityGroups = 
NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, hasSize(3)); assertThat(entityGroups.get(0).getClassName(), equalTo("PER")); assertThat(entityGroups.get(0).getEntity(), equalTo("Rita")); @@ -214,17 +228,19 @@ public void testGroupTaggedTokens_GivenConsecutiveEntities() { } public void testGroupTaggedTokens_GivenConsecutiveContinuingEntities() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("FirstName", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("SecondName", NerProcessor.IobTag.I_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("NextPerson", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("NextPersonSecondName", NerProcessor.IobTag.I_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("something_else", NerProcessor.IobTag.B_ORG, 1.0)); - - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens( - tokens, - "FirstName SecondName, NextPerson NextPersonSecondName. something_else" - ); + String input = "FirstName SecondName, NextPerson NextPersonSecondName. something_else"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0)); + + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, hasSize(3)); assertThat(entityGroups.get(0).getClassName(), equalTo("PER")); assertThat(entityGroups.get(0).getEntity(), equalTo("FirstName SecondName")); @@ -233,6 +249,39 @@ public void testGroupTaggedTokens_GivenConsecutiveContinuingEntities() { assertThat(entityGroups.get(2).getClassName(), equalTo("ORG")); } + public void testEntityContainsPunctuation() { + String input = "Alexander, my name is Benjamin Trent, I work at Acme Inc.."; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new 
NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + assertEquals(tokens.size(), taggedTokens.size()); + + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); + assertThat(entityGroups, hasSize(3)); + assertThat(entityGroups.get(0).getClassName(), equalTo("PER")); + assertThat(entityGroups.get(0).getEntity(), equalTo("Alexander")); + assertThat(entityGroups.get(1).getClassName(), equalTo("PER")); + assertThat(entityGroups.get(1).getEntity(), equalTo("Benjamin Trent")); + assertThat(entityGroups.get(2).getClassName(), equalTo("ORG")); + assertThat(entityGroups.get(2).getEntity(), equalTo("Acme Inc.")); + } + public void testAnnotatedTextBuilder() { String input = "Alexander, my name is Benjamin Trent, I work at Acme Inc."; List entities = List.of( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java index 87b02e6066a9c..4bdd24cafe92a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java @@ -11,9 +11,11 @@ import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; -import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.sameInstance; /** @@ -24,91 +26,102 @@ public class BasicTokenizerTests extends ESTestCase { public void testLowerCase() { BasicTokenizer tokenizer = new BasicTokenizer(); - List tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "); - assertThat(tokens, contains("hello", "!", "how", "are", "you", "?")); + var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "); + assertThat(tokenStrings(tokens), contains("hello", "!", "how", "are", "you", "?")); tokens = tokenizer.tokenize("H\u00E9llo"); - assertThat(tokens, contains("hello")); + assertThat(tokenStrings(tokens), contains("hello")); } public void testLowerCaseWithoutStripAccents() { BasicTokenizer tokenizer = new BasicTokenizer(true, true, false); - List tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "); - assertThat(tokens, contains("hällo", "!", "how", "are", "you", "?")); + var tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? 
"); + assertThat(tokenStrings(tokens), contains("hällo", "!", "how", "are", "you", "?")); tokens = tokenizer.tokenize("H\u00E9llo"); - assertThat(tokens, contains("h\u00E9llo")); + assertThat(tokenStrings(tokens), contains("h\u00E9llo")); } public void testLowerCaseStripAccentsDefault() { BasicTokenizer tokenizer = new BasicTokenizer(true, true); - List tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "); - assertThat(tokens, contains("hallo", "!", "how", "are", "you", "?")); + var tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokenStrings(tokens), contains("hallo", "!", "how", "are", "you", "?")); tokens = tokenizer.tokenize("H\u00E9llo"); - assertThat(tokens, contains("hello")); + assertThat(tokenStrings(tokens), contains("hello")); } public void testNoLower() { - List tokens = new BasicTokenizer(false, true, false).tokenize(" \tHäLLo!how \n Are yoU? "); - assertThat(tokens, contains("HäLLo", "!", "how", "Are", "yoU", "?")); + var tokens = new BasicTokenizer(false, true, false).tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokenStrings(tokens), contains("HäLLo", "!", "how", "Are", "yoU", "?")); } public void testNoLowerStripAccents() { - List tokens = new BasicTokenizer(false, true, true).tokenize(" \tHäLLo!how \n Are yoU? "); - assertThat(tokens, contains("HaLLo", "!", "how", "Are", "yoU", "?")); + var tokens = new BasicTokenizer(false, true, true).tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokenStrings(tokens), contains("HaLLo", "!", "how", "Are", "yoU", "?")); } public void testNeverSplit() { BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Collections.singleton("[UNK]")); - List tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"); - assertThat(tokens, contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]")); + var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? 
[UNK]"); + assertThat(tokenStrings(tokens), contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]")); tokens = tokenizer.tokenize("Hello [UNK]."); - assertThat(tokens, contains("Hello", "[UNK]", ".")); + assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", ".")); tokens = tokenizer.tokenize("Hello [UNK]?"); - assertThat(tokens, contains("Hello", "[UNK]", "?")); + assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", "?")); tokens = tokenizer.tokenize("Hello [UNK]!!"); - assertThat(tokens, contains("Hello", "[UNK]", "!", "!")); + assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", "!", "!")); tokens = tokenizer.tokenize("Hello-[UNK]"); - assertThat(tokens, contains("Hello", "-", "[UNK]")); + assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]")); tokens = tokenizer.tokenize("Hello-[UNK][UNK]"); - assertThat(tokens, contains("Hello", "-", "[UNK]", "[UNK]")); + assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]", "[UNK]")); } public void testSplitOnPunctuation() { - List tokens = BasicTokenizer.splitOnPunctuation("hi!"); - assertThat(tokens, contains("hi", "!")); + var tokens = BasicTokenizer.splitOnPunctuation(new DelimitedToken(0, 3, "hi!")); + assertEquals(new DelimitedToken(0, 2, "hi"), tokens.get(0)); + assertEquals(new DelimitedToken(2, 3, "!"), tokens.get(1)); - tokens = BasicTokenizer.splitOnPunctuation("hi."); - assertThat(tokens, contains("hi", ".")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("hi.")); + assertEquals(new DelimitedToken(0, 2, "hi"), tokens.get(0)); + assertEquals(new DelimitedToken(2, 3, "."), tokens.get(1)); - tokens = BasicTokenizer.splitOnPunctuation("!hi"); - assertThat(tokens, contains("!", "hi")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("!hi")); + assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0)); + assertEquals(new DelimitedToken(1, 3, "hi"), tokens.get(1)); - tokens = BasicTokenizer.splitOnPunctuation("don't"); - assertThat(tokens, contains("don", "'", "t")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("don't")); + assertEquals(new DelimitedToken(0, 3, "don"), tokens.get(0)); + assertEquals(new DelimitedToken(3, 4, "'"), tokens.get(1)); + assertEquals(new DelimitedToken(4, 5, "t"), tokens.get(2)); - tokens = BasicTokenizer.splitOnPunctuation("!!hi"); - assertThat(tokens, contains("!", "!", "hi")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("!!hi")); + assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0)); + assertEquals(new DelimitedToken(1, 2, "!"), tokens.get(1)); + assertEquals(new DelimitedToken(2, 4, "hi"), tokens.get(2)); - tokens = BasicTokenizer.splitOnPunctuation("[hi]"); - assertThat(tokens, contains("[", "hi", "]")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("[hi]")); + assertEquals(new DelimitedToken(0, 1, "["), tokens.get(0)); + assertEquals(new DelimitedToken(1, 3, "hi"), tokens.get(1)); + assertEquals(new DelimitedToken(3, 4, "]"), tokens.get(2)); - tokens = BasicTokenizer.splitOnPunctuation("hi."); - assertThat(tokens, contains("hi", ".")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("!!")); + assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0)); + assertEquals(new DelimitedToken(1, 2, "!"), tokens.get(1)); - tokens = BasicTokenizer.splitOnPunctuation("!!"); - assertThat(tokens, contains("!", "!")); + tokens = BasicTokenizer.splitOnPunctuation(makeToken("elastic’s")); + assertEquals(new DelimitedToken(0, 7, "elastic"), tokens.get(0)); + assertEquals(new DelimitedToken(7, 8, "’"), tokens.get(1)); + 
assertEquals(new DelimitedToken(8, 9, "s"), tokens.get(2)); - tokens = BasicTokenizer.splitOnPunctuation("elastic’s"); - assertThat(tokens, contains("elastic", "’", "s")); - - tokens = BasicTokenizer.splitOnPunctuation("elastic‘s"); - assertThat(tokens, contains("elastic", "‘", "s")); + tokens = BasicTokenizer.splitOnPunctuation(new DelimitedToken(4, 13, "elastic’s")); + assertEquals(new DelimitedToken(4, 11, "elastic"), tokens.get(0)); + assertEquals(new DelimitedToken(11, 12, "’"), tokens.get(1)); + assertEquals(new DelimitedToken(12, 13, "s"), tokens.get(2)); } public void testStripAccents() { @@ -123,8 +136,8 @@ public void testTokenizeCjkChars() { } public void testTokenizeChinese() { - List tokens = new BasicTokenizer().tokenize("ah\u535A\u63A8zz"); - assertThat(tokens, contains("ah", "\u535A", "\u63A8", "zz")); + var tokens = new BasicTokenizer().tokenize("ah\u535A\u63A8zz"); + assertThat(tokenStrings(tokens), contains("ah", "\u535A", "\u63A8", "zz")); } public void testCleanText() { @@ -132,11 +145,6 @@ public void testCleanText() { assertEquals("filter control chars", BasicTokenizer.cleanText("\u0000filter \uFFFDcontrol chars\u0005")); } - public void testWhiteSpaceTokenize() { - assertThat(BasicTokenizer.whiteSpaceTokenize("nochange"), arrayContaining("nochange")); - assertThat(BasicTokenizer.whiteSpaceTokenize(" some change "), arrayContaining("some", "", "change")); - } - public void testIsWhitespace() { assertTrue(BasicTokenizer.isWhiteSpace(' ')); assertTrue(BasicTokenizer.isWhiteSpace('\t')); @@ -187,4 +195,43 @@ public void testIsCjkChar() { assertTrue(BasicTokenizer.isCjkChar(0x2F800)); assertFalse(BasicTokenizer.isCjkChar(0x2FA20)); } + + public void testWhitespaceTokenize() { + { + List delimitedTokens = BasicTokenizer.whiteSpaceTokenize("hello! 
how are you?"); + assertThat(delimitedTokens, hasSize(4)); + assertThat(tokenStrings(delimitedTokens), contains("hello!", "how", "are", "you?")); + + assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(0, 6, "hello!"))); + assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(7, 10, "how"))); + assertThat(delimitedTokens.get(2), equalTo(new DelimitedToken(11, 14, "are"))); + assertThat(delimitedTokens.get(3), equalTo(new DelimitedToken(15, 19, "you?"))); + } + { + List delimitedTokens = BasicTokenizer.whiteSpaceTokenize(" leading whitespace"); + assertThat(delimitedTokens, hasSize(2)); + assertThat(tokenStrings(delimitedTokens), contains("leading", "whitespace")); + + assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(3, 10, "leading"))); + assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(11, 21, "whitespace"))); + } + { + List delimitedTokens = BasicTokenizer.whiteSpaceTokenize("double spaced text "); + assertThat(delimitedTokens, hasSize(3)); + assertThat(tokenStrings(delimitedTokens), contains("double", "spaced", "text")); + + assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(0, 6, "double"))); + assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(8, 14, "spaced"))); + assertThat(delimitedTokens.get(2), equalTo(new DelimitedToken(16, 20, "text"))); + } + } + + private List tokenStrings(List tokens) { + return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList()); + } + + private DelimitedToken makeToken(String str) { + return new DelimitedToken(0, str.length(), str); + } + } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java index 0d326f5e8931e..fa9a9235cf2f6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java @@ -16,8 +16,9 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; -import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -45,6 +46,10 @@ public class BertTokenizerTests extends ESTestCase { "with" ); + private List tokenStrings(List tokens) { + return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList()); + } + public void testTokenize() { BertTokenizer tokenizer = BertTokenizer.builder( TEST_CASED_VOCAB, @@ -52,7 +57,7 @@ public void testTokenize() { ).build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun")); + assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", "fun")); assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap()); } @@ -90,18 +95,21 @@ public void testTokenizeLargeInputTruncation() { "Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST ); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun", "with", "Pancake")); + assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.getTokenIds()); - tokenizer 
= BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 5, Tokenization.Truncate.FIRST)).build(); - tokenization = tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST); - assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]")); + BertTokenizer tokenizerWithSpecialTokens = BertTokenizer.builder( + TEST_CASED_VOCAB, + new BertTokenization(null, true, 5, Tokenization.Truncate.FIRST) + ).build(); + tokenization = tokenizerWithSpecialTokens.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST); + assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds()); + assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap()); } public void testTokenizeAppendSpecialTokens() { BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]")); assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap()); } @@ -118,7 +126,7 @@ public void testNeverSplitTokens() { "Elasticsearch " + specialToken + " fun", Tokenization.Truncate.NONE ); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", specialToken, "fun")); + assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", specialToken, "fun")); assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.getTokenMap()); } @@ -131,12 +139,12 @@ public void testDoLowerCase() { ).setDoLowerCase(false).setWithSpecialTokens(false).build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining(BertTokenizer.UNKNOWN_TOKEN, "fun")); assertArrayEquals(new int[] { 3, 2 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenMap()); tokenization = tokenizer.tokenize("elasticsearch fun", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun")); + assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds()); + assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap()); } { @@ -146,7 +154,7 @@ public void testDoLowerCase() { .build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun")); + assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds()); } } @@ -154,12 +162,11 @@ public void testPunctuation() { BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).setWithSpecialTokens(false).build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", ".")); + assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", ",", "fun", ".")); assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, 
tokenization.getTokenMap()); tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "[MASK]", ".")); assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap()); } @@ -171,16 +178,19 @@ public void testPunctuationWithMask() { ).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build(); TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]")); + assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "[MASK]", "-", "tastic", "!")); + assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.getTokenIds()); + assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.getTokenMap()); tokenization = tokenizer.tokenize("This is sub~[MASK]!", Tokenization.Truncate.NONE); - assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "sub", "~", "[MASK]", "!", "[SEP]")); + assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", "~", "[MASK]", "!")); + assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.getTokenIds()); + assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.getTokenMap()); tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!", Tokenization.Truncate.NONE); - assertThat( - tokenization.getTokens(), - arrayContaining("[CLS]", "This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!", "[SEP]") - ); + assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "tastic", "!")); + assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.getTokenIds()); + assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.getTokenMap()); } public void testBatchInput() { @@ -200,22 +210,18 @@ public void testBatchInput() { assertThat(tr.getTokenizations(), hasSize(4)); TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0); - assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search")); assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0 }, tokenization.getTokenMap()); tokenization = tr.getTokenizations().get(1); - assertThat(tokenization.getTokens(), arrayContaining("my", "little", "red", "car")); assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.getTokenMap()); tokenization = tr.getTokenizations().get(2); - assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "day")); assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap()); tokenization = tr.getTokenizations().get(3); - assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "Pancake", "red", "car", "day")); assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.getTokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap()); } @@ -230,9 +236,11 @@ public void testMultiSeqTokenization() { "Godzilla my little red car", Tokenization.Truncate.NONE ); + + var tokenStream = 
Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); assertThat( - tokenization.getTokens(), - arrayContaining( + tokenStream, + contains( BertTokenizer.CLASS_TOKEN, "Elastic", "##search", @@ -260,9 +268,11 @@ public void testTokenizeLargeInputMultiSequenceTruncation() { "Godzilla my little red car", Tokenization.Truncate.FIRST ); + + var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); assertThat( - tokenization.getTokens(), - arrayContaining( + tokenStream, + contains( BertTokenizer.CLASS_TOKEN, "Elastic", BertTokenizer.SEPARATOR_TOKEN, @@ -286,9 +296,10 @@ public void testTokenizeLargeInputMultiSequenceTruncation() { tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 10, Tokenization.Truncate.SECOND)).build(); tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.SECOND); + tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); assertThat( - tokenization.getTokens(), - arrayContaining( + tokenStream, + contains( BertTokenizer.CLASS_TOKEN, "Elastic", "##search", diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java index dd3bf3863d361..c62df28007eef 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java @@ -22,39 +22,39 @@ public class WordPieceTokenizerTests extends ESTestCase { public static final String UNKNOWN_TOKEN = "[UNK]"; public void testTokenize() { - Map vocabMap = createVocabMap( - UNKNOWN_TOKEN, - "[CLS]", - "[SEP]", - "want", - "##want", - "##ed", - "wa", - "un", - "runn", - "##ing" - ); + String[] vocab = { UNKNOWN_TOKEN, "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing" }; + Map vocabMap = createVocabMap(vocab); + WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, UNKNOWN_TOKEN, 100); - List tokenAndIds = tokenizer.tokenize(""); - assertThat(tokenAndIds, empty()); + var tokenIds = tokenizer.tokenize(new DelimitedToken(0, 0, "")); + assertThat(tokenIds, empty()); + + tokenIds = tokenizer.tokenize(makeToken("unwanted")); + List tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList()); + assertThat(tokenStrings, contains("un", "##want", "##ed")); - tokenAndIds = tokenizer.tokenize("unwanted running"); - List tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); - assertThat(tokens, contains("un", "##want", "##ed", "runn", "##ing")); + tokenIds = tokenizer.tokenize(makeToken("running")); + tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList()); + assertThat(tokenStrings, contains("runn", "##ing")); + + tokenIds = tokenizer.tokenize(makeToken("unwantedX")); + tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList()); + assertThat(tokenStrings, contains(UNKNOWN_TOKEN)); + } - tokenAndIds = tokenizer.tokenize("unwantedX running"); - tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); - assertThat(tokens, 
contains(UNKNOWN_TOKEN, "runn", "##ing")); + private DelimitedToken makeToken(String str) { + return new DelimitedToken(0, str.length(), str); + } public void testMaxCharLength() { - Map<String, Integer> vocabMap = createVocabMap("Some", "words", "will", "become", "UNK"); + String[] vocab = { "Some", "words", "will", "become", "UNK" }; + Map<String, Integer> vocabMap = createVocabMap(vocab); WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, "UNK", 4); - List<WordPieceTokenizer.TokenAndId> tokenAndIds = tokenizer.tokenize("Some words will become UNK"); - List<String> tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); - assertThat(tokens, contains("Some", "UNK", "will", "UNK", "UNK")); + var tokenIds = tokenizer.tokenize(new DelimitedToken(0, 0, "become")); + List<String> tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList()); + assertThat(tokenStrings, contains("UNK")); } static Map<String, Integer> createVocabMap(String... words) {
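The whiteSpaceTokenize tests above check that each token remembers its character offsets in the source string. As a rough, standalone sketch of that idea (not the BasicTokenizer implementation from this patch; the WhitespaceOffsetSplitter class and Span record are invented stand-ins for DelimitedToken), an offset-preserving whitespace splitter can look like this:

import java.util.ArrayList;
import java.util.List;

public class WhitespaceOffsetSplitter {

    /** Token text plus its [start, end) character offsets in the source string. */
    record Span(int start, int end, String token) {}

    // Split on whitespace while recording where each token came from, so later
    // stages can map results back onto the original, un-normalised text.
    static List<Span> split(String text) {
        List<Span> spans = new ArrayList<>();
        int start = -1;
        for (int i = 0; i < text.length(); i++) {
            if (Character.isWhitespace(text.charAt(i))) {
                if (start >= 0) {
                    spans.add(new Span(start, i, text.substring(start, i)));
                    start = -1;
                }
            } else if (start < 0) {
                start = i;
            }
        }
        if (start >= 0) {
            spans.add(new Span(start, text.length(), text.substring(start)));
        }
        return spans;
    }

    public static void main(String[] args) {
        // Mirrors the expectation in the test above: "leading" covers [3, 10) and "whitespace" covers [11, 21).
        System.out.println(split("   leading whitespace"));
    }
}

Because leading and repeated whitespace never start a token, the recorded offsets stay aligned with the original string rather than with a trimmed or collapsed copy.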
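The WordPieceTokenizerTests changes show tokenize() now operating on a single delimited token and returning vocabulary ids rather than TokenAndId pairs. A minimal sketch of greedy longest-match-first WordPiece over a plain word, with an invented class name and an assumed Map<String, Integer> vocabulary (the "##" continuation prefix and the max-chars-per-word cut-off follow the behaviour the tests exercise):

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class WordPieceSketch {

    // Greedy longest-match-first sub-word split: repeatedly take the longest prefix of the
    // remaining word found in the vocabulary, prefixing continuation pieces with "##".
    static List<Integer> tokenize(String word, Map<String, Integer> vocab, int unknownId, int maxChars) {
        if (word.length() > maxChars) {
            return List.of(unknownId);           // over-long words collapse to the unknown token
        }
        List<Integer> ids = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            Integer match = null;
            while (start < end) {
                String piece = (start == 0 ? "" : "##") + word.substring(start, end);
                match = vocab.get(piece);
                if (match != null) {
                    break;
                }
                end--;
            }
            if (match == null) {
                return List.of(unknownId);       // any unmatchable remainder makes the whole word unknown
            }
            ids.add(match);
            start = end;
        }
        return ids;
    }

    public static void main(String[] args) {
        Map<String, Integer> vocab = Map.of("[UNK]", 0, "un", 1, "##want", 2, "##ed", 3);
        // Prints [1, 2, 3], i.e. "un", "##want", "##ed", the same split the test asserts for "unwanted".
        System.out.println(tokenize("unwanted", vocab, vocab.get("[UNK]"), 100));
    }
}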
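The new NerProcessorTests cases ("Benjamin Trent", "Acme Inc.") verify that an entity label is the exact substring of the original input spanned by its tokens, so casing, accents and interior punctuation survive normalisation. A compact sketch of that offset-based grouping, with all class, record and enum names invented for illustration (the real logic lives in NerProcessor.NerResultProcessor.groupTaggedTokens):

import java.util.ArrayList;
import java.util.List;

public class EntityGrouper {

    enum Tag {
        O, B_PER, I_PER, B_ORG, I_ORG;

        String entityClass() { return this == O ? null : name().substring(2); }
        boolean isBeginning() { return name().startsWith("B_"); }
    }

    record Span(int start, int end, String token) {}
    record TaggedSpan(Span span, Tag tag) {}
    record Entity(String entityClass, String text) {}

    // Group IOB-tagged tokens and label each entity with source.substring(firstTokenStart, lastTokenEnd),
    // i.e. the original text, rather than by concatenating the normalised token strings.
    static List<Entity> group(String source, List<TaggedSpan> tagged) {
        List<Entity> entities = new ArrayList<>();
        int i = 0;
        while (i < tagged.size()) {
            Tag tag = tagged.get(i).tag();
            if (tag == Tag.O) {
                i++;
                continue;
            }
            int first = i++;
            // Consume continuation tokens of the same entity class; a B_* tag always starts a new entity.
            while (i < tagged.size()
                && tagged.get(i).tag() != Tag.O
                && tagged.get(i).tag().isBeginning() == false
                && tagged.get(i).tag().entityClass().equals(tag.entityClass())) {
                i++;
            }
            int from = tagged.get(first).span().start();
            int to = tagged.get(i - 1).span().end();
            entities.add(new Entity(tag.entityClass(), source.substring(from, to)));
        }
        return entities;
    }
}

With the "Acme Inc.." input above, the three I_ORG tokens ("Acme", "Inc", ".") cover one contiguous range of the source, so the substring comes out as "Acme Inc." with the space and the first full stop intact, details that plain token concatenation would lose.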