Fix (llm): correct handling of attention mask shape
Giuseppe5 committed Jul 4, 2023
1 parent fd4fb20 commit 8d48e24
Showing 1 changed file with 55 additions and 28 deletions.
83 changes: 55 additions & 28 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -4,6 +4,26 @@
from torch import nn


def attention_mask_handler(
        attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
    """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
    (implicit batch_size and n_heads)."""
    if len(attention_mask.shape) == 4:
        if attention_mask.shape[0] == 1:
            attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
        if attention_mask.shape[1] == 1:
            attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
        if attention_mask.shape[2] == 1:
            attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
        attention_mask = attention_mask.view(
            batch_size * num_heads, query_seq_length, key_value_seq_length)
    elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
        # This could happen in Encoder-like architecture
        assert query_seq_length == key_value_seq_length
        attention_mask = attention_mask.repeat(query_seq_length, 1)
    return attention_mask
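
For reference, a minimal sketch (not part of the commit) of how the handler is meant to be used, assuming attention_mask_handler is importable from brevitas_examples.llm.llm_quant.mha_layers; the shapes below are illustrative only:

import torch

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

batch_size, num_heads, query_len, kv_len = 2, 4, 5, 5
# Broadcastable 4D additive mask with singleton batch and head dims, as commonly
# produced by Hugging Face-style attention code.
mask_4d = torch.zeros(1, 1, query_len, kv_len)

mask_3d = attention_mask_handler(mask_4d, batch_size, num_heads, query_len, kv_len)
print(mask_3d.shape)  # torch.Size([8, 5, 5]), i.e. (batch_size * num_heads, L, S)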


class MultiheadAttentionWrapper(nn.Module):

    def __init__(
@@ -33,6 +53,41 @@ def __init__(
            device,
            dtype)


class QuantizableOPTAttention(MultiheadAttentionWrapper):

    def forward(
            self,
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if key_value_states is None:
            key_value_states = hidden_states
        if layer_head_mask is not None:
            raise RuntimeError("layer_head_mask is not supported.")
        if self.mha.batch_first:
            batch_size, query_seq_length = hidden_states.shape[:2]
            key_value_seq_length = key_value_states.shape[1]
        else:
            query_seq_length, batch_size = hidden_states.shape[:2]
            key_value_seq_length = key_value_states.shape[0]
        num_heads = self.mha.num_heads
        attention_mask = attention_mask_handler(
            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
        attn_output, attn_output_weights = self.mha(
            hidden_states,
            key_value_states,
            key_value_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
@@ -97,31 +152,3 @@ def set_weight(value):
                del state_dict[name]
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
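
A hedged usage sketch (not from the commit) of the new QuantizableOPTAttention.forward above, called with Hugging Face-style arguments; it assumes the wrapper's constructor mirrors nn.MultiheadAttention's arguments (embed_dim, num_heads, ..., batch_first), since most of __init__ is collapsed in this view, and the shapes are illustrative only:

import torch

from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

embed_dim, num_heads, batch_size, seq_len = 16, 4, 2, 5
# Assumed constructor signature: same arguments as nn.MultiheadAttention.
attn = QuantizableOPTAttention(embed_dim, num_heads, batch_first=True)

hidden_states = torch.randn(batch_size, seq_len, embed_dim)
# 4D mask with singleton batch/head dims; forward() expands it via attention_mask_handler.
attention_mask = torch.zeros(1, 1, seq_len, seq_len)

attn_output, attn_weights, past_kv = attn(
    hidden_states, attention_mask=attention_mask, output_attentions=True)
print(attn_output.shape)  # torch.Size([2, 5, 16])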


class QuantizableOPTAttention(MultiheadAttentionWrapper):

    def forward(
            self,
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if key_value_states is None:
            key_value_states = hidden_states
        if layer_head_mask is not None:
            raise RuntimeError("layer_head_mask is not supported.")
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze()
        attn_output, attn_output_weights = self.mha(
            hidden_states,
            key_value_states,
            key_value_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value
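
The squeeze-based forward above is the version removed by this commit. A hedged illustration (not from the commit) of the shape problem it runs into: for batch sizes greater than 1, squeezing a (batch, 1, L, S) mask yields (batch, L, S), whereas nn.MultiheadAttention expects a 3D attn_mask of shape (batch * num_heads, L, S), which is what attention_mask_handler now produces:

import torch
from torch import nn

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

batch_size, num_heads, seq_len, embed_dim = 2, 4, 5, 16
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(batch_size, seq_len, embed_dim)
mask_4d = torch.zeros(batch_size, 1, seq_len, seq_len)

try:
    # Old path: squeeze() gives a (2, 5, 5) mask, whose leading dim is not batch * num_heads.
    mha(x, x, x, attn_mask=mask_4d.squeeze())
except RuntimeError as exc:
    print("squeeze-based mask rejected:", exc)

# New path: the handler broadcasts the head dim and flattens to (8, 5, 5).
mask_3d = attention_mask_handler(mask_4d, batch_size, num_heads, seq_len, seq_len)
attn_output, _ = mha(x, x, x, attn_mask=mask_3d)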
