Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tokenizer] Fix slow and fast serialization #26570

Merged
merged 114 commits into from
Oct 18, 2023

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Oct 3, 2023

What does this PR do?

  • sets the defaults for AddedToken instances where needed to match what is pushed to the hub
  • sets the default for AddedToken to not strip left and right to match the fast tokenizers
  • fixes the added_tokens.json file: a recent push made it save all the added tokens encoder, but it should only save the indexes greater than the vocab size for forward compatibility.
  • fixes the list of additionnal_special_tokens that were added twice / overwritten
  • fixes add_tokens : if the added tokens is a string we check if it's not already in the added vocab instead of always defaulting to strip left or right.
  • fixes saving: the added_tokens_decoder should not add a "__type ":"AddedToken" field to the added tokens otherwise the previous versions of transformers will try to load them.

fixes #26732, fixes #26775, fixes #26773, fixes #26768, fixes #26859

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker ArthurZucker marked this pull request as ready for review October 16, 2023 13:26
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to merge when ready as seen offline with you

@tongyx361
Copy link

I ran into the error below

Traceback (most recent call last):
  File ".../src/train_flash_attn_2.py", line 11, in <module>
    train()
  File ".../src/train.py", line 157, in train
    tokenizer = transformers.AutoTokenizer.from_pretrained(
  File ".../lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py", line 751, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File ".../lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2017, in from_pretrained
    return cls._from_pretrained(
  File ".../lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2243, in _from_pretrained
    init_kwargs[key] = added_tokens_map.get(init_kwargs[key], init_kwargs[key])
TypeError: unhashable type: 'dict'

So I added some prints and get this intermediate values:

cls.SPECIAL_TOKENS_ATTRIBUTES: (list)['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']
added_tokens_map: (dict){'<unk>': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), '<s>': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), '</s>': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)}
init_kwargs: (dict){'add_bos_token': True, 'add_eos_token': False, 'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'clean_up_tokenization_spaces': False, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'legacy': None, 'model_max_length': 1024, 'pad_token': None, 'sp_model_kwargs': {}, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'vocab_file': '.../tokenizer.model', 'tokenizer_file': '.../tokenizer.json', 'name_or_path': '...'}
key: (streos_token
init_kwargs[key]: (dict){'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}

According to the output, I made a fix, which seemed to work out:

# Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
    if added_tokens_map != {} and init_kwargs[key] is not None:
        if key != "additional_special_tokens":
            # >>> debug
            def print_info(name, obj):
                print(f"{name}: ({type(obj).__name__}){obj}")
            print_info("cls.SPECIAL_TOKENS_ATTRIBUTES", cls.SPECIAL_TOKENS_ATTRIBUTES)
            print_info("added_tokens_map", added_tokens_map)
            print_info("init_kwargs", init_kwargs)
            print_info("key", key)
            print_info("init_kwargs[key]", init_kwargs[key])
            # <<< debug
-            init_kwargs[key] = added_tokens_map.get(init_kwargs[key], init_kwargs[key])
+            init_kwargs[key] = added_tokens_map.get(key, init_kwargs[key]) # fix

@ArthurZucker
Copy link
Collaborator Author

Could you share a reproducer? Would help me a lot as well!

@ArthurZucker ArthurZucker deleted the fix-main branch October 24, 2023 08:44
@tongyx361
Copy link

tongyx361 commented Oct 25, 2023

Could you share a reproducer? Would help me a lot as well!

Sorry that I'm too busy to do so right now 😭

But this only happened when I loaded the tokenizer of Llemma-7B.

I hope this description could help you reproduce the error.

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* fix

* last attempt

* current work

* fix forward compatibility

* save all special tokens

* current state

* revert additional changes

* updates

* remove tokenizer.model

* add a test and the fix

* nit

* revert one more break

* fix typefield issue

* quality

* more tests

* fix fields for FC

* more nits?

* new additional changes

* how

* some updates

* simplify all

* more nits

* revert some things to original

* nice

* nits

* a small hack

* more nits

* ahhaha

* fixup

* update

* make test run on ci

* use subtesting

* update

* Update .circleci/create_circleci_config.py

* updates

* fixup

* nits

* replace typo

* fix the test

* nits

* update

* None max dif pls

* a partial fix

* had to revert one thing

* test the fast

* updates

* fixup

* and more nits

* more fixes

* update

* Oupsy 👁️

* nits

* fix marian

* on our way to heaven

* Update src/transformers/models/t5/tokenization_t5.py

Co-authored-by: Lysandre Debut <[email protected]>

* fixup

* Update src/transformers/tokenization_utils_fast.py

Co-authored-by: Leo Tronchon <[email protected]>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Leo Tronchon <[email protected]>

* fix phobert

* skip some things, test more

* nits

* fixup

* fix deberta

* update

* update

* more updates

* skip one test

* more updates

* fix camembert

* can't test this one

* more good fixes

* kind of a major update

- seperate what is only done in fast in fast init and refactor
- add_token(AddedToken(..., speicla = True)) ignores it in fast
- better loading

* fixup

* more fixups

* fix pegasus and mpnet

* remove skipped tests

* fix phoneme tokenizer if self.verbose

* fix individual models

* update common tests

* update testing files

* all over again

* nits

* skip test for markup lm

* fixups

* fix order of addition in fast by sorting the added tokens decoder

* proper defaults for deberta

* correct default for fnet

* nits on add tokens, string initialized to special if special

* skip irrelevant herbert tests

* main fixes

* update test added_tokens_serialization

* the fix for bart like models and class instanciating

* update bart

* nit!

* update idefix test

* fix whisper!

* some fixup

* fixups

* revert some of the wrong chanegs

* fixup

* fixup

* skip marian

* skip the correct tests

* skip for tf and flax as well

---------

Co-authored-by: Lysandre Debut <[email protected]>
Co-authored-by: Leo Tronchon <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment