Flash Attention with Gemma 2 #31953
Comments
Hey! I think this is pretty much expected when you quantize the model: quality is reduced, and on top of that eager and flash attention don't run exactly the same code. If you pad, then it's expected.
I removed the …
The flash_attention_2 outputs depend on:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_text = "Hi, my name is"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

optims = ["eager", "flash_attention_2"]
tokens = int(input("Generate how many tokens? "))

for optim in optims:
    torch.manual_seed(0)  # IMPORTANT: same seed for both implementations
    torch.cuda.empty_cache()
    # Load the model with the attention implementation under test.
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="cuda", attn_implementation=optim)
    outputs = model.generate(**input_ids, max_new_tokens=tokens)
    print(f"The {optim} attn_implementation generated {tokenizer.decode(outputs[0])} with seed 0 and max_new_tokens {tokens}.")
    del model  # free GPU memory before loading the next variant
```
Hey! I don't think that flash attention outputs are deterministic, #31961 might help. How different are your outputs?
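For reference, one way to quantify "how different" is to generate greedily with both implementations and report where the token ids first diverge. This is a minimal sketch, assuming bf16 weights and the same checkpoint and prompt as the script above; the variable names are just illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hi, my name is", return_tensors="pt").to("cuda")

generations = {}
for impl in ["eager", "flash_attention_2"]:
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation=impl
    )
    # Greedy decoding so any divergence comes from the attention kernels, not sampling.
    generations[impl] = model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]
    del model
    torch.cuda.empty_cache()

eager_ids, flash_ids = generations["eager"], generations["flash_attention_2"]
n = min(len(eager_ids), len(flash_ids))  # lengths can differ if EOS is hit earlier in one run
mismatch = (eager_ids[:n] != flash_ids[:n]).nonzero()
if len(mismatch) == 0:
    print("Outputs are identical up to the shorter length.")
else:
    first = mismatch[0].item()
    print(f"Outputs diverge at token position {first}:")
    print("eager:", tokenizer.decode(eager_ids[first:]))
    print("flash:", tokenizer.decode(flash_ids[first:]))
```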
Okay, thanks! But it is still strange:
Ah, what's certain is that the cache is pre-allocated, so there might be some issue with this. cc @zucchini-nlp if you can have a look!
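A quick way to check which cache class generate() actually allocates here, as a small sketch reusing `model` and `input_ids` from the reproduction script above and assuming the generate output carries `past_key_values`:

```python
# Inspect the cache object used by generate(); in the 4.43 dev version reported here,
# Gemma 2 defaults to a pre-allocated, static-shaped HybridCache rather than a DynamicCache.
out = model.generate(**input_ids, max_new_tokens=5, return_dict_in_generate=True)
print(type(out.past_key_values))
```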
@ArthurZucker, sorry, I got confused. Do you mean that soft capping is only enabled in flash attention (not in eager or SDPA)?
The inputs_embeds shape is [1, 2910, 3584], and max_new_tokens must be larger than 2910. Then the generation is extremely slow. Does anybody know a solution? Thanks. I'm using the "eager" attention, but when I trained with "flash_attention_2", it shows "Detected flash_attn version: 2.6.1".
Soft capping is enabled in both eager and flash attention, not in SDPA.
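For context, the soft capping being discussed is Gemma 2's attention-logit soft capping: the raw attention scores are squashed with a tanh before the softmax. Roughly, with the config's `attn_logit_softcapping` value of 50.0:

```python
import torch

def softcap(scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    """Gemma 2 style logit soft capping: smoothly bound scores to (-cap, cap)."""
    return cap * torch.tanh(scores / cap)

# Small logits pass through almost unchanged, large ones saturate toward +/- cap.
print(softcap(torch.tensor([1.0, 30.0, 120.0])))  # roughly [1.00, 26.85, 49.18]
```

flash-attn added native support for this softcap in its 2.6 release, which lines up with the "Detected flash_attn version: 2.6.1" note above.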
Hey! Sorry it took long to look at this issue. Indeed Gemma generates gibberish with flash attention, and it seems related to how the cache is pre-allocated. I will see how to enable a static-shaped cache for flash-attn; it should be doable by tweaking the attention masks. @ArthurZucker, I also saw you tagged me in another issue. WDYT, do we also need a HybridDynamicCache class, which would be the default in forward calls and could be used if the above idea is not doable?
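To make the mask-tweaking idea concrete, here is a rough, purely illustrative sketch (not the actual modeling code): with a static HybridCache, the key/value tensors come back padded to the pre-allocated `max_cache_len`, so before calling flash-attn they would need to be trimmed to the length actually covered by the attention mask.

```python
import torch

def trim_static_cache_for_flash_attn(
    key_states: torch.Tensor,      # [batch, num_kv_heads, max_cache_len, head_dim] from the static cache
    value_states: torch.Tensor,    # same shape as key_states
    attention_mask: torch.Tensor,  # [batch, kv_len] padding mask seen by the model
):
    """Illustrative only: drop the pre-allocated-but-unused tail of a static cache
    so flash-attn never attends to the zero-filled slots."""
    if attention_mask is not None:
        kv_len = attention_mask.shape[-1]
        key_states = key_states[:, :, :kv_len, :]
        value_states = value_states[:, :, :kv_len, :]
    return key_states, value_states
```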
System Info
- transformers version: 4.43.0.dev0
- compute_environment: LOCAL_MACHINE
- distributed_type: NO
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 1
- machine_rank: 0
- num_machines: 1
- gpu_ids: all
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Who can help?
@ArthurZucker
Information
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
The flash attention implementation isn't returning the same output as eager. Sometimes the results are the same, but sometimes (especially with long contexts) it fails completely.