From 6794050897386c2543ef259b339c7018d4fa5b83 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Thu, 22 Aug 2024 13:21:21 +0100 Subject: [PATCH 1/5] Add find_shapeof_subgraphs method --- .../algorithms/accuracy_control/ranker.py | 8 +++-- nncf/quantization/passes.py | 33 ++++++++++--------- .../quantization/test_quantizer_removal.py | 8 +++-- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index 3332f527cd8..8d90223c1b2 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -30,7 +30,7 @@ from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset -from nncf.quantization.passes import remove_shapeof_subgraphs +from nncf.quantization.passes import find_shapeof_subgraphs TModel = TypeVar("TModel") TPModel = TypeVar("TPModel") @@ -109,11 +109,13 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) -> *self._algo_backend.get_start_nodes_for_activation_path_tracing(quantized_model_graph), ] - quantized_model_graph_without_shapeof = remove_shapeof_subgraphs( - deepcopy(quantized_model_graph), + shapeof_subgraphs = find_shapeof_subgraphs( + quantized_model_graph, self._algo_backend.get_shapeof_metatypes(), input_nodes, ) + quantized_model_graph_without_shapeof = deepcopy(quantized_model_graph) + quantized_model_graph_without_shapeof.remove_nodes_from(shapeof_subgraphs) for quantizer_node in reversed(quantizers): if processed.get(quantizer_node.node_name, False): diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index f274215dc38..e80527879c4 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -34,27 +34,29 @@ def transform_to_inference_graph( :param dropout_metatypes: List of backend-specific Dropout metatypes. :return: NNCFGraph in the inference style. """ - remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes) + shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes) + nncf_graph.remove_nodes_from(shapeof_subgraphs) filter_constant_nodes(nncf_graph, input_nodes) remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) return nncf_graph -def remove_shapeof_subgraphs( +def find_shapeof_subgraphs( nncf_graph: NNCFGraph, shapeof_metatypes: List[OperatorMetatype], input_nodes: List[NNCFNode], -) -> NNCFGraph: +) -> List[NNCFNode]: """ - Removes the ShapeOf subgraphs from the provided NNCFGraph instance inplace. - Constant subgraph should be already removed from the given NNCFGraph. - - :param nncf_graph: NNCFGraph instance for the transformation. - :param shapeof_metatypes: List of backend-specific ShapeOf metatypes. - :param input_nodes: List of input nodes for the given NNCFGraph. - :return: NNCFGraph without ShapeOf subgraphs. + Returns a list of nodes belonging to ShapeOf subgraphs. + + :param nncf_graph: The input graph to be analyzed. + :param shapeof_metatypes: A list of metatypes representing backend-specific + ShapeOf operations. + :param input_nodes: A list of nodes designated as graph inputs. These nodes are + used to identify which nodes depend on input data. + :return: A list of nodes belonging to ShapeOf subgraphs. """ - nodes_to_drop = set() + shapeof_subgraphs = set() shape_of_nodes = [] infer_nodes = [] @@ -70,21 +72,20 @@ def remove_shapeof_subgraphs( nodes_queue.extend(nncf_graph.get_next_nodes(node)) for shape_of_node in shape_of_nodes: - nodes_to_drop.add(shape_of_node.node_name) + shapeof_subgraphs.add(shape_of_node) shape_of_queue = collections.deque() shape_of_queue.extend(nncf_graph.get_next_nodes(shape_of_node)) while shape_of_queue: node = shape_of_queue.pop() - if node.node_name in nodes_to_drop or node.node_name in infer_nodes: + if node in shapeof_subgraphs or node.node_name in infer_nodes: continue - nodes_to_drop.add(node.node_name) + shapeof_subgraphs.add(node) # traverse forward and backward to exclude full shape of subgraph # recursion excluded due to infer_nodes list around subgraph shape shape_of_queue.extend(nncf_graph.get_next_nodes(node) + nncf_graph.get_previous_nodes(node)) - nncf_graph.remove_nodes_from([nncf_graph.get_node_by_name(name) for name in nodes_to_drop]) - return nncf_graph + return list(shapeof_subgraphs) def remove_nodes_and_reconnect_graph( diff --git a/tests/common/quantization/test_quantizer_removal.py b/tests/common/quantization/test_quantizer_removal.py index 46ee43e189c..9280666382f 100644 --- a/tests/common/quantization/test_quantizer_removal.py +++ b/tests/common/quantization/test_quantizer_removal.py @@ -18,7 +18,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph.layer_attributes import Dtype from nncf.common.quantization.quantizer_removal import find_quantizer_nodes_to_cut -from nncf.quantization.passes import remove_shapeof_subgraphs +from nncf.quantization.passes import find_shapeof_subgraphs from tests.common.quantization.metatypes import CONSTANT_METATYPES from tests.common.quantization.metatypes import METATYPES_FOR_TEST from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES @@ -304,7 +304,11 @@ def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: Parameter # As test graphs are fully connected and does not have readvariable metatype, # this should work input_nodes = nncf_graph.get_input_nodes() - nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), SHAPEOF_METATYPES, input_nodes) + + shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, SHAPEOF_METATYPES, input_nodes) + nncf_graph_without_shapeof = deepcopy(nncf_graph) + nncf_graph_without_shapeof.remove_nodes_from(shapeof_subgraphs) + nodes, ops = find_quantizer_nodes_to_cut( nncf_graph_without_shapeof, quantizer_node, From 85c595ba271618e4f299122365212a72274981fd Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Thu, 22 Aug 2024 16:43:33 +0100 Subject: [PATCH 2/5] Add find_constant_subgraphs method --- nncf/quantization/passes.py | 21 +++++++++++---------- tests/common/quantization/test_passes.py | 7 ++++--- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index e80527879c4..e871ab474f6 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -36,7 +36,8 @@ def transform_to_inference_graph( """ shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes) nncf_graph.remove_nodes_from(shapeof_subgraphs) - filter_constant_nodes(nncf_graph, input_nodes) + constant_subgraphs = find_constant_subgraphs(nncf_graph, input_nodes) + nncf_graph.remove_nodes_from(constant_subgraphs) remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) return nncf_graph @@ -138,17 +139,17 @@ def remove_nodes_and_reconnect_graph( return nncf_graph -def filter_constant_nodes( +def find_constant_subgraphs( nncf_graph: NNCFGraph, input_nodes: List[NNCFNode], -) -> NNCFGraph: +) -> List[NNCFNode]: """ - Removes all Constant nodes from NNCFGraph inplace, making it inference graph. - The traversing starts from the input nodes and nodes with weights. + Returns a list of nodes belonging to constant subgraphs. - :param nncf_graph: NNCFGraph instance for the transformation. - :param input_nodes: List of input nodes for the given NNCFGraph. - :return: NNCFGraph without Constant nodes. + :param nncf_graph: The input graph to be analyzed. + :param input_nodes: A list of nodes designated as graph inputs. These nodes are + used to identify which nodes depend on input data. + :return: A list of nodes belonging to constant subgraphs. """ if not input_nodes: return nncf_graph @@ -162,5 +163,5 @@ def filter_constant_nodes( visited_nodes.add(node) nodes_queue.extend(nncf_graph.get_next_nodes(node)) constant_nodes = [node for node in nncf_graph.get_all_nodes() if node not in visited_nodes] - nncf_graph.remove_nodes_from(constant_nodes) - return nncf_graph + + return constant_nodes diff --git a/tests/common/quantization/test_passes.py b/tests/common/quantization/test_passes.py index bbf75b5ef67..70d681efa01 100644 --- a/tests/common/quantization/test_passes.py +++ b/tests/common/quantization/test_passes.py @@ -16,7 +16,7 @@ from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.operator_metatypes import OperatorMetatype -from nncf.quantization.passes import filter_constant_nodes +from nncf.quantization.passes import find_constant_subgraphs from nncf.quantization.passes import remove_nodes_and_reconnect_graph from tests.cross_fw.test_templates.models import NNCFGraphDropoutRemovingCase from tests.cross_fw.test_templates.models import NNCFGraphToTestConstantFiltering @@ -63,7 +63,7 @@ def test_remove_nodes_and_reconnect_graph(mode: ParameterTestModes): @pytest.mark.parametrize("node_between_const_and_op", [False, True]) -def test_filter_constant_nodes(node_between_const_and_op): +def test_find_constant_subgraphs(node_between_const_and_op): dot_reference_path_before = ( Path("passes") / f"test_constant_filtering_model_before{int(node_between_const_and_op)}.dot" ) @@ -88,5 +88,6 @@ class NodeWithWeightMetatype(OperatorMetatype): additional_input_names = ["/Conv2_0", "/Concat_with_missed_input_0"] input_nodes = nncf_graph.get_input_nodes() + [nncf_graph.get_node_by_name(name) for name in additional_input_names] _check_graphs(dot_reference_path_before, nncf_graph) - filter_constant_nodes(nncf_graph, input_nodes) + constant_subgraphs = find_constant_subgraphs(nncf_graph, input_nodes) + nncf_graph.remove_nodes_from(constant_subgraphs) _check_graphs(dot_reference_path_after, nncf_graph) From a385f0d39fe9336e3b4d9376368e60a1cdf09176 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Thu, 22 Aug 2024 17:05:56 +0100 Subject: [PATCH 3/5] Minor fix --- nncf/quantization/passes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index e871ab474f6..0894abcbc91 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -152,7 +152,7 @@ def find_constant_subgraphs( :return: A list of nodes belonging to constant subgraphs. """ if not input_nodes: - return nncf_graph + return [] visited_nodes = set() nodes_queue = collections.deque(input_nodes) From 676983f9c6f4249a07b66190bc9dc738ac286112 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Thu, 22 Aug 2024 17:25:10 +0100 Subject: [PATCH 4/5] Add preserved nodes --- .../algorithms/layerwise/scheduler.py | 2 +- .../algorithms/min_max/algorithm.py | 1 + .../algorithms/min_max/backend.py | 7 ++++ .../algorithms/min_max/onnx_backend.py | 4 ++ .../algorithms/min_max/openvino_backend.py | 4 ++ .../algorithms/min_max/torch_backend.py | 4 ++ .../algorithms/min_max/torch_fx_backend.py | 4 ++ .../algorithms/weight_compression/awq.py | 2 +- nncf/quantization/passes.py | 41 ++++++++++++++++++- .../test_templates/test_ptq_params.py | 4 ++ .../test_templates/test_quantizer_config.py | 3 ++ 11 files changed, 72 insertions(+), 4 deletions(-) diff --git a/nncf/quantization/algorithms/layerwise/scheduler.py b/nncf/quantization/algorithms/layerwise/scheduler.py index d2d1461dd38..8eee99fad28 100644 --- a/nncf/quantization/algorithms/layerwise/scheduler.py +++ b/nncf/quantization/algorithms/layerwise/scheduler.py @@ -101,7 +101,7 @@ def schedule( """ # Initialize input nodes and create a copy of the graph for inference input_nodes = graph.get_input_nodes() - inference_graph = transform_to_inference_graph(deepcopy(graph), input_nodes, [], []) + inference_graph = transform_to_inference_graph(deepcopy(graph), input_nodes, [], [], []) steps = [] visited_map = {node: False for node in inference_graph.get_all_nodes()} diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index eea683b64ec..6306f6736c2 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -742,6 +742,7 @@ def _find_quantization_target_points( self._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.dropout_metatypes, + self._backend_entity.preserved_metatypes, ) quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 70e88d74b2a..26d1d7ca65d 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -32,6 +32,13 @@ class MinMaxAlgoBackend(ABC): + @property + @abstractmethod + def preserved_metatypes(self) -> List[OperatorMetatype]: + """ + TODO + """ + @property @abstractmethod def mat_mul_metatypes(self) -> List[OperatorMetatype]: diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index e9dc3dded13..ce750026a43 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -46,6 +46,10 @@ class ONNXMinMaxAlgoBackend(MinMaxAlgoBackend): + @property + def preserved_metatypes(self) -> List[OperatorMetatype]: + return [] + @property def mat_mul_metatypes(self) -> List[OperatorMetatype]: return MATMUL_METATYPES diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 8107e68143a..fbb420f308e 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -43,6 +43,10 @@ class OVMinMaxAlgoBackend(MinMaxAlgoBackend): + @property + def preserved_metatypes(self) -> List[OperatorMetatype]: + return [om.OVConvolutionMetatype, om.OVLSTMSequenceMetatype] + @property def mat_mul_metatypes(self) -> List[OperatorMetatype]: return [om.OVMatMulMetatype] diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 21a0a10cadb..98e41e03745 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -56,6 +56,10 @@ class PTMinMaxAlgoBackend(MinMaxAlgoBackend): + @property + def preserved_metatypes(self) -> List[OperatorMetatype]: + return [] + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 42bddf3fbba..4f82a1e0c8c 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -59,6 +59,10 @@ class FXMinMaxAlgoBackend(MinMaxAlgoBackend): TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, } + @property + def preserved_metatypes(self) -> List[OperatorMetatype]: + return [] + @property def mat_mul_metatypes(self) -> List[OperatorMetatype]: return [om.PTLinearMetatype, om.PTMatMulMetatype] diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index c85a6cabd50..c0c05d76e5b 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -133,7 +133,7 @@ def apply( """ matches = [] - inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], []) + inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) nx_graph = inference_nncf_graph.get_nx_graph_copy() for _, pattern_graph in self._patterns.items(): matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False)) diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index 0894abcbc91..4c32dfd85e3 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -14,6 +14,7 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.layer_attributes import Dtype from nncf.common.graph.operator_metatypes import OperatorMetatype TModel = TypeVar("TModel") @@ -24,6 +25,7 @@ def transform_to_inference_graph( input_nodes: List[NNCFNode], shapeof_metatypes: List[OperatorMetatype], dropout_metatypes: List[OperatorMetatype], + preserved_metatypes: List[OperatorMetatype], ) -> NNCFGraph: """ This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows. @@ -35,9 +37,12 @@ def transform_to_inference_graph( :return: NNCFGraph in the inference style. """ shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes) - nncf_graph.remove_nodes_from(shapeof_subgraphs) + preserved_nodes = find_preserved_nodes(nncf_graph, shapeof_subgraphs, preserved_metatypes) constant_subgraphs = find_constant_subgraphs(nncf_graph, input_nodes) - nncf_graph.remove_nodes_from(constant_subgraphs) + + nodes_to_drop = set([*shapeof_subgraphs, *constant_subgraphs]).difference(preserved_nodes) + nncf_graph.remove_nodes_from(nodes_to_drop) + remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) return nncf_graph @@ -89,6 +94,38 @@ def find_shapeof_subgraphs( return list(shapeof_subgraphs) +def find_preserved_nodes( + graph: NNCFGraph, + shapeof_subgraphs: List[NNCFNode], + preserved_metatypes: List[OperatorMetatype], +) -> List[NNCFNode]: + """ + :param graph: + :param shapeof_subgraphs: + :param preserved_metatypes: + :return: + """ + preserved_nodes = set() + for node in graph.get_nodes_by_metatypes(preserved_metatypes): + for e in graph.get_input_edges(node): + if e.from_node in shapeof_subgraphs and e.dtype == Dtype.FLOAT: + preserved_nodes.add(e.from_node) + + queue = collections.deque(preserved_nodes) + while queue: + node = queue.pop() + + for e in graph.get_input_edges(node): + if e.from_node in preserved_nodes: + continue + + if e.dtype == Dtype.FLOAT and e.from_node in shapeof_subgraphs: + queue.append(e.from_node) + preserved_nodes.add(e.from_node) + + return list(preserved_nodes) + + def remove_nodes_and_reconnect_graph( nncf_graph: NNCFGraph, metatypes: List[OperatorMetatype], diff --git a/tests/cross_fw/test_templates/test_ptq_params.py b/tests/cross_fw/test_templates/test_ptq_params.py index c6b1ce135f0..eacf57652e7 100644 --- a/tests/cross_fw/test_templates/test_ptq_params.py +++ b/tests/cross_fw/test_templates/test_ptq_params.py @@ -236,6 +236,7 @@ def test_quantize_outputs(self, test_params, quantize_outputs): min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) act_num_q, weight_num_q = 0, 0 @@ -261,6 +262,7 @@ def test_ignored_scopes(self, test_params, ignored_scopes_data): min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) act_num_q, weight_num_q = 0, 0 @@ -286,6 +288,7 @@ def test_model_type_pass(self, test_params, model_type): min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) for quantization_point in q_setup.quantization_points.values(): @@ -384,6 +387,7 @@ def test_validate_scope(self, test_params, validate_scopes): self.get_algo_backend().get_start_nodes_for_activation_path_tracing(nncf_graph), [], [], + [], ) ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"] algo = MinMaxQuantization( diff --git a/tests/cross_fw/test_templates/test_quantizer_config.py b/tests/cross_fw/test_templates/test_quantizer_config.py index 79e3a8e5171..a01e9ffdb8e 100644 --- a/tests/cross_fw/test_templates/test_quantizer_config.py +++ b/tests/cross_fw/test_templates/test_quantizer_config.py @@ -134,6 +134,7 @@ def test_default_quantizer_config(self, single_conv_nncf_graph): min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) q_setup = min_max_algo._get_quantizer_setup( nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() @@ -188,6 +189,7 @@ def test_quantizer_config_from_ptq_params_for_CPU( min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) if signed_weights is False or signed_activations in [True, False]: # Incompatible with HW CPU config with pytest.raises( @@ -230,6 +232,7 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, ) q_setup = min_max_algo._get_quantizer_setup( nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() From ba76dbcd12e290a90bd632dbaeeed9050d85f0ad Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Fri, 23 Aug 2024 13:45:13 +0100 Subject: [PATCH 5/5] Add docs --- nncf/quantization/algorithms/min_max/backend.py | 3 ++- nncf/quantization/passes.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 26d1d7ca65d..73befbe90c7 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -36,7 +36,8 @@ class MinMaxAlgoBackend(ABC): @abstractmethod def preserved_metatypes(self) -> List[OperatorMetatype]: """ - TODO + Property for backend-specific metatypes that require preserving float subgraphs + when removing the ShapeOf subgraph. """ @property diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index 4c32dfd85e3..3d5cfa58e5f 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -100,10 +100,11 @@ def find_preserved_nodes( preserved_metatypes: List[OperatorMetatype], ) -> List[NNCFNode]: """ - :param graph: - :param shapeof_subgraphs: - :param preserved_metatypes: - :return: + :param graph: The input graph to be analyzed. + :param shapeof_subgraphs: A list of nodes belonging to ShapeOf subgraphs. + :param preserved_metatypes: Backend-specific metatypes that require preserving + float subgraphs when removing the ShapeOf subgraph. + :return: A list of nodes in float subgraphs of ShapeOf subgraphs. """ preserved_nodes = set() for node in graph.get_nodes_by_metatypes(preserved_metatypes):