Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Graph visualization] Reserved dot symbols are replaced for each graph visualization #2231

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,29 +574,18 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
:param extended: whether the graph edges should have attributes: shape of the tensor and tensor primitive type.
:return: An nx.DiGraph to be used for structure analysis
"""
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

out_graph = nx.DiGraph()
for node_name, node in self._nx_graph.nodes.items():
visualization_node_name = node_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
attrs_node = {"id": node[NNCFNode.ID_NODE_ATTR], "type": node[NNCFNode.NODE_TYPE_ATTR]}
for attr in ["color", "label", "style"]:
if attr in node:
attrs_node[attr] = node[attr]
# If the node_name has reserved character, use visualization_node_name as node name.
# While use 'label' attribute with original node name for visualization.
if "label" not in attrs_node and __RESERVED_DOT_CHARACTER in node_name:
attrs_node["label"] = node_name

out_graph.add_node(visualization_node_name, **attrs_node)
out_graph.add_node(node_name, **attrs_node)

for u, v in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v]
attrs_edge = {}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]
Expand All @@ -614,7 +603,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def _get_graph_for_visualization(self) -> nx.DiGraph:
"""
Expand Down Expand Up @@ -645,7 +634,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
for node in out_graph.nodes.values():
node.pop("label")

return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode:
node_ids = self._node_name_to_node_id_map.get(name, None)
Expand Down Expand Up @@ -771,3 +760,35 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -
subgraph_list.append(self.get_node_by_key(node_key))
output.append(subgraph_list)
return output


def relabel_graph_for_dot_visualization(nx_graph: nx.Graph) -> nx.Graph:
"""
Relabels NetworkX graph nodes to exclude reserved symbols in keys.
In case replaced names match for two different nodes, integer index is added to its keys.
While nodes keys are being updated, visualized nodes names corresponds to the original nodes names.

:param nx_graph: NetworkX graph to visualize via dot.
:return: NetworkX graph with reserved symbols in nodes keys replaced.
"""
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

hits = defaultdict(lambda: 0)
mapping = {}
for original_name in nx_graph.nodes():
dot_name = original_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
hits[dot_name] += 1
if hits[dot_name] > 1:
dot_name = f"{dot_name}_{hits}"
if original_name != dot_name:
mapping[original_name] = dot_name

relabeled_graph = nx.relabel_nodes(nx_graph, mapping)
nx.set_node_attributes(
relabeled_graph,
name="label",
values={dot_key: original_key for original_key, dot_key in mapping.items()},
)
return relabeled_graph
3 changes: 2 additions & 1 deletion nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf import nncf_logger
from nncf.common.graph import NNCFNode
from nncf.common.graph import NNCFNodeName
from nncf.common.graph.graph import relabel_graph_for_dot_visualization
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import OUTPUT_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import NoopMetatype
Expand Down Expand Up @@ -997,7 +998,7 @@ def get_visualized_graph(self):
label="Unified group {}".format(gid),
)

return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def traverse_graph(
self,
Expand Down