Skip to content

Commit

Permalink
[8.0] [ML] Ensure BertTokenizer does not split special tokens (elastic#81254) (elastic#81371)
Browse files Browse the repository at this point in the history

* [ML] Ensure `BertTokenizer` does not split special tokens (elastic#81254)

This commit changes the way our Bert tokenizer preserves
special tokens without splitting them. The previous approach
split the tokens first on common punctuation, checked if there
were special tokens, and then did further split on all punctuation.
The problem with this was that words containing special tokens and
non-common punctuation were not handled properly.

This commit addresses this by building a trie for the special
tokens, splitting the input on all punctuation, and then looking up
the tokens in the special token trie in order to merge matching tokens
back together.

Closes elastic#80484

* Fix compilation
  • Loading branch information
dimitris-athanasiou authored Dec 6, 2021
1 parent c4dffe9 commit 9fc8415
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;

/**
* Basic tokenization of text by whitespace with optional extras:
Expand All @@ -32,7 +31,7 @@ public class BasicTokenizer {
private final boolean isLowerCase;
private final boolean isTokenizeCjkChars;
private final boolean isStripAccents;
private final Set<String> neverSplit;
private final TokenTrieNode neverSplitTokenTrieRoot;

/**
* Tokenizer behaviour is controlled by the options passed here.
Expand All @@ -46,14 +45,11 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i
this.isLowerCase = isLowerCase;
this.isTokenizeCjkChars = isTokenizeCjkChars;
this.isStripAccents = isStripAccents;
this.neverSplit = neverSplit;
this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenize);
}

public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) {
this.isLowerCase = isLowerCase;
this.isTokenizeCjkChars = isTokenizeCjkChars;
this.isStripAccents = isStripAccents;
this.neverSplit = Collections.emptySet();
this(isLowerCase, isTokenizeCjkChars, isStripAccents, Collections.emptySet());
}

/**
Expand All @@ -79,6 +75,10 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) {
* @return List of tokens
*/
public List<String> tokenize(String text) {
    // Tokenize first, then merge back together any never-split (special)
    // tokens that the punctuation splitting in doTokenize broke apart.
    return mergeNeverSplitTokens(doTokenize(text));
}

private List<String> doTokenize(String text) {
text = cleanText(text);
if (isTokenizeCjkChars) {
text = tokenizeCjkChars(text);
Expand All @@ -93,33 +93,47 @@ public List<String> tokenize(String text) {
continue;
}

if (neverSplit.contains(token)) {
processedTokens.add(token);
continue;
if (isLowerCase) {
token = token.toLowerCase(Locale.ROOT);
}

// At this point text has been tokenized by whitespace
// but one of the special never split tokens could be adjacent
// to one or more punctuation characters.
List<String> splitOnCommonTokens = splitOnPredicate(token, BasicTokenizer::isCommonPunctuation);
for (String splitOnCommon : splitOnCommonTokens) {
if (neverSplit.contains(splitOnCommon)) {
processedTokens.add(splitOnCommon);
} else {
if (isLowerCase) {
splitOnCommon = splitOnCommon.toLowerCase(Locale.ROOT);
}
if (isStripAccents) {
splitOnCommon = stripAccents(splitOnCommon);
}
processedTokens.addAll(splitOnPunctuation(splitOnCommon));
}
if (isStripAccents) {
token = stripAccents(token);
}
processedTokens.addAll(splitOnPunctuation(token));
}

return processedTokens;
}

/**
 * Merges runs of tokens that together form one of the never-split
 * (special) tokens back into a single token. Each special token was
 * pre-tokenized into the trie rooted at {@code neverSplitTokenTrieRoot},
 * so a special token appears in {@code tokens} as a consecutive run of
 * sub-tokens tracing a root-to-leaf path in that trie.
 *
 * @param tokens the tokens produced by {@code doTokenize}
 * @return the tokens with any special-token runs merged back together
 */
private List<String> mergeNeverSplitTokens(List<String> tokens) {
    // Fast path: no never-split tokens were configured.
    if (neverSplitTokenTrieRoot.isLeaf()) {
        return tokens;
    }
    List<String> mergedTokens = new ArrayList<>(tokens.size());
    int i = 0;
    while (i < tokens.size()) {
        // Try to match a special token starting at position i by walking the trie.
        TokenTrieNode current = neverSplitTokenTrieRoot;
        int matchEnd = -1; // exclusive end index of a completed (leaf) match
        for (int j = i; j < tokens.size(); j++) {
            current = current.getChild(tokens.get(j));
            if (current == null) {
                break;
            }
            if (current.isLeaf()) {
                matchEnd = j + 1;
                break;
            }
        }
        if (matchEnd == -1) {
            // No complete match starting here; emit this token and restart the scan
            // at the next one. Advancing by a single token (rather than discarding
            // the whole failed prefix) ensures a special token that starts inside a
            // failed prefix is still found, and a partial match that reaches the end
            // of the input is emitted rather than silently dropped.
            mergedTokens.add(tokens.get(i));
            i++;
        } else {
            // Merge the matched run back into the original special token.
            mergedTokens.add(String.join("", tokens.subList(i, matchEnd)));
            i = matchEnd;
        }
    }
    return mergedTokens;
}

public boolean isLowerCase() {
return isLowerCase;
}
Expand Down Expand Up @@ -159,16 +173,12 @@ static String stripAccents(String word) {
}

static List<String> splitOnPunctuation(String word) {
return splitOnPredicate(word, BasicTokenizer::isPunctuationMark);
}

static List<String> splitOnPredicate(String word, Predicate<Integer> test) {
List<String> split = new ArrayList<>();
int[] codePoints = word.codePoints().toArray();

int lastSplit = 0;
for (int i = 0; i < codePoints.length; i++) {
if (test.test(codePoints[i])) {
if (isPunctuationMark(codePoints[i])) {
int charCount = i - lastSplit;
if (charCount > 0) {
// add a new string for what has gone before
Expand Down Expand Up @@ -292,14 +302,4 @@ static boolean isPunctuationMark(int codePoint) {
return (category >= Character.DASH_PUNCTUATION && category <= Character.OTHER_PUNCTUATION)
|| (category >= Character.INITIAL_QUOTE_PUNCTUATION && category <= Character.FINAL_QUOTE_PUNCTUATION);
}

/**
* True if the code point is for a common punctuation character
* {@code ! " # $ % & ' ( ) * + , - . / and : ; < = > ?}
* @param codePoint codepoint
* @return true if codepoint is punctuation
*/
static boolean isCommonPunctuation(int codePoint) {
return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.core.Nullable;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * A node of a trie whose edges are labelled with string tokens rather than
 * single characters. The trie is used to recognise multi-token sequences
 * (e.g. special never-split tokens that the tokenizer has broken into
 * several sub-tokens): a sequence is in the trie iff it traces a path from
 * the root to a leaf.
 *
 * Not thread-safe during construction; immutable once {@link #build} returns
 * provided the returned root is not further mutated.
 */
class TokenTrieNode {

    private final Map<String, TokenTrieNode> children;

    private TokenTrieNode(Map<String, TokenTrieNode> children) {
        this.children = Objects.requireNonNull(children);
    }

    /** @return true if this node has no outgoing edges, i.e. it terminates a stored sequence */
    boolean isLeaf() {
        return children.isEmpty();
    }

    /**
     * @param token the edge label to follow
     * @return the child reached via {@code token}, or null if there is no such edge
     */
    @Nullable
    TokenTrieNode getChild(String token) {
        return children.get(token);
    }

    /**
     * Inserts the given token sequence into the trie rooted at this node.
     * An empty sequence is a no-op.
     */
    private void insert(List<String> tokens) {
        if (tokens.isEmpty()) {
            return;
        }
        TokenTrieNode currentNode = this;
        int currentTokenIndex = 0;

        // Walk the longest prefix of tokens that already exists in the trie.
        // A single map lookup per step (getChild) replaces the previous
        // containsKey-then-get double lookup.
        while (currentTokenIndex < tokens.size()) {
            TokenTrieNode child = currentNode.getChild(tokens.get(currentTokenIndex));
            if (child == null) {
                break;
            }
            currentNode = child;
            currentTokenIndex++;
        }
        // Append the remaining tokens as a fresh chain of nodes.
        while (currentTokenIndex < tokens.size()) {
            TokenTrieNode childNode = new TokenTrieNode(new HashMap<>());
            currentNode.children.put(tokens.get(currentTokenIndex), childNode);
            currentNode = childNode;
            currentTokenIndex++;
        }
    }

    /**
     * Builds a trie containing every token in {@code tokens}, where each token
     * is first split into its sub-token sequence by {@code tokenizeFunction}.
     *
     * @param tokens the (possibly multi-part) tokens to store
     * @param tokenizeFunction splits a token into the sub-token sequence used as trie edges
     * @return the root of the new trie; it is a leaf iff nothing was inserted
     */
    static TokenTrieNode build(Collection<String> tokens, Function<String, List<String>> tokenizeFunction) {
        TokenTrieNode root = new TokenTrieNode(new HashMap<>());
        for (String token : tokens) {
            root.insert(tokenizeFunction.apply(token));
        }
        return root;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public void testNeverSplit() {

tokens = tokenizer.tokenize("Hello [UNK]!!");
assertThat(tokens, contains("Hello", "[UNK]", "!", "!"));

tokens = tokenizer.tokenize("Hello-[UNK]");
assertThat(tokens, contains("Hello", "-", "[UNK]"));
tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
assertThat(tokens, contains("Hello", "-", "[UNK]", "[UNK]"));
}

public void testSplitOnPunctuation() {
Expand Down Expand Up @@ -154,21 +159,12 @@ public void testIsControl() {
}

public void testIsPunctuation() {
assertTrue(BasicTokenizer.isCommonPunctuation('-'));
assertTrue(BasicTokenizer.isCommonPunctuation('$'));
assertTrue(BasicTokenizer.isCommonPunctuation('.'));
assertFalse(BasicTokenizer.isCommonPunctuation(' '));
assertFalse(BasicTokenizer.isCommonPunctuation('A'));
assertFalse(BasicTokenizer.isCommonPunctuation('`'));

assertTrue(BasicTokenizer.isPunctuationMark('-'));
assertTrue(BasicTokenizer.isPunctuationMark('$'));
assertTrue(BasicTokenizer.isPunctuationMark('`'));
assertTrue(BasicTokenizer.isPunctuationMark('.'));
assertFalse(BasicTokenizer.isPunctuationMark(' '));
assertFalse(BasicTokenizer.isPunctuationMark('A'));

assertFalse(BasicTokenizer.isCommonPunctuation('['));
assertTrue(BasicTokenizer.isPunctuationMark('['));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,21 @@ public void testPunctuation() {

public void testPunctuationWithMask() {
BertTokenizer tokenizer = BertTokenizer.builder(
List.of("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"),
List.of("[CLS]", "This", "is", "[MASK]", "-", "~", "ta", "##stic", "!", "[SEP]", "sub", ",", "."),
Tokenization.createDefault()
).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build();

TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!");
assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"));

tokenization = tokenizer.tokenize("This is sub~[MASK]!");
assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "sub", "~", "[MASK]", "!", "[SEP]"));

tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!");
assertThat(
tokenization.getTokens(),
arrayContaining("[CLS]", "This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!", "[SEP]")
);
}

public void testBatchInput() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.test.ESTestCase;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

public class TokenTrieNodeTests extends ESTestCase {

    public void testEmpty() {
        // Building from no tokens yields a bare root with no children.
        TokenTrieNode root = TokenTrieNode.build(Collections.emptyList(), token -> Arrays.asList(token.split(":")));
        assertThat(root.isLeaf(), is(true));
    }

    public void testTokensWithoutDelimiter() {
        // Single-part tokens become direct leaf children of the root.
        TokenTrieNode root = TokenTrieNode.build(List.of("a", "b", "c"), token -> Arrays.asList(token.split(":")));
        assertThat(root.isLeaf(), is(false));

        assertThat(root.getChild("a").isLeaf(), is(true));
        assertThat(root.getChild("b").isLeaf(), is(true));
        assertThat(root.getChild("c").isLeaf(), is(true));
        assertThat(root.getChild("d"), is(nullValue()));
    }

    public void testTokensWithDelimiter() {
        // Multi-part tokens share prefixes: aa:bb:cc and aa:bb:dd share the aa->bb chain.
        TokenTrieNode root = TokenTrieNode.build(
            List.of("aa:bb:cc", "aa:bb:dd", "bb:aa:cc", "bb:bb:cc"),
            token -> Arrays.asList(token.split(":"))
        );
        assertThat(root.isLeaf(), is(false));

        // Let's look at the aa branch first
        {
            TokenTrieNode aaNode = checkChild(root, "aa", false);
            TokenTrieNode aaBbNode = checkChild(aaNode, "bb", false);
            checkChild(aaBbNode, "cc", true);
            checkChild(aaBbNode, "dd", true);
        }
        // Now the bb branch
        {
            TokenTrieNode bbNode = checkChild(root, "bb", false);
            TokenTrieNode bbAaNode = checkChild(bbNode, "aa", false);
            checkChild(bbAaNode, "cc", true);
            TokenTrieNode bbBbNode = checkChild(bbNode, "bb", false);
            checkChild(bbBbNode, "cc", true);
        }
    }

    /**
     * Asserts that {@code parent} has a child reached via {@code token} whose
     * leaf-ness matches {@code expectLeaf} and that has no "zz" child, then
     * returns that child for further inspection.
     */
    private static TokenTrieNode checkChild(TokenTrieNode parent, String token, boolean expectLeaf) {
        TokenTrieNode child = parent.getChild(token);
        assertThat(child, is(notNullValue()));
        assertThat(child.isLeaf(), is(expectLeaf));
        assertThat(child.getChild("zz"), is(nullValue()));
        return child;
    }
}

0 comments on commit 9fc8415

Please sign in to comment.