diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 937e7c462868b..9f195d7d71cee 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -62,8 +62,8 @@ class RewriteSimplifier::Impl : public IRMutator { // Run simplification in post order Expr PostOrderSimplify(Expr expr, int max_iter = 2) { + for (int i = 0; i < max_iter; ++i) { - recur_counter_ = 0; Expr new_expr = this->Mutate(expr); if (new_expr.same_as(expr)) return expr; expr = new_expr; @@ -80,12 +80,12 @@ class RewriteSimplifier::Impl : public IRMutator { private: // reference to the main analyzer Analyzer* parent_; - // counter to record recursive rewrite times. - int recur_counter_{0}; + // counter to record recursive rewrite depth. + int recur_depth_{0}; // internal variable map std::unordered_map var_map_; // maximum number of recursion allowed during a single pass. - static const constexpr int kMaxRecurCount = 10; + static const constexpr int kMaxRecurDepth = 5; // Whether x >= val bool CanProveGreaterEqual(const Expr& x, int64_t val) { return parent_->CanProveGreaterEqual(x, val); @@ -100,13 +100,16 @@ class RewriteSimplifier::Impl : public IRMutator { return false; } // Recursive rewrite x - // we limit maximum number of recursive rewrite allowed to + // we limit maximum depth of recursive rewrite allowed to // avoid infinite loop - Expr RecursiveRewrite(Expr x) { - if (recur_counter_ >= kMaxRecurCount) return x; - ++recur_counter_; - return Mutate(x); + Expr RecursiveRewrite(const Expr& x) { + if (recur_depth_ >= kMaxRecurDepth) return x; + ++recur_depth_; + Expr res = Mutate(x); + --recur_depth_; + return res; } + template PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 0); @@ -152,6 +155,8 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); + TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); + TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y); TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index ab3da21b7f710..bbfddddd41daf 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -72,7 +72,10 @@ def test_select_simplify(): tvm.expr.Select(x > 0, y + 1, z)) ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z), tvm.expr.Select(x > 0, y + (-1), 1 - z)) - + ck.verify(tvm.expr.Select(x > 0, y, z) - y, + tvm.expr.Select(x > 0, 0, z - y)) + ck.verify(tvm.expr.Select(x > 0, y, z) - z, + tvm.expr.Select(x > 0, y - z, 0)) def test_add_index_simplify():