Skip to content

Commit

Permalink
[ML] Preserve casing for never split tokens (#81429)
Browse files Browse the repository at this point in the history
This fixes a bug introduced by #81254. We are now using
a token trie to merge tokens that belong to one of the
never-split tokens back together. However, if the tokenizer
lower-cases its input, then the merged token will also be
lower case and won't match never-split tokens that are
expected to be in upper case.

This commit fixes this by looking up the original text and only
merging tokens together when the original text matches one
of the never-split tokens.
  • Loading branch information
dimitris-athanasiou authored Dec 7, 2021
1 parent 061d38c commit 86f31c2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class BasicTokenizer {
private final boolean isLowerCase;
private final boolean isTokenizeCjkChars;
private final boolean isStripAccents;
private final Set<String> neverSplitTokens;
private final TokenTrieNode neverSplitTokenTrieRoot;

/**
Expand All @@ -46,6 +47,7 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i
this.isLowerCase = isLowerCase;
this.isTokenizeCjkChars = isTokenizeCjkChars;
this.isStripAccents = isStripAccents;
this.neverSplitTokens = neverSplit;
this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenizeString);
}

Expand Down Expand Up @@ -76,7 +78,7 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) {
* @return List of tokens
*/
public List<DelimitedToken> tokenize(String text) {
    // Pass the original text through so never-split tokens can be matched
    // against their original casing even when this tokenizer lower-cases
    // its output (the merge must compare against the un-lowercased input).
    return mergeNeverSplitTokens(text, doTokenize(text));
}

private List<String> doTokenizeString(String text) {
Expand Down Expand Up @@ -111,7 +113,7 @@ private List<DelimitedToken> doTokenize(String text) {
return processedTokens;
}

private List<DelimitedToken> mergeNeverSplitTokens(List<DelimitedToken> tokens) {
private List<DelimitedToken> mergeNeverSplitTokens(String originalText, List<DelimitedToken> tokens) {
if (neverSplitTokenTrieRoot.isLeaf()) {
return tokens;
}
Expand All @@ -129,7 +131,13 @@ private List<DelimitedToken> mergeNeverSplitTokens(List<DelimitedToken> tokens)
mergedTokens.add(token);
} else if (childNode.isLeaf()) {
matchingTokens.add(token);
mergedTokens.add(DelimitedToken.mergeTokens(matchingTokens));
DelimitedToken mergedToken = DelimitedToken.mergeTokens(matchingTokens);
String originalTokenText = originalText.substring(mergedToken.getStartPos(), mergedToken.getEndPos());
if (neverSplitTokens.contains(originalTokenText)) {
mergedTokens.add(new DelimitedToken(mergedToken.getStartPos(), mergedToken.getEndPos(), originalTokenText));
} else {
mergedTokens.addAll(matchingTokens);
}
matchingTokens = new ArrayList<>();
current = neverSplitTokenTrieRoot;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void testNoLowerStripAccents() {
assertThat(tokenStrings(tokens), contains("HaLLo", "!", "how", "Are", "yoU", "?"));
}

public void testNeverSplit() {
public void testNeverSplit_GivenNoLowerCase() {
BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Collections.singleton("[UNK]"));
var tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]");
assertThat(tokenStrings(tokens), contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"));
Expand All @@ -77,8 +77,32 @@ public void testNeverSplit() {

tokens = tokenizer.tokenize("Hello-[UNK]");
assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]"));
tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]", "[UNK]"));
tokens = tokenizer.tokenize("Hello~[UNK][UNK]");
assertThat(tokenStrings(tokens), contains("Hello", "~", "[UNK]", "[UNK]"));
tokens = tokenizer.tokenize("Hello-[unk]");
assertThat(tokenStrings(tokens), contains("Hello", "-", "[", "unk", "]"));
}

public void testNeverSplit_GivenLowerCase() {
    // A lower-casing tokenizer must still recognize never-split tokens,
    // which are matched against the original (upper-case) text.
    BasicTokenizer tokenizer = new BasicTokenizer(true, false, false, Collections.singleton("[UNK]"));

    // Mixed whitespace and casing: ordinary tokens are lower-cased, [UNK] is preserved.
    var result = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]");
    assertThat(tokenStrings(result), contains("hello", "!", "how", "are", "you", "?", "[UNK]"));

    // [UNK] followed by trailing punctuation.
    result = tokenizer.tokenize("Hello [UNK].");
    assertThat(tokenStrings(result), contains("hello", "[UNK]", "."));

    result = tokenizer.tokenize("Hello [UNK]?");
    assertThat(tokenStrings(result), contains("hello", "[UNK]", "?"));

    result = tokenizer.tokenize("Hello [UNK]!!");
    assertThat(tokenStrings(result), contains("hello", "[UNK]", "!", "!"));

    // [UNK] adjacent to punctuation, and two [UNK]s back to back.
    result = tokenizer.tokenize("Hello-[UNK]");
    assertThat(tokenStrings(result), contains("hello", "-", "[UNK]"));
    result = tokenizer.tokenize("Hello~[UNK][UNK]");
    assertThat(tokenStrings(result), contains("hello", "~", "[UNK]", "[UNK]"));

    // Lower-case "[unk]" does not match the never-split token and is split up.
    result = tokenizer.tokenize("Hello-[unk]");
    assertThat(tokenStrings(result), contains("hello", "-", "[", "unk", "]"));
}

public void testSplitOnPunctuation() {
Expand Down

0 comments on commit 86f31c2

Please sign in to comment.