Fix (llm): correct handling of attention mask shape #652

Merged
merged 1 commit on Jul 6, 2023
31 changes: 29 additions & 2 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -4,6 +4,26 @@
 from torch import nn
 
 
+def attention_mask_handler(
+        attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
+    """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
+    (implicit batch_size and n_heads)."""
+    if len(attention_mask.shape) == 4:
+        if attention_mask.shape[0] == 1:
+            attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
+        if attention_mask.shape[1] == 1:
+            attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
+        if attention_mask.shape[2] == 1:
+            attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
+        attention_mask = attention_mask.view(
+            batch_size * num_heads, query_seq_length, key_value_seq_length)
+    elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
+        # This could happen in Encoder-like architecture
+        assert query_seq_length == key_value_seq_length
+        attention_mask = attention_mask.repeat(query_seq_length, 1)
+    return attention_mask
+
+
 class MultiheadAttentionWrapper(nn.Module):
 
     def __init__(
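
Illustrative note (not part of the diff): the sketch below shows how a broadcastable 4D mask, e.g. a `(1, 1, query_len, key_len)` causal mask of the kind Hugging Face models produce, is expanded by `attention_mask_handler` into the 3D `(batch_size * num_heads, query_len, key_len)` layout that `torch.nn.MultiheadAttention` accepts for `attn_mask`. The tensor sizes are made-up example values.

```python
import torch

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

batch_size, num_heads, query_len, key_len = 2, 4, 5, 5

# Broadcastable 4D additive mask: singleton batch and head dimensions.
mask_4d = torch.zeros(1, 1, query_len, key_len)

# Singleton dims are repeated and the first two dims are flattened together.
mask_3d = attention_mask_handler(mask_4d, batch_size, num_heads, query_len, key_len)
print(mask_3d.shape)  # torch.Size([8, 5, 5]) == (batch_size * num_heads, query_len, key_len)
```

The singleton dimensions are materialised with `repeat` rather than left to broadcasting because a 3D `attn_mask` is expected to carry an explicit `batch_size * num_heads` leading dimension.
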
@@ -114,8 +134,15 @@ def forward(
             key_value_states = hidden_states
         if layer_head_mask is not None:
             raise RuntimeError("layer_head_mask is not supported.")
-        if attention_mask is not None:
-            attention_mask = attention_mask.squeeze()
+        if self.mha.batch_first:
+            batch_size, query_seq_length = hidden_states.shape[:2]
+            key_value_seq_length = key_value_states.shape[1]
+        else:
+            query_seq_length, batch_size = hidden_states.shape[:2]
+            key_value_seq_length = key_value_states.shape[0]
+        num_heads = self.mha.num_heads
+        attention_mask = attention_mask_handler(
+            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
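
Again purely illustrative, assuming a bare `nn.MultiheadAttention` stands in for `self.mha` and dummy tensors stand in for the hidden states: the sketch mirrors the shape bookkeeping added to `forward()`, where `batch_first` decides which axes hold the batch size and the sequence lengths before the mask is re-arranged and passed as `attn_mask`.

```python
import torch
from torch import nn

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

hidden_states = torch.randn(2, 5, 16)     # (batch, query_len, embed_dim) because batch_first=True
key_value_states = torch.randn(2, 7, 16)  # (batch, key_len, embed_dim)
attention_mask = torch.zeros(1, 1, 5, 7)  # broadcastable 4D additive mask

# Same dimension bookkeeping as the new forward() code.
if mha.batch_first:
    batch_size, query_seq_length = hidden_states.shape[:2]
    key_value_seq_length = key_value_states.shape[1]
else:
    query_seq_length, batch_size = hidden_states.shape[:2]
    key_value_seq_length = key_value_states.shape[0]

attn_mask = attention_mask_handler(
    attention_mask, batch_size, mha.num_heads, query_seq_length, key_value_seq_length)

attn_output, attn_weights = mha(
    hidden_states, key_value_states, key_value_states, attn_mask=attn_mask)
print(attn_output.shape)  # torch.Size([2, 5, 16])
```

In the wrapper itself these values come straight from `self.mha` and the incoming activations, so the caller does not have to know which mask layout the underlying `nn.MultiheadAttention` expects.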