diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index 4eb61c72399..9e958474c72 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -573,29 +573,18 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph :param extended: whether the graph edges should have attributes: shape of the tensor and tensor primitive type. :return: An nx.DiGraph to be used for structure analysis """ - # .dot format reserves ':' character in node names - __RESERVED_DOT_CHARACTER = ":" - __CHARACTER_REPLACE_TO = "^" - out_graph = nx.DiGraph() for node_name, node in self._nx_graph.nodes.items(): - visualization_node_name = node_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) attrs_node = {"id": node[NNCFNode.ID_NODE_ATTR], "type": node[NNCFNode.NODE_TYPE_ATTR]} for attr in ["color", "label", "style"]: if attr in node: attrs_node[attr] = node[attr] - # If the node_name has reserved character, use visualization_node_name as node name. - # While use 'label' attribute with original node name for visualization. - if "label" not in attrs_node and __RESERVED_DOT_CHARACTER in node_name: - attrs_node["label"] = node_name - out_graph.add_node(visualization_node_name, **attrs_node) + out_graph.add_node(node_name, **attrs_node) for u, v in self._nx_graph.edges: edge = self._nx_graph.edges[u, v] attrs_edge = {} - u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) - v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) label = {} if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]: label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR] @@ -613,7 +602,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph else: attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items())) out_graph.add_edge(u, v, **attrs_edge) - return out_graph + return relabel_graph_for_dot_visualization(out_graph) def _get_graph_for_visualization(self) -> nx.DiGraph: """ @@ -644,7 +633,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph: for node in out_graph.nodes.values(): node.pop("label") - return out_graph + return relabel_graph_for_dot_visualization(out_graph) def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode: node_ids = self._node_name_to_node_id_map.get(name, None) @@ -770,3 +759,35 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) - subgraph_list.append(self.get_node_by_key(node_key)) output.append(subgraph_list) return output + + +def relabel_graph_for_dot_visualization(nx_graph: nx.Graph) -> nx.Graph: + """ + Relabels NetworkX graph nodes to exclude reserved symbols in keys. + In case replaced names match for two different nodes, integer index is added to its keys. + While nodes keys are being updated, visualized nodes names corresponds to the original nodes names. + + :param nx_graph: NetworkX graph to visualize via dot. + :return: NetworkX graph with reserved symbols in nodes keys replaced. + """ + # .dot format reserves ':' character in node names + __RESERVED_DOT_CHARACTER = ":" + __CHARACTER_REPLACE_TO = "^" + + hits = defaultdict(lambda: 0) + mapping = {} + for original_name in nx_graph.nodes(): + dot_name = original_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) + hits[dot_name] += 1 + if hits[dot_name] > 1: + dot_name = f"{dot_name}_{hits}" + if original_name != dot_name: + mapping[original_name] = dot_name + + relabeled_graph = nx.relabel_nodes(nx_graph, mapping) + nx.set_node_attributes( + relabeled_graph, + name="label", + values={dot_key: original_key for original_key, dot_key in mapping.items()}, + ) + return relabeled_graph diff --git a/nncf/common/quantization/quantizer_propagation/graph.py b/nncf/common/quantization/quantizer_propagation/graph.py index a46968ea974..0adadc627b9 100644 --- a/nncf/common/quantization/quantizer_propagation/graph.py +++ b/nncf/common/quantization/quantizer_propagation/graph.py @@ -20,6 +20,7 @@ from nncf import nncf_logger from nncf.common.graph import NNCFNode from nncf.common.graph import NNCFNodeName +from nncf.common.graph.graph import relabel_graph_for_dot_visualization from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES from nncf.common.graph.operator_metatypes import OUTPUT_NOOP_METATYPES from nncf.common.graph.operator_metatypes import NoopMetatype @@ -996,7 +997,7 @@ def get_visualized_graph(self): label="Unified group {}".format(gid), ) - return out_graph + return relabel_graph_for_dot_visualization(out_graph) def traverse_graph( self,