diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index e23bb876c3f037..ee5af616ec2b03 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -427,6 +427,7 @@ def forward(
             dropout=dropout_rate,
             softmax_scale=self.scaling,
             is_causal=self.is_causal,
+            sliding_window=self.sliding_window,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
         )
@@ -567,7 +568,8 @@ def forward(
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
             # Flash-attn is a 2D tensor
             if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask[:, -self.sliding_window :]
+                if past_key_value is not None:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
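
For context, here is a minimal sketch (not the transformers implementation) of the behaviour the second hunk guards: the 2D flash-attention padding mask should only be truncated to the last `sliding_window` positions when a KV cache is present (decoding); during prefill the full-length mask must be kept, since the first hunk now hands the window to `_flash_attention_forward` via `sliding_window=self.sliding_window` and flash-attn applies the windowing itself. The helper name and the `decoding` flag below are hypothetical, used only to illustrate the shapes.

```python
# Illustrative sketch, assuming the mask is the usual (batch, kv_len) padding mask.
import torch

sliding_window = 4

def slice_mask_for_flash_attn(attention_mask: torch.Tensor, decoding: bool) -> torch.Tensor:
    if decoding:
        # With a KV cache, only the last `sliding_window` key positions can be
        # attended to, so the padding mask may be truncated to match.
        return attention_mask[:, -sliding_window:]
    # During prefill the query spans the whole sequence, so the full mask is
    # needed; flash-attn enforces the sliding window internally.
    return attention_mask

# Prefill with a sequence of length 6: keep all 6 positions.
prefill_mask = torch.ones(1, 6, dtype=torch.long)
print(slice_mask_for_flash_attn(prefill_mask, decoding=False).shape)  # torch.Size([1, 6])

# Decoding step with 7 cached positions: only the last 4 matter.
decode_mask = torch.ones(1, 7, dtype=torch.long)
print(slice_mask_for_flash_attn(decode_mask, decoding=True).shape)    # torch.Size([1, 4])
```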