diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java
index d98f1000d42e4..4ca61a03f18c6 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java
@@ -15,6 +15,7 @@
 import java.util.Locale;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Predicate;
 
 /**
  * Basic tokenization of text by whitespace with optional extras:
@@ -31,7 +32,7 @@ public class BasicTokenizer {
     private final boolean isLowerCase;
     private final boolean isTokenizeCjkChars;
     private final boolean isStripAccents;
-    private final TokenTrieNode neverSplitTokenTrieRoot;
+    private final Set<String> neverSplit;
 
     /**
      * Tokenizer behaviour is controlled by the options passed here.
@@ -45,11 +46,14 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i
         this.isLowerCase = isLowerCase;
         this.isTokenizeCjkChars = isTokenizeCjkChars;
         this.isStripAccents = isStripAccents;
-        this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenize);
+        this.neverSplit = neverSplit;
     }
 
     public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) {
-        this(isLowerCase, isTokenizeCjkChars, isStripAccents, Collections.emptySet());
+        this.isLowerCase = isLowerCase;
+        this.isTokenizeCjkChars = isTokenizeCjkChars;
+        this.isStripAccents = isStripAccents;
+        this.neverSplit = Collections.emptySet();
     }
 
     /**
@@ -75,10 +79,6 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) {
      * @return List of tokens
      */
     public List<String> tokenize(String text) {
-        return mergeNeverSplitTokens(doTokenize(text));
-    }
-
-    private List<String> doTokenize(String text) {
         text = cleanText(text);
         if (isTokenizeCjkChars) {
             text = tokenizeCjkChars(text);
@@ -93,45 +93,31 @@ private List<String> doTokenize(String text) {
                 continue;
             }
 
-            if (isLowerCase) {
-                token = token.toLowerCase(Locale.ROOT);
-            }
-            if (isStripAccents) {
-                token = stripAccents(token);
+            if (neverSplit.contains(token)) {
+                processedTokens.add(token);
+                continue;
             }
-            processedTokens.addAll(splitOnPunctuation(token));
-        }
-
-        return processedTokens;
-    }
 
-    private List<String> mergeNeverSplitTokens(List<String> tokens) {
-        if (neverSplitTokenTrieRoot.isLeaf()) {
-            return tokens;
-        }
-        List<String> mergedTokens = new ArrayList<>(tokens.size());
-        List<String> matchingTokens = new ArrayList<>();
-        TokenTrieNode current = neverSplitTokenTrieRoot;
-        for (String token : tokens) {
-            TokenTrieNode childNode = current.getChild(token);
-            if (childNode == null) {
-                if (current != neverSplitTokenTrieRoot) {
-                    mergedTokens.addAll(matchingTokens);
-                    matchingTokens = new ArrayList<>();
-                    current = neverSplitTokenTrieRoot;
+            // At this point text has been tokenized by whitespace
+            // but one of the special never split tokens could be adjacent
+            // to one or more punctuation characters.
+            List<String> splitOnCommonTokens = splitOnPredicate(token, BasicTokenizer::isCommonPunctuation);
+            for (String splitOnCommon : splitOnCommonTokens) {
+                if (neverSplit.contains(splitOnCommon)) {
+                    processedTokens.add(splitOnCommon);
+                } else {
+                    if (isLowerCase) {
+                        splitOnCommon = splitOnCommon.toLowerCase(Locale.ROOT);
+                    }
+                    if (isStripAccents) {
+                        splitOnCommon = stripAccents(splitOnCommon);
+                    }
+                    processedTokens.addAll(splitOnPunctuation(splitOnCommon));
                 }
-                mergedTokens.add(token);
-            } else if (childNode.isLeaf()) {
-                matchingTokens.add(token);
-                mergedTokens.add(String.join("", matchingTokens));
-                matchingTokens = new ArrayList<>();
-                current = neverSplitTokenTrieRoot;
-            } else {
-                matchingTokens.add(token);
-                current = childNode;
             }
         }
-        return mergedTokens;
+
+        return processedTokens;
     }
 
     public boolean isLowerCase() {
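The rewritten loop above checks the never-split set twice: once against the raw whitespace token, and once against each fragment produced by the common-punctuation pre-split; only fragments that are not never-split are lowercased, accent-stripped, and fully split. A minimal usage sketch of the intended behaviour (hypothetical driver code in the same package, not part of this change; the expected output mirrors the `testPunctuationWithMask` case further down):

```java
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.util.List;
import java.util.Set;

class NeverSplitDemo {
    public static void main(String[] args) {
        // Keep "[MASK]" intact; no lowercasing, CJK splitting or accent stripping.
        BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Set.of("[MASK]"));

        // "[MASK]-tastic!" is pre-split on the common punctuation '-' and '!',
        // and the "[MASK]" fragment then matches the never-split set verbatim.
        List<String> tokens = tokenizer.tokenize("This is [MASK]-tastic!");
        System.out.println(tokens); // [This, is, [MASK], -, tastic, !]
    }
}
```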
@@ -173,12 +159,16 @@ static String stripAccents(String word) {
     }
 
     static List<String> splitOnPunctuation(String word) {
+        return splitOnPredicate(word, BasicTokenizer::isPunctuationMark);
+    }
+
+    static List<String> splitOnPredicate(String word, Predicate<Integer> test) {
         List<String> split = new ArrayList<>();
         int[] codePoints = word.codePoints().toArray();
         int lastSplit = 0;
 
         for (int i = 0; i < codePoints.length; i++) {
-            if (isPunctuationMark(codePoints[i])) {
+            if (test.test(codePoints[i])) {
                 int charCount = i - lastSplit;
                 if (charCount > 0) {
                     // add a new string for what has gone before
@@ -302,4 +292,14 @@ static boolean isPunctuationMark(int codePoint) {
         return (category >= Character.DASH_PUNCTUATION && category <= Character.OTHER_PUNCTUATION)
             || (category >= Character.INITIAL_QUOTE_PUNCTUATION && category <= Character.FINAL_QUOTE_PUNCTUATION);
     }
+
+    /**
+     * True if the code point is one of the common punctuation characters
+     * {@code ! " # $ % & ' ( ) * + , - . /} or {@code : ; < = > ? @}
+     * @param codePoint codepoint to check
+     * @return true if the codepoint is common punctuation
+     */
+    static boolean isCommonPunctuation(int codePoint) {
+        return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
+    }
 }
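The two ASCII ranges in `isCommonPunctuation` deliberately exclude the brackets (code points 91-93), which is what keeps tokens like `[MASK]` and `[UNK]` in one piece during the pre-split. A self-contained sketch of the predicate (the ranges mirror the patch; the wrapper class is illustrative):

```java
class CommonPunctuationDemo {
    // Same ranges as the patch: 33-47 covers ! " # $ % & ' ( ) * + , - . /
    // and 58-64 covers : ; < = > ? @
    static boolean isCommonPunctuation(int codePoint) {
        return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
    }

    public static void main(String[] args) {
        System.out.println(isCommonPunctuation('-')); // true  (45)
        System.out.println(isCommonPunctuation('[')); // false (91): brackets excluded
        System.out.println(isCommonPunctuation('`')); // false (96)
        System.out.println(isCommonPunctuation('~')); // false (126)
    }
}
```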
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNode.java
deleted file mode 100644
index a6716a9580372..0000000000000
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNode.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
-
-import org.elasticsearch.core.Nullable;
-
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.function.Function;
-
-class TokenTrieNode {
-
-    private static final String EMPTY_STRING = "";
-
-    private final Map<String, TokenTrieNode> children;
-
-    private TokenTrieNode(Map<String, TokenTrieNode> children) {
-        this.children = Objects.requireNonNull(children);
-    }
-
-    boolean isLeaf() {
-        return children.isEmpty();
-    }
-
-    @Nullable
-    TokenTrieNode getChild(String token) {
-        return children.get(token);
-    }
-
-    private void insert(List<String> tokens) {
-        if (tokens.isEmpty()) {
-            return;
-        }
-        TokenTrieNode currentNode = this;
-        int currentTokenIndex = 0;
-
-        // find leaf
-        while (currentTokenIndex < tokens.size() && currentNode.children.containsKey(tokens.get(currentTokenIndex))) {
-            currentNode = currentNode.getChild(tokens.get(currentTokenIndex));
-            currentTokenIndex++;
-        }
-        // add rest of tokens as new nodes
-        while (currentTokenIndex < tokens.size()) {
-            TokenTrieNode childNode = new TokenTrieNode(new HashMap<>());
-            currentNode.children.put(tokens.get(currentTokenIndex), childNode);
-            currentNode = childNode;
-            currentTokenIndex++;
-        }
-    }
-
-    static TokenTrieNode build(Collection<String> tokens, Function<String, List<String>> tokenizeFunction) {
-        TokenTrieNode root = new TokenTrieNode(new HashMap<>());
-        for (String token : tokens) {
-            List<String> subTokens = tokenizeFunction.apply(token);
-            root.insert(subTokens);
-        }
-        return root;
-    }
-}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
index 87b02e6066a9c..017c3323abf5e 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java
@@ -72,11 +72,6 @@ public void testNeverSplit() {
 
         tokens = tokenizer.tokenize("Hello [UNK]!!");
         assertThat(tokens, contains("Hello", "[UNK]", "!", "!"));
-
-        tokens = tokenizer.tokenize("Hello-[UNK]");
-        assertThat(tokens, contains("Hello", "-", "[UNK]"));
-        tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
-        assertThat(tokens, contains("Hello", "-", "[UNK]", "[UNK]"));
     }
 
     public void testSplitOnPunctuation() {
@@ -159,12 +154,21 @@ public void testIsControl() {
     }
 
     public void testIsPunctuation() {
+        assertTrue(BasicTokenizer.isCommonPunctuation('-'));
+        assertTrue(BasicTokenizer.isCommonPunctuation('$'));
+        assertTrue(BasicTokenizer.isCommonPunctuation('.'));
+        assertFalse(BasicTokenizer.isCommonPunctuation(' '));
+        assertFalse(BasicTokenizer.isCommonPunctuation('A'));
+        assertFalse(BasicTokenizer.isCommonPunctuation('`'));
+
         assertTrue(BasicTokenizer.isPunctuationMark('-'));
         assertTrue(BasicTokenizer.isPunctuationMark('$'));
         assertTrue(BasicTokenizer.isPunctuationMark('`'));
         assertTrue(BasicTokenizer.isPunctuationMark('.'));
         assertFalse(BasicTokenizer.isPunctuationMark(' '));
         assertFalse(BasicTokenizer.isPunctuationMark('A'));
+
+        assertFalse(BasicTokenizer.isCommonPunctuation('['));
         assertTrue(BasicTokenizer.isPunctuationMark('['));
     }
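The two assertions deleted from `testNeverSplit` are the cases that exercised the old trie merge: never-split tokens fused to each other with no separating whitespace. Under the new code a fragment survives only if it matches a never-split entry exactly after the common-punctuation pre-split. A hedged trace of my reading of the new path for the dropped input (not asserted anywhere in the patch):

```java
// Assuming neverSplit = Set.of("[UNK]"):
//
// tokenize("Hello-[UNK][UNK]")
//   pre-split on common punctuation ('-')  ->  "Hello" | "-" | "[UNK][UNK]"
//   "[UNK][UNK]" is not in the set, so it takes the full punctuation split
//   ('[' and ']' are punctuation marks)    ->  "[" | "UNK" | "]" | "[" | "UNK" | "]"
//
// The old expectation [Hello, -, [UNK], [UNK]] therefore cannot hold without the trie.
```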
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
index 2daf13816e1d5..b2ecd831d54a9 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
@@ -157,21 +157,12 @@ public void testPunctuation() {
 
     public void testPunctuationWithMask() {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            List.of("[CLS]", "This", "is", "[MASK]", "-", "~", "ta", "##stic", "!", "[SEP]", "sub", ",", "."),
+            List.of("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"),
             Tokenization.createDefault()
         ).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build();
 
         TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!");
         assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"));
-
-        tokenization = tokenizer.tokenize("This is sub~[MASK]!");
-        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "sub", "~", "[MASK]", "!", "[SEP]"));
-
-        tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!");
-        assertThat(
-            tokenization.getTokens(),
-            arrayContaining("[CLS]", "This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!", "[SEP]")
-        );
     }
 
     public void testBatchInput() {
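The dropped `sub~[MASK]!` case shows the trade-off this patch accepts: the removed assertion split on `~` as punctuation, yet code point 126 lies outside both common-punctuation ranges, so the pre-split leaves `sub~[MASK]` fused and the subsequent full punctuation split would break `[MASK]` apart. Test-style lines capturing that distinction (hypothetical additions, written against the package-private helpers in this patch):

```java
// '~' (126) is outside the 33-47 and 58-64 common-punctuation ranges,
// so a never-split token glued to it is no longer protected...
assertFalse(BasicTokenizer.isCommonPunctuation('~'));
// ...yet the removed assertions split on it as punctuation, so (like the
// asserted '$' and '`' cases) it should still satisfy the general check.
assertTrue(BasicTokenizer.isPunctuationMark('~'));
```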
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNodeTests.java
deleted file mode 100644
index a96d557d36b50..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNodeTests.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
-
-import org.elasticsearch.test.ESTestCase;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.notNullValue;
-import static org.hamcrest.Matchers.nullValue;
-
-public class TokenTrieNodeTests extends ESTestCase {
-
-    public void testEmpty() {
-        TokenTrieNode root = TokenTrieNode.build(Collections.emptyList(), s -> Arrays.asList(s.split(":")));
-        assertThat(root.isLeaf(), is(true));
-    }
-
-    public void testTokensWithoutDelimiter() {
-        TokenTrieNode root = TokenTrieNode.build(List.of("a", "b", "c"), s -> Arrays.asList(s.split(":")));
-        assertThat(root.isLeaf(), is(false));
-
-        assertThat(root.getChild("a").isLeaf(), is(true));
-        assertThat(root.getChild("b").isLeaf(), is(true));
-        assertThat(root.getChild("c").isLeaf(), is(true));
-        assertThat(root.getChild("d"), is(nullValue()));
-    }
-
-    public void testTokensWithDelimiter() {
-        TokenTrieNode root = TokenTrieNode.build(List.of("aa:bb:cc", "aa:bb:dd", "bb:aa:cc", "bb:bb:cc"), s -> Arrays.asList(s.split(":")));
-        assertThat(root.isLeaf(), is(false));
-
-        // Let's look at the aa branch first
-        {
-            TokenTrieNode aaNode = root.getChild("aa");
-            assertThat(aaNode, is(notNullValue()));
-            assertThat(aaNode.isLeaf(), is(false));
-            assertThat(aaNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode bbNode = aaNode.getChild("bb");
-            assertThat(bbNode, is(notNullValue()));
-            assertThat(bbNode.isLeaf(), is(false));
-            assertThat(bbNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode ccNode = bbNode.getChild("cc");
-            assertThat(ccNode, is(notNullValue()));
-            assertThat(ccNode.isLeaf(), is(true));
-            assertThat(ccNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode ddNode = bbNode.getChild("dd");
-            assertThat(ddNode, is(notNullValue()));
-            assertThat(ddNode.isLeaf(), is(true));
-            assertThat(ddNode.getChild("zz"), is(nullValue()));
-        }
-        // Now the bb branch
-        {
-            TokenTrieNode bbNode = root.getChild("bb");
-            assertThat(bbNode, is(notNullValue()));
-            assertThat(bbNode.isLeaf(), is(false));
-            assertThat(bbNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode aaNode = bbNode.getChild("aa");
-            assertThat(aaNode, is(notNullValue()));
-            assertThat(aaNode.isLeaf(), is(false));
-            assertThat(aaNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode aaCcNode = aaNode.getChild("cc");
-            assertThat(aaCcNode, is(notNullValue()));
-            assertThat(aaCcNode.isLeaf(), is(true));
-            assertThat(aaCcNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode bbBbNode = bbNode.getChild("bb");
-            assertThat(bbBbNode, is(notNullValue()));
-            assertThat(bbBbNode.isLeaf(), is(false));
-            assertThat(bbBbNode.getChild("zz"), is(nullValue()));
-            TokenTrieNode bbCcNode = bbBbNode.getChild("cc");
-            assertThat(bbCcNode, is(notNullValue()));
-            assertThat(bbCcNode.isLeaf(), is(true));
-            assertThat(bbCcNode.getChild("zz"), is(nullValue()));
-        }
-    }
-}
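For readers wondering why the trie existed at all: a never-split token such as `[UNK]` is itself tokenized into sub-tokens (`[`, `UNK`, `]`), so after the main pass those pieces had to be recognized and re-joined, which is what the deleted `mergeNeverSplitTokens` walk did. A compact standalone reconstruction of that merge pattern (illustrative sketch, not the deleted class verbatim):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TrieMergeSketch {
    static class Node {
        final Map<String, Node> children = new HashMap<>();
        boolean isLeaf() { return children.isEmpty(); }
    }

    public static void main(String[] args) {
        // Build a trie for the sub-tokens of "[UNK]": "[" -> "UNK" -> "]"
        Node root = new Node();
        Node node = root;
        for (String part : List.of("[", "UNK", "]")) {
            node = node.children.computeIfAbsent(part, p -> new Node());
        }

        // Merge a tokenized stream back together, as the deleted walk did.
        List<String> tokens = List.of("Hello", "-", "[", "UNK", "]");
        List<String> merged = new ArrayList<>();
        List<String> pending = new ArrayList<>();
        Node current = root;
        for (String token : tokens) {
            Node child = current.children.get(token);
            if (child == null) {
                merged.addAll(pending);  // abandon a partial match
                pending.clear();
                current = root;
                merged.add(token);
            } else if (child.isLeaf()) { // full never-split path matched
                pending.add(token);
                merged.add(String.join("", pending));
                pending.clear();
                current = root;
            } else {
                pending.add(token);      // keep walking the trie
                current = child;
            }
        }
        merged.addAll(pending);          // flush any trailing partial match
        System.out.println(merged);      // [Hello, -, [UNK]]
    }
}
```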