diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py index 68e15f2b1ddd..ae5721156860 100644 --- a/python/tvm/contrib/tedd.py +++ b/python/tvm/contrib/tedd.py @@ -37,6 +37,19 @@ 8: ('kTensorized', '#A9DFBF'), } +PALETTE = { + 0: '#000000', + 1: '#922B21', + 2: '#76448A', + 3: '#1F618D', + 4: '#148F77', + 5: '#B7950B', + 6: '#AF601A', + 7: '#F5B7B1', + 8: '#A9DFBF', +} + +PALETTE_SIZE = 9 def dom_path_to_string(dom_path, prefix=""): path_string = prefix @@ -458,8 +471,8 @@ def stage_node_label(stage): var_attr_label = '' if "thread" in leafiv["properties"] and \ leafiv["properties"]["thread"] is not None: - var_attr_label = var_attr_label + "
(" + str( - leafiv["properties"]["thread"]) + ")" + var_attr_label = var_attr_label + "
(" + str( + leafiv["properties"]["thread"]) + ")" if "intrin" in leafiv["properties"] and \ leafiv["properties"]["intrin"] is not None: var_attr_label = var_attr_label + "
" + \ @@ -483,7 +496,11 @@ def compute_at_dot(g, stage): [stage["attaching_to"][0]], "Stage") + ":" + dom_path_to_string( stage["attaching_to"], "IterVar") if stage["attaching_to"] is not None else "ROOT" - g.edge(src, dst) + color = PALETTE[ + stage["attaching_to"][1] + + 1] if stage["attaching_to"] is not None and stage["attaching_to"][ + 1] < PALETTE_SIZE - 1 else PALETTE[0] + g.edge(src, dst, color=color) graph = create_schedule_tree_graph("Schedule Tree") s = extract_dom_for_viz(sch) diff --git a/tests/python/contrib/test_tedd.py b/tests/python/contrib/test_tedd.py index 6e5f3a40fbcb..58ff06418201 100644 --- a/tests/python/contrib/test_tedd.py +++ b/tests/python/contrib/test_tedd.py @@ -125,7 +125,7 @@ def verify(): findany(r"r.outer\(kCommReduce\)", str) findany(r"label=ROOT", str) # Check the compute_at edge - findany(r"Stage_1", str) + findany(r"Stage_1.*\[color\=\"\#000000\"\]", str) if checkdepdency(): verify()