
Gemma2 and flash-attention #32188

Merged · 5 commits · Jul 31, 2024
Changes from 3 commits
13 changes: 10 additions & 3 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -324,6 +324,11 @@ def forward(
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        if attention_mask is not None:
+            seq_len = attention_mask.shape[1]
+            key_states = key_states[:, :, :seq_len]
+            value_states = value_states[:, :, :seq_len]
Comment on lines +329 to +330
Collaborator:
ah, yeah static cache does not support that because it's input dependent, but great catch


        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
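For context on the trimming added above, here is a minimal, hypothetical sketch (not part of the diff; all shapes and variable names are invented): a static cache pre-allocates key/value buffers up to a fixed maximum length, so past_key_value.update returns tensors that are longer than the current sequence and zero-padded at the end, while flash attention works from the 2D attention mask. Slicing keys/values down to attention_mask.shape[1] keeps the two consistent, and since only the mask's shape is read, no data-dependent branch is introduced.

import torch

# Hypothetical shapes, only to illustrate the length mismatch the slicing fixes.
batch_size, num_kv_heads, head_dim = 2, 4, 16
max_cache_len = 512  # what a static cache pre-allocates
cur_len = 10         # positions actually filled so far

# A static cache hands back the full pre-allocated buffers, zero-padded past cur_len.
key_states = torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim)
value_states = torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim)

# The 2D padding mask only covers the positions present in the input.
attention_mask = torch.ones(batch_size, cur_len, dtype=torch.long)

# Same trimming as the added lines: cut keys/values to the mask length.
if attention_mask is not None:
    seq_len = attention_mask.shape[1]
    key_states = key_states[:, :, :seq_len]
    value_states = value_states[:, :, :seq_len]

assert key_states.shape[2] == value_states.shape[2] == attention_mask.shape[1]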
@@ -821,10 +826,12 @@ def _update_causal_mask(
        past_key_values: Cache,
        output_attentions: bool,
    ):
+        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
+        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
+        # to cut out the trailing zeros that the static cache pads keys/values with. This workaround should be compile compatible
+        # as it doesn't cause dynamic control issues.
        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
+            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
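To illustrate the compile-compatibility point made in the comment above, a hypothetical sketch (not from the PR; the helper names are invented): the removed check 0.0 in attention_mask branches on the values inside the mask, i.e. data-dependent control flow, whereas the retained workaround only ever reads attention_mask.shape[1], which is known statically while tracing.

import torch

def has_padding(attention_mask: torch.Tensor) -> bool:
    # What the removed `0.0 in attention_mask` amounts to: the answer depends on
    # the tensor's contents, so the control flow is data dependent.
    return bool((attention_mask == 0.0).any())

def trim_to_mask(key_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # The retained approach only uses the mask's shape; no tensor values are read,
    # so the slice is a purely shape-dependent operation.
    return key_states[:, :, : attention_mask.shape[1]]

Both work in eager mode; the value-dependent branch is simply the part a static, compile-friendly path wants to avoid.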
29 changes: 29 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,8 +16,11 @@

import unittest

+from pytest import mark
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers.testing_utils import (
+    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
@@ -161,3 +164,29 @@ def test_model_9b_pipeline_bf16(self):

        self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
        self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
+
+    @require_read_token
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_model_9b_flash_attn(self):
+        # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
+        # NOTE: the quality is a lot better when fp16 is used, and worse for bf16

Review comment:
Could you elaborate more on this? Specifically, does using flash_attention_2 negatively impact the quality of the model's output with bfloat16 compared to using eager mode?

Member Author:
Oh, this was true until I discovered a bug I made. It doesn't matter much which precision is used; removed the comment.

+        model_id = "google/gemma-2-9b"
+        EXPECTED_TEXTS = [
+            '<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+        ]  # fmt: skip
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
+        ).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+        print(output_text)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)