Skip to content

Commit

Permalink
don't zero out the attention_mask when using sliding window with flas…
Browse files Browse the repository at this point in the history
…h attention
  • Loading branch information
winglian committed Jun 27, 2024
1 parent 1c68f2c commit 3d4ca0c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def forward(
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
Expand All @@ -625,7 +626,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
if self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
attention_mask = attention_mask * torch.tril(
torch.ones_like(attention_mask), diagonal=-self.sliding_window
)
Expand Down

0 comments on commit 3d4ca0c

Please sign in to comment.