Support fp16 models in weight-only quantization for the PyTorch framework (#1387)

Signed-off-by: Cheng, Penghui <[email protected]>
PenghuiCheng authored Nov 16, 2023
1 parent d81269d commit d5cb567
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -399,6 +399,9 @@ def rtn_quantize(
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
+        orig_dtype = next(m.parameters()).dtype
+        if orig_dtype != torch.float:
+            m = m.float()
         if name in weight_config:  # pragma: no cover
             num_bits = weight_config[name]["bits"]
             group_size = weight_config[name]["group_size"]
@@ -466,6 +469,8 @@ def rtn_quantize(
             )
             q_weight = q_weight.T if group_dim == 0 else q_weight
             m.weight.data.copy_(q_weight)
+        if orig_dtype != torch.float:
+            m = m.to(orig_dtype)
     return model


3 changes: 2 additions & 1 deletion test/quantization/test_weight_only_quantization.py
@@ -54,6 +54,7 @@ def test_trace(self):
 
     def test_rtn(self):
         fp32_model = copy.deepcopy(self.model)
+        fp16_model = copy.deepcopy(self.model).to(torch.float16)
         model1 = rtn_quantize(fp32_model, num_bits=3, group_size=-1)
         self.assertTrue(isinstance(model1.fc1, torch.nn.Linear))
         weight_config = {
@@ -67,7 +68,7 @@ def test_rtn(self):
             },
         }
         model2 = rtn_quantize(fp32_model, weight_config=weight_config)
-        model2 = rtn_quantize(fp32_model, weight_config=weight_config, return_int=True)
+        model2 = rtn_quantize(fp16_model, weight_config=weight_config, return_int=True)
         self.assertTrue(isinstance(model2.fc1, WeightOnlyLinear))
 
     def test_awq(self):
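Mirroring the updated test, here is a hedged usage sketch of calling rtn_quantize directly on an fp16 model. The two-layer stand-in model is an assumption (the real test uses its own fixture, self.model), and the import path is taken from the file touched above.

import copy

import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

# Stand-in model (assumption); the test fixture's self.model is not shown here.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
fp16_model = copy.deepcopy(model).to(torch.float16)

# With this commit, fp16 weights are promoted to fp32 for the RTN math and the
# quantized module is cast back to the original dtype before returning.
q_model = rtn_quantize(fp16_model, num_bits=4, group_size=-1)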
