From 1044c4df18b1a6ee19a3b74d5c937a5692a0eba3 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 15 Nov 2019 14:37:37 -0800 Subject: [PATCH] [FIX] Fix for a specific case when loop partitioning with indivisble (#4243) factors and resulting nested loop is broken. This is due to the fact that we are creating zero extent loops which are fixed afterwards. However unroll pass breaks due to the zero extent loop. --- src/pass/loop_partition.cc | 52 ++++++++++--------- .../unittest/test_pass_loop_partition.py | 22 ++++++++ 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index b2a1bea0dd5b..1d669c8b39f1 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - Expr cond = (body_begin - min >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); - // stop recursing on this interval if we can't prove it has non-negative length - pre_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); + if (!analyzer_.CanProve(body_begin == min)) { + Expr cond = (body_begin - min >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + // stop recursing on this interval if we can't prove it has non-negative length + pre_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); + pre_stmt = MakeFor(node, body_begin - min, pre_body); + } } } else { body_begin = min; @@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - // require the extent to be non-negative - Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max+1); - // stop recursing on this interval if we can't prove it has non-negative length - post_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + if (!analyzer_.CanProve(middle_interval.max() == max)) { + // require the extent to be non-negative + Expr cond = (max - post_doubt_begin + 1 >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max+1); + // stop recursing on this interval if we can't prove it has non-negative length + post_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt post_body = + Substitute(body, {{Var{var}, var + post_doubt_begin}}); + post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + } } } else { post_doubt_begin = max + 1; diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index b6fcfa3a1512..021709506754 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -365,6 +365,27 @@ def test_conv_tiling(): stmt = tvm.ir_pass.Simplify(stmt) assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_multilevel_splitting_with_indivisble_factors(): + import topi + A = tvm.placeholder((130,), dtype="float32") + B = topi.nn.relu(A) + s = tvm.create_schedule(B.op) + (y,) = s[B].op.axis + (yo, yi) = s[B].split(y, factor=8) + (yoo, yoi) = s[B].split(yo, factor=16) + s[B].reorder(yoo, yoi, yi) + s[B].unroll(yi) + + ## But this does the right thing. + with tvm.build_config(partition_const_loop=True): + lowered_body = tvm.lower(s, [A, B]).body + def visit_stmt(op): + return(isinstance(op, tvm.expr.Max)) + num_max = collect_visit(lowered_body, visit_stmt) + assert num_max.count(True) == 10 + + def test_double_splitting_with_indivisible_factors(): m = 48 dtype="float32" @@ -443,4 +464,5 @@ def test_simple_rfactor(): test_cce_loop_3() test_conv_tiling() test_double_splitting_with_indivisible_factors() + test_multilevel_splitting_with_indivisble_factors() test_simple_rfactor()