Feat (examples/llm): add first/last layer support (#699)
volcacius authored Aug 30, 2023
1 parent 0764832 commit 4ff62c7
Showing 4 changed files with 32 additions and 4 deletions.
11 changes: 8 additions & 3 deletions src/brevitas/nn/quant_embedding.py
@@ -6,11 +6,8 @@
 import torch
 from torch import Tensor
 from torch.nn import Embedding
-from torch.nn import EmbeddingBag
 from torch.nn.functional import embedding
 
-from brevitas.function.ops import max_int
-from brevitas.function.ops_ste import ceil_ste
 from brevitas.inject.defaults import Int8WeightPerTensorFloat
 from brevitas.quant_tensor import QuantTensor
 
@@ -53,6 +50,14 @@ def __init__(
         self.accept_quant_tensor = False
         self.return_quant_tensor = return_quant_tensor
 
+    @property
+    def output_channel_dim(self) -> int:
+        return 0
+
+    @property
+    def out_channels(self) -> int:
+        return self.num_embeddings
+
     def forward(self, inp):
         quant_weight = self.quant_weight()
         out = embedding(
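
For context on the two new properties: a QuantEmbedding weight has shape (num_embeddings, embedding_dim), so the output channel axis is dim 0 and out_channels matches num_embeddings. A minimal sketch of what they report, assuming the default weight quantizer and illustrative sizes:

# Minimal sketch of the new properties (illustrative sizes, default
# Int8WeightPerTensorFloat weight quantizer assumed).
import brevitas.nn as qnn

emb = qnn.QuantEmbedding(num_embeddings=1000, embedding_dim=64)
print(emb.weight.shape)        # torch.Size([1000, 64])
print(emb.output_channel_dim)  # 0: the output channel axis of the weight is dim 0
print(emb.out_channels)        # 1000 == num_embeddings
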
6 changes: 6 additions & 0 deletions src/brevitas_examples/llm/llm_quant/quantize.py
@@ -113,6 +113,7 @@ def quantize_model(
         input_quant_granularity=None,
         input_group_size=None,
         quantize_input_zero_point=False,
+        quantize_embedding=False,
         seqlen=None):
     """
     Replace float layers with quant layers in the target model
@@ -281,4 +282,9 @@ def quantize_model(
     layer_map = {
         nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
         nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)}
+
+    if quantize_embedding:
+        quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype}
+        layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)
+
     layerwise_quantize(model=model, compute_layer_map=layer_map)
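
For context, the new quantize_embedding branch extends the layer map so that layerwise_quantize also swaps nn.Embedding modules for qnn.QuantEmbedding built with the selected weight quantizer. A rough sketch of the equivalent manual replacement for a single module, assuming QuantEmbedding accepts a weight_quant constructor keyword and that the float weights are copied over (sizes are illustrative):

# Rough sketch of what the new layer_map entry amounts to for one module;
# sizes are illustrative and the quantizer is the default shown above.
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.inject.defaults import Int8WeightPerTensorFloat

float_emb = nn.Embedding(32000, 4096)
quant_emb = qnn.QuantEmbedding(
    float_emb.num_embeddings,
    float_emb.embedding_dim,
    weight_quant=Int8WeightPerTensorFloat)
quant_emb.weight.data.copy_(float_emb.weight.data)
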
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/llm_quant/quantizers.py
@@ -37,6 +37,8 @@ def expanded_scaling_shape(module, block_size):
         return module.weight.size(0), module.weight.size(1) // block_size, block_size, module.weight.size(2), module.weight.size(3)
     elif isinstance(module, nn.Linear):
         return module.weight.size(0), module.weight.size(1) // block_size, block_size
+    elif isinstance(module, nn.Embedding):
+        return module.weight.size(0), module.weight.size(1) // block_size, block_size
     else:
         raise RuntimeError("Module not supported.")
 
@@ -46,6 +48,8 @@ def scaling_shape(module, block_size):
         return module.weight.size(0), module.weight.size(1) // block_size, 1, module.weight.size(2), module.weight.size(3)
     elif isinstance(module, nn.Linear):
         return module.weight.size(0), module.weight.size(1) // block_size, 1
+    elif isinstance(module, nn.Embedding):
+        return module.weight.size(0), module.weight.size(1) // block_size, 1
     else:
         raise RuntimeError("Module not supported.")
 
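
To make the two new branches concrete: an nn.Embedding weight is (num_embeddings, embedding_dim), so per-group (block) quantization splits dim 1 into groups of block_size, with one scale per group. A small worked example with illustrative sizes:

# Shape arithmetic encoded by the new nn.Embedding branches (illustrative sizes).
import torch.nn as nn

emb = nn.Embedding(num_embeddings=32000, embedding_dim=4096)
block_size = 64  # e.g. a per-group weight quantizer with group size 64

expanded = (emb.weight.size(0), emb.weight.size(1) // block_size, block_size)
scales = (emb.weight.size(0), emb.weight.size(1) // block_size, 1)
print(expanded)  # (32000, 64, 64): weight viewed as 64-wide groups
print(scales)    # (32000, 64, 1): one scale per group, broadcast over the block
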
15 changes: 14 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -112,6 +112,10 @@
     help='Group size for per_group input quantization. Default: 64.')
 parser.add_argument(
     '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.')
+parser.add_argument(
+    '--quantize-embedding', action='store_true', help='Quantize first nn.Embedding layer.')
+parser.add_argument(
+    '--quantize-last-layer', action='store_true', help='Quantize last nn.Linear layer.')
 parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.')
 parser.add_argument('--act-calibration', action='store_true', help='Apply activation calibration.')
 parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
@@ -254,10 +258,15 @@ def main():
             ref_kwargs={'input_ids': calibration_loader[0]})
         print("Act equalization applied.")
 
+    if args.quantize_embedding or args.quantize_last_layer:
+        layers_to_quantize = model
+    else:
+        layers_to_quantize = get_model_impl(model).layers
+
     if not args.no_quantize:
         print("Applying model quantization...")
         quantize_model(
-            get_model_impl(model).layers,
+            layers_to_quantize,
             dtype=dtype,
             weight_quant_type=args.weight_quant_type,
             weight_bit_width=args.weight_bit_width,
@@ -274,7 +283,11 @@ def main():
             input_quant_granularity=args.input_quant_granularity,
             input_group_size=args.input_group_size,
             quantize_input_zero_point=args.quantize_input_zero_point,
+            quantize_embedding=args.quantize_embedding,
             seqlen=args.seqlen)
+        # Tie back first/last layer weights in case they got untied
+        model.tie_weights()
+        print(model)
         print("Model quantization applied.")
 
     if args.act_calibration:
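
A note on the tie_weights() call above: in many Hugging Face causal LMs the input embedding and the final lm_head nn.Linear share a single weight tensor, and swapping either module for its quantized counterpart can break that sharing, so the script re-ties them after quantization. A conceptual sketch of the effect (module names and sizes are illustrative, not taken from the commit):

# Conceptual sketch of weight tying getting lost when a tied module is replaced;
# names and sizes are illustrative.
import torch.nn as nn

emb = nn.Embedding(32000, 768)
lm_head = nn.Linear(768, 32000, bias=False)
lm_head.weight = emb.weight                   # tied: one shared Parameter
assert lm_head.weight is emb.weight

lm_head = nn.Linear(768, 32000, bias=False)   # module swapped out, tie is lost
assert lm_head.weight is not emb.weight       # hence model.tie_weights() afterwards
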
