fix cross attn and vision attn
winglian committed Oct 1, 2024
1 parent 6c9a312 commit 674ad4d
Showing 1 changed file with 42 additions and 61 deletions.
src/transformers/models/mllama/modeling_mllama.py (103 changes: 42 additions & 61 deletions)
@@ -20,6 +20,7 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func
from torch import nn
from torch.nn import CrossEntropyLoss
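
Note on the new top-level import: flash_attn is an optional dependency, so an unconditional `from flash_attn import flash_attn_func` makes the whole module fail to import when the package is absent. A minimal sketch of the guarded pattern transformers typically uses (the availability check is a real transformers utility; its placement here is only illustrative, not part of this commit):

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # only imported when a working flash-attn installation is present
    from flash_attn import flash_attn_func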

@@ -289,7 +290,7 @@ class MllamaVisionFlashAttention2(MllamaVisionAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Check if flash attention version is greater or equal to 2.1
# Determine if FlashAttention uses the top-left mask based on its version
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
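
For context, `is_flash_attn_greater_or_equal_2_10` gates on the installed flash-attn version because releases before 2.1 align the causal mask to the top-left corner rather than the bottom-right. A rough sketch of such a version check (illustrative helper, not the transformers internals verbatim):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse

def flash_attn_at_least(minimum: str = "2.1.0") -> bool:
    # returns False when flash-attn is not installed at all
    try:
        return parse(version("flash_attn")) >= parse(minimum)
    except PackageNotFoundError:
        return False

# flash-attn < 2.1 uses a top-left aligned causal mask, hence the flag above
use_top_left_mask = not flash_attn_at_least("2.1.0")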

def forward(
@@ -299,59 +300,55 @@ def forward(
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# FlashAttention does not support returning attention weights
logger.warning_once(
"MllamaVisionFlashAttention does not support `output_attentions=True`. "
"Falling back to the manual attention implementation.",
UserWarning
"MllamaModel is using MllamaVisionFlashAttention2, but flash_attention_2 does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_state=hidden_state,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
return super().forward(hidden_state, attention_mask, output_attentions)
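
The fallback above only covers individual forward calls; as the longer warning text notes, callers who actually need attention weights would normally select the eager implementation at load time. A hedged usage sketch (the checkpoint id is an assumption, not taken from this commit):

from transformers import MllamaForConditionalGeneration

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed checkpoint name
    attn_implementation="eager",  # avoids the flash-attention fallback and its warning
)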

batch_size, seq_len, _ = hidden_state.size()
bsz, seq_len, _ = hidden_state.size()

# Compute query, key, and value projections
query_states = self.q_proj(hidden_state)
key_states = self.k_proj(hidden_state)
value_states = self.v_proj(hidden_state)

# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
# Reshape and transpose to [batch_size, num_heads, seq_len, head_dim]
query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Prepare attention mask - sending attn mask triggers _upad_input which OOMs for vision tokens
attention_mask = None
# if attention_mask is not None:
# # Ensure attention_mask is of shape [batch_size, seq_len] and dtype torch.bool
# attention_mask = attention_mask.squeeze(1).squeeze(1)
# attention_mask = attention_mask.to(torch.bool)
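
A rough sketch of the trade-off behind forcing `attention_mask = None` (simplified, not the `_flash_attention_forward` helper verbatim): with no mask, the dense kernel runs directly on padded [batch, seq, heads, head_dim] tensors, while passing a padding mask routes through unpadding into the variable-length kernel, whose flattening and index bookkeeping is what reportedly blows up memory on the long vision sequences.

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

def dense_path(q, k, v):
    # q, k, v: [batch, seq_len, num_heads, head_dim]; every token treated as valid
    return flash_attn_func(q, k, v, dropout_p=0.0, causal=False)

def varlen_path(q, k, v, seqlens):
    # q, k, v already unpadded/flattened to [total_tokens, num_heads, head_dim];
    # seqlens: int32 tensor of shape [batch] holding each sequence's valid length
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        dropout_p=0.0, causal=False,
    )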

# Call the _flash_attention_forward function
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
seq_len,
is_causal=False, # Vision attention is typically not causal
dropout=0.0, # MllamaVisionAttention doesn't have dropout
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=False, # Vision attention is typically not causal
)

attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
# Reshape attn_output to [batch_size, seq_len, embed_dim]
attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)

# Apply the output projection
output = self.o_proj(attn_output)

return output, None # Return None for attn_weights as Flash Attention doesn't compute them
return output, None
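
To make the tensor layouts in this forward pass concrete, here is a small standalone check (assumptions: toy but roughly mllama-vision-sized shapes, a CUDA device, fp16, flash-attn installed) comparing the flash-attention kernel against PyTorch's scaled_dot_product_attention. It is an editorial illustration, not part of the committed code:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

bsz, seq_len, num_heads, head_dim = 2, 1600, 16, 80  # illustrative vision-tower sizes
hidden = torch.randn(bsz, seq_len, num_heads * head_dim, device="cuda", dtype=torch.float16)

# [bsz, seq_len, hidden] -> [bsz, seq_len, num_heads, head_dim], the layout flash-attn expects
q = k = v = hidden.view(bsz, seq_len, num_heads, head_dim)

out_flash = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)  # [bsz, seq_len, num_heads, head_dim]

# SDPA expects [bsz, num_heads, seq_len, head_dim], so transpose in and back out
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
).transpose(1, 2)

print(torch.allclose(out_flash, out_sdpa, atol=1e-2))  # True up to fp16 numerical noise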


class MllamaVisionSdpaAttention(MllamaVisionAttention):
@@ -647,8 +644,8 @@ def forward(

class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` as the weights of the module
stay untouched. The main changes are in the forward pass to use the flash attention implementation.
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""

def __init__(self, *args, **kwargs):
@@ -667,7 +664,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x seq_len x Channel"""
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -697,39 +694,23 @@ def forward(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)

# Transpose to match the expected input for flash attention
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

dropout_rate = self.dropout if self.training else 0.0

# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=False, # Cross-attention is not causal
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states, key_states, value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

if not output_attentions:
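
For reference, a hedged standalone sketch of the `flash_attn_func` call used in this cross-attention path, with the expected [batch, seq, heads, head_dim] layout spelled out (toy sizes; head counts are illustrative; flash-attn handles grouped-query attention natively when the query-head count is a multiple of the key/value-head count):

import torch
from flash_attn import flash_attn_func

bsz, q_len, kv_len, head_dim = 2, 32, 1600, 128
num_heads, num_kv_heads = 32, 8  # grouped-query attention (illustrative head counts)

q = torch.randn(bsz, q_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(bsz, kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(bsz, kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
print(out.shape)  # torch.Size([2, 32, 32, 128]) -> [bsz, q_len, num_heads, head_dim]

One caveat worth noting: when `return_attn_probs=True`, `flash_attn_func` returns a tuple (the output plus auxiliary softmax statistics) rather than a single tensor, so the single-tensor reshape in the committed code effectively assumes `output_attentions=False`.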
