Skip to content

Commit

Permalink
Reserved dot symbols are replaced for each graph visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 27, 2023
1 parent 729f6cc commit 2a10e3e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
49 changes: 35 additions & 14 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,29 +574,18 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
:param extended: whether the graph edges should have attributes: shape of the tensor and tensor primitive type.
:return: An nx.DiGraph to be used for structure analysis
"""
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

out_graph = nx.DiGraph()
for node_name, node in self._nx_graph.nodes.items():
visualization_node_name = node_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
attrs_node = {"id": node[NNCFNode.ID_NODE_ATTR], "type": node[NNCFNode.NODE_TYPE_ATTR]}
for attr in ["color", "label", "style"]:
if attr in node:
attrs_node[attr] = node[attr]
# If the node_name has reserved character, use visualization_node_name as node name.
# While use 'label' attribute with original node name for visualization.
if "label" not in attrs_node and __RESERVED_DOT_CHARACTER in node_name:
attrs_node["label"] = node_name

out_graph.add_node(visualization_node_name, **attrs_node)
out_graph.add_node(node_name, **attrs_node)

for u, v in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v]
attrs_edge = {}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]
Expand All @@ -614,7 +603,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def _get_graph_for_visualization(self) -> nx.DiGraph:
"""
Expand Down Expand Up @@ -645,7 +634,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
for node in out_graph.nodes.values():
node.pop("label")

return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode:
node_ids = self._node_name_to_node_id_map.get(name, None)
Expand Down Expand Up @@ -771,3 +760,35 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -
subgraph_list.append(self.get_node_by_key(node_key))
output.append(subgraph_list)
return output


def relabel_graph_for_dot_visualization(nx_graph: nx.Graph) -> nx.Graph:
"""
Relabels NetworkX graph nodes to exclude reserved symbols in keys.
In case replaced names match for two different nodes, integer index is added to its keys.
While node keys are being updated, visualized nodes names corresponds to the original nodes names.
:param nx_graph: NetworkX graph to visualize via dot.
:return: NetworkX graph with reserved symbols in nodes keys replaced.
"""
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

hits = defaultdict(lambda: 0)
mapping = {}
for original_name in nx_graph.nodes():
if __RESERVED_DOT_CHARACTER in original_name:
dot_name = original_name.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
hits[dot_name] += 1
if hits[dot_name] > 1:
dot_name = f"{dot_name}_{hits}"
mapping[original_name] = dot_name

relabeled_graph = nx.relabel_nodes(nx_graph, mapping)
nx.set_node_attributes(
relabeled_graph,
name="label",
values={dot_key: original_key for original_key, dot_key in mapping.items()},
)
return relabeled_graph
3 changes: 2 additions & 1 deletion nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf import nncf_logger
from nncf.common.graph import NNCFNode
from nncf.common.graph import NNCFNodeName
from nncf.common.graph.graph import relabel_graph_for_dot_visualization
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import OUTPUT_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import NoopMetatype
Expand Down Expand Up @@ -997,7 +998,7 @@ def get_visualized_graph(self):
label="Unified group {}".format(gid),
)

return out_graph
return relabel_graph_for_dot_visualization(out_graph)

def traverse_graph(
self,
Expand Down

0 comments on commit 2a10e3e

Please sign in to comment.