forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[8.0] [ML] Ensure `BertTokenizer` does not split special tokens (elastic#81254) (elastic#81371)

* [ML] Ensure `BertTokenizer` does not split special tokens (elastic#81254). This commit changes the way our Bert tokenizer preserves special tokens without splitting them. The previous approach split the tokens first on common punctuation, checked if there were special tokens, and then did a further split on all punctuation. The problem with this was that words containing special tokens and non-common punctuation were not handled properly. This commit addresses this by building a trie tree for the special tokens, splitting the input on all punctuation, and then looking up the tokens in the special token trie in order to merge matching tokens back together. Closes elastic#80484

* Fix compilation
- Loading branch information
1 parent
c4dffe9
commit 9fc8415
Showing
5 changed files
with
208 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
...n/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNode.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; | ||
|
||
import org.elasticsearch.core.Nullable; | ||
|
||
import java.util.Collection; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.function.Function; | ||
|
||
class TokenTrieNode { | ||
|
||
private static final String EMPTY_STRING = ""; | ||
|
||
private final Map<String, TokenTrieNode> children; | ||
|
||
private TokenTrieNode(Map<String, TokenTrieNode> children) { | ||
this.children = Objects.requireNonNull(children); | ||
} | ||
|
||
boolean isLeaf() { | ||
return children.isEmpty(); | ||
} | ||
|
||
@Nullable | ||
TokenTrieNode getChild(String token) { | ||
return children.get(token); | ||
} | ||
|
||
private void insert(List<String> tokens) { | ||
if (tokens.isEmpty()) { | ||
return; | ||
} | ||
TokenTrieNode currentNode = this; | ||
int currentTokenIndex = 0; | ||
|
||
// find leaf | ||
while (currentTokenIndex < tokens.size() && currentNode.children.containsKey(tokens.get(currentTokenIndex))) { | ||
currentNode = currentNode.getChild(tokens.get(currentTokenIndex)); | ||
currentTokenIndex++; | ||
} | ||
// add rest of tokens as new nodes | ||
while (currentTokenIndex < tokens.size()) { | ||
TokenTrieNode childNode = new TokenTrieNode(new HashMap<>()); | ||
currentNode.children.put(tokens.get(currentTokenIndex), childNode); | ||
currentNode = childNode; | ||
currentTokenIndex++; | ||
} | ||
} | ||
|
||
static TokenTrieNode build(Collection<String> tokens, Function<String, List<String>> tokenizeFunction) { | ||
TokenTrieNode root = new TokenTrieNode(new HashMap<>()); | ||
for (String token : tokens) { | ||
List<String> subTokens = tokenizeFunction.apply(token); | ||
root.insert(subTokens); | ||
} | ||
return root; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
...src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenTrieNodeTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; | ||
|
||
import org.elasticsearch.test.ESTestCase; | ||
|
||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.List; | ||
|
||
import static org.hamcrest.Matchers.is; | ||
import static org.hamcrest.Matchers.notNullValue; | ||
import static org.hamcrest.Matchers.nullValue; | ||
|
||
public class TokenTrieNodeTests extends ESTestCase { | ||
|
||
public void testEmpty() { | ||
TokenTrieNode root = TokenTrieNode.build(Collections.emptyList(), s -> Arrays.asList(s.split(":"))); | ||
assertThat(root.isLeaf(), is(true)); | ||
} | ||
|
||
public void testTokensWithoutDelimiter() { | ||
TokenTrieNode root = TokenTrieNode.build(List.of("a", "b", "c"), s -> Arrays.asList(s.split(":"))); | ||
assertThat(root.isLeaf(), is(false)); | ||
|
||
assertThat(root.getChild("a").isLeaf(), is(true)); | ||
assertThat(root.getChild("b").isLeaf(), is(true)); | ||
assertThat(root.getChild("c").isLeaf(), is(true)); | ||
assertThat(root.getChild("d"), is(nullValue())); | ||
} | ||
|
||
public void testTokensWithDelimiter() { | ||
TokenTrieNode root = TokenTrieNode.build(List.of("aa:bb:cc", "aa:bb:dd", "bb:aa:cc", "bb:bb:cc"), s -> Arrays.asList(s.split(":"))); | ||
assertThat(root.isLeaf(), is(false)); | ||
|
||
// Let's look at the aa branch first | ||
{ | ||
TokenTrieNode aaNode = root.getChild("aa"); | ||
assertThat(aaNode, is(notNullValue())); | ||
assertThat(aaNode.isLeaf(), is(false)); | ||
assertThat(aaNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode bbNode = aaNode.getChild("bb"); | ||
assertThat(bbNode, is(notNullValue())); | ||
assertThat(bbNode.isLeaf(), is(false)); | ||
assertThat(bbNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode ccNode = bbNode.getChild("cc"); | ||
assertThat(ccNode, is(notNullValue())); | ||
assertThat(ccNode.isLeaf(), is(true)); | ||
assertThat(ccNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode ddNode = bbNode.getChild("dd"); | ||
assertThat(ddNode, is(notNullValue())); | ||
assertThat(ddNode.isLeaf(), is(true)); | ||
assertThat(ddNode.getChild("zz"), is(nullValue())); | ||
} | ||
// Now the bb branch | ||
{ | ||
TokenTrieNode bbNode = root.getChild("bb"); | ||
assertThat(bbNode, is(notNullValue())); | ||
assertThat(bbNode.isLeaf(), is(false)); | ||
assertThat(bbNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode aaNode = bbNode.getChild("aa"); | ||
assertThat(aaNode, is(notNullValue())); | ||
assertThat(aaNode.isLeaf(), is(false)); | ||
assertThat(aaNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode aaCcNode = aaNode.getChild("cc"); | ||
assertThat(aaCcNode, is(notNullValue())); | ||
assertThat(aaCcNode.isLeaf(), is(true)); | ||
assertThat(aaCcNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode bbBbNode = bbNode.getChild("bb"); | ||
assertThat(bbBbNode, is(notNullValue())); | ||
assertThat(bbBbNode.isLeaf(), is(false)); | ||
assertThat(bbBbNode.getChild("zz"), is(nullValue())); | ||
TokenTrieNode bbCcNode = bbBbNode.getChild("cc"); | ||
assertThat(bbCcNode, is(notNullValue())); | ||
assertThat(bbCcNode.isLeaf(), is(true)); | ||
assertThat(bbCcNode.getChild("zz"), is(nullValue())); | ||
} | ||
} | ||
} |