Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nit-added-tokens #26538

Merged
merged 16 commits into from
Oct 3, 2023
Merged

Nit-added-tokens #26538

merged 16 commits into from
Oct 3, 2023

Conversation

ArthurZucker
Copy link
Collaborator

What does this PR do?

Fixes #26500, fixes #26536

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 2, 2023

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker ArthurZucker marked this pull request as ready for review October 2, 2023 15:01
@@ -2382,8 +2387,8 @@ def save_pretrained(
tokenizer_config = copy.deepcopy(self.init_kwargs)

# TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
# target_keys = self.init_kwargs.keys()
target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
target_keys = list(self.init_kwargs.keys())
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when saving; we should overwrite the init_kwargs with the content of self. Don't know why it was not the case before

@@ -2227,7 +2232,7 @@ def _from_pretrained(
if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
# legacy: we have to init with (rstrip=True, lstrip=True)
# legacy: we have to init with (rstrip=True, lstrip=True) (if the token is new? Failing test)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might have to update this. The tests are shitty and the default is biting us

Comment on lines 2202 to 2204
if str(token) in additional_special_tokens:
# at this point if the token is in `additional_special_tokens` as an str, should be updated
additional_special_tokens.remove(str(token))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only use the default legacy values for AddedToken if the token is not already in the added tokens decoder

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ArthurZucker
Copy link
Collaborator Author

ArthurZucker commented Oct 3, 2023

A small benchmark on the get_added_vocab():

from transformers import AutoTokenizer
import time 
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")
start = time.time();tokenizer.get_added_vocab();print(time.time()-start)
>>> 0.17021536827087402

start = time.time();{k.content: v for v, k in sorted(tokenizer.added_tokens_decoder.items(), key=lambda item: item[0])};print(time.time()-start)
>>> 0.0054759979248046875

start = time.time();tokenizer.added_tokens_decoder;print(time.time()-start)
0.0007669925689697266

will update rust to make tokenizer.added_tokens_encoder available.

@ArthurZucker ArthurZucker merged commit 1a2e966 into huggingface:main Oct 3, 2023
3 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Comment on lines +2204 to +2208
if str(token) in additional_special_tokens:
# at this point the token is in `additional_special_tokens` as an str, let's add the AddedToken info
additional_special_tokens.remove(str(token))
if token.special and token not in additional_special_tokens:
additional_special_tokens.append(token)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @LysandreJik here

blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* fix stripping

* nits

* fix another test

* styling

* fix?

* update

* revert bad merge

* found the bug

* YES SIR

* is that change really required?

* make fast even faster

* re order functions
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
* fix stripping

* nits

* fix another test

* styling

* fix?

* update

* revert bad merge

* found the bug

* YES SIR

* is that change really required?

* make fast even faster

* re order functions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Tokenizer AddedToken load from file bug Tokenizer pad token not saved with save_pretrained
3 participants