From d5cb567f0a86ecec3ef561ee69fc44c0cb6cb248 Mon Sep 17 00:00:00 2001
From: "Cheng, Penghui"
Date: Thu, 16 Nov 2023 10:51:03 +0800
Subject: [PATCH] Support fp16 model to weight-only quantization for PyTorch
 framework (#1387)

Signed-off-by: Cheng, Penghui
---
 neural_compressor/adaptor/torch_utils/weight_only.py | 5 +++++
 test/quantization/test_weight_only_quantization.py   | 3 ++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index 7ba86eaa344..c71fe7df02b 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -399,6 +399,9 @@ def rtn_quantize(
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
+        orig_dtype = next(m.parameters()).dtype
+        if orig_dtype != torch.float:
+            m = m.float()
         if name in weight_config:  # pragma: no cover
             num_bits = weight_config[name]["bits"]
             group_size = weight_config[name]["group_size"]
@@ -466,6 +469,8 @@ def rtn_quantize(
                 )
                 q_weight = q_weight.T if group_dim == 0 else q_weight
                 m.weight.data.copy_(q_weight)
+        if orig_dtype != torch.float:
+            m = m.to(orig_dtype)
     return model

diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py
index 3c700685676..087b985b15b 100644
--- a/test/quantization/test_weight_only_quantization.py
+++ b/test/quantization/test_weight_only_quantization.py
@@ -54,6 +54,7 @@ def test_trace(self):

     def test_rtn(self):
         fp32_model = copy.deepcopy(self.model)
+        fp16_model = copy.deepcopy(self.model).to(torch.float16)
         model1 = rtn_quantize(fp32_model, num_bits=3, group_size=-1)
         self.assertTrue(isinstance(model1.fc1, torch.nn.Linear))
         weight_config = {
@@ -67,7 +68,7 @@ def test_rtn(self):
             },
         }
         model2 = rtn_quantize(fp32_model, weight_config=weight_config)
-        model2 = rtn_quantize(fp32_model, weight_config=weight_config, return_int=True)
+        model2 = rtn_quantize(fp16_model, weight_config=weight_config, return_int=True)
         self.assertTrue(isinstance(model2.fc1, WeightOnlyLinear))

     def test_awq(self):
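
For illustration, a minimal sketch of the behavior this patch enables: an fp16 model can be passed straight into rtn_quantize, which now upcasts each supported layer to fp32 for the rounding math and casts it back to its original dtype afterwards. The ToyModel module below is a hypothetical stand-in for the test fixture's self.model and is not part of the patch; only rtn_quantize and its num_bits/group_size arguments are taken from the code above.

# A minimal sketch, assuming a toy two-layer model; ToyModel is a
# hypothetical stand-in for the test fixture's self.model, not part of
# this patch.
import copy

import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(self.fc1(x))


# Half-precision model: the case this patch adds support for.
fp16_model = ToyModel().to(torch.float16)

# Same call pattern as the existing fp32 test; with this patch each matched
# Linear layer is upcast to fp32 for quantization and then cast back to the
# model's original fp16 dtype.
q_model = rtn_quantize(copy.deepcopy(fp16_model), num_bits=3, group_size=-1)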