# Ignored scope matching working with multiple NNCFGraphs (#2723)
### Changes

- Add support for matching an `IgnoredScope` against multiple `NNCFGraph`s.
- Change unmatched-rule error logging when `strict=True`:
  - before: only the first unmatched rule was logged;
  - after: all unmatched rules are logged.
- Move `IgnoredScope` validation for OpenVINO models with an If operation so that it runs before the quantization pipeline starts (see the usage sketch below).
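
A minimal, hedged usage sketch of the user-facing effect. The model handle `ov_model_with_if`, the calibration items `data_items`, and `transform_fn` are assumed to exist; node and type names below are hypothetical:

```python
import nncf

# Assumed inputs: an openvino.runtime.Model containing an If operation and calibration data.
calibration_dataset = nncf.Dataset(data_items, transform_fn)

ignored_scope = nncf.IgnoredScope(
    names=["then_body/some_conv"],  # hypothetical node inside the then-body
    types=["MatMul"],
    validate=True,  # rules are checked up front against the main graph and every If-body graph
)

# Validation now happens before the recursive per-body quantization starts; with strict
# matching, all unmatched rules are reported instead of only the first one.
quantized_model = nncf.quantize(ov_model_with_if, calibration_dataset, ignored_scope=ignored_scope)
```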

### Reason for changes

`IgnoredScope` validation fails for an OpenVINO model with an If operation.

### Related tickets

135110

### Tests

TBD
kshpv authored Jun 18, 2024
1 parent 0c22c2c commit 5384965
Showing 12 changed files with 346 additions and 91 deletions.
nncf/common/graph/graph.py: 6 changes (4 additions, 2 deletions)
@@ -237,8 +237,10 @@ def get_nodes_by_types(self, type_list: List[str]) -> List[NNCFNode]:
:param type_list: List of types to look for.
:return: List of nodes with provided types.
"""
all_nodes_of_type = []
for nncf_node in self.get_all_nodes():
all_nodes_of_type: List[NNCFNode] = []
if not type_list:
return all_nodes_of_type
for nncf_node in self.nodes.values():
if nncf_node.node_type in type_list:
all_nodes_of_type.append(nncf_node)
return all_nodes_of_type
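
A small usage illustration of the updated lookup (assuming an existing `NNCFGraph` instance named `nncf_graph`; the type name is hypothetical):

```python
# Nodes whose node_type matches one of the requested types.
if_nodes = nncf_graph.get_nodes_by_types(["If"])

# An empty type list now short-circuits and returns an empty list without iterating the graph.
assert nncf_graph.get_nodes_by_types([]) == []
```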
nncf/openvino/quantization/quantize_ifmodel.py: 24 changes (12 additions, 12 deletions)
@@ -10,14 +10,13 @@
# limitations under the License.

from itertools import islice
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import openvino.runtime as ov

from nncf import Dataset
from nncf.common import factory
from nncf.common.engine import Engine
from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.model_transformer import ModelTransformer
@@ -135,26 +134,27 @@ def _add_outputs_before_if_node(model_transformer: ModelTransformer, model: ov.M
def apply_algorithm_if_bodies(
algorithm: Algorithm,
parent_model: ov.Model,
parent_graph: NNCFGraph,
graphs: Dict[str, NNCFGraph],
graph_id: str,
parent_dataset: Dataset,
subset_size: int,
current_model_num: int,
all_models_num: int,
parent_statistic_points: Optional[StatisticPointsContainer] = None,
) -> Tuple[ov.Model, int]:
"""
Applies an algorithm recursievley to each bodies of If node.
Applies an algorithm recursively to each body of an If node.
:param parent_model: Model to apply algorithm.
:param parent_graph: Graph of a model.
:param graphs: All model graphs.
:param graph_id: Current graph id in the graphs.
:param parent_dataset: Dataset for algorithm.
:param subset_size: Size of a dataset to use for calibration.
:param current_model_num: Current model number.
:param all_models_num: All model numbers.
:param parent_statistic_points: Statistics points for algorithm.
:return: The model with the algorithm applied to every If-node body, and the latest model number.
"""
nncf_logger.info(f"Iteration [{current_model_num}/{all_models_num}] ...")
nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...")
parent_graph = graphs[graph_id]
quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset)
if get_number_if_op(parent_model) == 0:
return quantized_model, current_model_num
@@ -183,20 +183,20 @@ def apply_algorithm_if_bodies(
then_quantized_model, current_model_num = apply_algorithm_if_bodies(
algorithm,
then_model,
NNCFGraphFactory.create(then_model),
graphs,
if_node.node_name + "_then",
then_dataset,
subset_size,
current_model_num + 1,
all_models_num,
)
else_quantized_model, current_model_num = apply_algorithm_if_bodies(
algorithm,
else_model,
NNCFGraphFactory.create(else_model),
graphs,
if_node.node_name + "_else",
else_dataset,
subset_size,
current_model_num + 1,
all_models_num,
)
model_transformer_int8 = factory.ModelTransformerFactory.create(quantized_model)
quantized_model = _update_if_body(model_transformer_int8, if_node, True, then_quantized_model)
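
For context, a hedged sketch of how a caller can prepare the `graphs` registry that `apply_algorithm_if_bodies` indexes by `graph_id`. This simplified variant is non-recursive, whereas the actual helper added in `quantize_model.py` below also descends into nested If bodies; the function name here is hypothetical:

```python
from typing import Dict

import openvino.runtime as ov

from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph


def build_graphs_registry(model: ov.Model) -> Dict[str, NNCFGraph]:
    # Register the top-level graph plus one graph per then/else body, keyed the same way
    # apply_algorithm_if_bodies derives graph ids from the If node name.
    graphs = {"main_model_graph": NNCFGraphFactory.create(model)}
    for op in model.get_ops():
        if op.get_type_name() == "If":
            graphs[op.get_friendly_name() + "_then"] = NNCFGraphFactory.create(op.get_function(0))
            graphs[op.get_friendly_name() + "_else"] = NNCFGraphFactory.create(op.get_function(1))
    return graphs
```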
nncf/openvino/quantization/quantize_model.py: 37 changes (32 additions, 5 deletions)
@@ -20,6 +20,8 @@
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.node_utils import get_number_if_op
@@ -42,10 +44,13 @@
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
from nncf.quantization.quantize_model import BATCHWISE_STATISTICS_WARNING
from nncf.quantization.quantize_model import is_model_no_batchwise_support
from nncf.quantization.quantize_model import quantize_with_tune_hyperparams
from nncf.quantization.quantize_model import warning_model_no_batchwise_support
from nncf.quantization.telemetry_extractors import CompressionStartedWithQuantizeApi
from nncf.scopes import IgnoredScope
from nncf.scopes import validate_ignored_scope
from nncf.telemetry.decorator import tracked_function
from nncf.telemetry.events import NNCF_OV_CATEGORY

@@ -72,6 +77,28 @@ def native_quantize_if_op_impl(
raise NotImplementedError(
"The BiasCorrection algorithm is not supported for OpenVINO models with If operation."
)
graphs = {}

def _extract_all_subgraphs(model: ov.Model, current_id: str) -> None:
"""
Creates all inner subgraphs from If nodes and adds them to 'graphs'.
:param model: Model.
:param current_id: Current graph id.
"""
graphs[current_id] = NNCFGraphFactory.create(model)
for op in model.get_ops():
if get_node_metatype(op) == OVIfMetatype:
_extract_all_subgraphs(op.get_function(0), op.get_friendly_name() + "_then")
_extract_all_subgraphs(op.get_function(1), op.get_friendly_name() + "_else")

main_model_graph_id = "main_model_graph"
_extract_all_subgraphs(model, main_model_graph_id)
if ignored_scope and ignored_scope.validate:
validate_ignored_scope(ignored_scope, graphs.values())
ignored_scope = IgnoredScope(
ignored_scope.names, ignored_scope.patterns, ignored_scope.types, ignored_scope.subgraphs, validate=False
)
quantization_algorithm = PostTrainingQuantization(
mode=mode,
preset=preset,
@@ -82,17 +109,17 @@ def native_quantize_if_op_impl(
ignored_scope=ignored_scope,
advanced_parameters=advanced_parameters,
)

graph = GraphConverter.create_nncf_graph(model)
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
for graph in graphs.values():
if is_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS):
nncf_logger.warning(BATCHWISE_STATISTICS_WARNING)
break
if_ops_number = get_number_if_op(model)
all_models_number = if_ops_number * 2 + 1
nncf_logger.info(
f"The model consists of {if_ops_number} If node(-s) with then and else bodies. \
Main model and all If bodies will be quantized recursively."
)
quantized_model, _ = apply_algorithm_if_bodies(
quantization_algorithm, model, graph, calibration_dataset, subset_size, 1, all_models_number
quantization_algorithm, model, graphs, main_model_graph_id, calibration_dataset, subset_size, 1
)

if is_weight_compression_needed(advanced_parameters):
nncf/quantization/quantize_model.py: 29 changes (24 additions, 5 deletions)
@@ -42,7 +42,7 @@
BATCHWISE_STATISTICS_WARNING = (
"For the particular model the batchwise statistics collection can lead to inaccurate statistics. "
"If the accuracy degradation after compression is unsatisfactory, then "
"the recomendation is to turn off batchwise statistics. If the results are still unsatisfactory, "
"the recommendation is to turn off batchwise statistics. If the results are still unsatisfactory, "
"provide a dataloader with batch_size = 1 to the calibration dataset."
)

@@ -54,19 +54,38 @@ def warning_model_no_batchwise_support(
no_batchwise_support_metatypes: List[OperatorMetatype],
) -> None:
"""
Prints the warning message if batchwise statistics could lead to a significant accuracy drop.
Logs when is_model_no_batchwise_support(...) returns True.
:param graph: Model's NNCFGraph.
:param advanced_quantization_parameters: AdvancedQuantizationParameters.
:param model_type: Model type algorithm option.
:param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support.
"""
if (
if is_model_no_batchwise_support(
graph, advanced_quantization_parameters, model_type, no_batchwise_support_metatypes
):
nncf_logger.warning(BATCHWISE_STATISTICS_WARNING)


def is_model_no_batchwise_support(
graph: NNCFGraph,
advanced_quantization_parameters: Optional[AdvancedQuantizationParameters],
model_type: ModelType,
no_batchwise_support_metatypes: List[OperatorMetatype],
) -> bool:
"""
Returns True if batchwise statistics could lead to a significant accuracy drop.
:param graph: Model's NNCFGraph.
:param advanced_quantization_parameters: AdvancedQuantizationParameters.
:param model_type: Model type algorithm option.
:param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support.
"""
return (
advanced_quantization_parameters
and advanced_quantization_parameters.batchwise_statistics
and (graph.get_nodes_by_metatypes(no_batchwise_support_metatypes) or model_type == ModelType.TRANSFORMER)
):
nncf_logger.warning(BATCHWISE_STATISTICS_WARNING)
)


def _update_advanced_quantization_parameters(
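
A minimal sketch of how a backend holding several `NNCFGraph`s (for example the main graph plus every If-body graph from the OpenVINO changes above) can reuse the new predicate and emit the shared warning at most once; the variable names are assumed to be provided by the caller:

```python
from nncf.common.logging import nncf_logger
from nncf.quantization.quantize_model import BATCHWISE_STATISTICS_WARNING
from nncf.quantization.quantize_model import is_model_no_batchwise_support

# `graphs`, `advanced_parameters`, `model_type`, and `no_batchwise_support_metatypes`
# are assumed to exist in the calling backend.
if any(
    is_model_no_batchwise_support(graph, advanced_parameters, model_type, no_batchwise_support_metatypes)
    for graph in graphs.values()
):
    nncf_logger.warning(BATCHWISE_STATISTICS_WARNING)
```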
(Diffs for the remaining 8 changed files are not shown.)
