
Commit

Fix errors after latest PR
volcacius committed Jul 7, 2023
1 parent 55c1c7c commit 951b5e2
Showing 3 changed files with 28 additions and 5 deletions.
27 changes: 25 additions & 2 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -3,6 +3,9 @@
 import torch
 from torch import nn
 
+from brevitas.nn.equalized_layer import EqualizedModule
+from brevitas.utils.torch_utils import KwargsForwardHook
+
 
 def attention_mask_handler(
         attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
@@ -53,6 +56,26 @@ def __init__(
             device,
             dtype)
 
+    @property
+    def wrapped_mha(self):
+        mha = self.mha
+        # Workaround for activation equalization for when mha is wrapped
+        # KwargsForwardHook is inserted during act equalization
+        # EqualizedModule is inserted after act equalization
+        if isinstance(mha, KwargsForwardHook):
+            mha = mha.module
+        if isinstance(mha, EqualizedModule):
+            mha = mha.layer
+        return mha
+
+    @property
+    def num_heads(self):
+        return self.wrapped_mha.num_heads
+
+    @property
+    def batch_first(self):
+        return self.wrapped_mha.batch_first
+
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
@@ -134,13 +157,13 @@ def forward(
             key_value_states = hidden_states
         if layer_head_mask is not None:
             raise RuntimeError("layer_head_mask is not supported.")
-        if self.mha.batch_first:
+        if self.batch_first:
             batch_size, query_seq_length = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[1]
         else:
             query_seq_length, batch_size = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[0]
-        num_heads = self.mha.num_heads
+        num_heads = self.num_heads
         attention_mask = attention_mask_handler(
             attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
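The wrapped_mha property above works around the fact that, during and after activation equalization, self.mha may be hidden behind wrapper modules that do not expose num_heads or batch_first themselves. Below is a minimal, self-contained sketch of the same unwrapping pattern; the _HookStandIn and _EqualizedStandIn classes are illustrative stand-ins for KwargsForwardHook and EqualizedModule, not the real Brevitas classes.

from torch import nn


class _HookStandIn(nn.Module):
    """Stand-in for a wrapper that exposes the wrapped module as `.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


class _EqualizedStandIn(nn.Module):
    """Stand-in for a wrapper that exposes the wrapped module as `.layer`."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer


def unwrap(mha):
    # Same unwrapping order as the wrapped_mha property in the diff above.
    if isinstance(mha, _HookStandIn):
        mha = mha.module
    if isinstance(mha, _EqualizedStandIn):
        mha = mha.layer
    return mha


mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
wrapped = _HookStandIn(_EqualizedStandIn(mha))
assert unwrap(wrapped).num_heads == 4
assert unwrap(wrapped).batch_first is True

The num_heads and batch_first properties in the diff apply the same idea, delegating attribute lookups to whatever module sits at the bottom of the wrapper stack.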
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/llm_quant/quantize.py
@@ -268,7 +268,7 @@ def quantize_model(
         'q_scaled_quant': q_scaled_quant,
         'k_transposed_quant': k_transposed_quant,
         'v_quant': v_quant,
-        'out_proj_input_quant': linear_2d_input_quant,
+        'out_proj_input_quant': input_quant,
         'out_proj_weight_quant': weight_quant,
         'out_proj_bias_quant': None,
         'out_proj_output_quant': None,
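The keys in this dict correspond to keyword arguments of Brevitas's quantized multi-head attention layer, so the fix changes which activation quantizer gets attached to the out_proj input. A rough sketch of how such overrides are typically passed at construction time follows; the QuantMultiheadAttention and Int8 quantizer imports are assumptions for illustration, not taken from this commit, and the actual layer replacement performed in quantize_model may differ.

from brevitas.nn import QuantMultiheadAttention
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat

# Hypothetical construction: each *_quant keyword selects the quantizer
# (or None) applied to the corresponding tensor of the attention layer.
quant_mha = QuantMultiheadAttention(
    embed_dim=64,
    num_heads=4,
    out_proj_input_quant=Int8ActPerTensorFloat,
    out_proj_weight_quant=Int8WeightPerTensorFloat,
    out_proj_bias_quant=None,
    out_proj_output_quant=None)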
4 changes: 2 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -78,7 +78,7 @@
     default='stats',
     choices=['stats', 'mse'],
     help=
-    'How scales/zero-point are determined. Default: stats (percentile for static, absmax minmax for dynamic).'
+    'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).'
 )
 parser.add_argument(
     '--input-scale-precision',
@@ -89,7 +89,7 @@
 parser.add_argument(
     '--input-scale-type',
     type=str,
-    default='float',
+    default='static',
     choices=['static', 'dynamic'],
     help='Whether input scale is a static value or a dynamic value.')
 parser.add_argument(
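Worth noting about the default fix above: argparse does not validate a default value against choices, so the old default of 'float' was accepted silently whenever the flag was omitted on the command line. A standalone sketch (not the example script itself) of the corrected pattern:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--input-scale-type',
    type=str,
    default='static',
    choices=['static', 'dynamic'],
    help='Whether input scale is a static value or a dynamic value.')

# choices are only enforced on values passed explicitly, so the default has
# to be kept consistent with them by hand.
args = parser.parse_args([])
print(args.input_scale_type)  # static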
