diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 57b7e7c46bfa..3645ebbf4369 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -217,7 +217,12 @@ Stmt Simplify::visit(const For *op) { (in_vector_loop || op->for_type == ForType::Vectorized)); - bool bounds_tracked = false; + Expr extent_positive = mutate(0 < new_extent, nullptr); + if (is_const_zero(extent_positive)) { + // This loop never runs + return Evaluate::make(0); + } + ExprInfo loop_var_info; // Deduce bounds for the loop var that are true for any code than runs // inside the loop body. Code in the inner loop only runs if the extent is @@ -226,24 +231,19 @@ Stmt Simplify::visit(const For *op) { loop_var_info.bounds = ConstantInterval::make_union(min_info.bounds, min_info.bounds + max(extent_info.bounds, 1) - 1); - - if (loop_var_info.bounds.max_defined || - loop_var_info.bounds.min_defined) { - bounds_tracked = true; - bounds_and_alignment_info.push(op->name, loop_var_info); - } - - Expr extent_positive = mutate(0 < new_extent, nullptr); - if (is_const_zero(extent_positive)) { - return Evaluate::make(0); - } - Stmt new_body; { + ScopedBinding bind_if((loop_var_info.bounds.max_defined || + loop_var_info.bounds.min_defined), + bounds_and_alignment_info, + op->name, + loop_var_info); + // If we're in the loop, the extent must be greater than 0. ScopedFact fact = scoped_truth(extent_positive); new_body = mutate(op->body); } + if (in_unreachable) { // We found that the body of this loop is unreachable when recursively // mutating it, so we can remove the loop. Additionally, if we know the @@ -254,10 +254,6 @@ Stmt Simplify::visit(const For *op) { return Evaluate::make(0); } - if (bounds_tracked) { - bounds_and_alignment_info.pop(op->name); - } - if (const Acquire *acquire = new_body.as()) { if (is_no_op(acquire->body)) { // Rewrite iterated no-op acquires as a single acquire.