Skip to content

Commit

Permalink
[RELAY/PASS] Fix the extent for the post_stmt in the loop partition (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Umang Yadav authored and ZihengJiang committed Oct 2, 2019
1 parent 59cf573 commit a7873b0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 34 deletions.
56 changes: 26 additions & 30 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, false);
if (middle_interval.is_nothing())
// we couldn't find an interval in which the condintions are provably true or false
// we couldn't find an interval in which the conditions are provably true or false
// Therefore, we can't partition the loop based on those conds
return Stmt();
cond_value = false;
Expand All @@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
} else {
body_begin = min;
}

// Calculating post-subrange and generating code for it.
// post-subrange = [post_doubt_begin, max]
// post-subrange = [post_doubt_begin, max+1)
Expr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
} else {
post_doubt_begin = max + 1;
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_pass_bound_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def lower(sch, args):
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.RemoveNoOp(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
Expand Down
24 changes: 20 additions & 4 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,16 @@ def check_device(device):
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

for device in get_all_backend():
check_device(device)
check_device('llvm')
check_device('cuda')
check_device('opencl')
check_device('metal')
check_device('rocm')
check_device('vulkan')
check_device('nvptx')
check_device('llvm -device=arm-cpu')
check_device('opencl -device=mali')
check_device('aocl_sw_emu')

def test_isnan(
low,
Expand Down Expand Up @@ -109,8 +117,16 @@ def check_device(device):
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

for device in get_all_backend():
check_device(device)
check_device('llvm')
check_device('cuda')
check_device('opencl')
check_device('metal')
check_device('rocm')
check_device('vulkan')
check_device('nvptx')
check_device('llvm -device=arm-cpu')
check_device('opencl -device=mali')
check_device('aocl_sw_emu')

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
Expand Down

0 comments on commit a7873b0

Please sign in to comment.