From 92533a69543d0a20b3bed4adb7d6787cdcd70f41 Mon Sep 17 00:00:00 2001 From: xinhe Date: Mon, 1 Apr 2024 14:03:14 +0800 Subject: [PATCH] Fix in-place processing error in quant_weight function (#1703) Signed-off-by: xin3he Signed-off-by: Cheng, Penghui --- .../adaptor/torch_utils/weight_only.py | 21 +++++++++++------- .../torch/algorithms/weight_only/utility.py | 22 ++++++++++++------- .../weight_only/test_woq_utility.py | 13 +++++++++++ 3 files changed, 40 insertions(+), 16 deletions(-) create mode 100644 test/3x/torch/algorithms/weight_only/test_woq_utility.py diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index 53ea98251a8..8a1a683483f 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -176,7 +176,7 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang weight.round_() weight.clamp_(minq, maxq) if return_int: - return weight, scale.type(torch.float), None + return weight, scale, None return weight.mul_(scale) @@ -238,6 +238,7 @@ def quant_weight( orig_shape = weight.shape if weight.shape[1] % group_size == 0: + orig_weight = weight weight = weight.reshape(-1, group_size) if return_int: weight, scale, zp = qdq_weight_actor( @@ -250,17 +251,21 @@ def quant_weight( data_type=data_type, ) weight = weight.reshape(orig_shape) + orig_weight.copy_(weight) scale = scale.reshape(orig_shape[0], -1) if zp is not None: zp = zp.reshape(orig_shape[0], -1) - return weight, scale, zp + return orig_weight, scale, zp else: qdq_weight_actor( weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range ) - return weight.reshape(orig_shape) + weight = weight.reshape(orig_shape) + orig_weight.copy_(weight) + return orig_weight else: split_index = weight.shape[1] // group_size * group_size + orig_weight = weight weight1 = weight[:, :split_index] weight1 = weight1.reshape(-1, group_size) if return_int: @@ -277,7 +282,7 @@ def quant_weight( if zp1 is not None: zp1 = zp1.reshape(orig_shape[0], -1) else: - weight1 = qdq_weight_actor( + qdq_weight_actor( weight1, num_bits, scheme=scheme, quantile=quantile, data_type=data_type, full_range=full_range ) weight1 = weight1.reshape(orig_shape[0], split_index) @@ -292,19 +297,19 @@ def quant_weight( return_int=True, full_range=full_range, ) - weight.copy_(torch.cat([weight1, weight2], dim=1)) + orig_weight.copy_(torch.cat([weight1, weight2], dim=1)) scale = torch.cat([scale1, scale2], dim=1) if zp2 is not None: zp = torch.cat([zp1, zp2], dim=1) else: zp = None - return weight, scale, zp + return orig_weight, scale, zp else: weight2 = qdq_weight_actor( weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range ) - weight.copy_(torch.cat([weight1, weight2], dim=1)) - return weight + orig_weight.copy_(torch.cat([weight1, weight2], dim=1)) + return orig_weight def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False): diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 1fd98d82715..0029016459f 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -266,8 +266,10 @@ def quant_tensor( group_size = weight.shape[1] # case 2, reshape based on group size orig_shape = weight.shape + orig_weight = weight if weight.shape[1] % group_size == 0: weight = weight.reshape(-1, group_size) + # return weight for unpacking scale and zp weight = qdq_weight_actor( weight, bits, @@ -281,12 +283,15 @@ def quant_tensor( if return_int or quant_scale: weight, scale, zp = weight weight = weight.reshape(orig_shape) + orig_weight.copy_(weight) scale = scale.reshape(orig_shape[0], -1) if zp is not None: zp = zp.reshape(orig_shape[0], -1) - q_state = weight, scale, zp + q_state = orig_weight, scale, zp else: - return weight.reshape(orig_shape) + weight = weight.reshape(orig_shape) + orig_weight.copy_(weight) + return orig_weight else: # case 3, process left part split by group size split_index = weight.shape[1] // group_size * group_size @@ -321,13 +326,13 @@ def quant_tensor( ) if return_int or quant_scale: weight2, scale2, zp2 = weight2 - weight.copy_(torch.cat([weight1, weight2], dim=1)) + orig_weight.copy_(torch.cat([weight1, weight2], dim=1)) scale = torch.cat([scale1, scale2], dim=1) zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1) q_state = (weight, scale, zp) else: - weight.copy_(torch.cat([weight1, weight2], dim=1)) - return weight + orig_weight.copy_(torch.cat([weight1, weight2], dim=1)) + return orig_weight if quant_scale: weight, scale, zp = q_state scale_dtype = kwargs.get("double_quant_dtype", "int") @@ -343,7 +348,7 @@ def quant_tensor( scale.sub_(scale_mean) scale_scheme = "sym" # process: scale - scale = quant_tensor( + quant_tensor( scale, dtype=scale_dtype, bits=scale_bits, @@ -375,7 +380,7 @@ def quant_tensor( weight1 = weight1.mul_(scale[:, :-1].reshape(-1, 1)) weight1 = weight1.reshape(orig_shape[0], -1) weight2 = weight2.mul_(scale[:, -1].reshape(-1, 1)) - weight.copy_(torch.cat([weight1, weight2], dim=1)) + orig_weight.copy_(torch.cat([weight1, weight2], dim=1)) else: if zp is not None: weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1) @@ -383,7 +388,8 @@ def quant_tensor( weight = weight.reshape(-1, group_size) weight = weight.mul_(scale.reshape(-1, 1)) weight = weight.reshape(orig_shape[0], -1) - return weight + orig_weight.copy_(weight) + return orig_weight else: return q_state diff --git a/test/3x/torch/algorithms/weight_only/test_woq_utility.py b/test/3x/torch/algorithms/weight_only/test_woq_utility.py new file mode 100644 index 00000000000..f672ec0ac1c --- /dev/null +++ b/test/3x/torch/algorithms/weight_only/test_woq_utility.py @@ -0,0 +1,13 @@ +import pytest +import torch + + +@pytest.mark.parametrize("shape", [1024, 512, 300]) +def test_quant_tensor_id(shape): + from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor + + input = torch.randn(shape, shape) + id1 = id(input) + output = quant_tensor(input) + id2 = id(output) + assert id1 == id2, "quant_tensor function is an in-place operator"