diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 33dbaed83b697..cba1209d7d91d 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -509,46 +509,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!analyzer_.CanProve(body_begin == min)) { - Expr cond = (body_begin - min >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); - // stop recursing on this interval if we can't prove it has non-negative length - pre_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); - } + Expr cond = (body_begin - min >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + // stop recursing on this interval if we can't prove it has non-negative length + pre_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); + pre_stmt = MakeFor(node, body_begin - min, pre_body); } } else { body_begin = min; } // Calculating post-subrange and generating code for it. - // post-subrange = [post_doubt_begin, max] + // post-subrange = [post_doubt_begin, max+1) Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!analyzer_.CanProve(middle_interval.max() == max)) { - // require the extent to be non-negative - Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max); - // stop recursing on this interval if we can't prove it has non-negative length - post_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); - } + // require the extent to be non-negative + Expr cond = (max - post_doubt_begin + 1 >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max+1); + // stop recursing on this interval if we can't prove it has non-negative length + post_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt post_body = + Substitute(body, {{Var{var}, var + post_doubt_begin}}); + post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } else { post_doubt_begin = max + 1;