Flash Attention with Gemma 2 #31953

Closed
2 of 4 tasks
Boubou78000 opened this issue Jul 14, 2024 · 11 comments · Fixed by #32188
Comments

@Boubou78000 commented Jul 14, 2024

System Info

  • transformers version: 4.43.0.dev0
  • Platform: Windows-10-10.0.22631-SP0
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA GeForce RTX 4070
  • Flash attention = 2.6.1 (Added by me)

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os
import warnings
import timeit

def start_timing(process="process"):
    global start_time
    start_time = timeit.default_timer()
    print(f"Started {process}.")

def finish_timing(description, suffix=""):
    global start_time
    end_time = timeit.default_timer()
    elapsed_time = end_time - start_time
    print(f"{description} took {elapsed_time:.2f}s {suffix}.")

warnings.filterwarnings("ignore")

os.environ["USE_FLASH_ATTENTION"]="1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

checkpoint = "google/gemma-2-9b"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="cuda", quantization_config=bnb_config, torch_dtype='bfloat16')

input_text = \
"""Hi, my name is"""

base_output = "idk"

context = ""

for line in input_text.split("\n"):
    context += line+" "
    for optim in ["eager", "sdpa", "flash_attention_2"]:
        torch.manual_seed(0)
        torch.cuda.empty_cache()
        input_ids = tokenizer(context, return_tensors="pt", add_special_tokens=True).to("cuda")
        model.config._attn_implementation = optim
        start_timing()
        outputs = model.generate(**input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
        finish_timing(f"The {optim} attention implementation", f"to infer 100 tokens")
        if optim == "eager":
            base_output = tokenizer.decode(outputs[0])
        if base_output == tokenizer.decode(outputs[0]):
            print(f"The {optim} attention returned the same output as eager attention.")
        else:
            print(f"The {optim} attention didn't return the same output as eager attention.")
        print()

Expected behavior

The flash attention implementation isn't returning the same output as eager.

Sometimes the results are the same, but sometimes (especially with long contexts) it fails completely.

@ArthurZucker (Collaborator) commented Jul 15, 2024

Hey! I think this is pretty much expected when you quantize the model: quality is reduced, plus eager and flash attention don't run exactly the same code. If you pad, then it's expected.

@ArthurZucker (Collaborator)

USE_FLASH_ATTENTION is not in transformers, no? So I suppose you are actually using SDPA. You should set attn_implementation="flash_attention_2", otherwise you don't have the soft capping.
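For reference, a minimal sketch of passing the implementation at load time instead of via an environment variable (it reuses the checkpoint from the reproduction above and loads in bf16 without quantization, to rule out 4-bit rounding as a source of differences):

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "google/gemma-2-9b"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Select the attention backend explicitly; "eager" and "sdpa" are the other valid choices
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

inputs = tokenizer("Hi, my name is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))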

@Boubou78000 (Author) commented Jul 15, 2024

I removed the bnb_config and the padding, but I can't understand why SDPA returns the same output as eager while flash attention does not. For all other models (like Gemma 1), it works.

@Boubou78000 (Author) commented Jul 15, 2024

The flash attention 2 outputs depend on max_new_tokens:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "google/gemma-2-9b"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

input_text = "Hi, my name is"

input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

optims = ["eager", "flash_attention_2"]

tokens = int(input("Generate how many tokens? "))
for optim in optims:
    torch.manual_seed(0) #IMPORTANT
    torch.cuda.empty_cache()
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="cuda", attn_implementation="eager")
    outputs = model.generate(**input_ids, max_new_tokens=tokens)
    print(f"The {optim} attn_implementation generated {tokenizer.decode(outputs[0])} with seed 0 and max_new_tokens {tokens}.")

@ArthurZucker (Collaborator)

Hey! I don't think flash attention outputs are deterministic; #31961 might help. How different are your outputs?
Also, your script here uses attn_implementation="eager", not attn_implementation=optim.
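For illustration, a corrected version of that loop could look like this (a sketch only; it passes the loop variable through to from_pretrained and reloads the model for each implementation):

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "google/gemma-2-9b"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer("Hi, my name is", return_tensors="pt").to("cuda")

for optim in ["eager", "flash_attention_2"]:
    torch.manual_seed(0)
    torch.cuda.empty_cache()
    # Use the loop variable here instead of hard-coding "eager"
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, device_map="cuda", attn_implementation=optim
    )
    outputs = model.generate(**input_ids, max_new_tokens=20)
    print(f"{optim}: {tokenizer.decode(outputs[0])}")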

@Boubou78000 (Author)

Okay, thanks!

But it is still strange: I have a text of 26 lines, and when max_new_tokens=10 it's fast and works well. But when I do max_new_tokens=1_000, it just outputs random tokens...

@ArthurZucker (Collaborator)

Ah, what's certain is that the cache is pre-allocated, so there might be some issue with that. cc @zucchini-nlp, can you have a look?

@tanliboy commented Jul 17, 2024

You should set attn_implementation="flash_attention_2", otherwise you don't have the soft capping

@ArthurZucker , sorry, I got confused. Do you mean that soft capping is only enabled in flash attention (not in eager or SDPA)?

@bug-fixed

The inputs_embeds shape is [1, 2910, 3584], and max_new_tokens must be larger than 2910. Then the generation is extremely slow. Does anybody know a solution? Thanks. I'm using "eager" attention.

But when I trained with "flash_attention_2", it shows "Detected flash_attn version: 2.6.1".

@ArthurZucker (Collaborator)

Soft capping is enabled in both eager and flash attention, not in SDPA.
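For context, a minimal sketch of inspecting the soft-capping parameters that Gemma 2 exposes in its config (field names as defined in Gemma2Config):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-2-9b")
# Gemma 2 applies a tanh soft cap to the attention logits and to the final logits
print(config.attn_logit_softcapping)
print(config.final_logit_softcapping)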

@zucchini-nlp (Member)

Hey! Sorry it took so long to look at this issue. Indeed, Gemma 2 generates gibberish with flash attention, and it's because the static cache implementation is not compatible with attn_implementation="flash_attention_2". In other words, Gemma 2 supports only the Hybrid cache, which is a statically shaped cache.

I will see how to enable a statically shaped cache for flash-attn; it should be doable by tweaking the attention masks.

@ArthurZucker I also saw you tagged me in another issue. WDYT, do we also need a HybridDynamicCache class, which would be the default in forward calls and could be used if the above idea is not doable?
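As a quick diagnostic (a sketch; it reuses model and input_ids from the scripts above and assumes a transformers version where generate exposes past_key_values when return_dict_in_generate=True), you can check which cache class Gemma 2 ends up using:

out = model.generate(
    **input_ids,
    max_new_tokens=5,
    return_dict_in_generate=True,
)
# For Gemma 2 this is expected to be the HybridCache (a statically shaped cache),
# which is what currently clashes with flash_attention_2
print(type(out.past_key_values))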
