
[bugfix][WIP] Fix Chinese tokenization bug in 1.10.12 #6755

Merged
1 change: 1 addition & 0 deletions changelog/6755.bugfix.rst
@@ -0,0 +1 @@
Treat the length of an OOV token as 1 to fix a token alignment issue that occurs when OOV tokens appear.
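For context (not part of this diff): a minimal sketch of why the alignment breaks, assuming the `transformers` package and the `bert-base-chinese` weights are available. BERT maps a rare character such as `畈` to the `[UNK]` piece, and counting `len("[UNK]") == 5` characters for it pushes every following offset out of alignment with the original text, while counting it as one character keeps the offsets intact.

```python
# A minimal sketch, not part of this PR; assumes `transformers` is installed
# and the `bert-base-chinese` weights can be downloaded.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
pieces = tokenizer.tokenize("去东畈村")
print(pieces)  # expected: ['去', '东', '[UNK]', '村'] -- '畈' is out of vocabulary

# Count the OOV piece as a single character so offsets stay aligned
# with the original text.
offset = 0
for piece in pieces:
    length = 1 if piece == tokenizer.unk_token else len(piece)
    print(piece, (offset, offset + length))
    offset += length
# expected offsets: (0, 1), (1, 2), (2, 3), (3, 4)
```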
2 changes: 1 addition & 1 deletion rasa/nlu/utils/hugging_face/hf_transformers.py
@@ -222,7 +222,7 @@ def _tokenize_example(
token_ids_out += split_token_ids

tokens_out += train_utils.align_tokens(
- split_token_strings, token.end, token.start
+ split_token_strings, token.end, token.start, self.tokenizer.unk_token
)

return tokens_out, token_ids_out
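The unknown-token string is read from the model's own tokenizer rather than hard-coded because it differs between model families; a quick check (a sketch, not part of this diff, assuming `transformers` is installed):

```python
# A sketch, not part of this diff: the unk token string is model-specific,
# which is why the call site forwards self.tokenizer.unk_token.
from transformers import AutoTokenizer

for weights in ("bert-base-chinese", "xlnet-base-cased"):
    tok = AutoTokenizer.from_pretrained(weights)
    print(weights, repr(tok.unk_token))  # e.g. '[UNK]' for BERT, '<unk>' for XLNet
```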
20 changes: 16 additions & 4 deletions rasa/utils/train_utils.py
@@ -80,7 +80,10 @@ def update_similarity_type(config: Dict[Text, Any]) -> Dict[Text, Any]:


def align_tokens(
- tokens_in: List[Text], token_end: int, token_start: int
+ tokens_in: List[Text],
+ token_end: int,
+ token_start: int,
+ unk_token: Optional[Text] = None,
) -> List[Token]:
"""Align sub-tokens of Language model with tokens return by the WhitespaceTokenizer.

@@ -95,22 +98,31 @@ def align_tokens(
current_token_offset = token_start

for index, string in enumerate(tokens_in):
# There is absolutely no guarantee that the length of an OOV token is always 1,
# but some documents (e.g.
# https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/#22-tokenization)
# suggest that it is very likely to be 1 in most cases.
# In most languages (except Chinese), OOV tokens seem to be emoji characters.
# Chinese has a very large character set, so some rare characters can become OOV.
# This is not a perfect solution, but in practice it resolves most OOV-related issues.
string_len = len(string) if unk_token is None or string != unk_token else 1

if index == 0:
if index == len(tokens_in) - 1:
s_token_end = token_end
else:
- s_token_end = current_token_offset + len(string)
+ s_token_end = current_token_offset + string_len
tokens_out.append(Token(string, token_start, end=s_token_end))
elif index == len(tokens_in) - 1:
tokens_out.append(Token(string, current_token_offset, end=token_end))
else:
tokens_out.append(
Token(
- string, current_token_offset, end=current_token_offset + len(string)
+ string, current_token_offset, end=current_token_offset + string_len
)
)

- current_token_offset += len(string)
+ current_token_offset += string_len

return tokens_out
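A hedged usage sketch of the updated function (not part of the diff): for the whitespace token `去东畈村` with character span 0..4, the BERT sub-tokens are `['去', '东', '[UNK]', '村']`, and passing `unk_token="[UNK]"` keeps the character offsets aligned.

```python
# A usage sketch of the updated align_tokens (not part of the diff); the
# expected output mirrors the new "bert-base-chinese" test case below.
from rasa.utils.train_utils import align_tokens

tokens = align_tokens(
    ["去", "东", "[UNK]", "村"],  # sub-tokens produced by the BERT tokenizer
    token_end=4,
    token_start=0,
    unk_token="[UNK]",
)
print([(t.text, t.start, t.end) for t in tokens])
# expected: [('去', 0, 1), ('东', 1, 2), ('[UNK]', 2, 3), ('村', 3, 4)]
```

Without `unk_token`, the `[UNK]` entry would advance the offset by five characters, so the following spans would no longer line up with the original text.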

44 changes: 41 additions & 3 deletions tests/nlu/tokenizers/test_lm_tokenizer.py
@@ -15,10 +15,11 @@
# TODO: need to fix this failing test
@pytest.mark.xfail(strict=False)
@pytest.mark.parametrize(
"model_name, texts, expected_tokens, expected_indices, expected_num_token_ids",
"model_name, model_weights, texts, expected_tokens, expected_indices, expected_num_token_ids",
[
(
"bert",
None,
[
"Good evening.",
"you're",
@@ -66,8 +67,32 @@
],
[4, 4, 5, 5, 13],
),
(
"bert",
"bert-base-chinese",
[
"晚上好", # normal & easy case
"没问题!", # `!` is a Chinese punctuation
"去东畈村", # `畈` is a OOV token for bert-base-chinese
"好的😃", # include a emoji which is common in Chinese text-based chat
],
[
["晚", "上", "好"],
["没", "问", "题", "!"],
["去", "东", "畈", "村"],
["好", "的", "😃"],
],
[
[(0, 1), (1, 2), (2, 3)],
[(0, 1), (1, 2), (2, 3), (3, 4)],
[(0, 1), (1, 2), (2, 3), (3, 4)],
[(0, 1), (1, 2), (2, 3)],
],
[3, 4, 4, 3],
),
(
"gpt",
None,
[
"Good evening.",
"hello",
@@ -106,6 +131,7 @@
),
(
"gpt2",
None,
[
"Good evening.",
"hello",
@@ -158,6 +184,7 @@
),
(
"xlnet",
None,
[
"Good evening.",
"hello",
@@ -208,6 +235,7 @@
),
(
"distilbert",
None,
[
"Good evening.",
"you're",
@@ -257,6 +285,7 @@
),
(
"roberta",
None,
[
"Good evening.",
"hello",
@@ -310,10 +339,19 @@
],
)
def test_lm_tokenizer_edge_cases(
- model_name, texts, expected_tokens, expected_indices, expected_num_token_ids
+ model_name,
+ model_weights,
+ texts,
+ expected_tokens,
+ expected_indices,
+ expected_num_token_ids,
):

transformers_config = {"model_name": model_name}
if model_weights is None:
model_weights_config = {}
else:
model_weights_config = {"model_weights": model_weights}
transformers_config = {**{"model_name": model_name}, **model_weights_config}

transformers_nlp = HFTransformersNLP(transformers_config)
lm_tokenizer = LanguageModelTokenizer()
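Roughly, the new `bert-base-chinese` case corresponds to the following standalone run (a sketch only; the import paths and call signatures are assumptions based on the Rasa 1.10 code base, and the weights must be downloadable):

```python
# A rough standalone equivalent of the new test case; import paths and
# signatures are assumptions based on Rasa 1.10, not part of this diff.
from rasa.nlu.constants import TEXT
from rasa.nlu.tokenizers.lm_tokenizer import LanguageModelTokenizer
from rasa.nlu.training_data import Message
from rasa.nlu.utils.hugging_face.hf_transformers import HFTransformersNLP

transformers_nlp = HFTransformersNLP(
    {"model_name": "bert", "model_weights": "bert-base-chinese"}
)
lm_tokenizer = LanguageModelTokenizer()

message = Message("去东畈村")
transformers_nlp.process(message)
tokens = lm_tokenizer.tokenize(message, TEXT)
print([(t.start, t.end) for t in tokens])
# expected (with this fix): [(0, 1), (1, 2), (2, 3), (3, 4)]
```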