From 86f31c267f81ed68f696ba449dcf09f506b08d40 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 7 Dec 2021 14:50:43 +0000 Subject: [PATCH] [ML] Preserve casing for never split tokens (#81429) This fixes a bug introduced by #81254. We are now using a token trie to merge tokens belonging to one of the never-split tokens back together. However, if the tokenizer is lowercasing, the merged token will also be lowercase and won't match never-split tokens that are expected to be in upper case. This commit fixes this by looking up the original text and only merging tokens together when the original text matches one of the never-split tokens. --- .../nlp/tokenizers/BasicTokenizer.java | 14 +++++++-- .../nlp/tokenizers/BasicTokenizerTests.java | 30 +++++++++++++++++-- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java index 77af60cab0119..789710cefbfb2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java @@ -32,6 +32,7 @@ public class BasicTokenizer { private final boolean isLowerCase; private final boolean isTokenizeCjkChars; private final boolean isStripAccents; + private final Set neverSplitTokens; private final TokenTrieNode neverSplitTokenTrieRoot; /** @@ -46,6 +47,7 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i this.isLowerCase = isLowerCase; this.isTokenizeCjkChars = isTokenizeCjkChars; this.isStripAccents = isStripAccents; + this.neverSplitTokens = neverSplit; this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenizeString); } @@ -76,7 +78,7 @@ public BasicTokenizer(boolean
isLowerCase, boolean isTokenizeCjkChars) { * @return List of tokens */ public List tokenize(String text) { - return mergeNeverSplitTokens(doTokenize(text)); + return mergeNeverSplitTokens(text, doTokenize(text)); } private List doTokenizeString(String text) { @@ -111,7 +113,7 @@ private List doTokenize(String text) { return processedTokens; } - private List mergeNeverSplitTokens(List tokens) { + private List mergeNeverSplitTokens(String originalText, List tokens) { if (neverSplitTokenTrieRoot.isLeaf()) { return tokens; } @@ -129,7 +131,13 @@ private List mergeNeverSplitTokens(List tokens) mergedTokens.add(token); } else if (childNode.isLeaf()) { matchingTokens.add(token); - mergedTokens.add(DelimitedToken.mergeTokens(matchingTokens)); + DelimitedToken mergedToken = DelimitedToken.mergeTokens(matchingTokens); + String originalTokenText = originalText.substring(mergedToken.getStartPos(), mergedToken.getEndPos()); + if (neverSplitTokens.contains(originalTokenText)) { + mergedTokens.add(new DelimitedToken(mergedToken.getStartPos(), mergedToken.getEndPos(), originalTokenText)); + } else { + mergedTokens.addAll(matchingTokens); + } matchingTokens = new ArrayList<>(); current = neverSplitTokenTrieRoot; } else { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java index 4bdd24cafe92a..0e08f31989a90 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java @@ -61,7 +61,7 @@ public void testNoLowerStripAccents() { assertThat(tokenStrings(tokens), contains("HaLLo", "!", "how", "Are", "yoU", "?")); } - public void testNeverSplit() { + public void testNeverSplit_GivenNoLowerCase() { BasicTokenizer tokenizer = new 
BasicTokenizer(false, false, false, Collections.singleton("[UNK]")); var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"); assertThat(tokenStrings(tokens), contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]")); @@ -77,8 +77,32 @@ public void testNeverSplit() { tokens = tokenizer.tokenize("Hello-[UNK]"); assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]")); - tokens = tokenizer.tokenize("Hello-[UNK][UNK]"); - assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]", "[UNK]")); + tokens = tokenizer.tokenize("Hello~[UNK][UNK]"); + assertThat(tokenStrings(tokens), contains("Hello", "~", "[UNK]", "[UNK]")); + tokens = tokenizer.tokenize("Hello-[unk]"); + assertThat(tokenStrings(tokens), contains("Hello", "-", "[", "unk", "]")); + } + + public void testNeverSplit_GivenLowerCase() { + BasicTokenizer tokenizer = new BasicTokenizer(true, false, false, Collections.singleton("[UNK]")); + var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"); + assertThat(tokenStrings(tokens), contains("hello", "!", "how", "are", "you", "?", "[UNK]")); + + tokens = tokenizer.tokenize("Hello [UNK]."); + assertThat(tokenStrings(tokens), contains("hello", "[UNK]", ".")); + + tokens = tokenizer.tokenize("Hello [UNK]?"); + assertThat(tokenStrings(tokens), contains("hello", "[UNK]", "?")); + + tokens = tokenizer.tokenize("Hello [UNK]!!"); + assertThat(tokenStrings(tokens), contains("hello", "[UNK]", "!", "!")); + + tokens = tokenizer.tokenize("Hello-[UNK]"); + assertThat(tokenStrings(tokens), contains("hello", "-", "[UNK]")); + tokens = tokenizer.tokenize("Hello~[UNK][UNK]"); + assertThat(tokenStrings(tokens), contains("hello", "~", "[UNK]", "[UNK]")); + tokens = tokenizer.tokenize("Hello-[unk]"); + assertThat(tokenStrings(tokens), contains("hello", "-", "[", "unk", "]")); } public void testSplitOnPunctuation() {