
Commit

Revert "[8.0] [ML] Ensure BertTokenizer does not split special tokens (elastic#81254) (elastic#81371)" (elastic#81422)

This reverts commit 9fc8415.
dimitris-athanasiou authored Dec 7, 2021
1 parent 69537e5 commit 504b86c
Showing 5 changed files with 52 additions and 208 deletions.
@@ -15,6 +15,7 @@
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;

/**
* Basic tokenization of text by whitespace with optional extras:
@@ -31,7 +32,7 @@ public class BasicTokenizer {
private final boolean isLowerCase;
private final boolean isTokenizeCjkChars;
private final boolean isStripAccents;
private final TokenTrieNode neverSplitTokenTrieRoot;
private final Set<String> neverSplit;

/**
* Tokenizer behaviour is controlled by the options passed here.
@@ -45,11 +46,14 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean i
this.isLowerCase = isLowerCase;
this.isTokenizeCjkChars = isTokenizeCjkChars;
this.isStripAccents = isStripAccents;
this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenize);
this.neverSplit = neverSplit;
}

public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) {
this(isLowerCase, isTokenizeCjkChars, isStripAccents, Collections.emptySet());
this.isLowerCase = isLowerCase;
this.isTokenizeCjkChars = isTokenizeCjkChars;
this.isStripAccents = isStripAccents;
this.neverSplit = Collections.emptySet();
}

/**
@@ -75,10 +79,6 @@ public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) {
* @return List of tokens
*/
public List<String> tokenize(String text) {
return mergeNeverSplitTokens(doTokenize(text));
}

private List<String> doTokenize(String text) {
text = cleanText(text);
if (isTokenizeCjkChars) {
text = tokenizeCjkChars(text);
@@ -93,45 +93,31 @@ private List<String> doTokenize(String text) {
continue;
}

if (isLowerCase) {
token = token.toLowerCase(Locale.ROOT);
}
if (isStripAccents) {
token = stripAccents(token);
if (neverSplit.contains(token)) {
processedTokens.add(token);
continue;
}
processedTokens.addAll(splitOnPunctuation(token));
}

return processedTokens;
}

private List<String> mergeNeverSplitTokens(List<String> tokens) {
if (neverSplitTokenTrieRoot.isLeaf()) {
return tokens;
}
List<String> mergedTokens = new ArrayList<>(tokens.size());
List<String> matchingTokens = new ArrayList<>();
TokenTrieNode current = neverSplitTokenTrieRoot;
for (String token : tokens) {
TokenTrieNode childNode = current.getChild(token);
if (childNode == null) {
if (current != neverSplitTokenTrieRoot) {
mergedTokens.addAll(matchingTokens);
matchingTokens = new ArrayList<>();
current = neverSplitTokenTrieRoot;
// At this point text has been tokenized by whitespace
// but one of the special never split tokens could be adjacent
// to one or more punctuation characters.
List<String> splitOnCommonTokens = splitOnPredicate(token, BasicTokenizer::isCommonPunctuation);
for (String splitOnCommon : splitOnCommonTokens) {
if (neverSplit.contains(splitOnCommon)) {
processedTokens.add(splitOnCommon);
} else {
if (isLowerCase) {
splitOnCommon = splitOnCommon.toLowerCase(Locale.ROOT);
}
if (isStripAccents) {
splitOnCommon = stripAccents(splitOnCommon);
}
processedTokens.addAll(splitOnPunctuation(splitOnCommon));
}
mergedTokens.add(token);
} else if (childNode.isLeaf()) {
matchingTokens.add(token);
mergedTokens.add(String.join("", matchingTokens));
matchingTokens = new ArrayList<>();
current = neverSplitTokenTrieRoot;
} else {
matchingTokens.add(token);
current = childNode;
}
}
return mergedTokens;

return processedTokens;
}

public boolean isLowerCase() {
@@ -173,12 +159,16 @@ static String stripAccents(String word) {
}

static List<String> splitOnPunctuation(String word) {
return splitOnPredicate(word, BasicTokenizer::isPunctuationMark);
}

static List<String> splitOnPredicate(String word, Predicate<Integer> test) {
List<String> split = new ArrayList<>();
int[] codePoints = word.codePoints().toArray();

int lastSplit = 0;
for (int i = 0; i < codePoints.length; i++) {
if (isPunctuationMark(codePoints[i])) {
if (test.test(codePoints[i])) {
int charCount = i - lastSplit;
if (charCount > 0) {
// add a new string for what has gone before
@@ -302,4 +292,14 @@ static boolean isPunctuationMark(int codePoint) {
return (category >= Character.DASH_PUNCTUATION && category <= Character.OTHER_PUNCTUATION)
|| (category >= Character.INITIAL_QUOTE_PUNCTUATION && category <= Character.FINAL_QUOTE_PUNCTUATION);
}

/**
* True if the code point is for a common punctuation character
* {@code ! " # $ % & ' ( ) * + , - . / and : ; < = > ?}
* @param codePoint codepoint
* @return true if codepoint is punctuation
*/
static boolean isCommonPunctuation(int codePoint) {
return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
}
}
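
Taken together, the hunks above restore never-split handling built directly on a Set<String>: each whitespace token is first split on the common ASCII punctuation predicate, any piece that appears in the set is emitted verbatim, and every other piece is lower-cased and accent-stripped (subject to the configured options) and then split on the full punctuation predicate. The sketch below is illustrative only, a minimal self-contained rewrite of that flow rather than the Elasticsearch source; the class name NeverSplitSketch and its main method are invented for the example, and splitOnPredicate is simplified relative to the version shown above.

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

public class NeverSplitSketch {

    // Mirrors isCommonPunctuation above: ASCII ranges 33-47 (! " # $ % & ' ( ) * + , - . /)
    // and 58-64 (: ; < = > ? @).
    static boolean isCommonPunctuation(int codePoint) {
        return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
    }

    // Simplified equivalent of splitOnPredicate: each code point matching the predicate
    // becomes its own one-character token, and the stretches in between stay intact.
    static List<String> splitOnPredicate(String word, Predicate<Integer> test) {
        List<String> split = new ArrayList<>();
        int[] codePoints = word.codePoints().toArray();
        int lastSplit = 0;
        for (int i = 0; i < codePoints.length; i++) {
            if (test.test(codePoints[i])) {
                if (i > lastSplit) {
                    split.add(new String(codePoints, lastSplit, i - lastSplit));
                }
                split.add(new String(codePoints, i, 1));
                lastSplit = i + 1;
            }
        }
        if (lastSplit < codePoints.length) {
            split.add(new String(codePoints, lastSplit, codePoints.length - lastSplit));
        }
        return split;
    }

    public static void main(String[] args) {
        Set<String> neverSplit = Set.of("[MASK]");
        // "[" and "]" fall outside the common ranges, so "[MASK]" survives the first split
        // and can be matched against the never-split set; "-" and "!" are split off.
        for (String piece : splitOnPredicate("[MASK]-tastic!", NeverSplitSketch::isCommonPunctuation)) {
            if (neverSplit.contains(piece)) {
                System.out.println(piece + " -> kept verbatim");
            } else {
                System.out.println(piece + " -> lower-cased, accent-stripped, split on punctuation");
            }
        }
        // Prints "[MASK]" on the kept branch, then "-", "tastic" and "!" on the other branch.
    }
}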

This file was deleted.

@@ -72,11 +72,6 @@ public void testNeverSplit() {

tokens = tokenizer.tokenize("Hello [UNK]!!");
assertThat(tokens, contains("Hello", "[UNK]", "!", "!"));

tokens = tokenizer.tokenize("Hello-[UNK]");
assertThat(tokens, contains("Hello", "-", "[UNK]"));
tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
assertThat(tokens, contains("Hello", "-", "[UNK]", "[UNK]"));
}

public void testSplitOnPunctuation() {
@@ -159,12 +154,21 @@ public void testIsControl() {
}

public void testIsPunctuation() {
assertTrue(BasicTokenizer.isCommonPunctuation('-'));
assertTrue(BasicTokenizer.isCommonPunctuation('$'));
assertTrue(BasicTokenizer.isCommonPunctuation('.'));
assertFalse(BasicTokenizer.isCommonPunctuation(' '));
assertFalse(BasicTokenizer.isCommonPunctuation('A'));
assertFalse(BasicTokenizer.isCommonPunctuation('`'));

assertTrue(BasicTokenizer.isPunctuationMark('-'));
assertTrue(BasicTokenizer.isPunctuationMark('$'));
assertTrue(BasicTokenizer.isPunctuationMark('`'));
assertTrue(BasicTokenizer.isPunctuationMark('.'));
assertFalse(BasicTokenizer.isPunctuationMark(' '));
assertFalse(BasicTokenizer.isPunctuationMark('A'));

assertFalse(BasicTokenizer.isCommonPunctuation('['));
assertTrue(BasicTokenizer.isPunctuationMark('['));
}
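
The split between the two predicates in these assertions comes down to ASCII code points: isCommonPunctuation only accepts the ranges 33-47 and 58-64, so '-' (45), '$' (36) and '.' (46) pass, while '`' (96) and '[' (91) fall outside even though isPunctuationMark treats them as punctuation. A small scratch check, hypothetical and not part of BasicTokenizerTests, makes the ranges concrete:

// Hypothetical scratch check: print the code points the assertions above rely on,
// using the same 33-47 / 58-64 ranges as BasicTokenizer.isCommonPunctuation.
public class CommonPunctuationRanges {

    static boolean isCommonPunctuation(int codePoint) {
        return (codePoint >= 33 && codePoint <= 47) || (codePoint >= 58 && codePoint <= 64);
    }

    public static void main(String[] args) {
        for (char c : new char[] { '-', '$', '.', ' ', 'A', '`', '[' }) {
            System.out.printf("'%c' = %d -> common punctuation: %b%n", c, (int) c, isCommonPunctuation(c));
        }
        // '-' = 45, '$' = 36 and '.' = 46 are true; ' ' = 32, 'A' = 65, '`' = 96 and '[' = 91 are false.
    }
}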

@@ -157,21 +157,12 @@ public void testPunctuation() {

public void testPunctuationWithMask() {
BertTokenizer tokenizer = BertTokenizer.builder(
List.of("[CLS]", "This", "is", "[MASK]", "-", "~", "ta", "##stic", "!", "[SEP]", "sub", ",", "."),
List.of("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"),
Tokenization.createDefault()
).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build();

TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!");
assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "[MASK]", "-", "ta", "##stic", "!", "[SEP]"));

tokenization = tokenizer.tokenize("This is sub~[MASK]!");
assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "This", "is", "sub", "~", "[MASK]", "!", "[SEP]"));

tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!");
assertThat(
tokenization.getTokens(),
arrayContaining("[CLS]", "This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!", "[SEP]")
);
}

public void testBatchInput() {

This file was deleted.
