From 5a6d2a30d9ae9f696b84999bf16250e3789af577 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 14 Mar 2019 09:52:33 -0700 Subject: [PATCH] [ARITH] RewriteSimplifier: min/max, logical, select (#2768) --- src/arithmetic/rewrite_simplify.cc | 657 +++++++++++++++++- .../unittest/test_arith_rewrite_simplify.py | 278 +++++++- 2 files changed, 919 insertions(+), 16 deletions(-) diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index b304a8dc4bf2..17f8e010f393 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -75,8 +75,29 @@ class RewriteSimplifier::Impl : public IRMutator { Expr Mutate_(const Mul* op, const Expr& self) final; Expr Mutate_(const Div* op, const Expr& self) final; Expr Mutate_(const Mod* op, const Expr& self) final; + Expr Mutate_(const Min* op, const Expr& self) final; + Expr Mutate_(const Max* op, const Expr& self) final; + Expr Mutate_(const EQ* op, const Expr& self) final; + Expr Mutate_(const NE* op, const Expr& self) final; + Expr Mutate_(const LT* op, const Expr& self) final; + Expr Mutate_(const LE* op, const Expr& self) final; + Expr Mutate_(const GT* op, const Expr& self) final; + Expr Mutate_(const GE* op, const Expr& self) final; + Expr Mutate_(const And* op, const Expr& self) final; + Expr Mutate_(const Or* op, const Expr& self) final; + Expr Mutate_(const Not* op, const Expr& self) final; + Expr Mutate_(const Select* op, const Expr& self) final; + Expr Mutate_(const Ramp* op, const Expr& self) final; private: + /*! \brief internal structure for comparison. */ + enum CompareResult { + kUnknown, + kEQ, + kGT, + kLT, + kNE + }; // reference to the main analyzer Analyzer* parent_; // counter to record recursive rewrite depth. @@ -92,12 +113,36 @@ class RewriteSimplifier::Impl : public IRMutator { // Whether x == val bool CanProveEqual(const Expr& x, int64_t val) { // TODO(tqchen) refer back to super-analyzer. - Expr res = Mutate(x); - if (const auto* ptr = res.as()) { - return ptr->value == val; + return TryCompare(x, val) == kEQ; + } + // try to prove x equals val + CompareResult TryCompare(const Expr& x, int64_t val) { + Expr diff = Mutate(x); + if (const auto* ptr = diff.as()) { + if (ptr->value == val) { + return kEQ; + } else if (ptr->value > val) { + return kGT; + } else if (ptr->value < val) { + return kLT; + } + } + if (val == 0) { + ModularSet dmod = parent_->modular_set(diff); + if (dmod->base != 0) { + return kNE; + } + } + ConstIntBound dbound = parent_->const_int_bound(diff); + if (dbound->min_value > val) { + return kGT; + } + if (dbound->max_value < val) { + return kLT; } - return false; + return kUnknown; } + // Recursive rewrite x // we limit maximum depth of recursive rewrite allowed to // avoid infinite loop @@ -557,7 +602,7 @@ Mutate_(const Mod* op, const Expr& self) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -626,6 +671,608 @@ Mutate_(const Mod* op, const Expr& self) { return ret; } +Expr RewriteSimplifier::Impl:: +Mutate_(const Min* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y, z, s1, s2; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + // vector rule + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), + broadcast(min(x, y), lanes)); + TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), + min(x, broadcast(min(y, z), lanes))); + } + if (IsIndexType(op->type)) { + TVM_TRY_REWRITE(min(x, x), x); + + // constant int bound + ConstIntBound a_bound = parent_->const_int_bound(op->a); + ConstIntBound b_bound = parent_->const_int_bound(op->b); + if (a_bound->max_value <= b_bound->min_value) { + return op->a; + } + if (b_bound->max_value <= a_bound->min_value) { + return op->b; + } + + // constant comparison + if (min(x + c1, x + c2).Match(ret)) { + if (c1.Eval()->value < c2.Eval()->value) { + return (x + c1).Eval(); + } else { + return (x + c2).Eval(); + } + } + if (min(x + c1, x).Match(ret) || + min(x, x + c1).Match(ret)) { + if (c1.Eval()->value < 0) { + return (x + c1).Eval(); + } else { + return x.Eval(); + } + } + if (min(c1 - x, c2 - x).Match(ret)) { + if (c1.Eval()->value < c2.Eval()->value) { + return (c1 - x).Eval(); + } else { + return (c2 - x).Eval(); + } + } + + // Divide up rounding + TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x, + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2), + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x, + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2), + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y)); + TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y)); + TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y)); + TVM_TRY_REWRITE(min(min(x, y), max(y, x)), min(x, y)); + + TVM_TRY_REWRITE(min(max(x, y), x), x); + TVM_TRY_REWRITE(min(max(x, y), y), y); + TVM_TRY_REWRITE(min(min(x, y), x), min(x, y)); + TVM_TRY_REWRITE(min(min(x, y), y), min(x, y)); + + TVM_TRY_REWRITE(min(x, max(x, y)), x); + TVM_TRY_REWRITE(min(y, max(x, y)), y); + TVM_TRY_REWRITE(min(x, min(x, y)), min(x, y)); + TVM_TRY_REWRITE(min(y, min(x, y)), min(x, y)); + + TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z)); + TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1)); + TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y), + min(min(min(min(x, y), z), s1), s2)); + + TVM_TRY_REWRITE(min(max(x, y), max(x, z)), max(min(y, z), x)); + TVM_TRY_REWRITE(min(max(x, y), max(z, x)), max(min(y, z), x)); + TVM_TRY_REWRITE(min(max(y, x), max(x, z)), max(min(y, z), x)); + TVM_TRY_REWRITE(min(max(y, x), max(z, x)), max(min(y, z), x)); + + TVM_TRY_REWRITE(min(min(x, y), min(x, z)), min(min(y, z), x)); + TVM_TRY_REWRITE(min(min(x, y), min(z, x)), min(min(y, z), x)); + TVM_TRY_REWRITE(min(min(y, x), min(x, z)), min(min(y, z), x)); + TVM_TRY_REWRITE(min(min(y, x), min(z, x)), min(min(y, z), x)); + + TVM_TRY_REWRITE(min(y + x, z + x), min(y, z) + x); + TVM_TRY_REWRITE(min(y + x, x + z), min(y, z) + x); + TVM_TRY_REWRITE(min(x + y, x + z), min(y, z) + x); + TVM_TRY_REWRITE(min(x + y, z + x), min(y, z) + x); + + // sub distribution + TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x); + TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z)); + + // constant folding rule. + TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2))); + + // scaling rule + if (min(x / c1, y / c1).Match(ret)) { + if (c1.Eval()->value > 0) { + return (min(x, y) / c1).Eval(); + } else { + return (max(x, y) / c1).Eval(); + } + } + if (min(x * c1, y * c1).Match(ret)) { + if (c1.Eval()->value > 0) { + return (min(x, y) * c1).Eval(); + } else { + return (max(x, y) * c1).Eval(); + } + } + if (min(x * c1, c2).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + if (c2val % c1val == 0) { + if (c2val / c1val >= 0) { + return (min(x, c2val / c1val) * c1val).Eval(); + } else { + return (max(x, c2val / c1val) * c1val).Eval(); + } + } + } + + // canonicalization + TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); + TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1)); + } + + // condition rules. + TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), + select(x, min(y, s1), min(z, s2))); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Max* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y, z, s1, s2; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + // vector rule + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), + broadcast(max(x, y), lanes)); + TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), + max(x, broadcast(max(y, z), lanes))); + } + if (IsIndexType(op->type)) { + TVM_TRY_REWRITE(max(x, x), x); + + // constant int bound + ConstIntBound a_bound = parent_->const_int_bound(op->a); + ConstIntBound b_bound = parent_->const_int_bound(op->b); + if (a_bound->min_value >= b_bound->max_value) { + return op->a; + } + if (b_bound->min_value >= a_bound->max_value) { + return op->b; + } + + // constant comparison + if (max(x + c1, x + c2).Match(ret)) { + if (c1.Eval()->value > c2.Eval()->value) { + return (x + c1).Eval(); + } else { + return (x + c2).Eval(); + } + } + if (max(x + c1, x).Match(ret) || + max(x, x + c1).Match(ret)) { + if (c1.Eval()->value > 0) { + return (x + c1).Eval(); + } else { + return x.Eval(); + } + } + if (max(c1 - x, c2 - x).Match(ret)) { + if (c1.Eval()->value > c2.Eval()->value) { + return (c1 - x).Eval(); + } else { + return (c2 - x).Eval(); + } + } + + // Divide up rounding + TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2, + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2, + c2.Eval()->value > 0 && + c1.Eval()->value + 1 == c2.Eval()->value); + + TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y)); + TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y)); + TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y)); + TVM_TRY_REWRITE(max(max(x, y), min(y, x)), max(x, y)); + + TVM_TRY_REWRITE(max(min(x, y), x), x); + TVM_TRY_REWRITE(max(min(x, y), y), y); + TVM_TRY_REWRITE(max(max(x, y), x), max(x, y)); + TVM_TRY_REWRITE(max(max(x, y), y), max(x, y)); + + TVM_TRY_REWRITE(max(x, min(x, y)), x); + TVM_TRY_REWRITE(max(y, min(x, y)), y); + TVM_TRY_REWRITE(max(x, max(x, y)), max(x, y)); + TVM_TRY_REWRITE(max(y, max(x, y)), max(x, y)); + + TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z)); + TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1)); + TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y), + max(max(max(max(x, y), z), s1), s2)); + + // max/max cancelation + TVM_TRY_REWRITE(max(max(x, y), max(x, z)), max(max(y, z), x)); + TVM_TRY_REWRITE(max(max(x, y), max(z, x)), max(max(y, z), x)); + TVM_TRY_REWRITE(max(max(y, x), max(x, z)), max(max(y, z), x)); + TVM_TRY_REWRITE(max(max(y, x), max(z, x)), max(max(y, z), x)); + + // max/min distribution + TVM_TRY_REWRITE(max(min(x, y), min(x, z)), min(max(y, z), x)); + TVM_TRY_REWRITE(max(min(x, y), min(z, x)), min(max(y, z), x)); + TVM_TRY_REWRITE(max(min(y, x), min(x, z)), min(max(y, z), x)); + TVM_TRY_REWRITE(max(min(y, x), min(z, x)), min(max(y, z), x)); + + // add distribution + TVM_TRY_REWRITE(max(y + x, z + x), max(y, z) + x); + TVM_TRY_REWRITE(max(y + x, x + z), max(y, z) + x); + TVM_TRY_REWRITE(max(x + y, x + z), max(y, z) + x); + TVM_TRY_REWRITE(max(x + y, z + x), max(y, z) + x); + + // sub distribution + TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x); + TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z)); + + // constant folding rule. + TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2))); + + // scaling rule + if (max(x / c1, y / c1).Match(ret)) { + if (c1.Eval()->value > 0) { + return (max(x, y) / c1).Eval(); + } else { + return (min(x, y) / c1).Eval(); + } + } + if (max(x * c1, y * c1).Match(ret)) { + if (c1.Eval()->value > 0) { + return (max(x, y) * c1).Eval(); + } else { + return (min(x, y) * c1).Eval(); + } + } + if (max(x * c1, c2).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + if (c2val % c1val == 0) { + if (c2val / c1val >= 0) { + return (max(x, c2val / c1val) * c1val).Eval(); + } else { + return (min(x, c2val / c1val) * c1val).Eval(); + } + } + } + + // canonicalization + TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); + TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1)); + } + + // condition rules. + TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), + select(x, max(y, s1), max(z, s2))); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const EQ* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y; + // Pattern var match IntImm + PVar c1; + PVar lanes; + + // vector rule + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), + broadcast(x == y, lanes)); + } + + if (IsIndexType(op->a.type())) { + CompareResult result = TryCompare(op->a - op->b, 0); + if (result != kUnknown) { + if (result == kEQ) { + return make_const(op->type, true); + } else { + return make_const(op->type, false); + } + } + TVM_TRY_REWRITE(x - c1 == 0, x == c1); + TVM_TRY_REWRITE(c1 - x == 0, x == c1); + TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1); + TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0); + } + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const NE* op, const Expr& self) { + return Mutate(Not::make(op->a == op->b)); +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const LE* op, const Expr& self) { + return Mutate(Not::make(op->b < op->a)); +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const GT* op, const Expr& self) { + return Mutate(op->b < op->a); +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const GE* op, const Expr& self) { + return Mutate(Not::make(op->a < op->b)); +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const LT* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y, z, s1, s2; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + // vector rule + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), + broadcast(x < y, lanes)); + TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), + broadcast(x < y, lanes)); + } + + if (IsIndexType(op->a.type())) { + CompareResult result = TryCompare(op->a - op->b, 0); + if (result == kLT) { + return make_const(op->type, true); + } + if (result == kEQ || result == kGT) { + return make_const(op->type, false); + } + + TVM_TRY_REWRITE(x + y < x + z, y < z); + TVM_TRY_REWRITE(x + y < z + x, y < z); + TVM_TRY_REWRITE(y + x < x + z, y < z); + TVM_TRY_REWRITE(y + x < z + x, y < z); + TVM_TRY_REWRITE(y - x < z - x, y < z); + TVM_TRY_REWRITE(x - y < x - z, z < y); + + TVM_TRY_REWRITE(x < x + z, 0 < z); + TVM_TRY_REWRITE(x < z + x, 0 < z); + TVM_TRY_REWRITE(x < x - z, z < 0); + TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x); + TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1); + + TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, + c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, + c1.Eval()->value < 0); + + // require c1 > 0 to work for any div mode + TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, + c1.Eval()->value > 0 && + c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + c1.Eval()->value > 0 && + c2.Eval()->value > 0); + + TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0); + + // division related simplificationx + // invariance for any div mod: x - (x / c1) * c1 == x % c1 + TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1, + c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y, + c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1, + c1.Eval()->value > 0); + + TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x, + c2 < (x + c2) % c1, + c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x + y, + c2 < (x + c2) % c1 + y, + c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x - y, + y < (x + c2) % c1 + (0 - c2), + c1.Eval()->value > 0); + + // canonicalization rule + TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z); + TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z); + TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y); + TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y); + + TVM_TRY_REWRITE(x - c1 < 0, x < c1); + TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1); + } + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Not* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a); + if (const_res.defined()) return const_res; + // Pattern var to match any expression + PVar x, y; + PVar lanes; + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); + } + + TVM_TRY_REWRITE(!(!x), x); + TVM_TRY_REWRITE(!(x <= y), y < x); + TVM_TRY_REWRITE(!(x >= y), x < y); + TVM_TRY_REWRITE(!(x < y), y <= x); + TVM_TRY_REWRITE(!(x > y), x <= y); + TVM_TRY_REWRITE(!(x == y), x != y); + TVM_TRY_REWRITE(!(x != y), x == y); + TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); + TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const And* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), + broadcast(x && y, lanes)); + } + + auto cfalse = PConst(make_const(op->type, false)); + TVM_TRY_REWRITE(x == y && x != y, cfalse); + TVM_TRY_REWRITE(x != y && x == y, cfalse); + TVM_TRY_REWRITE(x && !x, cfalse); + TVM_TRY_REWRITE(x <= y && y < x, cfalse); + TVM_TRY_REWRITE(y < x && y <= x, cfalse); + + TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, + c2.Eval()->value + 1 >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, + c2.Eval()->value + 1 >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, + c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, + c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, + c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, + c2.Eval()->value >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, + c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, + c2.Eval()->value > c1.Eval()->value); + + TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); + TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Or* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), + broadcast(x || y, lanes)); + } + + auto ctrue = PConst(make_const(op->type, true)); + + TVM_TRY_REWRITE(x == y || x != y, ctrue); + TVM_TRY_REWRITE(x != y || x == y, ctrue); + TVM_TRY_REWRITE(x || !x, ctrue); + TVM_TRY_REWRITE(x <= y || y < x, ctrue); + TVM_TRY_REWRITE(y < x || y <= x, ctrue); + + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, + c2.Eval()->value < c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, + c2.Eval()->value < c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, + c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, + c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, + c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, + c2.Eval()->value <= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, + c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, + c2.Eval()->value <= c1.Eval()->value + 1); + + TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); + TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Ramp* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + if (is_zero(op->stride)) { + return Broadcast::make(op->base, op->lanes); + } + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Select* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as