QLoRA bf16 + model.generate() in TrainerCallback: "RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16" #1515

Closed
geronimi73 opened this issue Feb 28, 2024 · 2 comments


geronimi73 commented Feb 28, 2024

System Info

bitsandbytes                      0.42.0
peft                              0.8.2
torch                             2.2.0
transformers                      4.38.1
trl                               0.7.11
Python 3.10.12

Who can help?

@pacman100 @younesbelkada @saya

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I've come across an error with model.generate() when it is called inside a TrainerCallback of SFTTrainer. It happens only when training with TrainingArguments(.., bf16=True, ..), not with fp16=True. Models tested: Mistral-7B and Llama2-7B.

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

Minimal reproducible example:

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, TrainerCallback
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
import torch

modelpath = "mistralai/Mistral-7B-v0.1"

model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(modelpath)    
tokenizer.pad_token=tokenizer.eos_token
dummy_dataset=Dataset.from_list([dict(messages=[dict(role="user" if i%2==0 else "assistant", content=str(x)) for i,x in enumerate(range(40))])]*5)

model.generate(**tokenizer("test", return_tensors="pt").to("cuda"))

trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dummy_dataset,
        eval_dataset=dummy_dataset,
        peft_config = LoraConfig(),
        args=TrainingArguments(
            output_dir="out",
            evaluation_strategy="steps",
            per_device_train_batch_size=1,
            eval_steps=1,
            bf16=True,
        ),
)

class EvaluateFirstStepCallback(TrainerCallback):
    def on_step_end(self, args, state, control, model, tokenizer, **kwargs):
        # TrainingArguments bf16=True: fails with RuntimeError: expected mat1 and mat2 
        # TrainingArguments fp16=True: works 
        display(model.generate(**tokenizer("test", return_tensors="pt").to("cuda")))
        
trainer.add_callback(EvaluateFirstStepCallback())
trainer.train()

Full stacktrace

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 input_tok=tokenizer("test", return_tensors="pt")
----> 2 trainer.model.generate(**input_tok)

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:1544, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1526     return self.assisted_decoding(
   1527         input_ids,
   1528         candidate_generator=candidate_generator,
   (...)
   1540         **model_kwargs,
   1541     )
   1542 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1543     # 11. run greedy search
-> 1544     return self.greedy_search(
   1545         input_ids,
   1546         logits_processor=prepared_logits_processor,
   1547         stopping_criteria=prepared_stopping_criteria,
   1548         pad_token_id=generation_config.pad_token_id,
   1549         eos_token_id=generation_config.eos_token_id,
   1550         output_scores=generation_config.output_scores,
   1551         output_logits=generation_config.output_logits,
   1552         return_dict_in_generate=generation_config.return_dict_in_generate,
   1553         synced_gpus=synced_gpus,
   1554         streamer=streamer,
   1555         **model_kwargs,
   1556     )
   1558 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1559     if not model_kwargs["use_cache"]:

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:2404, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2401 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2403 # forward pass to get next token
-> 2404 outputs = self(
   2405     **model_inputs,
   2406     return_dict=True,
   2407     output_attentions=output_attentions,
   2408     output_hidden_states=output_hidden_states,
   2409 )
   2411 if synced_gpus and this_peer_finished:
   2412     continue  # don't waste resources running the code we don't need

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/accelerate_fork/src/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1170, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1157 outputs = self.model(
   1158     input_ids=input_ids,
   1159     attention_mask=attention_mask,
   (...)
   1166     return_dict=return_dict,
   1167 )
   1169 hidden_states = outputs[0]
-> 1170 logits = self.lm_head(hidden_states)
   1171 logits = logits.float()
   1173 loss = None

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/accelerate_fork/src/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

Any idea what's going on? Thank you!

Expected behavior

model.generate() inside the callback should work with both bf16=True and fp16=True.


geronimi73 commented Feb 28, 2024

I found this thread, which seems related, and indeed the following fixes it:

with torch.cuda.amp.autocast():
    model.generate(**tokenizer("test", return_tensors="pt").to("cuda"))

The PR mentioned there, which fixed this, was merged a year ago, so why is autocast() still necessary?

@akshaydigheQ

> I found this thread, which seems related, and indeed the following fixes it:
>
> with torch.cuda.amp.autocast():
>     model.generate(**tokenizer("test", return_tensors="pt").to("cuda"))
>
> The PR mentioned there, which fixed this, was merged a year ago, so why is autocast() still necessary?

I am facing this issue while training Mistral-7B with bf16=True. I went through the code of peft/tuners/lora/bnb.py and it seems this autocast code is already added there. Any suggestions on how this error can be fixed?

It works fine with fp16=True.
