diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index dff55ce5f25..bd012e4d539 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -147,6 +147,9 @@ def apply( for match in matches: nncf_node = graph.get_node_by_key(match[-1]) + if not self._backend_entity.is_node_with_weights(nncf_node, graph): + continue + for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): target_node_names.append(weight_op_friendly_name) @@ -192,9 +195,13 @@ def apply( top_k = max(int(s.shape[0] * self._percent_to_apply), 1) topk_idxs = fns.argsort(-s)[:top_k] + group_size = config.group_size + if group_size == -1: + group_size = s.shape[0] + groups_to_correct = set() for idx in topk_idxs: - groups_to_correct.add(idx.data // config.group_size) + groups_to_correct.add(idx.data // group_size) groups_to_correct = list(groups_to_correct) @@ -215,15 +222,15 @@ def apply( awq_config.group_size = -1 for gi in groups_to_correct: - offset = gi * config.group_size - gscale = s[offset : offset + config.group_size] + offset = gi * group_size + gscale = s[offset : offset + group_size] a_min = fns.quantile(gscale, 0.1) a_max = 1e2 gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) - gweight = weight[:, offset : offset + config.group_size] - gacts = X[offset : offset + config.group_size, :] + gweight = weight[:, offset : offset + group_size] + gacts = X[offset : offset + group_size, :] fp32_out = fns.matmul(gweight, gacts) min_diff = fns.max(fns.abs(fp32_out)) @@ -247,7 +254,7 @@ def apply( alpha += alpha_step if best_scale is not None: - scale.data[offset : offset + config.group_size] = best_scale.data + scale.data[offset : offset + group_size] = best_scale.data a_scale = scale w_scale = scale diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index 2b83a95457a..339453e9202 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -903,41 +903,56 @@ class AWQMatmulModel(OVReferenceModel): Model for testing AWQ algorithm. Contains MatMul->Multiply->MatMul pattern. """ - def _create_ov_model(self): + def _get_weights(self, weights_data, is_int8, name): + if not is_int8: + return opset.constant(weights_data, dtype=np.float32, name=name) + else: + qw = opset.constant(weights_data, dtype=np.uint8, name="qw_" + name) + qw = opset.convert(qw, destination_type=np.float32) + + zp = opset.constant(np.array([2**7]), dtype=np.uint8, name="zp_" + name) + zp = opset.convert(zp, destination_type=np.float32) + + scale = opset.constant( + np.ones((weights_data.shape[0], 1), dtype=np.float32), dtype=np.float32, name="scale_" + name + ) + return (qw - zp) * scale + + def _create_ov_model(self, is_int8=False): input_node = opset.parameter([8, 8], name="Input_1") weights_data1 = np.arange(0, 64).reshape(8, 8) weights_data1[:] = 2.0 - weights1 = opset.constant(weights_data1, dtype=np.float32, name="weights_1") + weights1 = self._get_weights(weights_data1, is_int8, name="weights_1") node1 = opset.matmul(input_node, weights1, transpose_a=False, transpose_b=True, name="MatMul_1") weights_data2 = np.arange(0, 64).reshape(8, 8) weights_data2[:] = 3.0 - weights2 = opset.constant(weights_data2, dtype=np.float32, name="weights_2") + weights2 = self._get_weights(weights_data2, is_int8, name="weights_2") node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2") node_multiply = opset.multiply(node1, node2, name="Multiply") weights_data3 = np.arange(0, 64).reshape(8, 8) weights_data3[:] = 4.0 - weights3 = opset.constant(weights_data3, dtype=np.float32, name="weights_3") + weights3 = self._get_weights(weights_data3, is_int8, name="weights_3") node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3") weights_data4 = np.arange(0, 64).reshape(8, 8) weights_data4[:] = 2.0 - weights4 = opset.constant(weights_data4, dtype=np.float32, name="weights_4") + weights4 = self._get_weights(weights_data4, is_int8, name="weights_4") node4 = opset.matmul(node3, weights4, transpose_a=False, transpose_b=True, name="MatMul_4") weights_data5 = np.arange(0, 64).reshape(8, 8) weights_data5[:] = 3.0 - weights5 = opset.constant(weights_data5, dtype=np.float32, name="weights_5") + weights5 = self._get_weights(weights_data5, is_int8, name="weights_5") node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5") node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2") weights_data6 = np.arange(0, 64).reshape(8, 8) weights_data6[:] = 4.0 - weights6 = opset.constant(weights_data6, dtype=np.float32, name="weights_6") + weights6 = self._get_weights(weights_data6, is_int8, name="weights_6") node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6") result = opset.result(node6, name="Result") diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index cfdbb579b36..57ef30033f0 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -692,6 +692,22 @@ def test_call_max_var_criterion_with_dataset_by_default_awq(mode): compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True) +@pytest.mark.parametrize("mode", INT4_MODES) +def test_call_max_var_criterion_with_dataset_awq_for_compressed_model(mode): + model = AWQMatmulModel(is_int8=True).ov_model + dataset = Dataset([np.ones([8, 8])]) + + compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True) + + +@pytest.mark.parametrize("mode", INT4_MODES) +def test_call_max_var_criterion_with_dataset_awq_neg_group_size(mode): + model = AWQMatmulModel().ov_model + dataset = Dataset([np.ones([8, 8])]) + with pytest.raises(AttributeError): + compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True) + + def test_data_type_for_num_weights(mocker): stub = mocker.stub() params = WeightCompressionParameters(stub, stub, stub, np.int32(1), stub) @@ -779,3 +795,20 @@ def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode): dataset = Dataset([np.ones([8, 8])]) compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True) + + +@pytest.mark.parametrize("mode", INT4_MODES) +def test_call_max_var_criterion_with_dataset_scale_estimation_for_compressed_model(mode): + model = AWQMatmulModel(is_int8=True).ov_model + dataset = Dataset([np.ones([8, 8])]) + + compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True) + + +@pytest.mark.parametrize("mode", INT4_MODES) +def test_call_max_var_criterion_with_dataset_scale_estimation_neg_group_size(mode): + model = AWQMatmulModel().ov_model + dataset = Dataset([np.ones([8, 8])]) + + with pytest.raises(AttributeError): + compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, scale_estimation=True)