Skip to content

Commit

Permalink
[PTQ] Extend layers list for BiasCorrection (#2555)
Browse files Browse the repository at this point in the history
### Changes

- Extended OPERATIONS_WITH_BIAS for BiasCorrection algorithm;
- Extended conformance test scope with the `mobilenetv3` +
BiasCorrection;

### Reason for changes

- List extension would allow us to restore accuracy for the specific
models;
- To observe differences between BiasCorrection algorithms;

### Related tickets

- 133198

### Tests

- manual/post_training_quantization/326/ - passed
  • Loading branch information
KodiaqQ authored Mar 20, 2024
1 parent 7898e8c commit f987125
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 5 deletions.
5 changes: 3 additions & 2 deletions nncf/openvino/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,13 @@


# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS = [
# Limited operations scope
OPERATIONS_WITH_BIAS_REDUCED = [
ov_metatypes.OVConvolutionMetatype,
# TODO: add all metatypes with bias
ov_metatypes.OVMatMulMetatype,
]

OPERATIONS_WITH_BIAS = [*OPERATIONS_WITH_BIAS_REDUCED, ov_metatypes.OVDepthwiseConvolutionMetatype]

CONV_OPERATIONS = [
ov_metatypes.OVConvolutionMetatype,
Expand Down
10 changes: 8 additions & 2 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,23 @@
InplaceInsertionFnType = Callable[[ov.Node, int], ov.Node]


def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
def is_node_with_bias(
node: NNCFNode, nncf_graph: NNCFGraph, metatypes_with_bias: Optional[List[OVOpMetatype]] = None
) -> bool:
"""
Checks if the node has a bias or not.
:param node: The node to check.
:param nncf_graph: NNCFGraph instance.
:param metatypes_with_bias: List of the metatypes that contains biases.
:return: Return `True` if `node` corresponds to the operation
with bias (bias is added to the output tensor of that operation),
`False` otherwise.
"""
if node.metatype not in OPERATIONS_WITH_BIAS:
if metatypes_with_bias is None:
metatypes_with_bias = OPERATIONS_WITH_BIAS

if node.metatype not in metatypes_with_bias:
return False

add_node = nncf_graph.get_next_nodes(node)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor import Tensor
from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS_REDUCED
from nncf.openvino.graph.node_utils import get_bias_value
from nncf.openvino.graph.node_utils import is_node_with_bias
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
Expand Down Expand Up @@ -95,7 +96,7 @@ def process_model_output(raw_data: Dict, output_name: str) -> Tensor:

@staticmethod
def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return is_node_with_bias(node, nncf_graph)
return is_node_with_bias(node, nncf_graph, OPERATIONS_WITH_BIAS_REDUCED)

@staticmethod
def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]:
Expand Down
6 changes: 6 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ timm/mobilenetv3_small_050_backend_OV:
metric_value: 0.42184
timm/mobilenetv3_small_050_backend_TORCH:
metric_value: 0.4291
timm/mobilenetv3_small_050_BC_backend_FP32:
metric_value: 0.57906
timm/mobilenetv3_small_050_BC_backend_ONNX:
metric_value: 0.56496
timm/mobilenetv3_small_050_BC_backend_OV:
metric_value: 0.56476
timm/regnetx_002_backend_CUDA_TORCH:
metric_value: 0.67452
timm/regnetx_002_backend_FP32:
Expand Down
10 changes: 10 additions & 0 deletions tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@
},
"backends": ALL_PTQ_BACKENDS,
},
{
"reported_name": "timm/mobilenetv3_small_050_BC",
"model_id": "mobilenetv3_small_050",
"pipeline_cls": ImageClassificationTimm,
"compression_params": {
"preset": QuantizationPreset.MIXED,
"fast_bias_correction": False,
},
"backends": [BackendType.ONNX, BackendType.OV],
},
{
"reported_name": "timm/regnetx_002",
"model_id": "regnetx_002",
Expand Down

0 comments on commit f987125

Please sign in to comment.