From 5384965ae7fc8c6a211777351f2b5520bea3c77e Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Tue, 18 Jun 2024 11:52:34 +0200 Subject: [PATCH] Ignored scope matching working with multiple NNCFGraphs (#2723) ### Changes Add support of IgnoredScope matching many NNCFGraphs. Change unmatched error logging when `strict=True`:   before: the first unmatched rule is logged   after: all unmatched rules are logged. IgnoredScope validation for OpenVINO models with IF operation was updated with new changes - the validation is put before running the quantization pipeline. ### Reason for changes IgnoredScope validation failure for a OV model with IF operation. ### Related tickets 135110 ### Tests TBD --- nncf/common/graph/graph.py | 6 +- .../openvino/quantization/quantize_ifmodel.py | 24 +-- nncf/openvino/quantization/quantize_model.py | 37 +++- nncf/quantization/quantize_model.py | 29 ++- nncf/scopes.py | 200 ++++++++++++------ tests/common/test_scopes.py | 37 ++++ .../IfModel_2_ignored_scope_else.dot | 13 ++ .../IfModel_2_ignored_scope_main.dot | 9 + .../IfModel_2_ignored_scope_then.dot | 35 +++ tests/openvino/native/models.py | 18 ++ .../native/quantization/test_graphs.py | 27 +++ .../test_templates/test_ptq_params.py | 2 +- 12 files changed, 346 insertions(+), 91 deletions(-) create mode 100644 tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_else.dot create mode 100644 tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_main.dot create mode 100644 tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_then.dot diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index 13ea932d921..128c4cc3e3f 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -237,8 +237,10 @@ def get_nodes_by_types(self, type_list: List[str]) -> List[NNCFNode]: :param type_list: List of types to look for. :return: List of nodes with provided types. """ - all_nodes_of_type = [] - for nncf_node in self.get_all_nodes(): + all_nodes_of_type: List[NNCFNode] = [] + if not type_list: + return all_nodes_of_type + for nncf_node in self.nodes.values(): if nncf_node.node_type in type_list: all_nodes_of_type.append(nncf_node) return all_nodes_of_type diff --git a/nncf/openvino/quantization/quantize_ifmodel.py b/nncf/openvino/quantization/quantize_ifmodel.py index 2dc007201d5..8f106d7387c 100644 --- a/nncf/openvino/quantization/quantize_ifmodel.py +++ b/nncf/openvino/quantization/quantize_ifmodel.py @@ -10,14 +10,13 @@ # limitations under the License. from itertools import islice -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import openvino.runtime as ov from nncf import Dataset from nncf.common import factory from nncf.common.engine import Engine -from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.model_transformer import ModelTransformer @@ -135,26 +134,27 @@ def _add_outputs_before_if_node(model_transformer: ModelTransformer, model: ov.M def apply_algorithm_if_bodies( algorithm: Algorithm, parent_model: ov.Model, - parent_graph: NNCFGraph, + graphs: Dict[str, NNCFGraph], + graph_id: str, parent_dataset: Dataset, subset_size: int, current_model_num: int, - all_models_num: int, parent_statistic_points: Optional[StatisticPointsContainer] = None, ) -> Tuple[ov.Model, int]: """ - Applies an algorithm recursievley to each bodies of If node. + Applies an algorithm recursively to each bodies of If node. :param parent_model: Model to apply algorithm. - :param parent_graph: Graph of a model. + :param graphs: All model graphs. + :param graph_id: Current graph id in the graphs. :param parent_dataset: Dataset for algorithm. :param subset_size: Size of a dataset to use for calibration. :param current_model_num: Current model number. - :param all_models_num: All model numbers. :param parent_statistic_points: Statistics points for algorithm. :return: A model for every bodies of If nodes the algorithm was applied and the latest model number. """ - nncf_logger.info(f"Iteration [{current_model_num}/{all_models_num}] ...") + nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...") + parent_graph = graphs[graph_id] quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset) if get_number_if_op(parent_model) == 0: return quantized_model, current_model_num @@ -183,20 +183,20 @@ def apply_algorithm_if_bodies( then_quantized_model, current_model_num = apply_algorithm_if_bodies( algorithm, then_model, - NNCFGraphFactory.create(then_model), + graphs, + if_node.node_name + "_then", then_dataset, subset_size, current_model_num + 1, - all_models_num, ) else_quantized_model, current_model_num = apply_algorithm_if_bodies( algorithm, else_model, - NNCFGraphFactory.create(else_model), + graphs, + if_node.node_name + "_else", else_dataset, subset_size, current_model_num + 1, - all_models_num, ) model_transformer_int8 = factory.ModelTransformerFactory.create(quantized_model) quantized_model = _update_if_body(model_transformer_int8, if_node, True, then_quantized_model) diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 2fdaf91f1f7..d2db3eaf947 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -20,6 +20,8 @@ from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS +from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_number_if_op @@ -42,10 +44,13 @@ from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression +from nncf.quantization.quantize_model import BATCHWISE_STATISTICS_WARNING +from nncf.quantization.quantize_model import is_model_no_batchwise_support from nncf.quantization.quantize_model import quantize_with_tune_hyperparams from nncf.quantization.quantize_model import warning_model_no_batchwise_support from nncf.quantization.telemetry_extractors import CompressionStartedWithQuantizeApi from nncf.scopes import IgnoredScope +from nncf.scopes import validate_ignored_scope from nncf.telemetry.decorator import tracked_function from nncf.telemetry.events import NNCF_OV_CATEGORY @@ -72,6 +77,28 @@ def native_quantize_if_op_impl( raise NotImplementedError( "The BiasCorrection algorithm is not supported for OpenVINO models with If operation." ) + graphs = {} + + def _extract_all_subgraphs(model: ov.Model, current_id: str) -> None: + """ + Creates all inner subgraphs from If nodes and adds them to 'graphs'. + + :param model: Model. + :param current_id: Current graph id. + """ + graphs[current_id] = NNCFGraphFactory.create(model) + for op in model.get_ops(): + if get_node_metatype(op) == OVIfMetatype: + _extract_all_subgraphs(op.get_function(0), op.get_friendly_name() + "_then") + _extract_all_subgraphs(op.get_function(1), op.get_friendly_name() + "_else") + + main_model_graph_id = "main_model_graph" + _extract_all_subgraphs(model, main_model_graph_id) + if ignored_scope and ignored_scope.validate: + validate_ignored_scope(ignored_scope, graphs.values()) + ignored_scope = IgnoredScope( + ignored_scope.names, ignored_scope.patterns, ignored_scope.types, ignored_scope.subgraphs, validate=False + ) quantization_algorithm = PostTrainingQuantization( mode=mode, preset=preset, @@ -82,17 +109,17 @@ def native_quantize_if_op_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) - - graph = GraphConverter.create_nncf_graph(model) - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) + for graph in graphs.values(): + if is_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS): + nncf_logger.warning(BATCHWISE_STATISTICS_WARNING) + break if_ops_number = get_number_if_op(model) - all_models_number = if_ops_number * 2 + 1 nncf_logger.info( f"The model consists of {if_ops_number} If node(-s) with then and else bodies. \ Main model and all If bodies will be quantized recursively." ) quantized_model, _ = apply_algorithm_if_bodies( - quantization_algorithm, model, graph, calibration_dataset, subset_size, 1, all_models_number + quantization_algorithm, model, graphs, main_model_graph_id, calibration_dataset, subset_size, 1 ) if is_weight_compression_needed(advanced_parameters): diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 46426d77c21..5632bedf5b6 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -42,7 +42,7 @@ BATCHWISE_STATISTICS_WARNING = ( "For the particular model the batchwise statistics collection can lead to inaccurate statistics. " "If the accuracy degradation after compression is unsatisfactory, then " - "the recomendation is to turn off batchwise statistics. If the results are still unsatisfactory, " + "the recommendation is to turn off batchwise statistics. If the results are still unsatisfactory, " "provide a dataloader with batch_size = 1 to the calibration dataset." ) @@ -54,19 +54,38 @@ def warning_model_no_batchwise_support( no_batchwise_support_metatypes: List[OperatorMetatype], ) -> None: """ - Prints the warning message if batchwise statistics could lead to a significant accuracy drop. + Logs when is_model_no_batchwise_support(...) returns True. :param graph: Model's NNCFGraph. :param advanced_quantization_parameters: AdvancedQuantizationParameters. :param model_type: Model type algorithm option. :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support. """ - if ( + if is_model_no_batchwise_support( + graph, advanced_quantization_parameters, model_type, no_batchwise_support_metatypes + ): + nncf_logger.warning(BATCHWISE_STATISTICS_WARNING) + + +def is_model_no_batchwise_support( + graph: NNCFGraph, + advanced_quantization_parameters: Optional[AdvancedQuantizationParameters], + model_type: ModelType, + no_batchwise_support_metatypes: List[OperatorMetatype], +) -> None: + """ + Returns True if batchwise statistics could lead to a significant accuracy drop. + + :param graph: Model's NNCFGraph. + :param advanced_quantization_parameters: AdvancedQuantizationParameters. + :param model_type: Model type algorithm option. + :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support. + """ + return ( advanced_quantization_parameters and advanced_quantization_parameters.batchwise_statistics and (graph.get_nodes_by_metatypes(no_batchwise_support_metatypes) or model_type == ModelType.TRANSFORMER) - ): - nncf_logger.warning(BATCHWISE_STATISTICS_WARNING) + ) def _update_advanced_quantization_parameters( diff --git a/nncf/scopes.py b/nncf/scopes.py index c7cf631ad57..a2cee4b319b 100644 --- a/nncf/scopes.py +++ b/nncf/scopes.py @@ -12,7 +12,7 @@ import re from dataclasses import dataclass from dataclasses import field -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import nncf from nncf.common.graph.graph import NNCFGraph @@ -111,6 +111,23 @@ class IgnoredScope: validate: bool = True +def get_difference_ignored_scope(ignored_scope_1: IgnoredScope, ignored_scope_2: IgnoredScope) -> IgnoredScope: + """ + Returns ignored scope with rules from 'ignored_scope_1' not presented at 'ignored_scope_2' + + :param ignored_scope_1: First ignored scope. + :param ignored_scope_2: Second ignored scope. + :return: Ignored scope. + """ + return IgnoredScope( + names=list(set(ignored_scope_1.names) - set(ignored_scope_2.names)), + patterns=list(set(ignored_scope_1.patterns) - set(ignored_scope_2.patterns)), + types=list(set(ignored_scope_1.types) - set(ignored_scope_2.types)), + subgraphs=[subgraph for subgraph in ignored_scope_1.subgraphs if subgraph not in ignored_scope_2.subgraphs], + validate=ignored_scope_1.validate, + ) + + def convert_ignored_scope_to_list(ignored_scope: Optional[IgnoredScope]) -> List[str]: """ Convert the contents of the `IgnoredScope` class to the legacy ignored @@ -130,78 +147,129 @@ def convert_ignored_scope_to_list(ignored_scope: Optional[IgnoredScope]) -> List return results +def get_matched_ignored_scope_info( + ignored_scope: IgnoredScope, nncf_graphs: List[NNCFGraph] +) -> Tuple[IgnoredScope, Dict[str, Set[str]]]: + """ + Returns matched ignored scope for provided graphs along with all found matches. + The resulted ignored scope consist of all matched rules. + The found matches consist of a dictionary with a rule name as a key and matched node names as a value. + + :param ignored_scope: Ignored scope instance. + :param nncf_graphs: Graphs. + :returns: Matched ignored scope along with all matches. + """ + names, patterns, types, subgraphs_numbers = set(), set(), set(), set() + matches = {"names": names, "patterns": set(), "types": set(), "subgraphs": set()} + + for graph in nncf_graphs: + if ignored_scope.names or ignored_scope.patterns: + node_names = set(node.node_name for node in graph.nodes.values()) + + for ignored_node_name in filter(lambda name: name in node_names, ignored_scope.names): + names.add(ignored_node_name) + + for str_pattern in ignored_scope.patterns: + pattern = re.compile(str_pattern) + pattern_matched_names = set(filter(pattern.match, node_names)) + if pattern_matched_names: + matches["patterns"].update(pattern_matched_names) + patterns.add(str_pattern) + + for node in graph.get_nodes_by_types(ignored_scope.types): + matches["types"].add(node.node_name) + types.add(node.node_type) + + for i, subgraph in enumerate(ignored_scope.subgraphs): + names_from_subgraph = get_ignored_node_names_from_subgraph(graph, subgraph) + if names_from_subgraph: + matches["subgraphs"].update(names_from_subgraph) + subgraphs_numbers.add(i) + + matched_ignored_scope = IgnoredScope( + names=list(names), + patterns=list(patterns), + types=list(types), + subgraphs=[subgraph for i, subgraph in enumerate(ignored_scope.subgraphs) if i in subgraphs_numbers], + validate=ignored_scope.validate, + ) + return matched_ignored_scope, matches + + +def _info_matched_ignored_scope(matches: Dict[str, Set[str]]) -> None: + """ + Log matches. + + :param matches: Matches. + """ + for rule_type, rules in matches.items(): + if rules: + nncf_logger.info(f"{len(rules)} ignored nodes were found by {rule_type} in the NNCFGraph") + + +def _error_unmatched_ignored_scope(unmatched_ignored_scope: IgnoredScope) -> str: + """ + Returns an error message for unmatched ignored scope. + + :param unmatched_ignored_scope: Unmatched ignored scope. + :return str: Error message. + """ + err_msg = "\n" + for ignored_type in ("names", "types", "patterns"): + unmatched_rules = getattr(unmatched_ignored_scope, ignored_type) + if unmatched_rules: + err_msg += f"Ignored nodes that matches {ignored_type} {unmatched_rules} were not found in the NNCFGraph.\n" + for subgraph in unmatched_ignored_scope.subgraphs: + err_msg += ( + f"Ignored nodes that matches subgraph with input names {subgraph.inputs} " + f"and output names {subgraph.outputs} were not found in the NNCFGraph.\n" + ) + return err_msg + + +def _check_ignored_scope_strictly_matched(ignored_scope: IgnoredScope, matched_ignored_scope: IgnoredScope) -> None: + """ + Passes when ignored_scope and matched_ignored_scope are equal, otherwise - raises ValidationError. + + :param ignored_scope: Ignored scope. + :param matched_ignored_scope: Matched ignored scope. + """ + unmatched_ignored_scope = get_difference_ignored_scope(ignored_scope, matched_ignored_scope) + if ( + unmatched_ignored_scope.names + or unmatched_ignored_scope.types + or unmatched_ignored_scope.patterns + or unmatched_ignored_scope.subgraphs + ): + raise nncf.ValidationError(_error_unmatched_ignored_scope(unmatched_ignored_scope)) + + def get_ignored_node_names_from_ignored_scope( ignored_scope: IgnoredScope, nncf_graph: NNCFGraph, strict: bool = True ) -> Set[str]: """ Returns ignored names according to ignored scope and NNCFGraph. - If strict is True, raises RuntimeError if any ignored name is not found in the NNCFGraph or - any ignored pattern or any ignored type match 0 nodes in the NNCFGraph. + If strict is True, raises nncf.ValidationError if any ignored rule was not matched. If strict is False, returns all possible matches. - :param ignored_scope: Given ignored scope instance. - :param nncf_graph: Given NNCFGraph. + :param ignored_scope: Ignored scope. + :param nncf_graph: Graph. :param strict: Whether all ignored_scopes must match at least one node or not. - :returns: NNCF node names from given NNCFGraph specified in given ignored scope. + :return: NNCF node names from given NNCFGraph specified in given ignored scope. """ - error_msg = ( - "Refer to the original_graph.dot to discover the operations" - "in the model currently visible to NNCF and specify the ignored/target" - " scopes in terms of the names there." - ) + matched_ignored_scope, matches = get_matched_ignored_scope_info(ignored_scope, [nncf_graph]) + if strict: + _check_ignored_scope_strictly_matched(ignored_scope, matched_ignored_scope) + _info_matched_ignored_scope(matches) + return {name for match in matches.values() for name in match} - node_names = [node.node_name for node in nncf_graph.get_all_nodes()] - matched_by_names = [] - if ignored_scope.names: - for ignored_node_name in ignored_scope.names: - if ignored_node_name in node_names: - matched_by_names.append(ignored_node_name) - if strict and len(ignored_scope.names) != len(matched_by_names): - skipped_names = set(ignored_scope.names) - set(matched_by_names) - raise nncf.ValidationError( - f"Ignored nodes with name {list(skipped_names)} were not found in the NNCFGraph. " + error_msg - ) - nncf_logger.info(f"{len(matched_by_names)} ignored nodes were found by name in the NNCFGraph") - - matched_by_patterns = [] - if ignored_scope.patterns: - not_matched_patterns = [] - for str_pattern in ignored_scope.patterns: - pattern = re.compile(str_pattern) - matches = list(filter(pattern.match, node_names)) - if not matches: - not_matched_patterns.append(str_pattern) - matched_by_patterns.extend(matches) - if strict and not_matched_patterns: - raise nncf.ValidationError( - f"No matches for ignored patterns {not_matched_patterns} in the NNCFGraph. " + error_msg - ) - nncf_logger.info(f"{len(matched_by_patterns)} ignored nodes were found by patterns in the NNCFGraph") - - matched_by_types = [] - if ignored_scope.types: - types_found = set() - for node in nncf_graph.get_all_nodes(): - if node.node_type in ignored_scope.types: - types_found.add(node.node_type) - matched_by_types.append(node.node_name) - not_matched_types = set(ignored_scope.types) - types_found - if strict and not_matched_types: - raise nncf.ValidationError( - f"Nodes with ignored types {list(not_matched_types)} were not found in the NNCFGraph. " + error_msg - ) - nncf_logger.info(f"{len(matched_by_types)} ignored nodes were found by types in the NNCFGraph") - - matched_by_subgraphs = [] - if ignored_scope.subgraphs: - for subgraph in ignored_scope.subgraphs: - names_from_subgraph = get_ignored_node_names_from_subgraph(nncf_graph, subgraph) - if strict and not names_from_subgraph: - raise nncf.ValidationError( - f"Ignored subgraph with input names {subgraph.inputs} and output names {subgraph.outputs} " - "was not found in the NNCFGraph. " + error_msg - ) - - matched_by_subgraphs.extend(names_from_subgraph) - - return set(matched_by_names + matched_by_types + matched_by_patterns + matched_by_subgraphs) + +def validate_ignored_scope(ignored_scope: IgnoredScope, nncf_graphs: List[NNCFGraph]) -> None: + """ + Passes whether all rules at 'ignored_scope' have matches at provided graphs, otherwise - raises ValidationError. + + :param ignored_scope: Ignored scope. + :param nncf_graphs: Graphs. + """ + matched_ignored_scope, _ = get_matched_ignored_scope_info(ignored_scope, nncf_graphs) + _check_ignored_scope_strictly_matched(ignored_scope, matched_ignored_scope) diff --git a/tests/common/test_scopes.py b/tests/common/test_scopes.py index 0ae8ea78ba2..abb12a0e68c 100644 --- a/tests/common/test_scopes.py +++ b/tests/common/test_scopes.py @@ -13,6 +13,8 @@ from nncf.common.graph import NNCFNode from nncf.common.scopes import get_not_matched_scopes from nncf.scopes import IgnoredScope +from nncf.scopes import Subgraph +from nncf.scopes import get_difference_ignored_scope @pytest.mark.parametrize( @@ -38,3 +40,38 @@ def test_get_not_matched_scopes(scope, ref): ] not_matched = get_not_matched_scopes(scope, node_lists) assert not set(not_matched) - set(ref) + + +@pytest.mark.parametrize( + "scope_1, scope_2, ref", + ( + ( + IgnoredScope( + names=["A_name", "B_name"], + patterns=["A_pattern", "B_pattern"], + types=["A_type", "B_type"], + subgraphs=[ + Subgraph(inputs=["A_input"], outputs=["A_output"]), + Subgraph(inputs=["B_input"], outputs=["B_output"]), + ], + ), + IgnoredScope( + names=["B_name", "C_name"], + patterns=["B_pattern", "C_pattern"], + types=["B_type", "C_type"], + subgraphs=[ + Subgraph(inputs=["B_input"], outputs=["B_output"]), + Subgraph(inputs=["C_input"], outputs=["C_output"]), + ], + ), + IgnoredScope( + names=["A_name"], + patterns=["A_pattern"], + types=["A_type"], + subgraphs=[Subgraph(inputs=["A_input"], outputs=["A_output"])], + ), + ), + ), +) +def test_ignored_scope_diff(scope_1, scope_2, ref): + assert get_difference_ignored_scope(scope_1, scope_2) == ref diff --git a/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_else.dot b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_else.dot new file mode 100644 index 00000000000..a0e344d66ea --- /dev/null +++ b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_else.dot @@ -0,0 +1,13 @@ +strict digraph { +"0 Input" [id=0, type=Parameter]; +"1 MatMul" [id=1, type=MatMul]; +"2 Add" [id=2, type=Add]; +"3 Result_Add" [id=3, type=Result]; +"4 Add/Constant_16" [id=4, type=Constant]; +"5 MatMul/Constant_14" [id=5, type=Constant]; +"0 Input" -> "1 MatMul" [label="[1, 3, 4, 2]", style=solid]; +"1 MatMul" -> "2 Add" [label="[1, 3, 2, 5]", style=solid]; +"2 Add" -> "3 Result_Add" [label="[1, 3, 2, 5]", style=solid]; +"4 Add/Constant_16" -> "2 Add" [label="[1, 3, 1, 1]", style=solid]; +"5 MatMul/Constant_14" -> "1 MatMul" [label="[1, 3, 4, 5]", style=solid]; +} diff --git a/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_main.dot b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_main.dot new file mode 100644 index 00000000000..9ac187f8df5 --- /dev/null +++ b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_main.dot @@ -0,0 +1,9 @@ +strict digraph { +"0 Input_1" [id=0, type=Parameter]; +"1 Cond_input" [id=1, type=Parameter]; +"2 If_19" [id=2, type=If]; +"3 Result" [id=3, type=Result]; +"0 Input_1" -> "2 If_19" [label="[1, 3, 4, 2]", style=solid]; +"1 Cond_input" -> "2 If_19" [label="[]", style=dashed]; +"2 If_19" -> "3 Result" [label="[1, 3, 4, 5]", style=solid]; +} diff --git a/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_then.dot b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_then.dot new file mode 100644 index 00000000000..63d45ac61ea --- /dev/null +++ b/tests/openvino/native/data/2024.1/reference_graphs/quantized/IfModel_2_ignored_scope_then.dot @@ -0,0 +1,35 @@ +strict digraph { +"0 Input_1" [id=0, type=Parameter]; +"1 Sub" [id=1, type=Subtract]; +"2 Sub/fq_output_0" [id=2, type=FakeQuantize]; +"3 Conv" [id=3, type=Convolution]; +"4 Conv_Add" [id=4, type=Add]; +"5 Relu" [id=5, type=Relu]; +"6 Result" [id=6, type=Result]; +"7 NotBias" [id=7, type=Constant]; +"8 Conv/fq_weights_1" [id=8, type=Multiply]; +"9 Constant_14900" [id=9, type=Constant]; +"10 Convert_14899" [id=10, type=Convert]; +"11 Constant_14898" [id=11, type=Constant]; +"12 Sub/fq_output_0/output_high" [id=12, type=Constant]; +"13 Sub/fq_output_0/output_low" [id=13, type=Constant]; +"14 Sub/fq_output_0/input_high" [id=14, type=Constant]; +"15 Sub/fq_output_0/input_low" [id=15, type=Constant]; +"16 Sub/Constant_4" [id=16, type=Constant]; +"0 Input_1" -> "1 Sub" [label="[1, 3, 4, 2]", style=solid]; +"1 Sub" -> "2 Sub/fq_output_0" [label="[1, 3, 4, 2]", style=solid]; +"2 Sub/fq_output_0" -> "3 Conv" [label="[1, 3, 4, 2]", style=solid]; +"3 Conv" -> "4 Conv_Add" [label="[1, 3, 4, 2]", style=solid]; +"4 Conv_Add" -> "5 Relu" [label="[1, 3, 4, 2]", style=solid]; +"5 Relu" -> "6 Result" [label="[1, 3, 4, 2]", style=solid]; +"7 NotBias" -> "4 Conv_Add" [label="[1, 3, 4, 2]", style=solid]; +"8 Conv/fq_weights_1" -> "3 Conv" [label="[3, 3, 1, 1]", style=solid]; +"9 Constant_14900" -> "8 Conv/fq_weights_1" [label="[3, 1, 1, 1]", style=solid]; +"10 Convert_14899" -> "8 Conv/fq_weights_1" [label="[3, 3, 1, 1]", style=solid]; +"11 Constant_14898" -> "10 Convert_14899" [label="[3, 3, 1, 1]", style=dashed]; +"12 Sub/fq_output_0/output_high" -> "2 Sub/fq_output_0" [label="[]", style=solid]; +"13 Sub/fq_output_0/output_low" -> "2 Sub/fq_output_0" [label="[]", style=solid]; +"14 Sub/fq_output_0/input_high" -> "2 Sub/fq_output_0" [label="[]", style=solid]; +"15 Sub/fq_output_0/input_low" -> "2 Sub/fq_output_0" [label="[]", style=solid]; +"16 Sub/Constant_4" -> "1 Sub" [label="[1, 3, 1, 1]", style=solid]; +} diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index dfb6362b5f0..900fbdf29de 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -1103,3 +1103,21 @@ def _create_ov_model(self, stateful=True): model = ov.Model(results=[result], parameters=[input_data], name="TestModel") return model + + +class IfModel_2(OVReferenceModel): + def _create_ov_model(self): + input_1 = opset.parameter([1, 3, 4, 2], name="Input_1") + input_2 = opset.parameter([], dtype=bool, name="Cond_input") + + then_body = ConvNotBiasModel().ov_model + else_body = FPModel().ov_model + + if_node = opset.if_op(input_2) + if_node.set_then_body(then_body) + if_node.set_else_body(else_body) + if_node.set_input(input_1.outputs()[0], then_body.get_parameters()[0], else_body.get_parameters()[0]) + if_node.set_output(then_body.results[0], else_body.results[0]) + result = opset.result(if_node, name="Result") + model = ov.Model([result], [input_1, input_2]) + return model diff --git a/tests/openvino/native/quantization/test_graphs.py b/tests/openvino/native/quantization/test_graphs.py index c97a4873f70..1a36a7c0331 100644 --- a/tests/openvino/native/quantization/test_graphs.py +++ b/tests/openvino/native/quantization/test_graphs.py @@ -26,6 +26,7 @@ from nncf.parameters import QuantizationMode from nncf.parameters import TargetDevice from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant +from nncf.scopes import IgnoredScope from tests.openvino.native.common import compare_nncf_graphs from tests.openvino.native.common import convert_torch_model from tests.openvino.native.common import dump_model @@ -38,6 +39,7 @@ from tests.openvino.native.models import DepthwiseConv5DModel from tests.openvino.native.models import GRUSequenceModel from tests.openvino.native.models import IfModel +from tests.openvino.native.models import IfModel_2 from tests.openvino.native.models import MatmulSoftmaxMatmulBlock from tests.openvino.native.models import ScaledDotProductAttentionModel from tests.openvino.native.models import get_torch_model_info @@ -202,6 +204,31 @@ def test_if_model_fq_placement(): ) +def test_if_model_fq_placement_ignored_scope(): + if_model = IfModel_2() + ov_model = if_model.ov_model + dataset = get_dataset_for_if_model(ov_model) + quantized_model = quantize_impl( + ov_model, dataset, subset_size=2, fast_bias_correction=True, ignored_scope=IgnoredScope(names=["MatMul"]) + ) + if_ops = [op for op in quantized_model.get_ops() if op.get_type_name() == "If"] + assert len(if_ops) == 1 + if_op = if_ops[0] + main_model_path = if_model.ref_model_name + "_ignored_scope_main.dot" + then_body_path = if_model.ref_model_name + "_ignored_scope_then.dot" + else_body_path = if_model.ref_model_name + "_ignored_scope_else.dot" + + compare_nncf_graphs( + quantized_model, get_actual_reference_for_current_openvino(QUANTIZED_REF_GRAPHS_DIR / main_model_path) + ) + compare_nncf_graphs( + if_op.get_function(0), get_actual_reference_for_current_openvino(QUANTIZED_REF_GRAPHS_DIR / then_body_path) + ) + compare_nncf_graphs( + if_op.get_function(1), get_actual_reference_for_current_openvino(QUANTIZED_REF_GRAPHS_DIR / else_body_path) + ) + + @pytest.mark.parametrize("q_params", [{}, {"model_type": ModelType.TRANSFORMER}], ids=["default", "transformer"]) def test_scaled_dot_product_attention_placement(q_params, tmp_path): ov_major_version, ov_minor_version = get_openvino_major_minor_version() diff --git a/tests/post_training/test_templates/test_ptq_params.py b/tests/post_training/test_templates/test_ptq_params.py index 86c7cce64aa..941e8a9deeb 100644 --- a/tests/post_training/test_templates/test_ptq_params.py +++ b/tests/post_training/test_templates/test_ptq_params.py @@ -391,7 +391,7 @@ def test_validate_scope(self, test_params, validate_scopes): ) algo._backend_entity = self.get_algo_backend() if validate_scopes: - with pytest.raises(nncf.ValidationError, match="Ignored nodes with name"): + with pytest.raises(nncf.ValidationError, match="Ignored nodes that matches names"): algo._get_ignored_names(nncf_graph, inference_nncf_graph, ignored_patterns) else: algo._get_ignored_names(nncf_graph, inference_nncf_graph, ignored_patterns)