diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py
index 68e15f2b1ddd..ae5721156860 100644
--- a/python/tvm/contrib/tedd.py
+++ b/python/tvm/contrib/tedd.py
@@ -37,6 +37,19 @@
8: ('kTensorized', '#A9DFBF'),
}
+PALETTE = {
+ 0: '#000000',
+ 1: '#922B21',
+ 2: '#76448A',
+ 3: '#1F618D',
+ 4: '#148F77',
+ 5: '#B7950B',
+ 6: '#AF601A',
+ 7: '#F5B7B1',
+ 8: '#A9DFBF',
+}
+
+PALETTE_SIZE = 9
def dom_path_to_string(dom_path, prefix=""):
path_string = prefix
@@ -458,8 +471,8 @@ def stage_node_label(stage):
var_attr_label = ''
if "thread" in leafiv["properties"] and \
leafiv["properties"]["thread"] is not None:
- var_attr_label = var_attr_label + "
(" + str(
- leafiv["properties"]["thread"]) + ")"
+ var_attr_label = var_attr_label + "
(" + str(
+ leafiv["properties"]["thread"]) + ")"
if "intrin" in leafiv["properties"] and \
leafiv["properties"]["intrin"] is not None:
var_attr_label = var_attr_label + "
" + \
@@ -483,7 +496,11 @@ def compute_at_dot(g, stage):
[stage["attaching_to"][0]], "Stage") + ":" + dom_path_to_string(
stage["attaching_to"],
"IterVar") if stage["attaching_to"] is not None else "ROOT"
- g.edge(src, dst)
+ color = PALETTE[
+ stage["attaching_to"][1] +
+ 1] if stage["attaching_to"] is not None and stage["attaching_to"][
+ 1] < PALETTE_SIZE - 1 else PALETTE[0]
+ g.edge(src, dst, color=color)
graph = create_schedule_tree_graph("Schedule Tree")
s = extract_dom_for_viz(sch)
diff --git a/tests/python/contrib/test_tedd.py b/tests/python/contrib/test_tedd.py
index 6e5f3a40fbcb..58ff06418201 100644
--- a/tests/python/contrib/test_tedd.py
+++ b/tests/python/contrib/test_tedd.py
@@ -125,7 +125,7 @@ def verify():
findany(r"r.outer\(kCommReduce\)", str)
findany(r"label=ROOT", str)
# Check the compute_at edge
- findany(r"Stage_1", str)
+ findany(r"Stage_1.*\[color\=\"\#000000\"\]", str)
if checkdepdency():
verify()