From 8a53054bd4934c02a2d4441bcd16931e9745c3c2 Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Tue, 9 Apr 2024 16:06:01 -0700
Subject: [PATCH] int4 => int8 handler with bitwidth=4

---
 quantize.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/quantize.py b/quantize.py
index c2b27c1da..8b26638fe 100644
--- a/quantize.py
+++ b/quantize.py
@@ -481,6 +481,7 @@ def __init__(self, mod, group_size):
 
     def create_quantized_state_dict(self):
         from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
         for m in self.mod.modules():
             for name, child in m.named_children():
                 if isinstance(child, torch.nn.Linear):
@@ -495,11 +496,16 @@ def create_quantized_state_dict(self):
                         )
                     )
 
-        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+        # we use int4 packaged in an int8 for now, packing to follow
+        # return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+        return WeightOnlyInt8QuantHandler(self.mod, bitwidth=4, groupsize=self.groupsize).create_quantized_state_dict()
 
     def _convert_for_runtime(self):
-        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
-
+        # we use int4 packaged in an int8 for now, packing to follow
+        # ALSO: all code must work for CPU, CUDA, MPS
+        # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, bitwidth=4, groupsize=self.groupsize).convert_for_runtime()
+
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
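
Note (not part of the patch): a minimal sketch of what "int4 packaged in an int8" means here, assuming a group-wise symmetric scheme along the lines of what WeightOnlyInt8QuantHandler would do with bitwidth=4. The function names and the exact scale convention below are illustrative assumptions, not code from quantize.py.

    import torch

    def quantize_int4_in_int8(weight: torch.Tensor, groupsize: int = 128):
        # weight: [out_features, in_features]; in_features must be divisible by groupsize
        w = weight.reshape(weight.shape[0], -1, groupsize)
        # symmetric per-group scale: map the max magnitude onto the int4 range [-8, 7]
        scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
        # each 4-bit value occupies a full int8 for now; the "packing to follow"
        # step would store two such values per byte
        q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)
        return q.reshape(weight.shape), scales.squeeze(-1)

    def dequantize_int4_in_int8(q: torch.Tensor, scales: torch.Tensor, groupsize: int = 128):
        # invert the mapping above: rescale each group by its stored scale
        wq = q.reshape(q.shape[0], -1, groupsize).float()
        return (wq * scales.unsqueeze(-1)).reshape(q.shape)

For example, q, s = quantize_int4_in_int8(linear.weight.detach()) yields an int8 tensor whose values all lie in [-8, 7], which is why the existing int8 handler can carry int4 data unchanged until true packing lands.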