Skip to content

Commit

Permalink
WIP fix visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 27, 2023
1 parent 218cb6b commit 73ca49b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
12 changes: 10 additions & 2 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,15 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
:return: A user-friendly graph .dot file, making it easier to debug the network and setup
ignored/target scopes.
"""
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

out_graph = nx.DiGraph()
for node in self.get_all_nodes():
attrs_node = {}
attrs_node["label"] = f"{node.node_id} {node.node_name}"
node_key = self.get_node_key_by_id(node.node_id)
node_key = self.get_node_key_by_id(node.node_id).replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
out_graph.add_node(node_key, **attrs_node)

for u, v in self._nx_graph.edges:
Expand All @@ -638,9 +642,13 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
f"{edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]} \\n"
f"{edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR]} -> {edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]}"
)
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
out_graph.add_edge(u, v, label=edge_label, style=style)

mapping = {k: v["label"] for k, v in out_graph.nodes.items()}
mapping = {
k: v["label"].replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO) for k, v in out_graph.nodes.items()
}
out_graph = nx.relabel_nodes(out_graph, mapping)
for node in out_graph.nodes.values():
node.pop("label")
Expand Down
17 changes: 13 additions & 4 deletions nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,16 +917,21 @@ def is_branching_node_dominating_outputs(self, from_node_key: str) -> bool:
return from_node_key in self._branch_nodes_directly_dominating_outputs

def get_visualized_graph(self):
# .dot format reserves ':' character in node names
__RESERVED_DOT_CHARACTER = ":"
__CHARACTER_REPLACE_TO = "^"

out_graph = nx.DiGraph()
unified_scale_group_vs_pq_node_id_dict: Dict[int, List[str]] = {}
for node_key, node in self.nodes.items():
dot_node_key = node_key.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
node_type = node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
if self.is_insertion_point(node_type):
insertion_point_data: TargetPoint = node[
QuantizerPropagationStateGraph.QUANT_INSERTION_POINT_DATA_NODE_ATTR
]
label = "TP: {}".format(str(insertion_point_data))
out_graph.add_node(node_key, label=label, color="red")
out_graph.add_node(dot_node_key, label=label, color="red")
if node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] is not None:
prop_quantizer: PropagatingQuantizer = node[
QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR
Expand All @@ -947,7 +952,7 @@ def get_visualized_graph(self):
else "yellow"
)
out_graph.add_node(quant_node_key, color=pq_color, label=quant_node_label)
out_graph.add_edge(quant_node_key, node_key, style="dashed")
out_graph.add_edge(quant_node_key, dot_node_key, style="dashed")
if prop_quantizer.unified_scale_type is not None:
gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(
prop_quantizer.id
Expand All @@ -958,9 +963,9 @@ def get_visualized_graph(self):
unified_scale_group_vs_pq_node_id_dict[gid] = [quant_node_key]

elif node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
out_graph.add_node(node_key)
out_graph.add_node(dot_node_key)
elif node_type == QuantizerPropagationStateGraphNodeType.AUXILIARY_BARRIER:
out_graph.add_node(node_key, color="green", label=node["label"])
out_graph.add_node(dot_node_key, color="green", label=node["label"])
else:
raise RuntimeError("Invalid QuantizerPropagationStateGraph node!")
for u, v in self.edges:
Expand All @@ -973,6 +978,8 @@ def get_visualized_graph(self):
is_integer_path = edge[QuantizerPropagationStateGraph.IS_INTEGER_PATH_EDGE_ATTR]
if is_integer_path:
attrs = {"color": "violet", "style": "bold"}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
out_graph.add_edge(u, v, **attrs)

for gid, group_pq_node_keys in unified_scale_group_vs_pq_node_id_dict.items():
Expand All @@ -989,6 +996,8 @@ def get_visualized_graph(self):
except StopIteration:
done = True
next_pq_node_key = group_pq_node_keys[0] # back to the first elt
curr_pq_node_key = curr_pq_node_key.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
next_pq_node_key = next_pq_node_key.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
out_graph.add_edge(
curr_pq_node_key,
next_pq_node_key,
Expand Down

0 comments on commit 73ca49b

Please sign in to comment.