Support for additional_special_tokens #1221

Merged (10 commits) on Jan 31, 2024

Conversation

@DreamGenX (Contributor) commented on Jan 27, 2024:

Description

This lets users specify the "additional_special_tokens" key of the add_special_tokens method.

Motivation and Context

Special tokens are treated differently by the tokenizer, ensuring that they are never broken down. This is important for some use cases, like ChatML.

Consider Yi-200K, which has <|im_start|> and <|im_end|> in its vocabulary already. They are, however, not marked as special in the base variant. This means that:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-200K")

tokens = tokenizer("<|im_start|>system\nX<|im_end|>")

Results in:

{'input_ids': [59666, 59705, 622, 59593, 5858, 46826, 10707, 144, 59733, 59666, 59705, 622, 59593, 701, 46826], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

However, adding these tokens as special:

tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)

Fixes the issue:

{'input_ids': [6, 10707, 144, 59733, 7], 'attention_mask': [1, 1, 1, 1, 1]}

This change will let users handle these cases correctly.

How has this been tested?

I ran training for one step after adding this to my config:

special_tokens:
  additional_special_tokens: ["<|im_start|>", "<|im_end|>"]

Without this, it's likely that many ChatML models are in fact semi-broken, because it's common practice to add <|im_start|> and <|im_end|> like this:

special_tokens:
  eos_token: "<|im_end|>"
tokens:
  - "<|im_start|>"

This means that the tokenizer might incorrectly tokenize <|im_start|> strings.
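For anyone auditing an existing setup, here is a minimal sanity check, assuming a standard transformers tokenizer; the model path is a placeholder:

from transformers import AutoTokenizer

# Placeholder path; substitute the checkpoint you are training or serving.
tokenizer = AutoTokenizer.from_pretrained("path/to/your-chatml-model")

# If the ChatML markers are handled correctly, each should come back as a
# single token rather than being split into sub-word pieces.
for marker in ["<|im_start|>", "<|im_end|>"]:
    print(marker, "->", tokenizer.tokenize(marker))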

Social Handles (Optional)

dreamgen on discord
https://twitter.com/DreamGenAI

@DreamGenX (Contributor, Author) commented on Jan 27, 2024:

Slight clarification:

tokenizer.add_tokens(...)

also ensures that the tokens are treated as special by the tokenizer, even if the token is already in the vocabulary!

Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to it with indices starting from the length of the current vocabulary and will be isolated before the tokenization algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore not treated in the same way.

However, the special_tokens_map.json file will not contain such added tokens as special, which can confuse downstream libraries and inference engines that try to load these models and lead to mistokenization.

But at least during training they should be handled fine:

from transformers import AddedToken, AutoTokenizer

yi_tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-200K")

print(yi_tokenizer("<|im_start|>system\nX<|im_end|>")["input_ids"])

print(
    yi_tokenizer.add_tokens(
        [
            AddedToken("<|im_start|>", lstrip=False, rstrip=False, normalized=False),
            AddedToken("<|im_end|>", lstrip=False, rstrip=False, normalized=False),
        ]
    )
)

print(yi_tokenizer("<|im_start|>system\nX<|im_end|>")["input_ids"])

Results in the expected:

[59666, 59705, 622, 59593, 5858, 46826, 10707, 144, 59733, 59666, 59705, 622, 59593, 701, 46826]
2
[6, 10707, 144, 59733, 7]
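To see the special_tokens_map.json concern above, one can inspect the tokenizer's special-token attributes directly. A small sketch; the output directory is a placeholder:

from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-200K")
tokenizer.add_tokens(
    [
        AddedToken("<|im_start|>", lstrip=False, rstrip=False, normalized=False),
        AddedToken("<|im_end|>", lstrip=False, rstrip=False, normalized=False),
    ]
)

# add_tokens() isolates the markers during tokenization, but it does not mark
# them as special, so they will not appear here:
print(tokenizer.additional_special_tokens)
print(tokenizer.all_special_tokens)

# save_pretrained() writes special_tokens_map.json; with only add_tokens(),
# the ChatML markers will be missing from it.
tokenizer.save_pretrained("/tmp/yi-tokenizer-check")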

@komninoschatzipapas commented:

Thank you for the PR. I ran into this issue today and it also seems to me that <|im_start|> should be marked as special.

Running your code on the Mistral base model with this config:

special_tokens:
  eos_token: "<|im_end|>"
  unk_token: "<unk>"
  additional_special_tokens: ["<|im_start|>"]

yielded an error:

Traceback (most recent call last):
  File "/root/miniconda3/envs/py3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/py3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 58, in <module>
    fire.Fire(do_cli)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 34, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 54, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/workspace/axolotl/src/axolotl/train.py", line 156, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(

Changing the config to this worked:

special_tokens:
  eos_token: "<|im_end|>"
  unk_token: "<unk>"
  additional_special_tokens: ["<|im_start|>"]
tokens:
  - "<|im_start|>"

I believe it may have to do with the fact that Mistral doesn't have <|im_start|> in its vocabulary, unlike Yi.
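One quick way to check that hypothesis (not part of the PR, just a membership test against the base vocabularies):

from transformers import AutoTokenizer

for name in ["mistralai/Mistral-7B-v0.1", "01-ai/Yi-34B-200K"]:
    tok = AutoTokenizer.from_pretrained(name)
    vocab = tok.get_vocab()
    # True/False for whether each ChatML marker is already in the base vocab.
    print(name, "<|im_start|>" in vocab, "<|im_end|>" in vocab)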

@DreamGenX (Contributor, Author) commented:

Thank you for testing! Do you actually see tokenization differences? I ran over the whole training set with these two options:

(A) Only use add_tokens on <|im_start|> and <|im_end|>
(B) Use add_tokens + additional_special_tokens

And I did not find a single example where the token ids would differ. If you did, can you share an example so I can test with it?

On the issue you mention, I think additional_special_tokens should work even if the tokens were not in the vocab yet:

# %%
from transformers import AutoTokenizer, AddedToken

tokens = ["<|im_start|>", "<|im_end|>"]

tok1 = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tok2 = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

tok1.add_tokens(
    [
        AddedToken(token, rstrip=False, lstrip=False, normalized=False)
        for token in tokens
    ]
)

tok2.add_special_tokens({"additional_special_tokens": tokens})

# %%

inputs = [
    "<|im_start|>user\nhello<|im_end|>",
    "<|im_start|>user\nhello<|im_end|>\n<|im_start|> assistant\nworld<|im_end|>",
]

for input in inputs:
    print(tok1.tokenize(input))
    print(tok1(input)["input_ids"])
    print(tok2.tokenize(input))
    print(tok2(input)["input_ids"])
    print()

# %%

It does not raise an exception, and the outputs are the same (in this case the ids are swapped, but that's fine):

['<|im_start|>', '▁user', '<0x0A>', 'hello', '<|im_end|>']
[1, 32000, 2188, 13, 21558, 32001]
['<|im_start|>', '▁user', '<0x0A>', 'hello', '<|im_end|>']
[1, 32001, 2188, 13, 21558, 32000]

['<|im_start|>', '▁user', '<0x0A>', 'hello', '<|im_end|>', '▁', '<0x0A>', '<|im_start|>', '▁', '▁assistant', '<0x0A>', 'world', '<|im_end|>']
[1, 32000, 2188, 13, 21558, 32001, 28705, 13, 32000, 28705, 13892, 13, 9471, 32001]
['<|im_start|>', '▁user', '<0x0A>', 'hello', '<|im_end|>', '▁', '<0x0A>', '<|im_start|>', '▁', '▁assistant', '<0x0A>', 'world', '<|im_end|>']
[1, 32001, 2188, 13, 21558, 32000, 28705, 13, 32001, 28705, 13892, 13, 9471, 32000]

So I am not sure what the issue is. The stack trace does not have much information, unfortunately :-/

@winglian (Collaborator) commented:

@DreamGenX, I think the difference between the actual implementation and your test code above is that you are adding both tokens to the vocabulary and then setting them as special. Setting <|im_start|> as a special token without adding it to the vocabulary is what happens in @komninoschatzipapas's first version of the config.

@winglian (Collaborator) commented:

I also get this error if the additional_special_tokens are not already in the vocabulary or included in tokens:

  File "/home/wing/micromamba/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wing/micromamba/envs/dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1031, in forward
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  File "/home/wing/micromamba/envs/dev/lib/python3.10/site-packages/torch/_tensor.py", line 997, in __contains__
    return (element == self).any().item()  # type: ignore[union-attr]
RuntimeError: CUDA error: device-side assert triggered
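For context, as a general note rather than something read directly off the traceback: a device-side assert during an embedding lookup typically means the batch contains token ids past the end of the embedding matrix, for example when new special tokens are added to the tokenizer but the model's embeddings were never resized. A minimal sketch of the usual check, with a placeholder model name:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)

# If the tokenizer now knows more ids than the embedding table has rows, any
# batch containing the new tokens will index out of range on the GPU.
embedding_rows = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_rows:
    model.resize_token_embeddings(len(tokenizer))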

@winglian (Collaborator) commented on Jan 31, 2024:

alright, found the bug and added a test for it. dea2d31

@DreamGenX (Contributor, Author) commented:

Thank you for the fix! Just to clarify, you can see in the example that tok2 uses only add_special_tokens, and not add_tokens, and it works. It seems that normally add_special_tokens also adds the special tokens to the vocab if they aren't there yet.
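A quick way to confirm that behaviour (a small sketch using the same base model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
before = len(tok)
num_added = tok.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)
# add_special_tokens() returns how many genuinely new tokens were added, and
# the tokenizer length grows by the same amount.
print(before, num_added, len(tok))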

@winglian (Collaborator) commented:

Rebasing from main to pick up fixes for CI failures caused by the torch-2.2.0 release.

@winglian merged commit 25e037f into axolotl-ai-cloud:main on Jan 31, 2024. 7 checks passed.
@andysalerno commented:

What if the "additional_special_tokens" are already present in the tokenizer config for the model being fine-tuned? Is it still necessary to put them in the axolotl YAML manifest, or will they just be picked up as expected from the tokenizer config?
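One way to see what the stock tokenizer already declares as special before deciding whether to repeat it in the YAML (a sketch; the model path is a placeholder, and this does not answer how axolotl itself merges the config):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/base-model")  # placeholder
print(tok.additional_special_tokens)
print(tok.special_tokens_map)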

@ZQ-Dev8 commented on Mar 8, 2024:

@winglian @DreamGenX Based on all the discussion and changes above, it is now quite confusing to know how to properly set up your config.yml for training a ShareGPT dataset with the ChatML chat template. Would the following stripped-down example work as intended? It mirrors those found in other issues, but it's hard to know what's current at this point. Note this is for a model whose tokenizer does not already contain <|im_start|> and <|im_end|>.

base_model: /hf_downloads/mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

#etc

datasets:
  - path: 64bits/lima_vicuna_format
    type: sharegpt
chat_template: chatml

#etc

tokens: # added for chatml
  - "<|im_start|>"
  - "<|im_end|>"
special_tokens:
  eos_token: "<|im_end|>" # added for chatml

djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* Support for additional_special_tokens

* Support for additional_special_tokens. Adjust whitespace.

* Support for additional_special_tokens. Use correct quotes.

* Support for additional_special_tokens. Safe pop.

* Support for additional_special_tokens. nt.

* Support for additional_special_tokens. cfg.special_tokens may be None.

* add token if not in vocabulary when adding additional_special_tokens

* fix logic for copy/pasta

* bugfix for popping from config and tokenizer reload

* no need to add tokens manually now with previous bugfix

---------

Co-authored-by: Wing Lian <[email protected]>