Fix for pre-compressed MatMuls in AWQ. (#2648)
### Changes

1) Avoid processing already-compressed int8/int4 MatMuls in AWQ.
2) Tests for compressing pre-compressed models.
3) Tests for group_size == -1 with AWQ and scale estimation.

### Reason for changes

Compression with AWQ incorrectly fails on pre-compressed models that contain int8/int4 MatMuls.
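
A minimal sketch of the failing call, modeled on the new unit tests in this PR (the `AWQMatmulModel` helper and the int4 mode are taken from the test files; treat this as an illustration, not part of the commit):

```python
# Repro sketch: run AWQ-assisted int4 compression on a model whose
# MatMul weights are already int8-compressed.
import numpy as np

import nncf
from tests.openvino.native.models import AWQMatmulModel  # test helper from this PR

model = AWQMatmulModel(is_int8=True).ov_model  # MatMuls fed by u8 weights + decompression subgraph
dataset = nncf.Dataset([np.ones([8, 8], dtype=np.float32)])

# Before this fix, AWQ attempted to process the pre-compressed MatMuls and
# crashed; with the guard added in awq.py below, they are skipped.
compressed = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=1.0,
    group_size=2,
    dataset=dataset,
    awq=True,
)
```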

### Related tickets

CVS-137988

### Tests

Unit tests.
andreyanufr authored May 3, 2024
1 parent ac15a65 commit 7985126
Showing 3 changed files with 68 additions and 13 deletions.
19 changes: 13 additions & 6 deletions nncf/quantization/algorithms/weight_compression/awq.py
```diff
@@ -147,6 +147,9 @@ def apply(
 
         for match in matches:
             nncf_node = graph.get_node_by_key(match[-1])
+            if not self._backend_entity.is_node_with_weights(nncf_node, graph):
+                continue
+
             for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph):
                 target_node_names.append(weight_op_friendly_name)
 
```
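This guard skips pattern matches whose final node exposes no weight constants to the backend, which is the situation for a MatMul whose weights are already wrapped in a decompression subgraph, so AWQ no longer attempts to process pre-compressed layers.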
```diff
@@ -192,9 +195,13 @@ def apply(
             top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
             topk_idxs = fns.argsort(-s)[:top_k]
 
+            group_size = config.group_size
+            if group_size == -1:
+                group_size = s.shape[0]
+
             groups_to_correct = set()
             for idx in topk_idxs:
-                groups_to_correct.add(idx.data // config.group_size)
+                groups_to_correct.add(idx.data // group_size)
 
             groups_to_correct = list(groups_to_correct)
 
```
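A quick worked example of why this normalization matters. The numbers are illustrative only (`s_len` stands in for `s.shape[0]`, which is 8 in the test models):

```python
# With group_size == -1 the whole input dimension is one group, so the floor
# division must use the normalized size; idx // -1 would yield negative group
# indices and, further down, negative slice offsets.
s_len = 8
topk_idxs = [1, 5, 7]  # example top-k channel indices

group_size = -1
if group_size == -1:
    group_size = s_len

groups_to_correct = {idx // group_size for idx in topk_idxs}
assert groups_to_correct == {0}  # a single group covering the full channel

# The old code divided by config.group_size == -1 directly:
assert [idx // -1 for idx in topk_idxs] == [-1, -5, -7]  # bogus group ids
```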
```diff
@@ -215,15 +222,15 @@ def apply(
             awq_config.group_size = -1
 
             for gi in groups_to_correct:
-                offset = gi * config.group_size
-                gscale = s[offset : offset + config.group_size]
+                offset = gi * group_size
+                gscale = s[offset : offset + group_size]
 
                 a_min = fns.quantile(gscale, 0.1)
                 a_max = 1e2
                 gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)
 
-                gweight = weight[:, offset : offset + config.group_size]
-                gacts = X[offset : offset + config.group_size, :]
+                gweight = weight[:, offset : offset + group_size]
+                gacts = X[offset : offset + group_size, :]
 
                 fp32_out = fns.matmul(gweight, gacts)
                 min_diff = fns.max(fns.abs(fp32_out))
```
```diff
@@ -247,7 +254,7 @@ def apply(
                         alpha += alpha_step
 
                 if best_scale is not None:
-                    scale.data[offset : offset + config.group_size] = best_scale.data
+                    scale.data[offset : offset + group_size] = best_scale.data
 
             a_scale = scale
             w_scale = scale
```
29 changes: 22 additions & 7 deletions tests/openvino/native/models.py
```diff
@@ -903,41 +903,56 @@ class AWQMatmulModel(OVReferenceModel):
     Model for testing AWQ algorithm. Contains MatMul->Multiply->MatMul pattern.
     """
 
-    def _create_ov_model(self):
+    def _get_weights(self, weights_data, is_int8, name):
+        if not is_int8:
+            return opset.constant(weights_data, dtype=np.float32, name=name)
+        else:
+            qw = opset.constant(weights_data, dtype=np.uint8, name="qw_" + name)
+            qw = opset.convert(qw, destination_type=np.float32)
+
+            zp = opset.constant(np.array([2**7]), dtype=np.uint8, name="zp_" + name)
+            zp = opset.convert(zp, destination_type=np.float32)
+
+            scale = opset.constant(
+                np.ones((weights_data.shape[0], 1), dtype=np.float32), dtype=np.float32, name="scale_" + name
+            )
+            return (qw - zp) * scale
+
+    def _create_ov_model(self, is_int8=False):
         input_node = opset.parameter([8, 8], name="Input_1")
 
         weights_data1 = np.arange(0, 64).reshape(8, 8)
         weights_data1[:] = 2.0
-        weights1 = opset.constant(weights_data1, dtype=np.float32, name="weights_1")
+        weights1 = self._get_weights(weights_data1, is_int8, name="weights_1")
         node1 = opset.matmul(input_node, weights1, transpose_a=False, transpose_b=True, name="MatMul_1")
 
         weights_data2 = np.arange(0, 64).reshape(8, 8)
         weights_data2[:] = 3.0
-        weights2 = opset.constant(weights_data2, dtype=np.float32, name="weights_2")
+        weights2 = self._get_weights(weights_data2, is_int8, name="weights_2")
         node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
 
         node_multiply = opset.multiply(node1, node2, name="Multiply")
 
         weights_data3 = np.arange(0, 64).reshape(8, 8)
         weights_data3[:] = 4.0
-        weights3 = opset.constant(weights_data3, dtype=np.float32, name="weights_3")
+        weights3 = self._get_weights(weights_data3, is_int8, name="weights_3")
         node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3")
 
         weights_data4 = np.arange(0, 64).reshape(8, 8)
         weights_data4[:] = 2.0
-        weights4 = opset.constant(weights_data4, dtype=np.float32, name="weights_4")
+        weights4 = self._get_weights(weights_data4, is_int8, name="weights_4")
         node4 = opset.matmul(node3, weights4, transpose_a=False, transpose_b=True, name="MatMul_4")
 
         weights_data5 = np.arange(0, 64).reshape(8, 8)
         weights_data5[:] = 3.0
-        weights5 = opset.constant(weights_data5, dtype=np.float32, name="weights_5")
+        weights5 = self._get_weights(weights_data5, is_int8, name="weights_5")
         node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5")
 
         node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2")
 
         weights_data6 = np.arange(0, 64).reshape(8, 8)
         weights_data6[:] = 4.0
-        weights6 = opset.constant(weights_data6, dtype=np.float32, name="weights_6")
+        weights6 = self._get_weights(weights_data6, is_int8, name="weights_6")
         node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6")
 
         result = opset.result(node6, name="Result")
```
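The int8 branch of `_get_weights` builds the standard OpenVINO weight-decompression pattern, `w = (qw - zp) * scale`. A plain-NumPy rendering of the same arithmetic, with the same shapes and constants as the helper (illustration only):

```python
# NumPy equivalent of the int8 branch of _get_weights: dequantize uint8
# storage with a zero point of 128 and a per-row scale of 1.
import numpy as np

qw = np.full((8, 8), 2, dtype=np.uint8)    # raw stored values, as in weights_data1
zp = np.array([2**7], dtype=np.uint8)      # zero point = 128
scale = np.ones((8, 1), dtype=np.float32)  # one scale per output channel

w_fp32 = (qw.astype(np.float32) - zp.astype(np.float32)) * scale
assert w_fp32[0, 0] == -126.0  # the stored "2" decodes to -126 after the zp shift
```

The exact float values are irrelevant for these tests; what matters is that the weights now reach each MatMul through a Convert/Subtract/Multiply subgraph rather than as a plain f32 constant.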
33 changes: 33 additions & 0 deletions tests/openvino/native/quantization/test_weights_compression.py
```diff
@@ -692,6 +692,22 @@ def test_call_max_var_criterion_with_dataset_by_default_awq(mode):
     compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)
 
 
+@pytest.mark.parametrize("mode", INT4_MODES)
+def test_call_max_var_criterion_with_dataset_awq_for_compressed_model(mode):
+    model = AWQMatmulModel(is_int8=True).ov_model
+    dataset = Dataset([np.ones([8, 8])])
+
+    compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)
+
+
+@pytest.mark.parametrize("mode", INT4_MODES)
+def test_call_max_var_criterion_with_dataset_awq_neg_group_size(mode):
+    model = AWQMatmulModel().ov_model
+    dataset = Dataset([np.ones([8, 8])])
+    with pytest.raises(AttributeError):
+        compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True)
+
+
 def test_data_type_for_num_weights(mocker):
     stub = mocker.stub()
     params = WeightCompressionParameters(stub, stub, stub, np.int32(1), stub)
```
```diff
@@ -779,3 +795,20 @@ def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode):
     dataset = Dataset([np.ones([8, 8])])
 
     compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)
+
+
+@pytest.mark.parametrize("mode", INT4_MODES)
+def test_call_max_var_criterion_with_dataset_scale_estimation_for_compressed_model(mode):
+    model = AWQMatmulModel(is_int8=True).ov_model
+    dataset = Dataset([np.ones([8, 8])])
+
+    compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)
+
+
+@pytest.mark.parametrize("mode", INT4_MODES)
+def test_call_max_var_criterion_with_dataset_scale_estimation_neg_group_size(mode):
+    model = AWQMatmulModel().ov_model
+    dataset = Dataset([np.ones([8, 8])])
+
+    with pytest.raises(AttributeError):
+        compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, scale_estimation=True)
```
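Note that both `*_neg_group_size` tests expect an `AttributeError`: they pin down the current public behavior of `compress_weights` when `group_size=-1` is passed together with AWQ or scale estimation, while the normalization added in `awq.py` keeps the internal group arithmetic well defined for a -1 group size.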