Gemma2 and flash-attention (huggingface#32188)
* enable flash-attn & static cache

* this works, not the previous approach

* fix for sliding window layers

* not needed anymore
zucchini-nlp authored and tGhattas committed Jul 31, 2024
1 parent 4845653 commit a078483
Showing 2 changed files with 50 additions and 13 deletions.
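As a usage note, a minimal sketch of what this commit enables (based on the test added below; it assumes access to the gated google/gemma-2-9b checkpoint, a CUDA GPU, and flash-attn installed): Gemma2 can now be loaded with the flash-attention backend while keeping its default static/hybrid cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # gated checkpoint; requires accepting the license on the Hub

# Gemma2 generates with a static (hybrid) cache by default; with this change the
# flash-attention path trims the cache's zero-padded key/value slots to the mask length.
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello I am doing", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])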
35 changes: 22 additions & 13 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -327,6 +327,11 @@ def forward(
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]
value_states = value_states[:, :, :seq_len]

# TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
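The slicing added above is needed because the static cache is pre-allocated to its maximum length and hands back trailing zero slots; the 2D mask passed to flash-attn tells us how much of it is actually populated. A self-contained sketch with toy shapes (hypothetical sizes, not the module's actual forward):

import torch

batch, num_kv_heads, head_dim = 1, 4, 8
max_cache_len, seq_len = 16, 5  # static cache is pre-allocated to max_cache_len

# what cache.update() hands back: real entries followed by trailing zeros
key_states = torch.zeros(batch, num_kv_heads, max_cache_len, head_dim)
value_states = torch.zeros(batch, num_kv_heads, max_cache_len, head_dim)
key_states[:, :, :seq_len] = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# flash-attn receives a 2D (batch, seq_len) mask, so its width gives the populated length
attention_mask = torch.ones(batch, seq_len)

if attention_mask is not None:
    seq_len = attention_mask.shape[1]
    key_states = key_states[:, :, :seq_len]
    value_states = value_states[:, :, :seq_len]

assert key_states.shape == (batch, num_kv_heads, seq_len, head_dim)  # trailing zero slots dropped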
@@ -510,16 +515,18 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if (
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
): # efficient SDPA and no padding
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# With flash-attn, the attention mask is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -self.sliding_window :]
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]

residual = hidden_states
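For clarity, a standalone sketch of the two branches above with toy shapes (illustrative only): the flash-attention mask is 2D and only needs its last sliding_window key columns, while the SDPA/eager mask is a 4D additive mask in which positions more than sliding_window tokens in the past are pushed to the dtype minimum.

import torch

sliding_window, batch, q_len = 4, 1, 6
min_dtype = torch.finfo(torch.float16).min

# flash-attention branch: 2D (batch, kv_len) padding mask, keep only the last window
mask_2d = torch.ones(batch, q_len)
fa2_mask = mask_2d[:, -sliding_window:]

# SDPA/eager branch: 4D additive mask (batch, 1, q_len, kv_len); everything further
# than `sliding_window` tokens in the past is set to min_dtype
mask_4d = torch.zeros(batch, 1, q_len, q_len, dtype=torch.float16)
sliding_window_mask = torch.tril(
    torch.ones_like(mask_4d, dtype=torch.bool), diagonal=-sliding_window
)
mask_4d = torch.where(sliding_window_mask, min_dtype, mask_4d)

print(fa2_mask.shape)  # torch.Size([1, 4])
print(mask_4d[0, 0])   # lower-left corner filled with the float16 minimum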

@@ -824,10 +831,12 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
# So we will pass the attention mask as-is in any case, not only when there's padding. Then we'll use its shape
# to cut out the trailing zeros of the keys/values used in the static cache. This workaround should be compile-compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
return attention_mask

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
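In short, _update_causal_mask used to drop the mask on the flash-attention path whenever there was no padding; it now forwards it unconditionally so the attention layers can read its width and trim the static cache (see the slicing added above). A minimal before/after sketch with hypothetical helper names:

import torch

def old_fa2_mask(attention_mask):
    # pre-patch behaviour: keep the mask only when it actually contains padding
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

def new_fa2_mask(attention_mask):
    # post-patch behaviour: always pass the mask through so its shape can be
    # used to cut the static cache's trailing zero key/value slots
    return attention_mask

mask = torch.ones(1, 5)            # no padding tokens
assert old_fa2_mask(mask) is None  # flash-attn previously saw no mask at all
assert new_fa2_mask(mask) is mask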
28 changes: 28 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,8 +16,11 @@

import unittest

from pytest import mark

from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers.testing_utils import (
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
@@ -161,3 +164,28 @@ def test_model_9b_pipeline_bf16(self):

self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])

@require_read_token
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_model_9b_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
'<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
] # fmt: skip

model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
print(output_text)

self.assertEqual(output_text, EXPECTED_TEXTS)
