From a078483550211b0624fb507de51fa605ccb12d94 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 31 Jul 2024 10:33:38 +0500
Subject: [PATCH] Gemma2 and flash-attention (#32188)

* enable flash-attn & static cache

* this works, not the prev

* fix for sliding window layers

* not needed anymore

---
 .../models/gemma2/modeling_gemma2.py        | 35 ++++++++++++-------
 tests/models/gemma2/test_modeling_gemma2.py | 28 +++++++++++++++
 2 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index db564aa713b43a..53e08df4ce54a4 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -327,6 +327,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+            if attention_mask is not None:
+                seq_len = attention_mask.shape[1]
+                key_states = key_states[:, :, :seq_len]
+                value_states = value_states[:, :, :seq_len]
+
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
         query_states = query_states.transpose(1, 2)
@@ -510,16 +515,18 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if (
-            self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
-        ):  # efficient SDPA and no padding
-            min_dtype = torch.finfo(hidden_states.dtype).min
-            sliding_window_mask = torch.tril(
-                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
-            )
-            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-            if attention_mask.shape[-1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
+            # With flash_attention_2 the attention mask is a 2D padding mask, so slice it to the sliding window
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask[:, -self.sliding_window :]
+            else:
+                min_dtype = torch.finfo(hidden_states.dtype).min
+                sliding_window_mask = torch.tril(
+                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
+                )
+                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+                if attention_mask.shape[-1] <= 1:  # when decoding
+                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
 
         residual = hidden_states
 
@@ -824,10 +831,12 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
+        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its
+        # shape to cut out the trailing zeros that the static cache pads keys/values with. This workaround should be
+        # compile-compatible as it doesn't introduce dynamic control flow.
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
+            return attention_mask
 
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 20b8ea3ec5c825..1229ca47eb69c7 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -16,8 +16,11 @@
 
 import unittest
 
+from pytest import mark
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
+    require_flash_attn,
     require_read_token,
     require_torch,
     require_torch_gpu,
@@ -161,3 +164,28 @@ def test_model_9b_pipeline_bf16(self):
 
         self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
         self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
+
+    @require_read_token
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_model_9b_flash_attn(self):
+        # See https://github.com/huggingface/transformers/issues/31953: flash attention was generating garbage for Gemma2, especially with long context
+        model_id = "google/gemma-2-9b"
+        EXPECTED_TEXTS = [
+            'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+            "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+        ]  # fmt: skip
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
+        ).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+        print(output_text)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
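
Not part of the patch: a minimal usage sketch of what this change enables, mirroring the new test above. The model id and generation settings come from the test; the CUDA device string, the float16 dtype literal, and the availability of flash-attn and the gated google/gemma-2-9b checkpoint are assumptions about the runtime environment, not requirements stated in the patch.

# Standalone sketch (not part of the patch); assumes a CUDA GPU with flash-attn installed
# and access to the gated google/gemma-2-9b checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"

# Gemma2 works only with a static cache; with this patch the full attention mask is passed
# through to the flash-attention path and keys/values are sliced to the mask length.
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])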