Fix in-place processing error in quant_weight function (#1703)
Signed-off-by: xin3he <[email protected]>
Signed-off-by: Cheng, Penghui <[email protected]>
xin3he authored Apr 1, 2024
1 parent 3b150d6 commit 92533a6
Showing 3 changed files with 40 additions and 16 deletions.
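
The bug these diffs fix: inside `quant_weight` and `quant_tensor`, a statement like `weight = weight.reshape(-1, group_size)` rebinds the local name to a new tensor object, so a later `return weight` hands the caller a different object than the one passed in, and the caller's tensor is never updated. The fix keeps a handle to the original tensor (`orig_weight`) and writes results back into its storage with `Tensor.copy_()` before returning it. A minimal standalone sketch of the pattern (an illustration only, not the library code; `fake_quant_inplace` is a hypothetical name and `round()` stands in for the real quantize-dequantize step):

    import torch

    def fake_quant_inplace(weight: torch.Tensor, group_size: int = 32) -> torch.Tensor:
        """Mutate `weight` in place and return the very same tensor object."""
        orig_weight = weight                      # handle to the caller's tensor
        orig_shape = weight.shape
        weight = weight.reshape(-1, group_size)   # reshape() rebinds the local name...
        weight = weight.round()                   # ...to a different tensor object
        weight = weight.reshape(orig_shape)
        orig_weight.copy_(weight)                 # write back into the original storage
        return orig_weight                        # id(result) == id(input) holds

    t = torch.randn(4, 64)
    assert id(fake_quant_inplace(t)) == id(t)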
21 changes: 13 additions & 8 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -176,7 +176,7 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang
     weight.round_()
     weight.clamp_(minq, maxq)
     if return_int:
-        return weight, scale.type(torch.float), None
+        return weight, scale, None
     return weight.mul_(scale)
 
 
@@ -238,6 +238,7 @@ def quant_weight(
 
     orig_shape = weight.shape
     if weight.shape[1] % group_size == 0:
+        orig_weight = weight
         weight = weight.reshape(-1, group_size)
         if return_int:
             weight, scale, zp = qdq_weight_actor(
@@ -250,17 +251,21 @@ def quant_weight(
                 data_type=data_type,
             )
             weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
             scale = scale.reshape(orig_shape[0], -1)
             if zp is not None:
                 zp = zp.reshape(orig_shape[0], -1)
-            return weight, scale, zp
+            return orig_weight, scale, zp
         else:
             qdq_weight_actor(
                 weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            return weight.reshape(orig_shape)
+            weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         split_index = weight.shape[1] // group_size * group_size
+        orig_weight = weight
         weight1 = weight[:, :split_index]
         weight1 = weight1.reshape(-1, group_size)
         if return_int:
@@ -277,7 +282,7 @@ def quant_weight(
             if zp1 is not None:
                 zp1 = zp1.reshape(orig_shape[0], -1)
         else:
-            weight1 = qdq_weight_actor(
+            qdq_weight_actor(
                 weight1, num_bits, scheme=scheme, quantile=quantile, data_type=data_type, full_range=full_range
             )
         weight1 = weight1.reshape(orig_shape[0], split_index)
@@ -292,19 +297,19 @@ def quant_weight(
                 return_int=True,
                 full_range=full_range,
             )
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
             scale = torch.cat([scale1, scale2], dim=1)
             if zp2 is not None:
                 zp = torch.cat([zp1, zp2], dim=1)
             else:
                 zp = None
-            return weight, scale, zp
+            return orig_weight, scale, zp
         else:
             weight2 = qdq_weight_actor(
                 weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
-            return weight
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
+            return orig_weight
 
 
 def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False):
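
For context, a quick identity check (illustrative and independent of the patch) showing which torch operations preserve the tensor object and which create a new one:

    import torch

    t = torch.randn(2, 6)
    u = t.reshape(-1, 3)                          # view: a new Python object over the same storage
    v = torch.cat([t[:, :3], t[:, 3:]], dim=1)    # brand-new tensor with its own storage
    t.copy_(v)                                    # in-place: t's identity is unchanged
    print(id(u) == id(t), id(v) == id(t))         # False False

This is why each early `return weight` (or `return weight.reshape(orig_shape)`) above becomes a `copy_` into `orig_weight` followed by `return orig_weight`.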
22 changes: 14 additions & 8 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -266,8 +266,10 @@ def quant_tensor(
         group_size = weight.shape[1]
     # case 2, reshape based on group size
     orig_shape = weight.shape
+    orig_weight = weight
     if weight.shape[1] % group_size == 0:
         weight = weight.reshape(-1, group_size)
+        # return weight for unpacking scale and zp
         weight = qdq_weight_actor(
             weight,
             bits,
@@ -281,12 +283,15 @@ def quant_tensor(
         if return_int or quant_scale:
             weight, scale, zp = weight
             weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
             scale = scale.reshape(orig_shape[0], -1)
             if zp is not None:
                 zp = zp.reshape(orig_shape[0], -1)
-            q_state = weight, scale, zp
+            q_state = orig_weight, scale, zp
         else:
-            return weight.reshape(orig_shape)
+            weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         # case 3, process left part split by group size
         split_index = weight.shape[1] // group_size * group_size
@@ -321,13 +326,13 @@ def quant_tensor(
         )
         if return_int or quant_scale:
             weight2, scale2, zp2 = weight2
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
             scale = torch.cat([scale1, scale2], dim=1)
             zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1)
             q_state = (weight, scale, zp)
         else:
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
-            return weight
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
+            return orig_weight
     if quant_scale:
         weight, scale, zp = q_state
         scale_dtype = kwargs.get("double_quant_dtype", "int")
@@ -343,7 +348,7 @@ def quant_tensor(
         scale.sub_(scale_mean)
         scale_scheme = "sym"
         # process: scale
-        scale = quant_tensor(
+        quant_tensor(
             scale,
             dtype=scale_dtype,
             bits=scale_bits,
@@ -375,15 +380,16 @@ def quant_tensor(
             weight1 = weight1.mul_(scale[:, :-1].reshape(-1, 1))
             weight1 = weight1.reshape(orig_shape[0], -1)
             weight2 = weight2.mul_(scale[:, -1].reshape(-1, 1))
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
         else:
             if zp is not None:
                 weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1)
             else:
                 weight = weight.reshape(-1, group_size)
             weight = weight.mul_(scale.reshape(-1, 1))
             weight = weight.reshape(orig_shape[0], -1)
-            return weight
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         return q_state
 
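
One detail worth noting in the double-quant path: the recursive call changes from `scale = quant_tensor(...)` to a bare `quant_tensor(...)`. Once `quant_tensor` honors the in-place contract, the `scale` tensor is mutated directly, so rebinding the return value is redundant; before this fix, that assignment silently compensated for the broken identity.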
13 changes: 13 additions & 0 deletions test/3x/torch/algorithms/weight_only/test_woq_utility.py
@@ -0,0 +1,13 @@
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("shape", [1024, 512, 300])
+def test_quant_tensor_id(shape):
+    from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor
+
+    input = torch.randn(shape, shape)
+    id1 = id(input)
+    output = quant_tensor(input)
+    id2 = id(output)
+    assert id1 == id2, "quant_tensor function is an in-place operator"
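
The test pins down the contract directly: the object `quant_tensor` returns must be the very tensor that was passed in. The 1024 and 512 shapes cover evenly divisible widths, while 300 covers widths that are not a multiple of a typical group size (with group_size=32, for instance, 300 // 32 * 32 = 288 columns fall into full groups and a 12-column tail remains). To run it locally from a source checkout with pytest installed:

    pytest test/3x/torch/algorithms/weight_only/test_woq_utility.py -q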
