SFT chatml applied using setup_chat_format leads to surprising updates that (potentially) break pretraining decisions #1412

Closed
jacobwjs opened this issue Mar 10, 2024 · 8 comments


jacobwjs commented Mar 10, 2024

Issues/questions:

  1. setup_chat_format seems to need more fine-grained control based on the model/tokenizer being used.
  2. Should ChatMlSpecialTokens be abstracted into a base class, allowing more customized implementations based on the pretrained models used for SFT?
  3. Would it be better to initialize the new special token embeddings from the average of the existing embeddings (see: https://nlp.stanford.edu/~johnhew/vocab-expansion.html)?

I'm instruction tuning google/gemma-7b on a custom dataset using TRL. I came across setup_chat_format, which is a great step towards fixing everything that is problematic about chat templates. After looking into the updates it applies to the tokenizer, though, I'm wondering whether some of the individual token id updates are problematic, as well as the resulting chat_template update.

The Gemma models require a <bos> token prepended to each message. If I apply setup_chat_format to the tokenizer, the new bos_token becomes <|im_start|>. Other issues occur as well, such as the eos_token being overwritten with <|im_end|>. This leaves me with some rather surprising updates to the tokenizer (e.g. ['<bos>', '<eos>', '<unk>', '<pad>'] are lost), as well as to the token embeddings (why aren't the added tokens initialized from the average of the existing token embeddings? see question 3 above and the sketch below).
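For reference, this is roughly the mean-initialization I have in mind (a minimal sketch with a hypothetical helper name; not what setup_chat_format does today):

import torch

def add_tokens_with_mean_init(model, tokenizer, new_special_tokens):
    # Hypothetical helper: add the new special tokens, then initialize their
    # embeddings to the mean of the pretrained embeddings instead of random values.
    num_added = tokenizer.add_special_tokens(
        {"additional_special_tokens": new_special_tokens}
    )
    model.resize_token_embeddings(len(tokenizer))

    if num_added > 0:
        with torch.no_grad():
            input_emb = model.get_input_embeddings().weight
            output_emb = model.get_output_embeddings().weight
            input_emb[-num_added:] = input_emb[:-num_added].mean(dim=0, keepdim=True)
            output_emb[-num_added:] = output_emb[:-num_added].mean(dim=0, keepdim=True)
    return model, tokenizer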

Quick and dirty example:

model_name = "google/gemma-7b"    # Base model

chatml_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
custom_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)       # using my own custom setup_chat_format
phil_tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")  # see: https://huggingface.co/philschmid/gemma-tokenizer-chatml
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2", 
    device_map="auto", 
    token=HF_TOKEN
)

# Special tokens from Phil
# ------------
print(phil_tokenizer.all_special_tokens)
print("Chat template: ", repr(phil_tokenizer.chat_template))
# Output:
# ['<bos>', '<eos>', '<unk>', '<pad>', '<|im_start|>', '<|im_end|>']
# Chat template:  "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}"



# Use TRL's implementation
# ------------
print("Special tokens (before): ", chatml_tokenizer.all_special_tokens)
model, chatml_tokenizer = setup_chat_format(
    model,
    chatml_tokenizer,
    format = "chatml",
    resize_to_multiple_of = 64
)
print("Special tokens (after): ", chatml_tokenizer.all_special_tokens)
print("Chat template: ", chatml_tokenizer.chat_template)
# Output:
# Special tokens (before):  ['<bos>', '<eos>', '<unk>', '<pad>']
# Special tokens (after):  ['<|im_start|>', '<|im_end|>', '<unk>']
# Chat template:  "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"


# Apply the chat_template to dummy example
# ----------
messages = [
  {"role": "system", "content": "You are Gemma."},
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]

philml = phil_tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
print(philml)
print()
chatml = chatml_tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
print(chatml)
print()
# Output:
# <bos><|im_start|>system
# You are Gemma.<|im_end|>
# <|im_start|>user
# Hello, how are you?<|im_end|>
# <|im_start|>assistant
# I'm doing great. How can I help you today?<|im_end|>
# <eos>

# <|im_start|>system
# You are Gemma.<|im_end|>
# <|im_start|>user
# Hello, how are you?<|im_end|>
# <|im_start|>assistant
# I'm doing great. How can I help you today?<|im_end|>

All of this leads to surprising behavior and updates that ignore pretraining decisions for the model/tokenizer and directly impact downstream SFT using the chatml format. I only noticed it after seeing some rather strange loss values.

To accommodate this, one has to embed <bos> and <eos> in some processing_fn for the dataset (a sketch of what I mean is below) and keep this in mind during inference, or completely ignore Gemma's pretraining decisions. It seems like this could be resolved with finer-grained control and more robust per-model templating when applying setup_chat_format. I feel that is the better approach: it reduces individual user effort and gets more people on board with a seamless way to unify chat structure during SFT, which is already a mess as of today.
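As a concrete example, the workaround I'm talking about looks something like this (the processing_fn name and the "messages"/"text" columns are placeholders I made up):

def processing_fn(example, tokenizer):
    # Re-wrap the chatml-formatted conversation with Gemma's pretrained
    # <bos>/<eos> tokens, since setup_chat_format no longer adds them.
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    example["text"] = "<bos>" + text + "<eos>"
    return example

# dataset = dataset.map(lambda ex: processing_fn(ex, chatml_tokenizer))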

I also like @philschmid's workaround of essentially reassigning tokens to comply with chatml, without losing the model/tokenizer's trained embeddings.

@concrete13377

concrete13377 commented Mar 25, 2024

Why is setup_chat_format implemented in a way that the <|im_start|> token replaces the bos_token? As far as I understand, <|im_start|> and <|im_end|> are special tokens that separate the user's messages in the chat history; shouldn't they be added to the tokenizer as new tokens? A minimal sketch of what I mean is below.
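For illustration only (a sketch of the behavior I'd expect, not the current implementation; Gemma is just used as an example model here):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")

# Add the chatml markers as *additional* special tokens instead of
# overwriting bos_token/eos_token, so <bos>/<eos> keep their pretrained roles.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)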


@jacobwjs
Author

Any input from the TRL team on this?

@ivsanro1
Contributor

I'm also having problems with this. I trained a LoRA for Llama3 using setup_chat_format and when I try to load the model with the adapter I get:

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
	size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([128258, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([128258, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).

@geronimi73

> I'm also having problems with this. I trained a LoRA for Llama3 using setup_chat_format and when I try to load the model with the adapter I get: size mismatch for model.embed_tokens.weight and lm_head.weight ([128258, 4096] vs [128256, 4096]).

Did you save lm_head and embed_tokens to the adapter when training?

LoraConfig(
    ...,
    modules_to_save=["lm_head", "embed_tokens"],
)

If you did and it still fails, then you probably didn't apply setup_chat_format to the base model before merging the adapter. A minimal sketch of the loading order I mean is below.
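Roughly like this (the model id and adapter path are placeholders):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Resize the base model's vocab/embeddings the same way as during training,
# *before* loading the adapter, so the checkpoint shapes match (128258 vs 128256).
base_model, tokenizer = setup_chat_format(base_model, tokenizer)

model = PeftModel.from_pretrained(base_model, "path/to/your-adapter")
model = model.merge_and_unload()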


@zyzhang1130

Should we just use formatting_func and data_collator instead? It seems that together they serve the same purpose as setup_chat_format, except when you also want a system prompt in your training data. Something like the sketch below.
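For example, following the pattern from the TRL docs, and assuming model, tokenizer, and dataset are already loaded (the "prompt"/"completion" column names are assumptions on my part):

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

def formatting_func(examples):
    # Batched formatting function; assumes "prompt" and "completion" columns.
    texts = []
    for i in range(len(examples["prompt"])):
        texts.append(
            f"### Question: {examples['prompt'][i]}\n### Answer: {examples['completion'][i]}"
        )
    return texts

# Mask the loss on everything before the response template.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Answer:", tokenizer=tokenizer
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=formatting_func,
    data_collator=collator,
)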

@ArthurZucker

huggingface/tokenizers#1570 could also help here.
