From 22ded913c23057438a49fce8d5c4495916a7c9ae Mon Sep 17 00:00:00 2001 From: Salem Derisavi Date: Tue, 12 Mar 2019 15:33:55 -0400 Subject: [PATCH] Ensure loop count is a constant before trying to unroll. --- src/pass/unroll_loop.cc | 2 +- tests/python/unittest/test_pass_unroll.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index d4481e86c0fc..f1f13cb87fa3 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -78,7 +78,7 @@ class LoopUnroller : public IRMutator { if ((auto_unroll && explicit_unroll_) || // unroll loops with extent = 1, no matter how many steps in body - (value <= auto_max_extent_ && auto_max_extent_ == 1)) { + (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) { return Unroll(op); } else { if (auto_unroll) { diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index c88a019a8bce..d51c3f016d03 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -51,7 +51,20 @@ def test_unroll_fake_loop(): ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) assert isinstance(ret.first, tvm.stmt.Store) +def test_unroll_single_count_loops(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.compute((n,), lambda *i: A(*i), name='B') + s = tvm.create_schedule(B.op) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + # all parameters to UnrolLoops are default values except for + # auto_unroll_max_extent which has been set to 1 (default:0) + after_unroll_stmt = tvm.ir_pass.UnrollLoop(stmt, 0, 8, 1, True) + assert after_unroll_stmt == stmt if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() + test_unroll_single_count_loops()