diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index f4d99638f62..5a5649a7d45 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -142,7 +142,7 @@ def transform_model(
             const_attributes = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]
             const_node_name = const_attributes["name"]
             const_node = self.name_to_node_mapping[const_node_name]
-            const_dtype = const_node.output(0).get_element_type().to_dtype()
+            const_dtype = const_node.output(0).get_element_type()
 
             weight = Tensor(get_const_value(const_node))
             original_shape = weight.shape
@@ -151,19 +151,21 @@ def transform_model(
             compressed_const = opset.constant(
                 compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
             )
-            converted_const = opset.convert(compressed_const, const_dtype)
+            converted_const = opset.convert(compressed_const, ov.Type.f16)
             if compressed_weight.zero_point is not None:
                 zero_point_const = opset.constant(
                     compressed_weight.zero_point.data,
                     dtype=compression_dtype,
                     name=f"{const_node_name}/zero_point",
                 )
-                converted_zero_point = opset.convert(zero_point_const, const_dtype)
-                converted_const = opset.subtract(converted_const, converted_zero_point)
+                converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
+                converted_const = opset.subtract(
+                    converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
+                )
 
-            scale_const = opset.constant(compressed_weight.scale.data, dtype="float16", name=f"{const_node_name}/scale")
-            if const_dtype != "float16":
-                scale_const = opset.convert(scale_const, const_dtype, name=f"{const_node_name}/scale_convert")
+            scale_const = opset.constant(
+                compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
+            )
             mul = opset.multiply(
                 converted_const,
                 scale_const,
@@ -173,6 +175,11 @@ def transform_model(
             if compression_config.group_size != -1:
                 mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)
 
+            if const_dtype != ov.Type.f16:
+                mul = opset.convert(
+                    mul, const_dtype, name=f"{const_node_name}/fq_weights_{wc_params.weight_port_id}/convert"
+                )
+
             mul_output = mul.output(0)
             for target_input in const_node.output(0).get_target_inputs():
                 target_input.replace_source_output(mul_output)
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index c25df4d862f..2b83a95457a 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -790,16 +790,19 @@ def _create_ov_model(self):
 
 
 class IdentityMatmul(OVReferenceModel):
-    def _create_ov_model(self, weights_dtype=None):
+    def _create_ov_model(self, weights_dtype=None, activation_dtype=None):
         """
         :param: weights_dtype: precision of weights, should be either np.float32 or np.float16
+        :param: activation_dtype: precision of activations, should be either np.float32 or np.float16
         """
         weights_dtype = np.float32 if weights_dtype is None else weights_dtype
-        input_node = opset.parameter([3, 3], name="Input_1")
+        activation_dtype = np.float32 if activation_dtype is None else activation_dtype
+
+        input_node = opset.parameter([3, 3], dtype=activation_dtype, name="Input_1")
         weights_data = np.eye(3) * 255
         current_weights = opset.constant(weights_data, dtype=weights_dtype, name="weights")
-        if weights_dtype != np.float32:
-            current_weights = opset.convert(current_weights, np.float32, name="weights/convert")
+        if weights_dtype != activation_dtype:
+            current_weights = opset.convert(current_weights, activation_dtype, name="weights/convert")
         matmul_node = opset.matmul(input_node, current_weights, transpose_a=False, transpose_b=True, name="MatMul")
         result = opset.result(matmul_node, name="Result")
         result.get_output_tensor(0).set_names(set(["Result"]))
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 71556ce3ac6..f339d316d0d 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -96,8 +96,6 @@ def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode
     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
     scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
     scale = get_const_value(scale_node)
 
     return {
@@ -134,13 +132,14 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int =
     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
     scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
     assert list(scale_node.shape) == reduced_weight_shape
 
     reshape_node = get_next_node(mul_node)
     assert reshape_node.get_type_name() == "Reshape"
 
+    convert_node = get_next_node(reshape_node)
+    assert convert_node.get_type_name() == "Convert"
+
     return {
         "scale": get_const_value(scale_node),
     }
@@ -160,13 +159,14 @@ def check_nf4_grouped(op: ov.Node, group_size: int = 7):
     mul_node = get_next_node(convert_node)
     assert mul_node.get_type_name() == "Multiply"
     scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
     assert list(scale_node.shape) == reduced_weight_shape
 
     reshape_node = get_next_node(mul_node)
     assert reshape_node.get_type_name() == "Reshape"
 
+    convert_node = get_next_node(reshape_node)
+    assert convert_node.get_type_name() == "Convert"
+
     return {
         "scale": get_const_value(scale_node),
     }
@@ -697,22 +697,35 @@ def test_data_type_for_num_weights(mocker):
     assert isinstance(params.num_weights, np.uint64)
 
 
-def test_weight_scale_datatype():
-    # When model weight is in fp32, there will be an extra convert node for weight scale f16 > f32
-    model_fp32 = IdentityMatmul(weights_dtype=np.float32).ov_model
-    compressed_model_fp32 = compress_weights(model_fp32)
-    name_to_node_map = {op.get_friendly_name(): op for op in compressed_model_fp32.get_ops()}
-    assert "weights/scale_convert" in name_to_node_map
-    scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
-    assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f32
-
-    # When model weight is in fp16, there will be no extra convert node for weight scale
-    model_fp16 = IdentityMatmul(weights_dtype=np.float16).ov_model
-    compressed_model_fp16 = compress_weights(model_fp16)
-    name_to_node_map = {op.get_friendly_name(): op for op in compressed_model_fp16.get_ops()}
-    assert "weights/scale_convert" not in name_to_node_map
-    scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
-    assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f16
+def test_compression_for_different_dtypes():
+    for activation_dtype in [np.float32, np.float16]:
+        for weight_dtype in [np.float32, np.float16]:
+            if activation_dtype == np.float16 and weight_dtype == np.float32:
+                # Activations can be in f16 only if weights are in f16
+                continue
+
+            model = IdentityMatmul(weights_dtype=weight_dtype, activation_dtype=activation_dtype).ov_model
+            compressed_model = compress_weights(
+                model, mode=CompressWeightsMode.INT4_SYM, ratio=1, group_size=1, all_layers=True
+            )
+            name_to_node_map = {op.get_friendly_name(): op for op in compressed_model.get_ops()}
+
+            # The weight scale should be in fp16 regardless of the weight data type
+            scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
+            assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f16
+
+            reshape_node = get_next_node(scale_multiply_node)
+            assert reshape_node.get_type_name() == "Reshape"
+
+            next_node = get_next_node(reshape_node)
+            if activation_dtype == np.float16:
+                # There should be no convert node after the multiply if both weights and activations are in f16
+                assert next_node.get_type_name() != "Convert"
+            else:
+                assert next_node.get_type_name() == "Convert"
+                # When the weight is in fp32, the convert node is inserted manually
+                if weight_dtype == np.float32:
+                    assert next_node.get_friendly_name() == "weights/fq_weights_1/convert"
 
 
 DATASET_SIZE = 129
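
For reference, the sketch below shows the decompression pattern this change produces for an fp32 weight: Constant(int) -> Convert(f16) -> Subtract -> Multiply(f16 scale) -> [Reshape] -> Convert(f32). Keeping the scale in f16 and converting once after the Multiply replaces the old per-scale "scale_convert" node. This is illustrative only and not part of the patch: the model mirrors IdentityMatmul from the test diff, while the openvino.runtime / opset13 imports and the use of the default 8-bit compression mode are assumptions about the environment rather than anything introduced by this PR.

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

from nncf import compress_weights

# Tiny fp32 MatMul model, mirroring the IdentityMatmul reference model from the diff.
input_node = opset.parameter([3, 3], dtype=np.float32, name="Input_1")
weights = opset.constant(np.eye(3, dtype=np.float32) * 255, dtype=np.float32, name="weights")
matmul = opset.matmul(input_node, weights, transpose_a=False, transpose_b=True, name="MatMul")
model = ov.Model([matmul], [input_node], "IdentityMatmul")

# Default mode compresses the weight to 8 bit, no grouping, so no Reshape is inserted.
compressed = compress_weights(model)

ops = {op.get_friendly_name(): op for op in compressed.get_ops()}
mul = ops["weights/fq_weights_1"]

# With this change the scale constant feeding the Multiply stays in f16 ...
print(mul.input_value(1).get_node().get_element_type())  # expected: f16

# ... and, because the original weight was f32, a Convert back to f32 follows the Multiply.
consumer = next(iter(mul.output(0).get_target_inputs())).get_node()
print(consumer.get_type_name(), consumer.get_friendly_name())
# expected: Convert weights/fq_weights_1/convert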