int4 => int8 handler with bitwidth=4
mikekgfb committed Apr 9, 2024
1 parent ffe8d23 commit 8a53054
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions quantize.py
@@ -481,6 +481,7 @@ def __init__(self, mod, group_size):
def create_quantized_state_dict(self):
from hqq.core.quantize import Quantizer # TODO maybe torchao


for m in self.mod.modules():
for name, child in m.named_children():
if isinstance(child, torch.nn.Linear):
@@ -495,11 +496,16 @@ def create_quantized_state_dict(self):
)
)

return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
# we use Int4 packaged in an int8 for now, packing to follow
# return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
return WeightOnlyInt8QuantHandler(self.mod, bitwidth=4, groupsize=self.groupsize).create_quantized_state_dict()

def _convert_for_runtime(self):
return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)

# we use Int4 packaged in an int8 for now, packing to follow
# ALSO: all code must work for CPU, CUDA, MPS
# return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
return WeightOnlyInt4GPTQQuantHandler(self.mod, bitwidth=4, groupsize=self.groupsize).convert_for_runtime()

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
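The commit message and the in-diff comments describe the approach: quantize weights to 4 bits but reuse the int8 weight-only handler by passing bitwidth=4, so each 4-bit value currently occupies a full int8 element ("Int4 packaged in an int8 for now"). As a hedged sketch of that idea (this is not the repository's WeightOnlyInt8QuantHandler; the helper name quantize_per_group and its shapes are assumptions for illustration), group-wise symmetric quantization to a configurable bitwidth could look like this:

import torch

def quantize_per_group(weight: torch.Tensor, groupsize: int, bitwidth: int = 4):
    # weight: [out_features, in_features]; quantization groups run along the input dim
    qmin, qmax = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1  # -8..7 for bitwidth=4
    w = weight.reshape(weight.shape[0], -1, groupsize)
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(w / scales), qmin, qmax)
    # 4-bit values are stored one per int8 element for now ("packing to follow")
    return q.to(torch.int8).reshape_as(weight), scales.squeeze(-1)

Dequantization is then just q.reshape(out_features, -1, groupsize) * scales.unsqueeze(-1), reshaped back to the weight's original shape.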
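The "packing to follow" note refers to eventually storing two 4-bit values per byte instead of one per int8 element. A minimal sketch of such nibble packing (illustrative only; pack_int4 and unpack_int4 are hypothetical names, not functions in quantize.py) could be:

import torch

def pack_int4(q: torch.Tensor) -> torch.Tensor:
    # q holds signed 4-bit values in [-8, 7]; the last dimension must be even
    u = q.to(torch.int32) & 0x0F                 # two's-complement nibbles, 0..15
    packed = u[..., 0::2] | (u[..., 1::2] << 4)  # low nibble = even index, high = odd
    return packed.to(torch.uint8)

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    p = packed.to(torch.int32)
    lo, hi = p & 0x0F, (p >> 4) & 0x0F
    nibbles = torch.stack((lo, hi), dim=-1).flatten(-2)
    # sign-extend: nibble values >= 8 encode negatives
    return torch.where(nibbles >= 8, nibbles - 16, nibbles).to(torch.int8)

This halves weight storage relative to the current one-value-per-int8 layout, and, per the other comment in the diff, any such code would still need to work on CPU, CUDA, and MPS.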
