Skip to content

Commit

Permalink
Remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
andrey-churkin committed Jul 5, 2024
1 parent 28d99a0 commit c73dd03
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 58 deletions.
34 changes: 1 addition & 33 deletions nncf/common/insertion_point_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional, Set
from typing import Dict, List, Set

import networkx as nx

Expand All @@ -23,7 +23,6 @@
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.common.graph.patterns import GraphPattern
from nncf.common.logging import nncf_logger


class InsertionPointGraphNodeType(Enum):
Expand Down Expand Up @@ -393,34 +392,3 @@ def get_pre_hook_node_key(node_key: str, input_port_id: int = 0) -> str:
@staticmethod
def get_post_hook_node_key(node_key: str) -> str:
return InsertionPointGraph.POST_HOOK_ID_PREFIX + node_key


class ConstantNodesFilter:
@staticmethod
def filter(ip_graph: InsertionPointGraph, start_traversing_node_keys: Optional[List[str]]) -> InsertionPointGraph:
"""
Removes all Constant nodes from InsertionPointGraph, making it inference graph.
The traversing starts from the input nodes and nodes with weights.
:param ip_graph: The original InsertionPointGraph.
:param start_traversing_node_keys: Keys of the nodes from which the traversing will be start.
:return: InsertionPointGraph without Constant nodes.
"""
input_nodes = ip_graph.get_input_nodes()
if not input_nodes:
nncf_logger.debug("Skipped filtering - no input nodes found")
return ip_graph
weight_nodes = []
if start_traversing_node_keys is not None:
weight_nodes = [
ip_graph.get_merged_node_from_single_node_key(weight_node) for weight_node in start_traversing_node_keys
]
visited_nodes = set()
start_nodes = input_nodes + weight_nodes
for node in start_nodes:
for node_from, node_to in nx.bfs_edges(ip_graph, source=node):
visited_nodes.add(node_from)
visited_nodes.add(node_to)
constant_nodes = [node for node in ip_graph.nodes if node not in visited_nodes]
ip_graph.remove_nodes_from(constant_nodes)
return ip_graph
25 changes: 0 additions & 25 deletions tests/common/quantization/test_filter_constant_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,16 @@
import re
from collections import Counter

import pytest

from nncf.common.graph.operator_metatypes import InputNoopMetatype
from nncf.common.graph.operator_metatypes import OutputNoopMetatype
from nncf.common.insertion_point_graph import ConstantNodesFilter
from nncf.common.insertion_point_graph import InsertionPointGraph
from nncf.common.quantization.structs import QuantizableWeightedLayerNode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.utils.registry import Registry
from tests.common.quantization.metatypes import WEIGHT_LAYER_METATYPES
from tests.common.quantization.metatypes import Conv2dTestMetatype
from tests.common.quantization.metatypes import IdentityTestMetatype
from tests.common.quantization.metatypes import LinearTestMetatype
from tests.common.quantization.metatypes import ReshapeTestMetatype
from tests.common.quantization.mock_graphs import NodeWithType
from tests.common.quantization.mock_graphs import create_mock_graph
from tests.common.quantization.mock_graphs import get_ip_graph_for_test
from tests.common.quantization.mock_graphs import get_nncf_graph_from_mock_nx_graph

SYNTHETIC_NNCF_GRAPH_WITH_CONSTANT_SUBGRAPHS = Registry("SYNTHETIC_MODELS_WITH_CONSTANT_SUBGRAPHS")
Expand Down Expand Up @@ -218,24 +211,6 @@ def __init__(self):
self.ref_nncf_graph = get_nncf_graph_from_mock_nx_graph(reference_mock_graph)


@pytest.mark.parametrize("model_to_test", SYNTHETIC_NNCF_GRAPH_WITH_CONSTANT_SUBGRAPHS.values())
def test_constant_nodes_filter(model_to_test):
model_to_test = model_to_test()
nncf_graph = model_to_test.nncf_graph
weight_nodes = nncf_graph.get_nodes_by_metatypes(WEIGHT_LAYER_METATYPES)
quantizable_layer_nodes = [
QuantizableWeightedLayerNode(weight_node, [QuantizerConfig()]) for weight_node in weight_nodes
]
quantizable_layer_node_keys = [node.node.node_key for node in quantizable_layer_nodes]

ip_graph = get_ip_graph_for_test(nncf_graph, quantizable_layer_nodes)
filtered_ip_graph = ConstantNodesFilter.filter(ip_graph, quantizable_layer_node_keys)

ref_ip_graph = get_ip_graph_for_test(model_to_test.ref_nncf_graph, quantizable_layer_nodes)

check_ip_graphs_are_equal(filtered_ip_graph, ref_ip_graph)


def check_ip_graphs_are_equal(graph_1: InsertionPointGraph, graph_2: InsertionPointGraph):
graph_1_node_keys_without_index = [graph_1_node_key.split(" ")[-1] for graph_1_node_key in graph_1.nodes]
graph_2_node_keys_without_index = [graph_2_node_key.split(" ")[-1] for graph_2_node_key in graph_2.nodes]
Expand Down

0 comments on commit c73dd03

Please sign in to comment.