diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index dd0907c95ead2..fd1751387942c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -163,6 +164,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public String toString() { + return Strings.toString(this); + } + public Map toMap() { Map map = new LinkedHashMap<>(); map.put("entity", entity); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java index 7e829da85e9f1..60ab42fd300f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java @@ -34,7 +34,7 @@ public BertRequestBuilder(BertTokenizer tokenizer) { @Override public NlpTask.Request buildRequest(List inputs, String requestId, Tokenization.Truncate truncate) throws IOException { - if (tokenizer.getPadToken().isEmpty()) { + if (tokenizer.getPadTokenId().isEmpty()) { throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary"); } @@ -46,10 +46,10 @@ public NlpTask.Request buildRequest(List inputs, String requestId, Token @Override public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException { - if (tokenizer.getPadToken().isEmpty()) { + if (tokenizer.getPadTokenId().isEmpty()) { throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary"); } - return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId)); + return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId)); } static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index 3e54d7ab6ca16..72acd2891141b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; -import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import 
org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; @@ -27,10 +26,10 @@ public class FillMaskProcessor implements NlpTask.Processor { - private final NlpTask.RequestBuilder requestBuilder; + private final NlpTokenizer tokenizer; FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) { - this.requestBuilder = tokenizer.requestBuilder(); + this.tokenizer = tokenizer; } @Override @@ -39,22 +38,23 @@ public void validateInputs(List inputs) { throw new IllegalArgumentException("input request is empty"); } + final String mask = tokenizer.getMaskToken(); for (String input : inputs) { - int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN); + int maskIndex = input.indexOf(mask); if (maskIndex < 0) { - throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found"); + throw new IllegalArgumentException("no " + mask + " token could be found"); } - maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length()); + maskIndex = input.indexOf(mask, maskIndex + mask.length()); if (maskIndex > 0) { - throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input"); + throw new IllegalArgumentException("only one " + mask + " token should exist in the input"); } } } @Override public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) { - return requestBuilder; + return tokenizer.requestBuilder(); } @Override @@ -64,25 +64,55 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) { return (tokenization, result) -> processResult( tokenization, result, + tokenizer, fillMaskConfig.getNumTopClasses(), fillMaskConfig.getResultsField() ); } else { - return (tokenization, result) -> processResult(tokenization, result, FillMaskConfig.DEFAULT_NUM_RESULTS, DEFAULT_RESULTS_FIELD); + return (tokenization, result) -> processResult( + tokenization, + result, + tokenizer, + FillMaskConfig.DEFAULT_NUM_RESULTS, + DEFAULT_RESULTS_FIELD + ); } } static InferenceResults processResult( TokenizationResult tokenization, PyTorchResult pyTorchResult, + NlpTokenizer tokenizer, int numResults, String resultsField ) { - if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) { + if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { return new WarningInferenceResults("No valid tokens for inference"); } - int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN); + if (tokenizer.getMaskTokenId().isEmpty()) { + return new WarningInferenceResults( + "The token id for the mask token {} is not known in the tokenizer. 
Check the vocabulary contains the mask token", + tokenizer.getMaskToken() + ); + } + + int maskTokenIndex = -1; + int maskTokenId = tokenizer.getMaskTokenId().getAsInt(); + for (int i = 0; i < tokenization.getTokenizations().get(0).getTokenIds().length; i++) { + if (tokenization.getTokenizations().get(0).getTokenIds()[i] == maskTokenId) { + maskTokenIndex = i; + break; + } + } + if (maskTokenIndex == -1) { + return new WarningInferenceResults( + "mask token id [{}] not found in the tokenization {}", + maskTokenId, + Arrays.asList(tokenization.getTokenizations().get(0).getTokenIds()) + ); + } + // TODO - process all results in the batch double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]); @@ -103,7 +133,7 @@ static InferenceResults processResult( tokenization.getTokenizations() .get(0) .getInput() - .replace(BertTokenizer.MASK_TOKEN, tokenization.getFromVocab(scoreAndIndices[0].index)), + .replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)), results, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), scoreAndIndices[0].score, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index 2674d47fdaa4d..3d33d577066c1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; @@ -193,7 +194,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor { @Override public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) { - if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) { + if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { return new WarningInferenceResults("no valid tokens to build result"); } // TODO - process all results in the batch @@ -213,6 +214,7 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe ? 
tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT) : tokenization.getTokenizations().get(0).getInput() ); + return new NerResults( resultsField, buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities), @@ -230,7 +232,7 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe static List tagTokens(TokenizationResult.Tokenization tokenization, double[][] scores, IobTag[] iobMap) { List taggedTokens = new ArrayList<>(); int startTokenIndex = 0; - while (startTokenIndex < tokenization.getTokens().length) { + while (startTokenIndex < tokenization.getTokenIds().length) { int inputMapping = tokenization.getTokenMap()[startTokenIndex]; if (inputMapping < 0) { // This token does not map to a token in the input (special tokens) @@ -238,15 +240,9 @@ static List tagTokens(TokenizationResult.Tokenization tokenization, continue; } int endTokenIndex = startTokenIndex; - StringBuilder word = new StringBuilder(tokenization.getTokens()[startTokenIndex]); - while (endTokenIndex < tokenization.getTokens().length - 1 + while (endTokenIndex < tokenization.getTokenMap().length - 1 && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) { endTokenIndex++; - // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens. - // It is probably more correct to implement detokenization on the tokenizer - // that does reverse lookup based on token IDs. - String endTokenWord = tokenization.getTokens()[endTokenIndex].substring(2); - word.append(endTokenWord); } double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length); for (int i = startTokenIndex + 1; i <= endTokenIndex; i++) { @@ -262,7 +258,7 @@ static List tagTokens(TokenizationResult.Tokenization tokenization, } int maxScoreIndex = NlpHelpers.argmax(avgScores); double score = avgScores[maxScoreIndex]; - taggedTokens.add(new TaggedToken(word.toString(), iobMap[maxScoreIndex], score)); + taggedTokens.add(new TaggedToken(tokenization.getTokens().get(inputMapping), iobMap[maxScoreIndex], score)); startTokenIndex = endTokenIndex + 1; } return taggedTokens; @@ -283,14 +279,12 @@ static List groupTaggedTokens(List tokens, } List entities = new ArrayList<>(); int startTokenIndex = 0; - int startFindInSeq = 0; while (startTokenIndex < tokens.size()) { TaggedToken token = tokens.get(startTokenIndex); if (token.tag.getEntity() == Entity.NONE) { startTokenIndex++; continue; } - StringBuilder entityWord = new StringBuilder(token.word); int endTokenIndex = startTokenIndex + 1; double scoreSum = token.score; while (endTokenIndex < tokens.size()) { @@ -298,43 +292,50 @@ static List groupTaggedTokens(List tokens, if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) { break; } - // TODO Here we add a space between tokens. - // It is probably more correct to implement detokenization on the tokenizer - // that does reverse lookup based on token IDs. - entityWord.append(" ").append(endToken.word); scoreSum += endToken.score; endTokenIndex++; } - String entity = entityWord.toString(); - int i = inputSeq.indexOf(entity, startFindInSeq); + + int startPos = token.token.getStartPos(); + int endPos = tokens.get(endTokenIndex - 1).token.getEndPos(); + String entity = inputSeq.substring(startPos, endPos); entities.add( new NerResults.EntityGroup( entity, token.tag.getEntity().toString(), scoreSum / (endTokenIndex - startTokenIndex), - i, - i == -1 ? 
-1 : i + entity.length() + startPos, + endPos ) ); startTokenIndex = endTokenIndex; - if (i != -1) { - startFindInSeq = i + entity.length(); - } } return entities; } static class TaggedToken { - private final String word; + private final DelimitedToken token; private final IobTag tag; private final double score; - TaggedToken(String word, IobTag tag, double score) { - this.word = word; + TaggedToken(DelimitedToken token, IobTag tag, double score) { + this.token = token; this.tag = tag; this.score = score; } + + @Override + public String toString() { + return new StringBuilder("{").append("token:") + .append(token) + .append(", ") + .append(tag) + .append(", ") + .append(score) + .append("}") + .toString(); + } } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java index d98f1000d42e4..77af60cab0119 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java @@ -15,6 +15,7 @@ import java.util.Locale; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; /** * Basic tokenization of text by whitespace with optional extras: @@ -45,7 +46,7 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i this.isLowerCase = isLowerCase; this.isTokenizeCjkChars = isTokenizeCjkChars; this.isStripAccents = isStripAccents; - this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenize); + this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenizeString); } public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) { @@ -74,46 +75,51 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) { * @param text The input text to tokenize * @return List of tokens */ - public List tokenize(String text) { + public List tokenize(String text) { return mergeNeverSplitTokens(doTokenize(text)); } - private List doTokenize(String text) { + private List doTokenizeString(String text) { + return doTokenize(text).stream().map(DelimitedToken::getToken).collect(Collectors.toList()); + } + + private List doTokenize(String text) { text = cleanText(text); if (isTokenizeCjkChars) { text = tokenizeCjkChars(text); } - String[] tokens = whiteSpaceTokenize(text); + List tokens = whiteSpaceTokenize(text); - List processedTokens = new ArrayList<>(tokens.length); - for (String token : tokens) { + List processedTokens = new ArrayList<>(tokens.size()); + for (DelimitedToken tokenRecord : tokens) { - if (Strings.EMPTY.equals(token)) { + String tokenStr = tokenRecord.getToken(); + if (Strings.EMPTY.equals(tokenStr)) { continue; } if (isLowerCase) { - token = token.toLowerCase(Locale.ROOT); + tokenStr = tokenStr.toLowerCase(Locale.ROOT); } if (isStripAccents) { - token = stripAccents(token); + tokenStr = stripAccents(tokenStr); } - processedTokens.addAll(splitOnPunctuation(token)); + processedTokens.addAll(splitOnPunctuation(new DelimitedToken(tokenRecord.getStartPos(), tokenRecord.getEndPos(), tokenStr))); } return processedTokens; } - private List mergeNeverSplitTokens(List tokens) { + private List mergeNeverSplitTokens(List tokens) { if (neverSplitTokenTrieRoot.isLeaf()) { return tokens; } - List mergedTokens = new 
ArrayList<>(tokens.size()); - List matchingTokens = new ArrayList<>(); + List mergedTokens = new ArrayList<>(tokens.size()); + List matchingTokens = new ArrayList<>(); TokenTrieNode current = neverSplitTokenTrieRoot; - for (String token : tokens) { - TokenTrieNode childNode = current.getChild(token); + for (DelimitedToken token : tokens) { + TokenTrieNode childNode = current.getChild(token.getToken()); if (childNode == null) { if (current != neverSplitTokenTrieRoot) { mergedTokens.addAll(matchingTokens); @@ -123,7 +129,7 @@ private List mergeNeverSplitTokens(List tokens) { mergedTokens.add(token); } else if (childNode.isLeaf()) { matchingTokens.add(token); - mergedTokens.add(String.join("", matchingTokens)); + mergedTokens.add(DelimitedToken.mergeTokens(matchingTokens)); matchingTokens = new ArrayList<>(); current = neverSplitTokenTrieRoot; } else { @@ -146,9 +152,53 @@ public boolean isTokenizeCjkChars() { return isTokenizeCjkChars; } - static String[] whiteSpaceTokenize(String text) { - text = text.trim(); - return text.split(" "); + /** + * Split the input text by whitespace. + * For the returned objects {@link DelimitedToken#getStartPos()} is the + * start character index inclusive and {@link DelimitedToken#getEndPos()} + * the index exclusive. The number of whitespace characters between 2 consecutive + * {@link DelimitedToken}s is the difference between the first's {@code endPos} + * and the second's {@code startPos}. + * + * The input should be normalized via a call to {@link #cleanText(String)} + * before it is passed to this function. + * + * @param text to tokenize + * @return White space separated strings + */ + static List whiteSpaceTokenize(String text) { + var tokens = new ArrayList(); + + // whitespace at beginning + int index = 0; + while (index < text.length() && text.charAt(index) == ' ') { + index++; + } + + int tokenStart = index; + + while (index < text.length()) { + if (text.charAt(index) == ' ') { + int tokenEnd = index; + index++; + // consume trail whitespace before the next word + // or end of text + while (index < text.length() && text.charAt(index) == ' ') { + index++; + } + + tokens.add(new DelimitedToken(tokenStart, tokenEnd, text.substring(tokenStart, tokenEnd))); + tokenStart = index; + } + index++; + } + + // trailing whitespace + if (tokenStart != text.length()) { + tokens.add(new DelimitedToken(tokenStart, text.length(), text.substring(tokenStart))); + } + + return tokens; } /** @@ -172,9 +222,9 @@ static String stripAccents(String word) { return new String(codePoints, 0, codePoints.length); } - static List splitOnPunctuation(String word) { - List split = new ArrayList<>(); - int[] codePoints = word.codePoints().toArray(); + static List splitOnPunctuation(DelimitedToken word) { + List splits = new ArrayList<>(); + int[] codePoints = word.getToken().codePoints().toArray(); int lastSplit = 0; for (int i = 0; i < codePoints.length; i++) { @@ -182,18 +232,30 @@ static List splitOnPunctuation(String word) { int charCount = i - lastSplit; if (charCount > 0) { // add a new string for what has gone before - split.add(new String(codePoints, lastSplit, i - lastSplit)); + splits.add( + new DelimitedToken( + word.getStartPos() + lastSplit, + word.getStartPos() + i, + new String(codePoints, lastSplit, i - lastSplit) + ) + ); } - split.add(new String(codePoints, i, 1)); + splits.add(new DelimitedToken(word.getStartPos() + i, word.getStartPos() + i + 1, new String(codePoints, i, 1))); lastSplit = i + 1; } } if (lastSplit < codePoints.length) { - split.add(new 
String(codePoints, lastSplit, codePoints.length - lastSplit)); + splits.add( + new DelimitedToken( + word.getStartPos() + lastSplit, + word.getStartPos() + codePoints.length, + new String(codePoints, lastSplit, codePoints.length - lastSplit) + ) + ); } - return split; + return splits; } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java index 837f10e1e3bfb..9163a8737b48c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder; @@ -80,7 +79,7 @@ protected BertTokenizer( } @Override - public OptionalInt getPadToken() { + public OptionalInt getPadTokenId() { Integer pad = vocab.get(PAD_TOKEN); if (pad != null) { return OptionalInt.of(pad); @@ -89,6 +88,21 @@ public OptionalInt getPadToken() { } } + @Override + public OptionalInt getMaskTokenId() { + Integer pad = vocab.get(MASK_TOKEN); + if (pad != null) { + return OptionalInt.of(pad); + } else { + return OptionalInt.empty(); + } + } + + @Override + public String getMaskToken() { + return MASK_TOKEN; + } + @Override public TokenizationResult buildTokenizationResult(List tokenizations) { TokenizationResult tokenizationResult = new TokenizationResult(originalVocab); @@ -112,16 +126,17 @@ public TokenizationResult buildTokenizationResult(List wordPieceTokens = innerResult.v1(); - List tokenPositionMap = innerResult.v2(); - int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size(); + List wordPieceTokenIds = innerResult.wordPieceTokenIds; + List tokenPositionMap = innerResult.tokenPositionMap; + int numTokens = withSpecialTokens ? wordPieceTokenIds.size() + 2 : wordPieceTokenIds.size(); boolean isTruncated = false; + if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: case SECOND: isTruncated = true; - wordPieceTokens = wordPieceTokens.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength); + wordPieceTokenIds = wordPieceTokenIds.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength); break; case NONE: throw ExceptionsHelper.badRequestException( @@ -132,78 +147,75 @@ public TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncat } numTokens = maxSequenceLength; } - String[] tokens = new String[numTokens]; + int[] tokenIds = new int[numTokens]; int[] tokenMap = new int[numTokens]; if (withSpecialTokens) { - tokens[0] = CLASS_TOKEN; tokenIds[0] = vocab.get(CLASS_TOKEN); tokenMap[0] = SPECIAL_TOKEN_POSITION; } int i = withSpecialTokens ? 1 : 0; final int decrementHandler = withSpecialTokens ? 
1 : 0; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIds) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMap.get(i - decrementHandler); i++; } if (withSpecialTokens) { - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; } - return new TokenizationResult.Tokenization(seq, isTruncated, tokens, tokenIds, tokenMap); + return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, tokenIds, tokenMap); } @Override public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate) { - var innerResult = innerTokenize(seq1); - List wordPieceTokenSeq1s = innerResult.v1(); - List tokenPositionMapSeq1 = innerResult.v2(); - innerResult = innerTokenize(seq2); - List wordPieceTokenSeq2s = innerResult.v1(); - List tokenPositionMapSeq2 = innerResult.v2(); + var innerResultSeq1 = innerTokenize(seq1); + List wordPieceTokenIdsSeq1 = innerResultSeq1.wordPieceTokenIds; + List tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap; + var innerResultSeq2 = innerTokenize(seq2); + List wordPieceTokenIdsSeq2 = innerResultSeq2.wordPieceTokenIds; + List tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap; if (withSpecialTokens == false) { throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens"); } // [CLS] seq1 [SEP] seq2 [SEP] - int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3; + int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + 3; boolean isTruncated = false; if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: isTruncated = true; - if (wordPieceTokenSeq2s.size() > maxSequenceLength - 3) { + if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the second sequence. " + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + "when taking special tokens into account", truncate.toString(), - wordPieceTokenSeq2s.size(), + wordPieceTokenIdsSeq2.size(), maxSequenceLength - 3 ); } - wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq2s.size()); + wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq2.size()); break; case SECOND: isTruncated = true; - if (wordPieceTokenSeq1s.size() > maxSequenceLength - 3) { + if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the first sequence. 
" + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + "when taking special tokens into account", truncate.toString(), - wordPieceTokenSeq2s.size(), + wordPieceTokenIdsSeq1.size(), maxSequenceLength - 3 ); } - wordPieceTokenSeq2s = wordPieceTokenSeq2s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq1s.size()); + wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq1.size()); break; case NONE: throw ExceptionsHelper.badRequestException( @@ -214,62 +226,71 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokeni } numTokens = maxSequenceLength; } - String[] tokens = new String[numTokens]; int[] tokenIds = new int[numTokens]; int[] tokenMap = new int[numTokens]; - tokens[0] = CLASS_TOKEN; tokenIds[0] = vocab.get(CLASS_TOKEN); tokenMap[0] = SPECIAL_TOKEN_POSITION; int i = 1; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq1s) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIdsSeq1) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMapSeq1.get(i - 1); i++; } - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; ++i; int j = 0; - for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq2s) { - tokens[i] = tokenAndId.getToken(); - tokenIds[i] = tokenAndId.getId(); + for (var tokenId : wordPieceTokenIdsSeq2) { + tokenIds[i] = tokenId; tokenMap[i] = tokenPositionMapSeq2.get(j); i++; j++; } - tokens[i] = SEPARATOR_TOKEN; tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; - return new TokenizationResult.Tokenization(seq1 + seq2, isTruncated, tokens, tokenIds, tokenMap); + List tokens = new ArrayList<>(innerResultSeq1.tokens); + tokens.addAll(innerResultSeq2.tokens); + return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, tokenIds, tokenMap); } - private Tuple, List> innerTokenize(String seq) { + private InnerTokenization innerTokenize(String seq) { BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit); - List delineatedTokens = basicTokenizer.tokenize(seq); - List wordPieceTokens = new ArrayList<>(); + var tokenSequences = basicTokenizer.tokenize(seq); + List wordPieceTokens = new ArrayList<>(); List tokenPositionMap = new ArrayList<>(); - for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) { - String token = delineatedTokens.get(sourceIndex); + for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); sourceIndex++) { + String token = tokenSequences.get(sourceIndex).getToken(); if (neverSplit.contains(token)) { - wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)))); + wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))); tokenPositionMap.add(sourceIndex); } else { - List tokens = wordPieceTokenizer.tokenize(token); + List tokens = wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex)); for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) { tokenPositionMap.add(sourceIndex); } wordPieceTokens.addAll(tokens); } } - return Tuple.tuple(wordPieceTokens, tokenPositionMap); + + return new InnerTokenization(tokenSequences, wordPieceTokens, tokenPositionMap); + } + + private static class InnerTokenization { + List tokens; + List wordPieceTokenIds; + List tokenPositionMap; + + InnerTokenization(List tokens, List 
wordPieceTokenIds, List tokenPositionMap) { + this.tokens = tokens; + this.wordPieceTokenIds = wordPieceTokenIds; + this.tokenPositionMap = tokenPositionMap; + } } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java new file mode 100644 index 0000000000000..74f1121cc467f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class DelimitedToken { + + /** + * Merges the list of tokens. + * + * Assumes that the tokens are in order. + * + * @param tokens + * @return The merged token + */ + public static DelimitedToken mergeTokens(List tokens) { + if (tokens.size() == 1) { + return tokens.get(0); + } + + String merged = tokens.stream().map(DelimitedToken::getToken).collect(Collectors.joining()); + return new DelimitedToken(tokens.get(0).getStartPos(), tokens.get(tokens.size() - 1).getEndPos(), merged); + } + + private final int startPos; + private final int endPos; + private final String token; + + DelimitedToken(int startPos, int endPos, String token) { + this.startPos = startPos; + this.endPos = endPos; + this.token = token; + } + + public int getStartPos() { + return startPos; + } + + public int getEndPos() { + return endPos; + } + + public String getToken() { + return token; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DelimitedToken that = (DelimitedToken) o; + return startPos == that.startPos && endPos == that.endPos && Objects.equals(token, that.token); + } + + @Override + public int hashCode() { + return Objects.hash(startPos, endPos, token); + } + + @Override + public String toString() { + return "{" + "startPos=" + startPos + ", endPos=" + endPos + ", token=" + token + '}'; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java index c039ab8de0fda..1146b2d8ae1c1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java @@ -30,7 +30,11 @@ public interface NlpTokenizer { NlpTask.RequestBuilder requestBuilder(); - OptionalInt getPadToken(); + OptionalInt getPadTokenId(); + + OptionalInt getMaskTokenId(); + + String getMaskToken(); static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) { ExceptionsHelper.requireNonNull(params, TOKENIZATION); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java index b50c9548504a9..862be3c43bf67 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java @@ -33,9 +33,9 @@ public List getTokenizations() { return tokenizations; } - public void addTokenization(String input, boolean isTruncated, String[] tokens, int[] tokenIds, int[] tokenMap) { + public void addTokenization(String input, boolean isTruncated, List tokens, int[] tokenIds, int[] tokenMap) { maxLength = Math.max(maxLength, tokenIds.length); - tokenizations.add(new Tokenization(input, isTruncated, tokens, tokenIds, tokenMap)); + tokenizations.add(new Tokenization(input, tokens, isTruncated, tokenIds, tokenMap)); } public void addTokenization(Tokenization tokenization) { @@ -49,16 +49,15 @@ public int getLongestSequenceLength() { public static class Tokenization { - private final String inputSeqs; - private final String[] tokens; + private final String input; + private final List tokens; private final int[] tokenIds; private final int[] tokenMap; private final boolean truncated; - public Tokenization(String input, boolean truncated, String[] tokens, int[] tokenIds, int[] tokenMap) { - assert tokens.length == tokenIds.length; + public Tokenization(String input, List tokens, boolean truncated, int[] tokenIds, int[] tokenMap) { assert tokenIds.length == tokenMap.length; - this.inputSeqs = input; + this.input = input; this.tokens = tokens; this.tokenIds = tokenIds; this.tokenMap = tokenMap; @@ -66,16 +65,7 @@ public Tokenization(String input, boolean truncated, String[] tokens, int[] toke } /** - * The token strings from the tokenization process - * - * @return A list of tokens - */ - public String[] getTokens() { - return tokens; - } - - /** - * The integer values of the tokens in {@link #getTokens()} + * The integer values of the tokens} * * @return A list of token Ids */ @@ -95,7 +85,11 @@ public int[] getTokenMap() { } public String getInput() { - return inputSeqs; + return input; + } + + public List getTokens() { + return tokens; } public boolean isTruncated() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java index a566007a8c5ae..b50e70f85f12a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -25,26 +26,7 @@ public class WordPieceTokenizer { private final String unknownToken; private final int maxInputCharsPerWord; - public static class TokenAndId { - private final String token; - private final int id; - - TokenAndId(String token, int id) { - this.token = token; - this.id = id; - } - - public int getId() { - return id; - } - - public String getToken() { - return token; - } - } - /** - * * @param vocab The token vocabulary * @param unknownToken If not found in the vocabulary * @param maxInputCharsPerWord Inputs tokens longer than this are 'unknown' @@ -58,63 +40,55 @@ public WordPieceTokenizer(Map vocab, String unknownToken, int m /** * Wordpiece tokenize the input text. 
* - * @param text A single token or whitespace separated tokens. - * Input should have been normalized by the {@link BasicTokenizer}. - * @return List of tokens + * @param token Word to tokenize + * @return List of token IDs */ - public List tokenize(String text) { - String[] tokens = BasicTokenizer.whiteSpaceTokenize(text); - - List output = new ArrayList<>(); - for (String token : tokens) { - if (token.length() > maxInputCharsPerWord) { - assert vocab.containsKey(unknownToken); - output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); - continue; - } + public List tokenize(DelimitedToken token) { + + if (token.getToken().length() > maxInputCharsPerWord) { + assert vocab.containsKey(unknownToken); + return Collections.singletonList(vocab.get(unknownToken)); + } - boolean isBad = false; - int start = 0; - List subTokens = new ArrayList<>(); - int length = token.length(); - while (start < length) { - int end = length; - - String currentValidSubStr = null; - - while (start < end) { - String subStr; - if (start > 0) { - subStr = CONTINUATION + token.substring(start, end); - } else { - subStr = token.substring(start, end); - } - - if (vocab.containsKey(subStr)) { - currentValidSubStr = subStr; - break; - } - - end--; + List output = new ArrayList<>(); + boolean isBad = false; + int start = 0; + int length = token.getToken().length(); + while (start < length) { + int end = length; + + String currentValidSubStr = null; + + while (start < end) { + String subStr; + if (start > 0) { + subStr = CONTINUATION + token.getToken().substring(start, end); + } else { + subStr = token.getToken().substring(start, end); } - if (currentValidSubStr == null) { - isBad = true; + if (vocab.containsKey(subStr)) { + currentValidSubStr = subStr; break; } - subTokens.add(new TokenAndId(currentValidSubStr, vocab.get(currentValidSubStr))); - - start = end; + end--; } - if (isBad) { - output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); - } else { - output.addAll(subTokens); + if (currentValidSubStr == null) { + isBad = true; + break; } + + output.add(vocab.get(currentValidSubStr)); + + start = end; } - return output; + if (isBad) { + return Collections.singletonList(vocab.get(unknownToken)); + } else { + return output; + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index e43d1f3e41d60..af41090e4567a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -14,18 +14,22 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.OptionalInt; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class FillMaskProcessorTests extends ESTestCase { @@ -45,17 +49,23 @@ public void testProcessResults() { String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris"; List vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France"); - String[] tokens = input.split(" "); + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + int[] tokenMap = new int[] { 0, 1, 2, 3, 4, 5 }; int[] tokenIds = new int[] { 0, 1, 2, 3, 4, 5 }; TokenizationResult tokenization = new TokenizationResult(vocab); tokenization.addTokenization(input, false, tokens, tokenIds, tokenMap); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn(BertTokenizer.MASK_TOKEN); + when(tokenizer.getMaskTokenId()).thenReturn(OptionalInt.of(3)); + String resultsField = randomAlphaOfLength(10); FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult( tokenization, new PyTorchResult("1", scores, 0L, null), + tokenizer, 4, resultsField ); @@ -72,12 +82,15 @@ public void testProcessResults() { } public void testProcessResults_GivenMissingTokens() { + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); + TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); - tokenization.addTokenization("", false, new String[] {}, new int[] {}, new int[] {}); + tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {}); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null); assertThat( - FillMaskProcessor.processResult(tokenization, pyTorchResult, 5, randomAlphaOfLength(10)), + FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)), instanceOf(WarningInferenceResults.class) ); } @@ -85,8 +98,10 @@ public void testProcessResults_GivenMissingTokens() { public void testValidate_GivenMissingMaskToken() { List input = List.of("The capital of France is Paris"); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null); - FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); + FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input)); assertThat(e.getMessage(), containsString("no [MASK] token could be found")); @@ -95,8 +110,11 @@ public void testValidate_GivenMissingMaskToken() { public void testProcessResults_GivenMultipleMaskTokens() { List input = List.of("The capital of [MASK] is [MASK]"); + BertTokenizer tokenizer = mock(BertTokenizer.class); + when(tokenizer.getMaskToken()).thenReturn("[MASK]"); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null); - FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); + FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input)); assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input")); diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index 1375a16bc02f8..041d74e7656c4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -16,7 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; import java.util.ArrayList; @@ -160,23 +162,26 @@ public void testProcessResults_withIobMap() { } public void testGroupTaggedTokens() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Sarah", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Jessica", NerProcessor.IobTag.I_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("I", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("live", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("in", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Manchester", NerProcessor.IobTag.B_LOC, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("and", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("work", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("for", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Elastic", NerProcessor.IobTag.B_ORG, 1.0)); - - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens( - tokens, - "Hi Sarah Jessica, I live in Manchester and work for Elastic" - ); + String input = "Hi Sarah Jessica, I live in Manchester and work for Elastic"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + assertThat(tokens, hasSize(12)); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new 
NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_LOC, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0)); + + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, hasSize(3)); assertThat(entityGroups.get(0).getClassName(), equalTo("PER")); assertThat(entityGroups.get(0).getEntity(), equalTo("Sarah Jessica")); @@ -187,23 +192,32 @@ public void testGroupTaggedTokens() { } public void testGroupTaggedTokens_GivenNoEntities() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("there", NerProcessor.IobTag.O, 1.0)); + String input = "Hi there"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.O, 1.0)); - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(tokens, "Hi there"); + List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input); assertThat(entityGroups, is(empty())); } public void testGroupTaggedTokens_GivenConsecutiveEntities() { - List tokens = new ArrayList<>(); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Rita", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Sue", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("and", NerProcessor.IobTag.O, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("Bob", NerProcessor.IobTag.B_PER, 1.0)); - tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("too", NerProcessor.IobTag.O, 1.0)); - - List entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(tokens, "Rita, Sue, and Bob too"); + String input = "Rita, Sue, and Bob too"; + List tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input); + + List taggedTokens = new ArrayList<>(); + int i = 0; + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0)); + taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0)); + + List entityGroups = 
NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
         assertThat(entityGroups.get(0).getClassName(), equalTo("PER"));
         assertThat(entityGroups.get(0).getEntity(), equalTo("Rita"));
@@ -214,17 +228,19 @@ public void testGroupTaggedTokens_GivenConsecutiveEntities() {
     }
 
     public void testGroupTaggedTokens_GivenConsecutiveContinuingEntities() {
-        List<NerProcessor.NerResultProcessor.TaggedToken> tokens = new ArrayList<>();
-        tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("FirstName", NerProcessor.IobTag.B_PER, 1.0));
-        tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("SecondName", NerProcessor.IobTag.I_PER, 1.0));
-        tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("NextPerson", NerProcessor.IobTag.B_PER, 1.0));
-        tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("NextPersonSecondName", NerProcessor.IobTag.I_PER, 1.0));
-        tokens.add(new NerProcessor.NerResultProcessor.TaggedToken("something_else", NerProcessor.IobTag.B_ORG, 1.0));
-
-        List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(
-            tokens,
-            "FirstName SecondName, NextPerson NextPersonSecondName. something_else"
-        );
+        String input = "FirstName SecondName, NextPerson NextPersonSecondName. something_else";
+        List<DelimitedToken> tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input);
+
+        List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
+        int i = 0;
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0));
+
+        List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
         assertThat(entityGroups.get(0).getClassName(), equalTo("PER"));
         assertThat(entityGroups.get(0).getEntity(), equalTo("FirstName SecondName"));
