Fix sliding window attention used in Gemma2FlashAttention2 #32522

Merged (8 commits) on Aug 12, 2024
src/transformers/models/gemma2/modeling_gemma2.py: 3 additions & 1 deletion
@@ -427,6 +427,7 @@ def forward(
            dropout=dropout_rate,
            softmax_scale=self.scaling,
            is_causal=self.is_causal,
+            sliding_window=self.sliding_window,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
        )
@@ -567,7 +568,8 @@ def forward(
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask[:, -self.sliding_window :]
+                if past_key_value is not None:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]
@ArthurZucker (Collaborator) commented on Aug 12, 2024:
Ah, that looks better, yes. On the first forward pass you need the full attention mask but have to pass self.sliding_window when calling attention; then, when you decode, the mask is too big and the KV cache is sliding_window - 1.

            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
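To make the behavior settled on in this diff and the review comment easier to follow, here is a minimal, self-contained sketch of the prefill vs. decode mask handling for a flash-attention-2 sliding-window layer. It is not the transformers code itself; the helper name `prepare_fa2_sliding_mask` and the `is_decoding` flag are illustrative stand-ins for the `past_key_value is not None` check in the diff.

```python
import torch


def prepare_fa2_sliding_mask(
    attention_mask: torch.Tensor,  # 2D padding mask of shape [batch, kv_len]
    sliding_window: int,
    is_decoding: bool,  # True once a KV cache exists (past_key_value is not None)
) -> torch.Tensor:
    """Sketch of the mask handling for a flash-attention-2 sliding-window layer."""
    if not is_decoding:
        # Prefill: keep the full mask; the window itself is enforced by passing
        # `sliding_window` to the flash-attention kernel (first hunk of the diff).
        return attention_mask
    # Decoding: the sliding layer's KV cache holds at most `sliding_window` tokens,
    # so the 2D padding mask must be sliced to its last `sliding_window` columns
    # to stay aligned with the cached keys/values.
    return attention_mask[:, -sliding_window:]


# Example: 10 cached positions, window of 4 -> the mask is sliced only while decoding.
mask = torch.ones(1, 10, dtype=torch.long)
print(prepare_fa2_sliding_mask(mask, sliding_window=4, is_decoding=True).shape)   # torch.Size([1, 4])
print(prepare_fa2_sliding_mask(mask, sliding_window=4, is_decoding=False).shape)  # torch.Size([1, 10])
```

This mirrors the reviewer's point: the full mask is needed on the first forward pass (with the window handled by the kernel via `sliding_window`), and the slice is applied only during decoding, when the mask would otherwise be longer than the capped KV cache.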