Skip to content

Commit

Permalink
Add colors to compute_at edges and thread/block indices. (#5111)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongfeng-nv authored Mar 20, 2020
1 parent 841725c commit b91dbca
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
23 changes: 20 additions & 3 deletions python/tvm/contrib/tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@
8: ('kTensorized', '#A9DFBF'),
}

PALETTE = {
0: '#000000',
1: '#922B21',
2: '#76448A',
3: '#1F618D',
4: '#148F77',
5: '#B7950B',
6: '#AF601A',
7: '#F5B7B1',
8: '#A9DFBF',
}

PALETTE_SIZE = 9

def dom_path_to_string(dom_path, prefix=""):
path_string = prefix
Expand Down Expand Up @@ -458,8 +471,8 @@ def stage_node_label(stage):
var_attr_label = ''
if "thread" in leafiv["properties"] and \
leafiv["properties"]["thread"] is not None:
var_attr_label = var_attr_label + "<br/>(" + str(
leafiv["properties"]["thread"]) + ")"
var_attr_label = var_attr_label + "<br/><font color=\"#2980B9\">(" + str(
leafiv["properties"]["thread"]) + ")</font>"
if "intrin" in leafiv["properties"] and \
leafiv["properties"]["intrin"] is not None:
var_attr_label = var_attr_label + "<br/>" + \
Expand All @@ -483,7 +496,11 @@ def compute_at_dot(g, stage):
[stage["attaching_to"][0]], "Stage") + ":" + dom_path_to_string(
stage["attaching_to"],
"IterVar") if stage["attaching_to"] is not None else "ROOT"
g.edge(src, dst)
color = PALETTE[
stage["attaching_to"][1] +
1] if stage["attaching_to"] is not None and stage["attaching_to"][
1] < PALETTE_SIZE - 1 else PALETTE[0]
g.edge(src, dst, color=color)

graph = create_schedule_tree_graph("Schedule Tree")
s = extract_dom_for_viz(sch)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def verify():
findany(r"r.outer\(kCommReduce\)", str)
findany(r"label=ROOT", str)
# Check the compute_at edge
findany(r"Stage_1", str)
findany(r"Stage_1.*\[color\=\"\#000000\"\]", str)

if checkdepdency():
verify()
Expand Down

0 comments on commit b91dbca

Please sign in to comment.