@@ -233,6 +249,39 @@ public void testGroupTaggedTokens_GivenConsecutiveContinuingEntities() {
         assertThat(entityGroups.get(2).getClassName(), equalTo("ORG"));
     }
 
+    public void testEntityContainsPunctuation() {
+        String input = "Alexander, my name is Benjamin Trent, I work at Acme Inc..";
+        List<DelimitedToken> tokens = new BasicTokenizer(randomBoolean(), randomBoolean(), randomBoolean()).tokenize(input);
+
+        List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
+        int i = 0;
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        assertEquals(tokens.size(), taggedTokens.size());
+
+        List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
+        assertThat(entityGroups, hasSize(3));
+        assertThat(entityGroups.get(0).getClassName(), equalTo("PER"));
+        assertThat(entityGroups.get(0).getEntity(), equalTo("Alexander"));
+        assertThat(entityGroups.get(1).getClassName(), equalTo("PER"));
+        assertThat(entityGroups.get(1).getEntity(), equalTo("Benjamin Trent"));
+        assertThat(entityGroups.get(2).getClassName(), equalTo("ORG"));
+        assertThat(entityGroups.get(2).getEntity(), equalTo("Acme Inc."));
+    }
+
     public void testAnnotatedTextBuilder() {
         String input = "Alexander, my name is Benjamin Trent, I work at Acme Inc.";
         List<NerResults.EntityGroup> entities = List.of(
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
index 87b02e6066a9c..4bdd24cafe92a 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
@@ -11,9 +11,11 @@
 
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
 
-import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.sameInstance;
 
 /**
@@ -24,91 +26,102 @@ public class BasicTokenizerTests extends ESTestCase {
 
     public void testLowerCase() {
         BasicTokenizer tokenizer = new BasicTokenizer();
-        List<String> tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? ");
-        assertThat(tokens, contains("hello", "!", "how", "are", "you", "?"));
+        var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? ");
+        assertThat(tokenStrings(tokens), contains("hello", "!", "how", "are", "you", "?"));
 
         tokens = tokenizer.tokenize("H\u00E9llo");
-        assertThat(tokens, contains("hello"));
+        assertThat(tokenStrings(tokens), contains("hello"));
     }
 
     public void testLowerCaseWithoutStripAccents() {
         BasicTokenizer tokenizer = new BasicTokenizer(true, true, false);
-        List<String> tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? ");
-        assertThat(tokens, contains("hällo", "!", "how", "are", "you", "?"));
+        var tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? ");
+        assertThat(tokenStrings(tokens), contains("hällo", "!", "how", "are", "you", "?"));
 
         tokens = tokenizer.tokenize("H\u00E9llo");
-        assertThat(tokens, contains("h\u00E9llo"));
+        assertThat(tokenStrings(tokens), contains("h\u00E9llo"));
     }
 
     public void testLowerCaseStripAccentsDefault() {
         BasicTokenizer tokenizer = new BasicTokenizer(true, true);
-        List<String> tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? ");
-        assertThat(tokens, contains("hallo", "!", "how", "are", "you", "?"));
+        var tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? ");
+        assertThat(tokenStrings(tokens), contains("hallo", "!", "how", "are", "you", "?"));
 
         tokens = tokenizer.tokenize("H\u00E9llo");
-        assertThat(tokens, contains("hello"));
+        assertThat(tokenStrings(tokens), contains("hello"));
     }
 
     public void testNoLower() {
-        List<String> tokens = new BasicTokenizer(false, true, false).tokenize(" \tHäLLo!how \n Are yoU? ");
-        assertThat(tokens, contains("HäLLo", "!", "how", "Are", "yoU", "?"));
+        var tokens = new BasicTokenizer(false, true, false).tokenize(" \tHäLLo!how \n Are yoU? ");
+        assertThat(tokenStrings(tokens), contains("HäLLo", "!", "how", "Are", "yoU", "?"));
     }
 
     public void testNoLowerStripAccents() {
-        List<String> tokens = new BasicTokenizer(false, true, true).tokenize(" \tHäLLo!how \n Are yoU? ");
-        assertThat(tokens, contains("HaLLo", "!", "how", "Are", "yoU", "?"));
+        var tokens = new BasicTokenizer(false, true, true).tokenize(" \tHäLLo!how \n Are yoU? ");
+        assertThat(tokenStrings(tokens), contains("HaLLo", "!", "how", "Are", "yoU", "?"));
    }
 
     public void testNeverSplit() {
         BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Collections.singleton("[UNK]"));
-        List<String> tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]");
-        assertThat(tokens, contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"));
+        var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]");
+        assertThat(tokenStrings(tokens), contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"));
 
         tokens = tokenizer.tokenize("Hello [UNK].");
-        assertThat(tokens, contains("Hello", "[UNK]", "."));
+        assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", "."));
 
         tokens = tokenizer.tokenize("Hello [UNK]?");
-        assertThat(tokens, contains("Hello", "[UNK]", "?"));
+        assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", "?"));
 
         tokens = tokenizer.tokenize("Hello [UNK]!!");
-        assertThat(tokens, contains("Hello", "[UNK]", "!", "!"));
+        assertThat(tokenStrings(tokens), contains("Hello", "[UNK]", "!", "!"));
 
         tokens = tokenizer.tokenize("Hello-[UNK]");
-        assertThat(tokens, contains("Hello", "-", "[UNK]"));
+        assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]"));
 
         tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
-        assertThat(tokens, contains("Hello", "-", "[UNK]", "[UNK]"));
+        assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]", "[UNK]"));
     }
 
     public void testSplitOnPunctuation() {
-        List<String> tokens = BasicTokenizer.splitOnPunctuation("hi!");
-        assertThat(tokens, contains("hi", "!"));
+        var tokens = BasicTokenizer.splitOnPunctuation(new DelimitedToken(0, 3, "hi!"));
+        assertEquals(new DelimitedToken(0, 2, "hi"), tokens.get(0));
+        assertEquals(new DelimitedToken(2, 3, "!"), tokens.get(1));
 
-        tokens = BasicTokenizer.splitOnPunctuation("hi.");
-        assertThat(tokens, contains("hi", "."));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("hi."));
+        assertEquals(new DelimitedToken(0, 2, "hi"), tokens.get(0));
+        assertEquals(new DelimitedToken(2, 3, "."), tokens.get(1));
 
-        tokens = BasicTokenizer.splitOnPunctuation("!hi");
-        assertThat(tokens, contains("!", "hi"));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("!hi"));
+        assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0));
+        assertEquals(new DelimitedToken(1, 3, "hi"), tokens.get(1));
 
-        tokens = BasicTokenizer.splitOnPunctuation("don't");
-        assertThat(tokens, contains("don", "'", "t"));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("don't"));
+        assertEquals(new DelimitedToken(0, 3, "don"), tokens.get(0));
+        assertEquals(new DelimitedToken(3, 4, "'"), tokens.get(1));
+        assertEquals(new DelimitedToken(4, 5, "t"), tokens.get(2));
 
-        tokens = BasicTokenizer.splitOnPunctuation("!!hi");
-        assertThat(tokens, contains("!", "!", "hi"));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("!!hi"));
+        assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0));
+        assertEquals(new DelimitedToken(1, 2, "!"), tokens.get(1));
+        assertEquals(new DelimitedToken(2, 4, "hi"), tokens.get(2));
 
-        tokens = BasicTokenizer.splitOnPunctuation("[hi]");
-        assertThat(tokens, contains("[", "hi", "]"));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("[hi]"));
+        assertEquals(new DelimitedToken(0, 1, "["), tokens.get(0));
+        assertEquals(new DelimitedToken(1, 3, "hi"), tokens.get(1));
+        assertEquals(new DelimitedToken(3, 4, "]"), tokens.get(2));
 
-        tokens = BasicTokenizer.splitOnPunctuation("hi.");
-        assertThat(tokens, contains("hi", "."));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("!!"));
+        assertEquals(new DelimitedToken(0, 1, "!"), tokens.get(0));
+        assertEquals(new DelimitedToken(1, 2, "!"), tokens.get(1));
 
-        tokens = BasicTokenizer.splitOnPunctuation("!!");
-        assertThat(tokens, contains("!", "!"));
+        tokens = BasicTokenizer.splitOnPunctuation(makeToken("elastic’s"));
+        assertEquals(new DelimitedToken(0, 7, "elastic"), tokens.get(0));
+        assertEquals(new DelimitedToken(7, 8, "’"), tokens.get(1));
+        assertEquals(new DelimitedToken(8, 9, "s"), tokens.get(2));
 
-        tokens = BasicTokenizer.splitOnPunctuation("elastic’s");
-        assertThat(tokens, contains("elastic", "’", "s"));
-
-        tokens = BasicTokenizer.splitOnPunctuation("elastic‘s");
-        assertThat(tokens, contains("elastic", "‘", "s"));
+        tokens = BasicTokenizer.splitOnPunctuation(new DelimitedToken(4, 13, "elastic’s"));
+        assertEquals(new DelimitedToken(4, 11, "elastic"), tokens.get(0));
+        assertEquals(new DelimitedToken(11, 12, "’"), tokens.get(1));
+        assertEquals(new DelimitedToken(12, 13, "s"), tokens.get(2));
     }
 
     public void testStripAccents() {
@@ -123,8 +136,8 @@ public void testTokenizeCjkChars() {
     }
 
     public void testTokenizeChinese() {
-        List<String> tokens = new BasicTokenizer().tokenize("ah\u535A\u63A8zz");
-        assertThat(tokens, contains("ah", "\u535A", "\u63A8", "zz"));
+        var tokens = new BasicTokenizer().tokenize("ah\u535A\u63A8zz");
+        assertThat(tokenStrings(tokens), contains("ah", "\u535A", "\u63A8", "zz"));
     }
 
     public void testCleanText() {
@@ -132,11 +145,6 @@ public void testCleanText() {
         assertEquals("filter control chars", BasicTokenizer.cleanText("\u0000filter \uFFFDcontrol chars\u0005"));
     }
 
-    public void testWhiteSpaceTokenize() {
-        assertThat(BasicTokenizer.whiteSpaceTokenize("nochange"), arrayContaining("nochange"));
-        assertThat(BasicTokenizer.whiteSpaceTokenize(" some change "), arrayContaining("some", "", "change"));
-    }
-
     public void testIsWhitespace() {
         assertTrue(BasicTokenizer.isWhiteSpace(' '));
         assertTrue(BasicTokenizer.isWhiteSpace('\t'));
@@ -187,4 +195,43 @@ public void testIsCjkChar() {
         assertTrue(BasicTokenizer.isCjkChar(0x2F800));
         assertFalse(BasicTokenizer.isCjkChar(0x2FA20));
     }
+
+    public void testWhitespaceTokenize() {
+        {
+            List<DelimitedToken> delimitedTokens = BasicTokenizer.whiteSpaceTokenize("hello! how are you?");
+            assertThat(delimitedTokens, hasSize(4));
+            assertThat(tokenStrings(delimitedTokens), contains("hello!", "how", "are", "you?"));
+
+            assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(0, 6, "hello!")));
+            assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(7, 10, "how")));
+            assertThat(delimitedTokens.get(2), equalTo(new DelimitedToken(11, 14, "are")));
+            assertThat(delimitedTokens.get(3), equalTo(new DelimitedToken(15, 19, "you?")));
+        }
+        {
+            List<DelimitedToken> delimitedTokens = BasicTokenizer.whiteSpaceTokenize("   leading whitespace");
+            assertThat(delimitedTokens, hasSize(2));
+            assertThat(tokenStrings(delimitedTokens), contains("leading", "whitespace"));
+
+            assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(3, 10, "leading")));
+            assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(11, 21, "whitespace")));
+        }
+        {
+            List<DelimitedToken> delimitedTokens = BasicTokenizer.whiteSpaceTokenize("double  spaced  text  ");
+            assertThat(delimitedTokens, hasSize(3));
+            assertThat(tokenStrings(delimitedTokens), contains("double", "spaced", "text"));
+
+            assertThat(delimitedTokens.get(0), equalTo(new DelimitedToken(0, 6, "double")));
+            assertThat(delimitedTokens.get(1), equalTo(new DelimitedToken(8, 14, "spaced")));
+            assertThat(delimitedTokens.get(2), equalTo(new DelimitedToken(16, 20, "text")));
+        }
+    }
+
+    private List<String> tokenStrings(List<DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList());
+    }
+
+    private DelimitedToken makeToken(String str) {
+        return new DelimitedToken(0, str.length(), str);
+    }
+
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
index 0d326f5e8931e..fa9a9235cf2f6 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
@@ -16,8 +16,9 @@
 
 import java.util.Collections;
 import java.util.List;
 import java.util.Set;
+import java.util.stream.Collectors;
 
-import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
@@ -45,6 +46,10 @@ public class BertTokenizerTests extends ESTestCase {
         "with"
     );
 
+    private List<String> tokenStrings(List<DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList());
+    }
+
     public void testTokenize() {
         BertTokenizer tokenizer = BertTokenizer.builder(
             TEST_CASED_VOCAB,
@@ -52,7 +57,7 @@
         ).build();
 
         TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun"));
+        assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", "fun"));
         assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
     }
@@ -90,18 +95,21 @@ public void testTokenizeLargeInputTruncation() {
             "Elasticsearch fun with Pancake and Godzilla",
             Tokenization.Truncate.FIRST
         );
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun", "with", "Pancake"));
+        assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.getTokenIds());
 
-        tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 5, Tokenization.Truncate.FIRST)).build();
-        tokenization = tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST);
-        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]"));
+        BertTokenizer tokenizerWithSpecialTokens = BertTokenizer.builder(
+            TEST_CASED_VOCAB,
+            new BertTokenization(null, true, 5, Tokenization.Truncate.FIRST)
+        ).build();
+        tokenization = tokenizerWithSpecialTokens.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST);
+        assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
     }
 
     public void testTokenizeAppendSpecialTokens() {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).build();
 
         TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
     }
@@ -118,7 +126,7 @@ public void testNeverSplitTokens() {
             "Elasticsearch " + specialToken + " fun",
             Tokenization.Truncate.NONE
         );
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", specialToken, "fun"));
+        assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", specialToken, "fun"));
         assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.getTokenMap());
     }
@@ -131,12 +139,12 @@ public void testDoLowerCase() {
             ).setDoLowerCase(false).setWithSpecialTokens(false).build();
 
             TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenization.getTokens(), arrayContaining(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] { 3, 2 }, tokenization.getTokenIds());
             assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenMap());
 
             tokenization = tokenizer.tokenize("elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
         }
 
         {
@@ -146,7 +154,7 @@ public void testDoLowerCase() {
                 .build();
 
             TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
         }
     }
 
@@ -154,12 +162,11 @@ public void testPunctuation() {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).setWithSpecialTokens(false).build();
 
         TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "."));
+        assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", ",", "fun", "."));
         assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.getTokenMap());
 
         tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "[MASK]", "."));
         assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
     }
@@ -171,16 +178,19 @@ public void testPunctuationWithMask() {
         ).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build();
 
         TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"));
+        assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "[MASK]", "-", "tastic", "!"));
+        assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.getTokenMap());
 
         tokenization = tokenizer.tokenize("This is sub~[MASK]!", Tokenization.Truncate.NONE);
-        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "sub", "~", "[MASK]", "!", "[SEP]"));
+        assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", "~", "[MASK]", "!"));
+        assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.getTokenMap());
 
         tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!", Tokenization.Truncate.NONE);
-        assertThat(
-            tokenization.getTokens(),
-            arrayContaining("[CLS]", "This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!", "[SEP]")
-        );
+        assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "tastic", "!"));
+        assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.getTokenMap());
     }
 
     public void testBatchInput() {
@@ -200,22 +210,18 @@ public void testBatchInput() {
         assertThat(tr.getTokenizations(), hasSize(4));
 
         TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search"));
         assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0 }, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(1);
-        assertThat(tokenization.getTokens(), arrayContaining("my", "little", "red", "car"));
         assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(2);
-        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "day"));
         assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(3);
-        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "Pancake", "red", "car", "day"));
         assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
     }
@@ -230,9 +236,11 @@ public void testMultiSeqTokenization() {
             "Godzilla my little red car",
             Tokenization.Truncate.NONE
         );
+
+        var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
         assertThat(
-            tokenization.getTokens(),
-            arrayContaining(
+            tokenStream,
+            contains(
                 BertTokenizer.CLASS_TOKEN,
                 "Elastic",
                 "##search",
@@ -260,9 +268,11 @@ public void testTokenizeLargeInputMultiSequenceTruncation() {
             "Godzilla my little red car",
             Tokenization.Truncate.FIRST
         );
+
+        var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
         assertThat(
-            tokenization.getTokens(),
-            arrayContaining(
+            tokenStream,
+            contains(
                 BertTokenizer.CLASS_TOKEN,
                 "Elastic",
                 BertTokenizer.SEPARATOR_TOKEN,
@@ -286,9 +296,10 @@ public void testTokenizeLargeInputMultiSequenceTruncation() {
 
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 10, Tokenization.Truncate.SECOND)).build();
         tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.SECOND);
+        tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
         assertThat(
-            tokenization.getTokens(),
-            arrayContaining(
+            tokenStream,
+            contains(
                 BertTokenizer.CLASS_TOKEN,
                 "Elastic",
                 "##search",
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java
index dd3bf3863d361..c62df28007eef 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java
@@ -22,39 +22,39 @@ public class WordPieceTokenizerTests extends ESTestCase {
     public static final String UNKNOWN_TOKEN = "[UNK]";
 
     public void testTokenize() {
-        Map<String, Integer> vocabMap = createVocabMap(
-            UNKNOWN_TOKEN,
-            "[CLS]",
-            "[SEP]",
-            "want",
-            "##want",
-            "##ed",
-            "wa",
-            "un",
-            "runn",
-            "##ing"
-        );
+        String[] vocab = { UNKNOWN_TOKEN, "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing" };
+        Map<String, Integer> vocabMap = createVocabMap(vocab);
+
         WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, UNKNOWN_TOKEN, 100);
 
-        List<WordPieceTokenizer.TokenAndId> tokenAndIds = tokenizer.tokenize("");
-        assertThat(tokenAndIds, empty());
+        var tokenIds = tokenizer.tokenize(new DelimitedToken(0, 0, ""));
+        assertThat(tokenIds, empty());
+
+        tokenIds = tokenizer.tokenize(makeToken("unwanted"));
+        List<String> tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList());
+        assertThat(tokenStrings, contains("un", "##want", "##ed"));
 
-        tokenAndIds = tokenizer.tokenize("unwanted running");
-        List<String> tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList());
-        assertThat(tokens, contains("un", "##want", "##ed", "runn", "##ing"));
+        tokenIds = tokenizer.tokenize(makeToken("running"));
+        tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList());
+        assertThat(tokenStrings, contains("runn", "##ing"));
+
+        tokenIds = tokenizer.tokenize(makeToken("unwantedX"));
+        tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList());
+        assertThat(tokenStrings, contains(UNKNOWN_TOKEN));
+    }
 
-        tokenAndIds = tokenizer.tokenize("unwantedX running");
-        tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList());
-        assertThat(tokens, contains(UNKNOWN_TOKEN, "runn", "##ing"));
+    private DelimitedToken makeToken(String str) {
+        return new DelimitedToken(0, str.length(), str);
     }
 
     public void testMaxCharLength() {
-        Map<String, Integer> vocabMap = createVocabMap("Some", "words", "will", "become", "UNK");
+        String[] vocab = { "Some", "words", "will", "become", "UNK" };
+        Map<String, Integer> vocabMap = createVocabMap(vocab);
         WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, "UNK", 4);
 
-        List<WordPieceTokenizer.TokenAndId> tokenAndIds = tokenizer.tokenize("Some words will become UNK");
-        List<String> tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList());
-        assertThat(tokens, contains("Some", "UNK", "will", "UNK", "UNK"));
+        var tokenIds = tokenizer.tokenize(new DelimitedToken(0, 0, "become"));
+        List<String> tokenStrings = tokenIds.stream().map(index -> vocab[index]).collect(Collectors.toList());
+        assertThat(tokenStrings, contains("UNK"));
     }
 
     static Map<String, Integer> createVocabMap(String... words) {