diff --git a/nncf/openvino/graph/model_transformer.py b/nncf/openvino/graph/model_transformer.py
index 106f8e0b271..c073ce7605c 100644
--- a/nncf/openvino/graph/model_transformer.py
+++ b/nncf/openvino/graph/model_transformer.py
@@ -67,11 +67,6 @@ def __init__(self, model: TModel, inplace: bool = False):
             (OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
         ]
 
-    @staticmethod
-    def _convert_to_fp16(data):
-        clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max)
-        return clip_data.astype(np.float16)
-
     @staticmethod
     def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]:
         """
@@ -102,16 +97,16 @@ def _get_activation_node_names(model: ov.Model) -> List[str]:
         return list(activation_nodes)
 
     @staticmethod
-    def _update_tensor_name(tensors: List[DescriptorTensor], name: str) -> None:
+    def _update_tensor_names(tensors: List[DescriptorTensor], names: List[str]) -> None:
         """
         Updates tensors names in-place.
 
         :param model: List of the tensors.
-        :param name: New name for tensor.
+        :param names: List of the new names for tensors.
         """
         for tensor in tensors:
             current_names = tensor.get_names()
-            current_names.add(name)
+            current_names.update(names)
             tensor.set_names(current_names)
 
     def transform(self, transformation_layout: TransformationLayout) -> ov.Model:
@@ -172,15 +167,16 @@ def _get_extra_model_outputs(
             node_name = transformation.target_point.target_node_name
             node = name_to_node_mapping[node_name]
             port_id = transformation.target_point.port_id
+            output_dtype = transformation.output_dtype
             if transformation.target_point.type == TargetType.POST_LAYER_OPERATION:
                 output = node.output(port_id)
-                extra_model_outputs.append((output, port_id))
+                extra_model_outputs.append((output, port_id, output_dtype))
             elif transformation.target_point.type in [
                 TargetType.PRE_LAYER_OPERATION,
                 TargetType.OPERATION_WITH_WEIGHTS,
             ]:
                 output = node.input_value(port_id)
-                extra_model_outputs.append((output, output.get_index()))
+                extra_model_outputs.append((output, output.get_index(), output_dtype))
             else:
                 raise NotImplementedError(f"Unsupported target point type {transformation.target_point.type}")
 
@@ -199,12 +195,17 @@ def _insert_outputs(model: ov.Model, outputs: List[Tuple[ov.Output, int, Callabl
         params = model.get_parameters()
 
         extra_model_outputs = []
-        for output, port_id in outputs:
-            output_name = output.get_node().get_friendly_name()
-            # TODO: (KodiaqQ) check out the models with the Split
+        for output, port_id, dtype in outputs:
+            node_output = output
+            output_name = node_output.get_node().get_friendly_name()
             result_name = get_result_node_name(output_name, port_id)
-            result = opset.result(output, name=result_name)
-            OVModelTransformer._update_tensor_name([result.get_output_tensor(0)], result_name)
+
+            if node_output.get_element_type() != dtype:
+                node_output = opset.convert(output, destination_type=dtype)
+
+            result = opset.result(node_output, name=result_name)
+            result_tensor_names = [result_name] + list(output.get_names())
+            OVModelTransformer._update_tensor_names([result.get_output_tensor(0)], result_tensor_names)
             extra_model_outputs.append(result)
 
         model_with_outputs = ov.Model(
@@ -284,7 +285,7 @@ def _create_fake_quantize(
         op_output: ov.Output,
         fake_quantize_params: FakeQuantizeParameters,
         fake_quantize_name: str,
-        convert_to_fp16: bool,
+        data_type: ov.Type,
     ) -> ov.Node:
         """
         Creates FakeQuantize node.
@@ -292,7 +293,7 @@ def _create_fake_quantize(
         :param op_output: Output of the previous node.
         :param fake_quantize_params: FakeQuantizeParameters instance.
         :param fake_quantize_name: New layer name.
-        :param convert_to_fp16: Whether convert parameters to FP16 or not.
+        :param data_type: ov.Type instance for data.
         :return: ov.Node instance.
         """
@@ -301,24 +302,18 @@ def _create_fake_quantize(
         output_low = fake_quantize_params.output_low.data
         output_high = fake_quantize_params.output_high.data
         levels = fake_quantize_params.levels
-        dtype = ov.Type.f32
 
-        if convert_to_fp16:
-            input_low = OVModelTransformer._convert_to_fp16(input_low)
-            input_high = OVModelTransformer._convert_to_fp16(input_high)
-            output_low = OVModelTransformer._convert_to_fp16(output_low)
-            output_high = OVModelTransformer._convert_to_fp16(output_high)
-            dtype = ov.Type.f16
-
-        input_low = OVModelTransformer._create_constant(input_low, dtype=dtype, name=f"{fake_quantize_name}/input_low")
+        input_low = OVModelTransformer._create_constant(
+            input_low, dtype=data_type, name=f"{fake_quantize_name}/input_low"
+        )
         input_high = OVModelTransformer._create_constant(
-            input_high, dtype=dtype, name=f"{fake_quantize_name}/input_high"
+            input_high, dtype=data_type, name=f"{fake_quantize_name}/input_high"
         )
         output_low = OVModelTransformer._create_constant(
-            output_low, dtype=dtype, name=f"{fake_quantize_name}/output_low"
+            output_low, dtype=data_type, name=f"{fake_quantize_name}/output_low"
         )
         output_high = OVModelTransformer._create_constant(
-            output_high, dtype=dtype, name=f"{fake_quantize_name}/output_high"
+            output_high, dtype=data_type, name=f"{fake_quantize_name}/output_high"
         )
 
         return opset.fake_quantize(
@@ -330,7 +325,7 @@ def _create_fake_convert(
         op_output: ov.Output,
         fake_convert_params: FakeConvertParameters,
         fake_convert_name: str,
-        convert_to_fp16: bool,
+        data_type: ov.Type,
     ) -> ov.Node:
         """
         Creates FakeConvert node.
@@ -338,22 +333,16 @@ def _create_fake_convert(
         :param op_output: Output of the previous node.
         :param fake_convert_params: FakeConvertParameters instance.
         :param fake_convert_name: New layer name.
-        :param convert_to_fp16: Whether convert parameters to FP16 or not.
+        :param data_type: ov.Type instance for data.
         :return: ov.Node instance.
         """
         scale = fake_convert_params.scale.data
         shift = fake_convert_params.shift.data
-        dtype = ov.Type.f32
-
-        if convert_to_fp16:
-            scale = OVModelTransformer._convert_to_fp16(scale)
-            shift = OVModelTransformer._convert_to_fp16(shift)
-            dtype = ov.Type.f16
 
         destination_type = fake_convert_params.destination_type.value
-        scale = OVModelTransformer._create_constant(scale, dtype=dtype, name=f"{fake_convert_name}/scale")
-        shift = OVModelTransformer._create_constant(shift, dtype=dtype, name=f"{fake_convert_name}/shift")
+        scale = OVModelTransformer._create_constant(scale, dtype=data_type, name=f"{fake_convert_name}/scale")
+        shift = OVModelTransformer._create_constant(shift, dtype=data_type, name=f"{fake_convert_name}/shift")
 
         return opset.fake_convert(
             data=op_output,
@@ -383,7 +372,6 @@ def _insert_fake_quantize_op(
             inp_node = target_node.input(port_id)
             input_node_output = inp_node.get_source_output()
             data_type = inp_node.get_element_type()
-            convert_to_fp16 = data_type == ov.Type(np.float16)
             name = "fq_weights" if transform_type == TargetType.OPERATION_WITH_WEIGHTS else "fq_input"
             fq_name = f"{node_name}/{name}_{port_id}"
 
@@ -398,20 +386,19 @@ def _insert_fake_quantize_op(
                     op_output=input_node_output,
                     fake_quantize_params=fq_params,
                     fake_quantize_name=fq_name,
-                    convert_to_fp16=convert_to_fp16,
+                    data_type=data_type,
                 )
                 inp_node.replace_source_output(fq.output(0))
         elif transform_type == TargetType.POST_LAYER_OPERATION:
             output = target_node.output(port_id)
             data_type = output.get_element_type()
-            convert_to_fp16 = data_type == ov.Type(np.float16)
             target_inputs = output.get_target_inputs()
             fq_name = f"{node_name}/fq_output_{port_id}"
             fq = OVModelTransformer._create_fake_quantize(
                 op_output=output,
                 fake_quantize_params=fq_params,
                 fake_quantize_name=fq_name,
-                convert_to_fp16=convert_to_fp16,
+                data_type=data_type,
             )
             for inp_node in target_inputs:
                 inp_node.replace_source_output(fq.output(0))
@@ -447,25 +434,25 @@ def _insert_fake_convert_op(
                 if out.get_node().get_type_name() == "FakeConvert":
                     fc = out.get_node()
             if fc is None:
-                convert_to_fp16 = inp_node.get_element_type() == ov.Type(np.float16)
+                data_type = inp_node.get_element_type()
                 fc_name = f"{node_name}/fc_{name}_{port_id}"
                 fc = OVModelTransformer._create_fake_convert(
                     op_output=input_node_output,
                     fake_convert_params=fc_params,
                     fake_convert_name=fc_name,
-                    convert_to_fp16=convert_to_fp16,
+                    data_type=data_type,
                 )
                 inp_node.replace_source_output(fc.output(0))
         elif transform_type == TargetType.POST_LAYER_OPERATION:
             output = target_node.output(port_id)
-            convert_to_fp16 = output.get_element_type() == ov.Type(np.float16)
+            data_type = output.get_element_type()
             target_inputs = output.get_target_inputs()
             fc_name = f"{node_name}/fc_output_{port_id}"
             fc = OVModelTransformer._create_fake_convert(
                 op_output=output,
                 fake_convert_params=fc_params,
                 fake_convert_name=fc_name,
-                convert_to_fp16=convert_to_fp16,
+                data_type=data_type,
             )
             for inp_node in target_inputs:
                 inp_node.replace_source_output(fc.output(0))
@@ -517,13 +504,19 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
         if const_node is None:
             raise nncf.InternalError("Constant node was expected but could not find it.")
 
-        const_shape = const_node.data.shape
-        const_dtype = const_node.data.dtype
-        const_value = np.reshape(const_value, const_shape).astype(const_dtype)
+        const_value = np.reshape(const_value, const_node.data.shape)
 
-        # TODO(andrey-churkin): Replace on opset13.constant() in 2023.3 release
-        new_const_node = ov.op.Constant(const_value, shared_memory=True)
-        new_const_node.set_friendly_name(const_node.get_friendly_name())
+        shared_memory = True
+        if const_port.get_element_type() == ov.Type.bf16:
+            # Shared memory does not work for BF16 precision
+            shared_memory = False
+
+        new_const_node = opset.constant(
+            const_value,
+            dtype=const_port.get_element_type(),
+            name=const_node.get_friendly_name(),
+            shared_memory=shared_memory,
+        )
         const_port.replace_source_output(new_const_node.output(0))
 
     @staticmethod
@@ -553,6 +546,7 @@ def _apply_model_extraction_transformation(
         :param transformation: Model extraction transformation.
         :return: Extracted sub-model.
         """
+        outputs_type = ov.Type.f32
         transformation = transformations[-1]
         name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
@@ -564,25 +558,34 @@ def _apply_model_extraction_transformation(
                 continue
 
             input_port = input_node.input(input_port_id)
+            input_type = input_port.get_element_type()
             input_node_output = input_port.get_source_output()
             parameter_name = get_parameter_node_name(input_name, input_port_id)
+
             new_param = opset.parameter(
                 shape=input_node_output.partial_shape,
-                dtype=input_node_output.get_element_type(),
+                dtype=outputs_type,
                 name=parameter_name,
             )
-            input_port.replace_source_output(new_param.output(0))
+            new_input = new_param.output(0)
+
+            if input_type != outputs_type:
+                new_input = opset.convert(new_param, destination_type=input_type).output(0)
+
+            input_port.replace_source_output(new_input)
             new_param_tensors = [o.get_tensor() for o in new_param.outputs()]
-            OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name)
+            OVModelTransformer._update_tensor_names(new_param_tensors, [parameter_name])
             params.append(new_param)
 
         for output_name, output_port_id in transformation.output_ids:
             output_node = name_to_node_mapping[output_name]
-            output_port = output_node.output(output_port_id)
             result_name = get_result_node_name(output_name, output_port_id)
-            new_result = opset.result(output_port, name=result_name)
-            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_name)
+            if output_node.get_element_type() != outputs_type:
+                output_node = opset.convert(output_node, destination_type=outputs_type)
+            new_result = opset.result(output_node, name=result_name)
+            result_tensor_names = [result_name] + list(output_node.output(0).get_names())
+            OVModelTransformer._update_tensor_names([new_result.get_output_tensor(0)], result_tensor_names)
             results.append(new_result)
 
         if not results:
@@ -624,7 +627,7 @@ def _apply_stateless_model_extraction_transformation(
                 for input_port in output_port.get_target_inputs():
                     input_port.replace_source_output(new_param.output(0))
             new_param_tensors = [o.get_tensor() for o in new_param.outputs()]
-            OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name)
+            OVModelTransformer._update_tensor_names(new_param_tensors, [parameter_name])
             params.append(new_param)
 
         for output_name, output_port_id in transformation.output_ids:
@@ -633,7 +636,7 @@ def _apply_stateless_model_extraction_transformation(
             output_port = output_node.output(output_port_id)
             result_name = get_result_node_name(output_name, output_port_id)
             new_result = opset.result(output_port, name=result_name)
-            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_name)
+            OVModelTransformer._update_tensor_names([new_result.get_output_tensor(0)], [result_name])
             results.append(new_result)
 
         if not results:
@@ -688,15 +691,16 @@ def _insert_inplace_operation(
         target_node = name_to_node_mapping[node_name]
         port_id = transformation.target_point.port_id
         fn_output_port_id = transformation.fn_output_port_id
+        output_dtype = transformation.output_dtype
         if transform_type == TargetType.POST_LAYER_OPERATION:
             new_node = transformation.inplace_op_fn(target_node, port_id, transformation.last_inplace_node_name)
-            return (new_node.output(fn_output_port_id), fn_output_port_id)
+            return (new_node.output(fn_output_port_id), fn_output_port_id, output_dtype)
         if transform_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
             output = target_node.input_value(port_id)
             new_node = transformation.inplace_op_fn(
                 output.get_node(), output.get_index(), transformation.last_inplace_node_name
             )
-            return (new_node.output(fn_output_port_id), fn_output_port_id)
+            return (new_node.output(fn_output_port_id), fn_output_port_id, output_dtype)
         raise nncf.InternalError(f"Transform type {transform_type} is not supported")
 
     @staticmethod
@@ -759,10 +763,7 @@ _apply_multiply_insertion_transformations(
                 if target_node.get_friendly_name() in transformation.destination_node_names:
                     destination_ports.append(target_input_port)
 
-            scale_dtype = ov.Type(np.float32)
-            fp16_dtype = ov.Type(np.float16)
-            if all(p.get_element_type() == fp16_dtype for p in destination_ports):
-                scale_dtype = fp16_dtype
+            scale_dtype = node_output_port.get_element_type()
 
             scale_constant = OVModelTransformer._create_constant(
                 transformation.scale_value, dtype=scale_dtype, name=f"{transformation.multiply_node_name}/scale"
diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py
index 69047337830..619c4fa27a6 100644
--- a/nncf/openvino/graph/node_utils.py
+++ b/nncf/openvino/graph/node_utils.py
@@ -107,17 +107,18 @@ def cnt_if_op(model: ov.Model, cnt: int) -> int:
     return cnt_if_op(model, 0)
 
 
-def get_const_value(const_node: ov.Node, dtype: Optional[np.dtype] = None) -> np.ndarray:
+def get_const_value(const_node: ov.Node) -> np.ndarray:
     """
     Returns the constant tensor for the node.
+    This method is applicable only to floating-point constant data.
 
     :param const_node: OpenVINO node.
-    :param dtype: Destination type.
     :return: The constant value.
     """
-    if dtype is None:
-        return const_node.data
-    return const_node.get_data(dtype=dtype)
+    if const_node.get_element_type() == ov.Type.bf16:
+        # BF16 constants are returned as FP32 data
+        return const_node.get_data(dtype=np.float32)
+    return const_node.data
 
 
 def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
diff --git a/nncf/openvino/graph/transformations/commands.py b/nncf/openvino/graph/transformations/commands.py
index f69e68c5504..2f8f891a1ac 100644
--- a/nncf/openvino/graph/transformations/commands.py
+++ b/nncf/openvino/graph/transformations/commands.py
@@ -48,6 +48,10 @@ def __init__(self, target_point: OVTargetPoint):
 
 
 class OVOutputInsertionCommand(OVInsertionCommand):
+    def __init__(self, target_point: OVTargetPoint, output_dtype: ov.Type = ov.Type.f32):
+        super().__init__(target_point)
+        self.output_dtype = output_dtype
+
     def union(self, other: "TransformationCommand") -> "TransformationCommand":
         # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
         raise NotImplementedError()
@@ -60,11 +64,13 @@ def __init__(
         inplace_op_fn: InplaceInsertionFnType,
         fn_output_port_id: int,
         last_inplace_node_name: str,
+        output_dtype: ov.Type = ov.Type.f32,
     ):
         super().__init__(target_point)
         self.inplace_op_fn = inplace_op_fn
         self.fn_output_port_id = fn_output_port_id
         self.last_inplace_node_name = last_inplace_node_name
+        self.output_dtype = output_dtype
 
     def union(self, other: "TransformationCommand") -> "TransformationCommand":
         # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
diff --git a/nncf/openvino/quantization/quantize_ifmodel.py b/nncf/openvino/quantization/quantize_ifmodel.py
index 8f106d7387c..07d22171a17 100644
--- a/nncf/openvino/quantization/quantize_ifmodel.py
+++ b/nncf/openvino/quantization/quantize_ifmodel.py
@@ -300,8 +300,8 @@ def create_output_insertion_commands_if_node(model: ov.Model, if_node: NNCFNode)
     commands = []
     name_to_node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
     ov_node = name_to_node_mapping[if_node.node_name]
-    for port_id in range(len(ov_node.inputs())):
-        commands.append(
-            OVOutputInsertionCommand(OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id))
-        )
+    for port_id, ov_input in enumerate(ov_node.inputs()):
+        target_point = OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id)
+        ov_input_dtype = ov_input.get_element_type()
+        commands.append(OVOutputInsertionCommand(target_point, output_dtype=ov_input_dtype))
     return commands
diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index 81d61cb648c..53011c39d41 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -10,7 +10,6 @@
 # limitations under the License.
 from typing import Dict, Iterable, List, Optional, Tuple
 
-import numpy as np
 import openvino as ov
 from openvino.runtime import opset13 as opset
 
@@ -114,7 +113,12 @@ def set_weight(
         const_port = node_with_const.input(weight_port_id)
         const_node = node_with_const.input_value(weight_port_id).get_node()
 
-        new_const_node = ov.runtime.op.Constant(weight.data, shared_memory=True)
+        shared_memory = True
+        if const_node.get_element_type() == ov.Type.bf16:
+            # Shared memory does not work for BF16 precision
+            shared_memory = False
+
+        new_const_node = ov.runtime.op.Constant(weight.data, shared_memory=shared_memory)
         new_const_node.set_friendly_name(const_node.get_friendly_name())
         const_port.replace_source_output(new_const_node.output(0))
 
@@ -167,7 +171,7 @@ def transform_model(
                     should_add_convert_node = True
                     break
 
-            weight = Tensor(get_const_value(const_node, np.float32 if const_dtype == ov.Type.bf16 else None))
+            weight = Tensor(get_const_value(const_node))
             original_shape = weight.shape
             compressed_weight = compress_weight(
                 weight,
diff --git a/tests/openvino/native/common.py b/tests/openvino/native/common.py
index def2149745c..47118916d87 100644
--- a/tests/openvino/native/common.py
+++ b/tests/openvino/native/common.py
@@ -47,7 +47,8 @@ def get_dataset_for_test(model):
     input_data = {}
     for param in model.get_parameters():
         input_shape = param.partial_shape.get_max_shape()
-        input_data[param.get_output_tensor(0).get_any_name()] = rng.uniform(0, 1, input_shape)
+        tensor = param.get_output_tensor(0)
+        input_data[tensor.get_any_name()] = rng.uniform(0, 1, input_shape).astype(tensor.get_element_type().to_dtype())
 
     dataset = Dataset([input_data])
     return dataset
diff --git a/tests/openvino/native/data/2024.2/reference_graphs/original_nncf_graph/exctracted_FPModel.dot b/tests/openvino/native/data/2024.2/reference_graphs/original_nncf_graph/exctracted_FPModel.dot
new file mode 100644
index 00000000000..24323054b2a
--- /dev/null
+++ b/tests/openvino/native/data/2024.2/reference_graphs/original_nncf_graph/exctracted_FPModel.dot
@@ -0,0 +1,13 @@
+strict digraph {
+"0 Parameter_MatMul.0" [id=0, type=Parameter];
+"1 Convert_430" [id=1, type=Convert];
+"2 MatMul" [id=2, type=MatMul];
+"3 Convert_431" [id=3, type=Convert];
+"4 Result_MatMul.0" [id=4, type=Result];
+"5 MatMul_const" [id=5, type=Constant];
+"0 Parameter_MatMul.0" -> "1 Convert_430" [label="[1, 3, 4, 2]", style=solid];
+"1 Convert_430" -> "2 MatMul" [label="[1, 3, 4, 2]", style=solid];
+"2 MatMul" -> "3 Convert_431" [label="[1, 3, 2, 5]", style=solid];
+"3 Convert_431" -> "4 Result_MatMul.0" [label="[1, 3, 2, 5]", style=solid];
+"5 MatMul_const" -> "2 MatMul" [label="[1, 3, 4, 5]", style=solid];
+}
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index c7f883b448f..88bdce3bac6 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -282,21 +282,21 @@ def _create_ov_model(self):
 
 
 class FPModel(OVReferenceModel):
-    def __init__(self, const_dtype="FP32", input_dtype="FP32"):
-        self.const_dtype = np.float32 if const_dtype == "FP32" else np.float16
-        self.input_dtype = np.float32 if input_dtype == "FP32" else np.float16
+    def __init__(self, const_dtype: ov.Type = ov.Type.f32, input_dtype: ov.Type = ov.Type.f32):
+        self.const_dtype = const_dtype
+        self.input_dtype = input_dtype
         super().__init__()
 
     def _create_ov_model(self):
         input_shape = [1, 3, 4, 2]
         input_1 = opset.parameter(input_shape, name="Input", dtype=self.input_dtype)
-        data = self._rng.random((1, 3, 4, 5)).astype(self.const_dtype)
+        data = opset.constant(value=self._rng.random((1, 3, 4, 5)), dtype=self.const_dtype, name="MatMul_const")
         if self.const_dtype != self.input_dtype:
-            data = opset.convert(data, self.input_dtype)
+            data = opset.convert(data, self.input_dtype.to_string())
         matmul = opset.matmul(input_1, data, transpose_a=True, transpose_b=False, name="MatMul")
-        bias = self._rng.random((1, 3, 1, 1)).astype(self.const_dtype)
+        bias = opset.constant(value=self._rng.random((1, 3, 1, 1)), dtype=self.const_dtype, name="MatMul_bias")
         if self.const_dtype != self.input_dtype:
-            bias = opset.convert(bias, self.input_dtype)
+            bias = opset.convert(bias, self.input_dtype.to_string())
         add = opset.add(matmul, bias, name="Add")
         result = opset.result(add, name="Result_Add")
         result.get_output_tensor(0).set_names(set(["Result_Add"]))
diff --git a/tests/openvino/native/quantization/test_fq_params_calculation.py b/tests/openvino/native/quantization/test_fq_params_calculation.py
index 459f5a313f4..f455f1ca995 100644
--- a/tests/openvino/native/quantization/test_fq_params_calculation.py
+++ b/tests/openvino/native/quantization/test_fq_params_calculation.py
@@ -11,7 +11,6 @@
 
 from pathlib import Path
 
-import numpy as np
 import openvino as ov
 import pytest
 import torch
@@ -163,8 +162,8 @@ def test_synthetic_models_fq_shapes(model_creator_func, ref_shapes, inplace_stat
         assert node["output_high"].shape == ref_shapes[node_name]
 
 
-@pytest.mark.parametrize("const_dtype", ["FP16", "FP32"])
-@pytest.mark.parametrize("input_dtype", ["FP16", "FP32"])
+@pytest.mark.parametrize("const_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
+@pytest.mark.parametrize("input_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
 def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistics):
     model = FPModel(const_dtype, input_dtype)
     quantized_model = quantize_model(
@@ -174,10 +173,10 @@ def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistic
         if op.get_type_name() == "FakeQuantize":
             inp_node = op.input(0)
             fq_input_node = inp_node.get_source_output().get_node()
-            if fq_input_node.get_element_type() == "Constant":
-                assert op.get_element_type() == ov.Type(np.float32 if input_dtype == "FP32" else np.float16)
+            if fq_input_node.get_type_name() == "Constant":
+                assert op.get_element_type() == const_dtype
         elif op.get_type_name() == "Convert":
             inp_node = op.input(0)
             fq_input_node = inp_node.get_source_output().get_node()
-            if fq_input_node.get_element_type() == "Constant":
-                assert op.get_element_type() == ov.Type(np.float32 if const_dtype == "FP32" else np.float16)
+            if fq_input_node.get_type_name() == "Constant":
+                assert op.get_element_type() == input_dtype
diff --git a/tests/openvino/native/quantization/test_quantization_pipeline.py b/tests/openvino/native/quantization/test_quantization_pipeline.py
index 8196dbdfc75..657b14ea864 100644
--- a/tests/openvino/native/quantization/test_quantization_pipeline.py
+++ b/tests/openvino/native/quantization/test_quantization_pipeline.py
@@ -48,7 +48,8 @@ def test_compress_weights(model_creator_func, ref_nodes):
     fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fqs_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fqs_names
 
     for op in quantized_model.get_ops():
@@ -76,7 +77,8 @@ def test_overflow_fix_applied(model_creator_func, ref_nodes):
     fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fqs_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fqs_names
 
     for op in quantized_model.get_ops():
diff --git a/tests/openvino/native/test_model_transformer.py b/tests/openvino/native/test_model_transformer.py
index 871aebc354e..e4bb876e76a 100644
--- a/tests/openvino/native/test_model_transformer.py
+++ b/tests/openvino/native/test_model_transformer.py
@@ -22,6 +22,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.openvino.graph.model_transformer import OVModelTransformer
+from nncf.openvino.graph.node_utils import get_const_value
 from nncf.openvino.graph.node_utils import get_inplace_batch_mean_op
 from nncf.openvino.graph.node_utils import get_inplace_max_op
 from nncf.openvino.graph.node_utils import get_inplace_mean_op
@@ -81,22 +82,25 @@ def create_transformed_model(model, target_layers, target_type, command_type, po
     return transformed_model
 
 
-def get_extra_outputs(original_model, transformed_model):
-    extra_outputs = set()
+def get_extra_outputs(original_model, transformed_model, as_nodes=False):
+    extra_outputs = {}
     for out in transformed_model.get_results():
-        extra_outputs.add(out.get_friendly_name())
+        extra_outputs[out.get_friendly_name()] = out
 
     for out in original_model.get_results():
-        extra_outputs.remove(out.get_friendly_name())
+        extra_outputs.pop(out.get_friendly_name())
 
-    return extra_outputs
+    names_set = set(extra_outputs.keys())
+    nodes_set = set(extra_outputs.values())
+
+    return nodes_set if as_nodes else names_set
 
 
 def get_nodes_by_type(model: ov.Model, type_name: str) -> List[ov.Node]:
     fq_nodes = []
     for op in model.get_ops():
         if op.get_type_name() == type_name:
-            fq_nodes.append(op.get_friendly_name())
+            fq_nodes.append(op)
 
     return fq_nodes
 
@@ -113,6 +117,12 @@ def create_fake_convert_params(destination_type: FP8Type) -> FakeConvertParamete
     return FakeConvertParameters(scale, shift, destination_type)
 
 
+def verify_inputs_equality(node: ov.Node) -> None:
+    act_type = node.input(0).get_source_output().get_element_type()
+    for node_input in node.inputs():
+        assert node_input.get_source_output().get_element_type() == act_type
+
+
 @dataclass
 class InplaceOpTestCase:
     name: str
@@ -386,31 +396,43 @@ def test_inplace_mean_per_ch_fn_dynamic_shapes(test_params: InplaceOpTestCase, i
 @pytest.mark.parametrize(
     "target_type", [TargetType.PRE_LAYER_OPERATION, TargetType.POST_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]
 )
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
 @pytest.mark.parametrize("target_layers", TARGET_INSERT_LAYERS)
-def test_output_insertion(target_type, target_layers):
-    model = LinearModel().ov_model
+def test_output_insertion(target_type, target_layers, model_object):
     port_id = 1 if target_type == TargetType.OPERATION_WITH_WEIGHTS else 0
-    transformed_model = create_transformed_model(model, target_layers, target_type, OVOutputInsertionCommand, port_id)
-
-    if target_type == TargetType.PRE_LAYER_OPERATION:
-        target_layers = ["Reshape"]
+    transformed_model = create_transformed_model(
+        model_object, target_layers, target_type, OVOutputInsertionCommand, port_id
+    )
 
-    target_nodes = []
+    target_nodes = set()
     for target_layer in target_layers:
         target_node = get_node_by_name(transformed_model, target_layer)
         if target_type == TargetType.OPERATION_WITH_WEIGHTS:
             source_output = target_node.input(1).get_source_output()
             target_node = source_output.get_node()
             port_id = source_output.get_index()
-        target_nodes.append((target_node, port_id))
+        elif target_type == TargetType.PRE_LAYER_OPERATION:
+            source_output = target_node.input(0).get_source_output()
+            target_node = source_output.get_node()
+        target_nodes.add((target_node, port_id))
 
-    extra_outputs = get_extra_outputs(model, transformed_model)
+    extra_outputs = get_extra_outputs(model_object, transformed_model, as_nodes=True)
     ref_output_names = [
         get_result_node_name(target_node.get_friendly_name(), port_id) for target_node, port_id in target_nodes
    ]
 
     assert len(extra_outputs) == len(ref_output_names)
-    for out_name in extra_outputs:
-        assert out_name in ref_output_names
+    for extra_output in extra_outputs:
+        extra_output_name = extra_output.get_friendly_name()
+        assert extra_output_name in ref_output_names
+        assert extra_output.output(0).get_element_type() == ov.Type.f32
 
 
 @pytest.mark.parametrize("test_params", INPLACE_OPS_TEST_CASES, ids=[str(case) for case in INPLACE_OPS_TEST_CASES])
@@ -453,11 +475,18 @@ def test_node_removing(target_layers):
 
 
 @pytest.mark.parametrize("target_layers, ref_fq_names", zip(TARGET_INSERT_LAYERS, TARGET_PRE_LAYER_FQS))
-def test_fq_insertion_pre_layer(target_layers, ref_fq_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fq_insertion_pre_layer(target_layers, ref_fq_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.PRE_LAYER_OPERATION,
         OVQuantizerInsertionCommand,
@@ -466,16 +495,25 @@ def test_fq_insertion_pre_layer(target_layers, ref_fq_names):
 
     fq_nodes = get_nodes_by_type(transformed_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fq_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fq_names
+        verify_inputs_equality(fq_node)
 
 
-@pytest.mark.parametrize("target_layers, ref_fс_names", zip(TARGET_INSERT_LAYERS, TARGET_PRE_LAYER_FCS))
-def test_fc_insertion_pre_layer(target_layers, ref_fс_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize("target_layers, ref_fc_names", zip(TARGET_INSERT_LAYERS, TARGET_PRE_LAYER_FCS))
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fc_insertion_pre_layer(target_layers, ref_fc_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.PRE_LAYER_OPERATION,
         OVConvertInsertionCommand,
@@ -483,17 +521,26 @@ def test_fc_insertion_pre_layer(target_layers, ref_fс_names):
     )
 
     fc_nodes = get_nodes_by_type(transformed_model, type_name="FakeConvert")
-    assert len(fc_nodes) == len(ref_fс_names)
-    for fc_name in fc_nodes:
-        assert fc_name in ref_fс_names
+    assert len(fc_nodes) == len(ref_fc_names)
+    for fc_node in fc_nodes:
+        fc_name = fc_node.get_friendly_name()
+        assert fc_name in ref_fc_names
+        verify_inputs_equality(fc_node)
 
 
 @pytest.mark.parametrize("target_layers, ref_fq_names", zip(TARGET_INSERT_LAYERS, TARGET_POST_LAYER_FQS))
-def test_fq_insertion_post_layer(target_layers, ref_fq_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fq_insertion_post_layer(target_layers, ref_fq_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.POST_LAYER_OPERATION,
         OVQuantizerInsertionCommand,
@@ -502,16 +549,25 @@ def test_fq_insertion_post_layer(target_layers, ref_fq_names):
 
     fq_nodes = get_nodes_by_type(transformed_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fq_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fq_names
+        verify_inputs_equality(fq_node)
 
 
-@pytest.mark.parametrize("target_layers, ref_fс_names", zip(TARGET_INSERT_LAYERS, TARGET_POST_LAYER_FCS))
-def test_fc_insertion_post_layer(target_layers, ref_fс_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize("target_layers, ref_fc_names", zip(TARGET_INSERT_LAYERS, TARGET_POST_LAYER_FCS))
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fc_insertion_post_layer(target_layers, ref_fc_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.POST_LAYER_OPERATION,
         OVConvertInsertionCommand,
@@ -519,17 +575,26 @@ def test_fc_insertion_post_layer(target_layers, ref_fс_names):
     )
 
     fc_nodes = get_nodes_by_type(transformed_model, type_name="FakeConvert")
-    assert len(fc_nodes) == len(ref_fс_names)
-    for fc_name in fc_nodes:
-        assert fc_name in ref_fс_names
+    assert len(fc_nodes) == len(ref_fc_names)
+    for fc_node in fc_nodes:
+        fc_name = fc_node.get_friendly_name()
+        assert fc_name in ref_fc_names
+        verify_inputs_equality(fc_node)
 
 
 @pytest.mark.parametrize("target_layers, ref_fq_names", zip(TARGET_INSERT_LAYERS, TARGET_WEIGHTS_FQS))
-def test_fq_insertion_weights(target_layers, ref_fq_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fq_insertion_weights(target_layers, ref_fq_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.OPERATION_WITH_WEIGHTS,
         OVQuantizerInsertionCommand,
@@ -539,16 +604,25 @@ def test_fq_insertion_weights(target_layers, ref_fq_names):
 
     fq_nodes = get_nodes_by_type(transformed_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fq_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fq_names
+        verify_inputs_equality(fq_node)
 
 
-@pytest.mark.parametrize("target_layers, ref_fс_names", zip(TARGET_INSERT_LAYERS, TARGET_WEIGHTS_FCS))
-def test_fc_insertion_weights(target_layers, ref_fс_names):
-    model = LinearModel().ov_model
-
+@pytest.mark.parametrize("target_layers, ref_fc_names", zip(TARGET_INSERT_LAYERS, TARGET_WEIGHTS_FCS))
+@pytest.mark.parametrize(
+    "model_object",
+    [
+        LinearModel().ov_model,
+        FPModel(input_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16).ov_model,
+        FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+    ],
+)
+def test_fc_insertion_weights(target_layers, ref_fc_names, model_object):
     transformed_model = create_transformed_model(
-        model,
+        model_object,
         target_layers,
         TargetType.OPERATION_WITH_WEIGHTS,
         OVConvertInsertionCommand,
@@ -557,22 +631,36 @@ def test_fc_insertion_weights(target_layers, ref_fс_names):
     )
 
     fc_nodes = get_nodes_by_type(transformed_model, type_name="FakeConvert")
-    assert len(fc_nodes) == len(ref_fс_names)
-    for fc_name in fc_nodes:
-        assert fc_name in ref_fс_names
+    assert len(fc_nodes) == len(ref_fc_names)
+    for fc_node in fc_nodes:
+        fc_name = fc_node.get_friendly_name()
+        assert fc_name in ref_fc_names
+        verify_inputs_equality(fc_node)
 
 
 MODELS_WITH_PARAMETERS = [
     {
         "model": ConvModel().ov_model,
         "layers": ["Conv"],
-        "values": [np.full((3,), 2)],
+        "values": [np.full((3,), 2).astype(np.float32)],
+        "refs": [2.0],
+    },
+    {
+        "model": FPModel(const_dtype=ov.Type.f16).ov_model,
+        "layers": ["MatMul"],
+        "values": [np.full((3,), 2).astype(np.float16)],
         "refs": [2.0],
     },
     {
-        "model": FPModel(const_dtype="FP16").ov_model,
+        "model": FPModel(const_dtype=ov.Type.f16, input_dtype=ov.Type.f16).ov_model,
         "layers": ["MatMul"],
-        "values": [np.full((3,), 2)],
+        "values": [np.full((3,), 2).astype(np.float16)],
+        "refs": [2.0],
+    },
+    {
+        "model": FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+        "layers": ["MatMul"],
+        "values": [np.full((3,), 2).astype(np.float32)],
         "refs": [2.0],
     },
 ]
@@ -599,7 +687,8 @@ def test_bias_correction(model_with_parameters):
         if potential_bias.get_type_name() == "Convert":
             potential_bias = potential_bias.input_value(0).node
         assert potential_bias.get_type_name() == "Constant"
-        assert np.all(potential_bias.get_vector() == bias_reference)
+        potential_bias_value = get_const_value(potential_bias)
+        assert np.all(potential_bias_value == bias_reference)
 
 
 def test_no_transformations():
@@ -668,6 +757,11 @@ def test_null_biases_insertion(model_with_parameters):
         "input_ids": [("Relu_1", 0), ("Transpose", 0)],
         "output_ids": [("Conv_3", 0), ("Add_2", 0)],
     },
+    {
+        "model": FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16),
+        "input_ids": [("MatMul", 0)],
+        "output_ids": [("MatMul", 0)],
+    },
 ]
 
 
@@ -729,6 +823,12 @@ def test_stateless_model_extraction(model_with_data):
         "destination_node_names": ["MatMul_0"],
         "scale": np.ones((1, 1, 1, 1)),
     },
+    {
+        "model": FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.bf16).ov_model,
+        "layers": ["Input"],
+        "destination_node_names": ["MatMul"],
+        "scale": np.ones((1, 1, 1, 1)),
+    },
 ]
 
 
@@ -762,7 +862,9 @@ def test_multiply_insertion(model_with_parameters):
         assert scale_node.get_type_name() == "Multiply"
 
         scale_const = scale_node.input(1).get_source_output().get_node()
+        assert scale_node.input(0).get_element_type() == scale_const.output(0).get_element_type()
+
         assert scale_const.get_type_name() == "Constant"
-        scale_const_data = scale_const.data
+        scale_const_data = get_const_value(scale_const)
         assert np.all(scale_const_data == scale)
diff --git a/tests/openvino/native/test_node_utils.py b/tests/openvino/native/test_node_utils.py
index 4fdc7eea444..69ddcc8d146 100644
--- a/tests/openvino/native/test_node_utils.py
+++ b/tests/openvino/native/test_node_utils.py
@@ -10,32 +10,50 @@
 # limitations under the License.
 
 import numpy as np
+import openvino.runtime as ov
 import pytest
 from openvino.runtime import opset13 as opset
 
-from nncf.common.factory import NNCFGraphFactory
 from nncf.common.graph.graph import NNCFNode
 from nncf.openvino.graph.layer_attributes import OVLayerAttributes
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.openvino.graph.nncf_graph_builder import GraphConverter
+from nncf.openvino.graph.node_utils import get_const_value
 from nncf.openvino.graph.node_utils import get_weight_channel_axes
-from nncf.openvino.graph.node_utils import get_weight_value
 from nncf.openvino.graph.node_utils import get_weighted_layer_attributes
 from nncf.openvino.graph.node_utils import is_node_with_bias
 from tests.openvino.native.models import ConvModel
 from tests.openvino.native.models import ConvNotBiasModel
-from tests.openvino.native.models import FPModel
 from tests.openvino.native.models import MatMul2DModel
 from tests.openvino.native.models import MatMul2DNotBiasModel
 
 
-def test_get_weight_value_const_with_convert():
-    model = FPModel(const_dtype="FP16").ov_model
-    nncf_graph = NNCFGraphFactory.create(model)
-    node_with_weight = nncf_graph.get_node_by_name("MatMul")
+@pytest.mark.parametrize(
+    "precisions",
+    [
+        # base FP32 precision
+        {
+            "type_for_const": ov.Type.f32,
+            "ref_type": np.float32,
+        },
+        # base FP16 precision
+        {
+            "type_for_const": ov.Type.f16,
+            "ref_type": np.float16,
+        },
+        # base BF16 precision should be cast to FP32
+        {
+            "type_for_const": ov.Type.bf16,
+            "ref_type": np.float32,
+        },
+    ],
+)
+def test_get_const_value(precisions):
+    const_data = np.ones((1, 2, 3), dtype=np.float32)
+    weight_const = opset.constant(const_data, dtype=precisions["type_for_const"])
 
-    actual_value = get_weight_value(node_with_weight, model, port_id=1)
-    assert actual_value.dtype == np.float16
+    const_value = get_const_value(weight_const)
+    assert const_value.dtype == precisions["ref_type"]
 
 
 @pytest.mark.parametrize(
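
The BF16 handling above follows one pattern across `get_const_value()`, `_set_const_value()`, and `set_weight()`: read BF16 constants back as FP32 (NumPy has no native bfloat16 dtype), and disable `shared_memory` when re-creating a BF16 constant. A minimal sketch of that round-trip, using only calls that appear in the patch; the `W` constant is illustrative:

```python
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

# Illustrative BF16 constant; in the patch this is a weight node found in the model.
const_node = opset.constant(np.ones((2, 2), dtype=np.float32), dtype=ov.Type.bf16, name="W")

# Reading: get_const_value() fixes FP32 as the result type for BF16 constants.
if const_node.get_element_type() == ov.Type.bf16:
    value = const_node.get_data(dtype=np.float32)
else:
    value = const_node.data

# Writing: per the patch, shared memory does not work for BF16, so it is disabled in that case.
shared_memory = const_node.get_element_type() != ov.Type.bf16
new_const_node = opset.constant(
    value, dtype=const_node.get_element_type(), name=const_node.get_friendly_name(), shared_memory=shared_memory
)
```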
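On the statistics side, `OVOutputInsertionCommand` now carries an `output_dtype` (defaulting to `ov.Type.f32`), and `_insert_outputs()` bridges any mismatch with a `Convert` node. A condensed sketch of that logic; the helper name `insert_result` and the naming scheme are illustrative, the real code uses `get_result_node_name()`:

```python
import openvino as ov
from openvino.runtime import opset13 as opset


def insert_result(output: ov.Output, port_id: int, dtype: ov.Type = ov.Type.f32) -> ov.Node:
    node_output = output
    result_name = f"Result_{output.get_node().get_friendly_name()}.{port_id}"  # simplified naming
    # Only add a Convert when the producer's element type differs from the requested one.
    if node_output.get_element_type() != dtype:
        node_output = opset.convert(node_output, destination_type=dtype)
    result = opset.result(node_output, name=result_name)
    # Propagate the original tensor names so collected statistics can still be matched by name.
    tensor = result.get_output_tensor(0)
    names = tensor.get_names()
    names.update({result_name, *output.get_names()})
    tensor.set_names(names)
    return result
```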
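`_apply_model_extraction_transformation()` applies the same idea at submodel boundaries: extracted submodels get FP32 parameters and results, with `Convert` nodes restoring the original precision inside, which is what the new `exctracted_FPModel.dot` reference (Parameter -> Convert -> MatMul -> Convert -> Result) captures. A sketch of the input side under those assumptions; the helper name `make_boundary_param` is illustrative:

```python
import openvino as ov
from openvino.runtime import opset13 as opset


def make_boundary_param(input_port: ov.Input, name: str, outputs_type: ov.Type = ov.Type.f32) -> ov.Node:
    input_type = input_port.get_element_type()
    source = input_port.get_source_output()
    # The submodel boundary is always FP32...
    new_param = opset.parameter(shape=source.partial_shape, dtype=outputs_type, name=name)
    new_input = new_param.output(0)
    # ...and a Convert restores the consumer's original element type inside the submodel.
    if input_type != outputs_type:
        new_input = opset.convert(new_param, destination_type=input_type).output(0)
    input_port.replace_source_output(new_input)
    return new_param
```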