From a7873b0aeca8f20e6e7a1e43b1c9a912f610d1b1 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 2 Oct 2019 16:13:10 -0400 Subject: [PATCH] [RELAY/PASS] Fix the extent for the post_stmt in the loop partition (#3734) --- src/pass/loop_partition.cc | 56 +++++++++---------- .../unittest/test_pass_bound_checkers.py | 1 + topi/tests/python/test_topi_math.py | 24 ++++++-- 3 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index dc791477c616..b2a1bea0dd5b 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -492,7 +492,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, false); if (middle_interval.is_nothing()) - // we couldn't find an interval in which the condintions are provably true or false + // we couldn't find an interval in which the conditions are provably true or false // Therefore, we can't partition the loop based on those conds return Stmt(); cond_value = false; @@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!analyzer_.CanProve(body_begin == min)) { - Expr cond = (body_begin - min >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); - // stop recursing on this interval if we can't prove it has non-negative length - pre_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); - } + Expr cond = (body_begin - min >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + // stop recursing on this interval if we can't prove it has non-negative length + pre_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); + pre_stmt = MakeFor(node, body_begin - min, pre_body); } } else { body_begin = min; } // Calculating post-subrange and generating code for it. - // post-subrange = [post_doubt_begin, max] + // post-subrange = [post_doubt_begin, max+1) Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!analyzer_.CanProve(middle_interval.max() == max)) { - // require the extent to be non-negative - Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max); - // stop recursing on this interval if we can't prove it has non-negative length - post_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); - } + // require the extent to be non-negative + Expr cond = (max - post_doubt_begin + 1 >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max+1); + // stop recursing on this interval if we can't prove it has non-negative length + post_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt post_body = + Substitute(body, {{Var{var}, var + post_doubt_begin}}); + post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } else { post_doubt_begin = max + 1; diff --git a/tests/python/unittest/test_pass_bound_checkers.py b/tests/python/unittest/test_pass_bound_checkers.py index c319d4d9202a..ada81690d05d 100644 --- a/tests/python/unittest/test_pass_bound_checkers.py +++ b/tests/python/unittest/test_pass_bound_checkers.py @@ -37,6 +37,7 @@ def lower(sch, args): bounds = tvm.schedule.InferBound(sch) stmt = tvm.schedule.ScheduleOps(sch, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.RemoveNoOp(stmt) stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True) stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.VectorizeLoop(stmt) diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index 660d22ccf2bc..ebbf6f7bd3d6 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -69,8 +69,16 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + check_device('llvm') + check_device('cuda') + check_device('opencl') + check_device('metal') + check_device('rocm') + check_device('vulkan') + check_device('nvptx') + check_device('llvm -device=arm-cpu') + check_device('opencl -device=mali') + check_device('aocl_sw_emu') def test_isnan( low, @@ -109,8 +117,16 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + check_device('llvm') + check_device('cuda') + check_device('opencl') + check_device('metal') + check_device('rocm') + check_device('vulkan') + check_device('nvptx') + check_device('llvm -device=arm-cpu') + check_device('opencl -device=mali') + check_device('aocl_sw_emu') test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)