diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index f3633772f72..7831aae397b 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -603,7 +603,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph else: attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items())) out_graph.add_edge(u, v, **attrs_edge) - return relabel_graph_for_dot_visualization(out_graph) + return out_graph def _get_graph_for_visualization(self) -> nx.DiGraph: """ @@ -634,7 +634,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph: for node in out_graph.nodes.values(): node.pop("label") - return relabel_graph_for_dot_visualization(out_graph) + return out_graph def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode: node_ids = self._node_name_to_node_id_map.get(name, None) @@ -760,35 +760,3 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) - subgraph_list.append(self.get_node_by_key(node_key)) output.append(subgraph_list) return output - - -def relabel_graph_for_dot_visualization(nx_graph: nx.Graph) -> nx.Graph: - """ - Relabels NetworkX graph nodes to exclude reserved symbols in keys. - In case replaced names match for two different nodes, integer index is added to its keys. - While nodes keys are being updated, visualized nodes names corresponds to the original nodes names. - - :param nx_graph: NetworkX graph to visualize via dot. - :return: NetworkX graph with reserved symbols in nodes keys replaced. - """ - # .dot format reserves ':' character in node names - __RESERVED_DOT_CHARACTER = ":" - __CHARACTER_REPLACE_TO = "^" - - hits = defaultdict(lambda: 0) - mapping = {} - for original_name in nx_graph.nodes(): - dot_name = original_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) - hits[dot_name] += 1 - if hits[dot_name] > 1: - dot_name = f"{dot_name}_{hits}" - if original_name != dot_name: - mapping[original_name] = dot_name - - relabeled_graph = nx.relabel_nodes(nx_graph, mapping) - nx.set_node_attributes( - relabeled_graph, - name="label", - values={dot_key: original_key for original_key, dot_key in mapping.items()}, - ) - return relabeled_graph diff --git a/nncf/common/quantization/quantizer_propagation/graph.py b/nncf/common/quantization/quantizer_propagation/graph.py index 0adadc627b9..a46968ea974 100644 --- a/nncf/common/quantization/quantizer_propagation/graph.py +++ b/nncf/common/quantization/quantizer_propagation/graph.py @@ -20,7 +20,6 @@ from nncf import nncf_logger from nncf.common.graph import NNCFNode from nncf.common.graph import NNCFNodeName -from nncf.common.graph.graph import relabel_graph_for_dot_visualization from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES from nncf.common.graph.operator_metatypes import OUTPUT_NOOP_METATYPES from nncf.common.graph.operator_metatypes import NoopMetatype @@ -997,7 +996,7 @@ def get_visualized_graph(self): label="Unified group {}".format(gid), ) - return relabel_graph_for_dot_visualization(out_graph) + return out_graph def traverse_graph( self, diff --git a/nncf/common/utils/dot_file_rw.py b/nncf/common/utils/dot_file_rw.py index a67dd9d5df8..f956b22c8ac 100644 --- a/nncf/common/utils/dot_file_rw.py +++ b/nncf/common/utils/dot_file_rw.py @@ -8,17 +8,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import pathlib +from collections import defaultdict +from typing import Dict import networkx as nx def write_dot_graph(G: nx.DiGraph, path: pathlib.Path): # NOTE: writing dot files with colons even in labels or other node/edge/graph attributes leads to an - # error. See https://github.com/networkx/networkx/issues/5962. This limits the networkx version in - # NNCF to 2.8.3 unless this is fixed upstream or an inconvenient workaround is made in NNCF. - nx.nx_pydot.write_dot(G, str(path)) + # error. See https://github.com/networkx/networkx/issues/5962. If `relabel` is True in this function, + # then the colons (:) will be replaced with (^) symbols. + relabeled = relabel_graph_for_dot_visualization(G) + nx.nx_pydot.write_dot(relabeled, str(path)) def get_graph_without_data(G: nx.DiGraph) -> nx.DiGraph: @@ -36,4 +39,90 @@ def get_graph_without_data(G: nx.DiGraph) -> nx.DiGraph: def read_dot_graph(path: pathlib.Path) -> nx.DiGraph: - return nx.nx_pydot.read_dot(str(path)) + loaded = nx.DiGraph(nx.nx_pydot.read_dot(str(path))) + return relabel_graph_for_dot_visualization(loaded, from_reference=True) + + +RESERVED_CHAR = ":" +REPLACEMENT_CHAR = "^" + + +def _maybe_escape_colons_in_attrs(data: Dict): + for attr_name in data: + attr_val = data[attr_name] + if RESERVED_CHAR in attr_val and not (attr_val[0] == '"' or attr_val[-1] == '"'): + data[attr_name] = '"' + data[attr_name] + '"' # escaped colons are allowed + + +def _unescape_colons_in_attrs_with_colons(data: Dict): + for attr_name in data: + attr_val = data[attr_name] + if RESERVED_CHAR in attr_val and (attr_val[0] == '"' and attr_val[-1] == '"'): + data[attr_name] = data[attr_name][1:-1] + + +def _remove_cosmetic_labels(graph: nx.DiGraph): + for node_name, node_data in graph.nodes(data=True): + if "label" in node_data: + label = node_data["label"] + if node_name == label or '"' + node_name + '"' == label: + del node_data["label"] + + +def _add_cosmetic_labels(graph: nx.DiGraph, relabeled_node_mapping: Dict[str, str]): + for original_name, dot_name in relabeled_node_mapping.items(): + node_data = graph.nodes[dot_name] + if "label" not in node_data: + node_data["label"] = '"' + original_name + '"' + + +def relabel_graph_for_dot_visualization(nx_graph: nx.Graph, from_reference: bool = False) -> nx.DiGraph: + """ + Relabels NetworkX graph nodes to exclude reserved symbols in keys. + In case replaced names match for two different nodes, integer index is added to its keys. + While nodes keys are being updated, visualized nodes names corresponds to the original nodes names. + + :param nx_graph: NetworkX graph to visualize via dot. + :return: NetworkX graph with reserved symbols in nodes keys replaced. + """ + + nx_graph = copy.deepcopy(nx_graph) + + # .dot format reserves ':' character in node names + if not from_reference: + # dumping to disk + __CHARACTER_REPLACE_FROM = RESERVED_CHAR + __CHARACTER_REPLACE_TO = REPLACEMENT_CHAR + else: + # loading from disk + __CHARACTER_REPLACE_FROM = REPLACEMENT_CHAR + __CHARACTER_REPLACE_TO = RESERVED_CHAR + + hits = defaultdict(lambda: 0) + mapping = {} + for original_name in nx_graph.nodes(): + dot_name = original_name.replace(__CHARACTER_REPLACE_FROM, __CHARACTER_REPLACE_TO) + hits[dot_name] += 1 + if hits[dot_name] > 1: + dot_name = f"{dot_name}_{hits}" + if original_name != dot_name: + mapping[original_name] = dot_name + + relabeled_graph = nx.relabel_nodes(nx_graph, mapping) + + if not from_reference: + # dumping to disk + _add_cosmetic_labels(relabeled_graph, mapping) + for _, node_data in relabeled_graph.nodes(data=True): + _maybe_escape_colons_in_attrs(node_data) + for _, _, edge_data in relabeled_graph.edges(data=True): + _maybe_escape_colons_in_attrs(edge_data) + else: + # loading from disk + _remove_cosmetic_labels(relabeled_graph) + for _, node_data in relabeled_graph.nodes(data=True): + _unescape_colons_in_attrs_with_colons(node_data) + for _, _, edge_data in relabeled_graph.edges(data=True): + _unescape_colons_in_attrs_with_colons(edge_data) + + return relabeled_graph diff --git a/setup.py b/setup.py index 7663a33a528..05a8239ca71 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ def find_version(*file_paths): "jsonschema>=3.2.0", "jstyleson>=0.0.2", "natsort>=7.1.0", - "networkx>=2.6, <=2.8.2", # see ticket 94048 or https://github.com/networkx/networkx/issues/5962 + "networkx>=2.6, <=3.1", # see ticket 94048 or https://github.com/networkx/networkx/issues/5962 "ninja>=1.10.0.post2, <1.11", "numpy>=1.19.1, <1.27", "openvino-telemetry>=2023.2.0", @@ -113,12 +113,6 @@ def find_version(*file_paths): "psutil", "pydot>=1.4.1", "pymoo>=0.6.0.1", - # The recent pyparsing major version update seems to break - # integration with networkx - the graphs parsed from current .dot - # reference files no longer match against the graphs produced in tests. - # Using 2.x versions of pyparsing seems to fix the issue. - # Ticket: 69520 - "pyparsing<3.0", "rich>=13.5.2", "scikit-learn>=0.24.0", "scipy>=1.3.2", diff --git a/tests/common/data/reference_graphs/dot_rw_reference.dot b/tests/common/data/reference_graphs/dot_rw_reference.dot new file mode 100644 index 00000000000..90ce0a6b826 --- /dev/null +++ b/tests/common/data/reference_graphs/dot_rw_reference.dot @@ -0,0 +1,9 @@ +strict digraph { +"Node^^A" [label=":baz"]; +"Node^^B" [label="qux:"]; +"Node^^C" [label="Node::C"]; +D; +E [label=no_label]; +F [label="has^label"]; +"Node^^A" -> "Node^^B" [label="foo:bar"]; +} diff --git a/tests/common/graph/test_dot_file_rw.py b/tests/common/graph/test_dot_file_rw.py new file mode 100644 index 00000000000..75b965b88b5 --- /dev/null +++ b/tests/common/graph/test_dot_file_rw.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import filecmp +from copy import deepcopy +from pathlib import Path + +import networkx as nx +import pytest + +from nncf.common.utils.dot_file_rw import read_dot_graph +from nncf.common.utils.dot_file_rw import write_dot_graph +from tests.shared.nx_graph import check_nx_graph +from tests.shared.paths import TEST_ROOT + + +@pytest.fixture(scope="module") +def ref_graph() -> nx.DiGraph: + graph = nx.DiGraph() + graph.add_node("Node::A", label=":baz") + graph.add_node("Node::B", label="qux:") + graph.add_node("Node::C") + graph.add_node("D") + graph.add_node("E", label="no_label") + graph.add_node("F", label="has^label") + graph.add_node("F", label="has^label") + graph.add_edge("Node::A", "Node::B", label="foo:bar"), + return graph + + +REF_DOT_REPRESENTATION_GRAPH_PATH = TEST_ROOT / "common" / "data" / "reference_graphs" / "dot_rw_reference.dot" + + +def test_writing_does_not_modify_original_graph(tmp_path: Path, ref_graph: nx.DiGraph): + ref_graph_copy = deepcopy(ref_graph) + write_dot_graph(ref_graph_copy, tmp_path / "graph.dot") + assert nx.utils.graphs_equal(ref_graph_copy, ref_graph) + + +def test_colons_are_replaced_in_written_dot_file(tmp_path: Path, ref_graph: nx.DiGraph): + tmp_path_to_graph = tmp_path / "graph.dot" + write_dot_graph(ref_graph, tmp_path_to_graph) + assert filecmp.cmp(tmp_path_to_graph, REF_DOT_REPRESENTATION_GRAPH_PATH) + + +def test_read_dot_file_gives_graph_with_colons(tmp_path: Path, ref_graph: nx.DiGraph): + test_graph = read_dot_graph(REF_DOT_REPRESENTATION_GRAPH_PATH) + check_nx_graph(test_graph, ref_graph, check_edge_attrs=True)