Skip to content

Commit

Permalink
Fix (llm): correct handling of attention mask shape (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jul 6, 2023
1 parent 4ee3335 commit 53ed201
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@
from torch import nn


def attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
"""Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
(implicit batch_size and n_heads)."""
if len(attention_mask.shape) == 4:
if attention_mask.shape[0] == 1:
attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
if attention_mask.shape[1] == 1:
attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
if attention_mask.shape[2] == 1:
attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
attention_mask = attention_mask.view(
batch_size * num_heads, query_seq_length, key_value_seq_length)
elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
# This could happen in Encoder-like architecture
assert query_seq_length == key_value_seq_length
attention_mask = attention_mask.repeat(query_seq_length, 1)
return attention_mask


class MultiheadAttentionWrapper(nn.Module):

def __init__(
Expand Down Expand Up @@ -114,8 +134,15 @@ def forward(
key_value_states = hidden_states
if layer_head_mask is not None:
raise RuntimeError("layer_head_mask is not supported.")
if attention_mask is not None:
attention_mask = attention_mask.squeeze()
if self.mha.batch_first:
batch_size, query_seq_length = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[1]
else:
query_seq_length, batch_size = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[0]
num_heads = self.mha.num_heads
attention_mask = attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
attn_output, attn_output_weights = self.mha(
hidden_states,
key_value_states,
Expand Down

0 comments on commit 53ed201

Please sign in to comment.