From c73dd03c2953ec7aa2abb0d8c97b028b01f00c99 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Fri, 5 Jul 2024 11:24:40 +0100 Subject: [PATCH] Remove unused code --- nncf/common/insertion_point_graph.py | 34 +------------------ .../test_filter_constant_nodes.py | 25 -------------- 2 files changed, 1 insertion(+), 58 deletions(-) diff --git a/nncf/common/insertion_point_graph.py b/nncf/common/insertion_point_graph.py index 5cd1a8e4710..118d079fc8e 100644 --- a/nncf/common/insertion_point_graph.py +++ b/nncf/common/insertion_point_graph.py @@ -12,7 +12,7 @@ from collections import defaultdict from copy import deepcopy from enum import Enum -from typing import Dict, List, Optional, Set +from typing import Dict, List, Set import networkx as nx @@ -23,7 +23,6 @@ from nncf.common.graph.layer_attributes import Dtype from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES from nncf.common.graph.patterns import GraphPattern -from nncf.common.logging import nncf_logger class InsertionPointGraphNodeType(Enum): @@ -393,34 +392,3 @@ def get_pre_hook_node_key(node_key: str, input_port_id: int = 0) -> str: @staticmethod def get_post_hook_node_key(node_key: str) -> str: return InsertionPointGraph.POST_HOOK_ID_PREFIX + node_key - - -class ConstantNodesFilter: - @staticmethod - def filter(ip_graph: InsertionPointGraph, start_traversing_node_keys: Optional[List[str]]) -> InsertionPointGraph: - """ - Removes all Constant nodes from InsertionPointGraph, making it inference graph. - The traversing starts from the input nodes and nodes with weights. - - :param ip_graph: The original InsertionPointGraph. - :param start_traversing_node_keys: Keys of the nodes from which the traversing will be start. - :return: InsertionPointGraph without Constant nodes. - """ - input_nodes = ip_graph.get_input_nodes() - if not input_nodes: - nncf_logger.debug("Skipped filtering - no input nodes found") - return ip_graph - weight_nodes = [] - if start_traversing_node_keys is not None: - weight_nodes = [ - ip_graph.get_merged_node_from_single_node_key(weight_node) for weight_node in start_traversing_node_keys - ] - visited_nodes = set() - start_nodes = input_nodes + weight_nodes - for node in start_nodes: - for node_from, node_to in nx.bfs_edges(ip_graph, source=node): - visited_nodes.add(node_from) - visited_nodes.add(node_to) - constant_nodes = [node for node in ip_graph.nodes if node not in visited_nodes] - ip_graph.remove_nodes_from(constant_nodes) - return ip_graph diff --git a/tests/common/quantization/test_filter_constant_nodes.py b/tests/common/quantization/test_filter_constant_nodes.py index 385e9244f4c..8c1ff2412ce 100644 --- a/tests/common/quantization/test_filter_constant_nodes.py +++ b/tests/common/quantization/test_filter_constant_nodes.py @@ -12,23 +12,16 @@ import re from collections import Counter -import pytest - from nncf.common.graph.operator_metatypes import InputNoopMetatype from nncf.common.graph.operator_metatypes import OutputNoopMetatype -from nncf.common.insertion_point_graph import ConstantNodesFilter from nncf.common.insertion_point_graph import InsertionPointGraph -from nncf.common.quantization.structs import QuantizableWeightedLayerNode -from nncf.common.quantization.structs import QuantizerConfig from nncf.common.utils.registry import Registry -from tests.common.quantization.metatypes import WEIGHT_LAYER_METATYPES from tests.common.quantization.metatypes import Conv2dTestMetatype from tests.common.quantization.metatypes import IdentityTestMetatype from tests.common.quantization.metatypes import LinearTestMetatype from tests.common.quantization.metatypes import ReshapeTestMetatype from tests.common.quantization.mock_graphs import NodeWithType from tests.common.quantization.mock_graphs import create_mock_graph -from tests.common.quantization.mock_graphs import get_ip_graph_for_test from tests.common.quantization.mock_graphs import get_nncf_graph_from_mock_nx_graph SYNTHETIC_NNCF_GRAPH_WITH_CONSTANT_SUBGRAPHS = Registry("SYNTHETIC_MODELS_WITH_CONSTANT_SUBGRAPHS") @@ -218,24 +211,6 @@ def __init__(self): self.ref_nncf_graph = get_nncf_graph_from_mock_nx_graph(reference_mock_graph) -@pytest.mark.parametrize("model_to_test", SYNTHETIC_NNCF_GRAPH_WITH_CONSTANT_SUBGRAPHS.values()) -def test_constant_nodes_filter(model_to_test): - model_to_test = model_to_test() - nncf_graph = model_to_test.nncf_graph - weight_nodes = nncf_graph.get_nodes_by_metatypes(WEIGHT_LAYER_METATYPES) - quantizable_layer_nodes = [ - QuantizableWeightedLayerNode(weight_node, [QuantizerConfig()]) for weight_node in weight_nodes - ] - quantizable_layer_node_keys = [node.node.node_key for node in quantizable_layer_nodes] - - ip_graph = get_ip_graph_for_test(nncf_graph, quantizable_layer_nodes) - filtered_ip_graph = ConstantNodesFilter.filter(ip_graph, quantizable_layer_node_keys) - - ref_ip_graph = get_ip_graph_for_test(model_to_test.ref_nncf_graph, quantizable_layer_nodes) - - check_ip_graphs_are_equal(filtered_ip_graph, ref_ip_graph) - - def check_ip_graphs_are_equal(graph_1: InsertionPointGraph, graph_2: InsertionPointGraph): graph_1_node_keys_without_index = [graph_1_node_key.split(" ")[-1] for graph_1_node_key in graph_1.nodes] graph_2_node_keys_without_index = [graph_2_node_key.split(" ")[-1] for graph_2_node_key in graph_2.nodes]