From dc21596459d96d3712141cd59c2ef3cdf3fb6855 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 10 Mar 2020 09:53:18 +0530 Subject: [PATCH] Early checking added and new test cases added for schedule fuse (#5010) * [1] New test case added for fuse * [2] New test case added for fuse * [3] New test case added for fuse * [4] New test case added for fuse * [5] Early check added --- src/te/schedule/schedule_lang.cc | 4 +- tests/python/unittest/test_lang_schedule.py | 46 +++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 7a2e05a56fd3..bfee0d5a0a6b 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -263,10 +263,10 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT std::swap(outer, inner); std::swap(pos_inner, pos_outer); } - self->relations.push_back(FuseNode::make(outer, inner, fused)); - all_vars->data.push_back(fused); CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; + self->relations.push_back(FuseNode::make(outer, inner, fused)); + all_vars->data.push_back(fused); leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, leaf_vars->data.begin() + pos_inner + 1); leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index dae43bb2bbf2..5a4c02173f39 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -102,6 +102,49 @@ def test_fuse(): assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations) assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi) +def test_fuse_with_split(): + m = te.size_var('m') + n = te.size_var('n') + A = te.placeholder((m, n), name='A') + T = te.compute((m, n), lambda i, j: A[i, j]) + + s = te.create_schedule(T.op) + y = T.op.axis[1] + xo, xi = s[T].split(T.op.axis[0], factor=10) + fused = s[T].fuse(xi, y) + assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations) + assert tuple(s[T].leaf_iter_vars) == (xo, fused) + +@pytest.mark.xfail +def test_fuse_with_out_of_order_axis(): + m = te.size_var('m') + n = te.size_var('n') + A = te.placeholder((m, n), name='A') + T = te.compute((m, n), lambda i, j: A[i, j]) + + s = te.create_schedule(T.op) + y = T.op.axis[1] + xo, xi = s[T].split(T.op.axis[0], factor=10) + fused = s[T].fuse(xo, y) # should throw here + +@pytest.mark.xfail +def test_fuse_with_out_of_order_axis_with_reorder(): + m = te.size_var('m') + n = te.size_var('n') + A = te.placeholder((m, n), name='A') + T = te.compute((m, n), lambda i, j: A[i, j]) + + s = te.create_schedule(T.op) + y = T.op.axis[1] + xo, xi = s[T].split(T.op.axis[0], factor=10) + s[T].reorder(y, xo, xi) + fused = s[T].fuse(y, xo) # should be ok + + s = te.create_schedule(T.op) + y = T.op.axis[1] + xo, xi = s[T].split(T.op.axis[0], factor=10) + s[T].reorder(y, xo, xi) + fused = s[T].fuse(y, xi) # should throw here def test_singleton(): print("test singleton") @@ -257,5 +300,8 @@ def intrin_func(ins, outs, sp): test_tile() test_split() test_fuse() + test_fuse_with_split() + test_fuse_with_out_of_order_axis() + test_fuse_with_out_of_order_axis_with_reorder() test_vectorize() test_vectorize_commreduce()