diff --git a/nncf/openvino/graph/metatypes/groups.py b/nncf/openvino/graph/metatypes/groups.py index eed716f9f32..38ffff12753 100644 --- a/nncf/openvino/graph/metatypes/groups.py +++ b/nncf/openvino/graph/metatypes/groups.py @@ -183,12 +183,13 @@ # Contains the operation metatypes for which bias can be applied. -OPERATIONS_WITH_BIAS = [ +# Reduced subset of operations with bias (does not include depthwise convolution). +OPERATIONS_WITH_BIAS_REDUCED = [ ov_metatypes.OVConvolutionMetatype, - # TODO: add all metatypes with bias ov_metatypes.OVMatMulMetatype, ] +OPERATIONS_WITH_BIAS = [*OPERATIONS_WITH_BIAS_REDUCED, ov_metatypes.OVDepthwiseConvolutionMetatype] CONV_OPERATIONS = [ ov_metatypes.OVConvolutionMetatype, diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py index 59d5e58c45b..ae9bf264ea5 100644 --- a/nncf/openvino/graph/node_utils.py +++ b/nncf/openvino/graph/node_utils.py @@ -44,17 +44,23 @@ InplaceInsertionFnType = Callable[[ov.Node, int], ov.Node] -def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: +def is_node_with_bias( + node: NNCFNode, nncf_graph: NNCFGraph, metatypes_with_bias: Optional[List[OVOpMetatype]] = None +) -> bool: """ Checks if the node has a bias or not. :param node: The node to check. :param nncf_graph: NNCFGraph instance. + :param metatypes_with_bias: List of the metatypes that contain biases. Defaults to `OPERATIONS_WITH_BIAS` if None. :return: Return `True` if `node` corresponds to the operation with bias (bias is added to the output tensor of that operation), `False` otherwise. 
""" - if node.metatype not in OPERATIONS_WITH_BIAS: + if metatypes_with_bias is None: + metatypes_with_bias = OPERATIONS_WITH_BIAS + + if node.metatype not in metatypes_with_bias: return False add_node = nncf_graph.get_next_nodes(node)[0] diff --git a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index 58cc3f04ff2..159b2111f1a 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -20,6 +20,7 @@ from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor import Tensor from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS +from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS_REDUCED from nncf.openvino.graph.node_utils import get_bias_value from nncf.openvino.graph.node_utils import is_node_with_bias from nncf.openvino.graph.transformations.command_creation import OVCommandCreator @@ -95,7 +96,7 @@ def process_model_output(raw_data: Dict, output_name: str) -> Tensor: @staticmethod def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: - return is_node_with_bias(node, nncf_graph) + return is_node_with_bias(node, nncf_graph, OPERATIONS_WITH_BIAS_REDUCED) @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml index 3605403a440..6792bbe809e 100644 --- a/tests/post_training/data/ptq_reference_data.yaml +++ b/tests/post_training/data/ptq_reference_data.yaml @@ -146,6 +146,12 @@ timm/mobilenetv3_small_050_backend_OV: metric_value: 0.42184 timm/mobilenetv3_small_050_backend_TORCH: metric_value: 0.4291 +timm/mobilenetv3_small_050_BC_backend_FP32: + metric_value: 0.57906 
+timm/mobilenetv3_small_050_BC_backend_ONNX: + metric_value: 0.56496 +timm/mobilenetv3_small_050_BC_backend_OV: + metric_value: 0.56476 timm/regnetx_002_backend_CUDA_TORCH: metric_value: 0.67452 timm/regnetx_002_backend_FP32: diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py index e643b68087d..69c1b9890e0 100644 --- a/tests/post_training/model_scope.py +++ b/tests/post_training/model_scope.py @@ -188,6 +188,16 @@ }, "backends": ALL_PTQ_BACKENDS, }, + { + "reported_name": "timm/mobilenetv3_small_050_BC", + "model_id": "mobilenetv3_small_050", + "pipeline_cls": ImageClassificationTimm, + "compression_params": { + "preset": QuantizationPreset.MIXED, + "fast_bias_correction": False, + }, + "backends": [BackendType.ONNX, BackendType.OV], + }, { "reported_name": "timm/regnetx_002", "model_id": "regnetx_002",