From e12a8dc7e5342d9537301c292255c1b2d3c2bf73 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 10 Mar 2019 12:31:34 -0400 Subject: [PATCH] [ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722) --- include/tvm/arithmetic.h | 35 + python/tvm/arith.py | 16 + src/api/api_arith.cc | 4 + src/arithmetic/analyzer.cc | 11 +- src/arithmetic/const_fold.h | 4 +- src/arithmetic/pattern_match.h | 51 +- src/arithmetic/rewrite_simplify.cc | 650 ++++++++++++++++++ tests/cpp/pattern_match_test.cc | 1 + .../unittest/test_arith_rewrite_simplify.py | 252 +++++++ 9 files changed, 1016 insertions(+), 8 deletions(-) create mode 100644 src/arithmetic/rewrite_simplify.cc create mode 100644 tests/python/unittest/test_arith_rewrite_simplify.py diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 44b00b5d89fa..d023f8f1cf7e 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -192,6 +192,39 @@ class ModularSetAnalyzer { Impl* impl_; }; +/*! + * \brief Rewrite-rule based simplifier. + */ +class RewriteSimplifier { + public: + /*! + * \brief analyze the expr + * \param expr The expression of interest. + * \return the result of the analysis. + */ + Expr operator()(const Expr& expr); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_expr + * \param override Whether do we allow override of existing information. + */ + void Update(const Var& var, + const Expr& new_expr, + bool override = false); + + private: + friend class Analyzer; + friend class ConstraintContext; + explicit RewriteSimplifier(Analyzer* parent); + ~RewriteSimplifier(); + class Impl; + /*! \brief Internal impl */ + Impl* impl_; +}; + /*! * \brief A RAII constraint context. * @@ -242,6 +275,8 @@ class Analyzer { ConstIntBoundAnalyzer const_int_bound; /*! \brief sub-analyzer: modular set */ ModularSetAnalyzer modular_set; + /*! \brief sub-analyzer rewrite simplfy */ + RewriteSimplifier rewrite_simplify; /*! \brief constructor */ Analyzer(); /*! diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 92aaa36aa10f..3981a4815aeb 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -96,6 +96,7 @@ def __init__(self): self._const_int_bound_update = _mod("const_int_bound_update") self._bind = _mod("bind") self._modular_set = _mod("modular_set") + self._rewrite_simplify = _mod("rewrite_simplify") self._enter_constraint_context = _mod("enter_constraint_context") def const_int_bound(self, expr): @@ -128,6 +129,21 @@ def modular_set(self, expr): """ return self._modular_set(expr) + def rewrite_simplify(self, expr): + """Simplify expression via rewriting rules. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + Returns + ------- + result : Expr + The result. + """ + return self._rewrite_simplify(expr) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index a714fe37005b..cc7d814617a9 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { self->const_int_bound.Update(args[0], args[1], args[2]); }); + } else if (name == "rewrite_simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->rewrite_simplify(args[0]); + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { auto& sptr = args[1].node_sptr(); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 236a21ba71f5..81195eba2747 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -2,6 +2,7 @@ * Copyright (c) 2019 by Contributors * \file tvm/arithmetic/analyzer.cc */ +#include #include namespace tvm { @@ -9,19 +10,22 @@ namespace arith { Analyzer::Analyzer() : const_int_bound(this), - modular_set(this) { + modular_set(this), + rewrite_simplify(this) { } void Analyzer::Bind(const VarExpr& v, const Expr& expr) { Var var(v.node_); this->const_int_bound.Update(var, this->const_int_bound(expr)); this->modular_set.Update(var, this->modular_set(expr)); + this->rewrite_simplify.Update(var, this->rewrite_simplify(expr)); } void Analyzer::Bind(const VarExpr& v, const Range& range) { Var var(v.node_); this->const_int_bound.Bind(var, range); // skip modular_set + // skip rewrite simplify } ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { @@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) } bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { - auto bd = this->const_int_bound(expr); + if (const auto* ptr = expr.as()) { + return ptr->value > lower_bound; + } + auto bd = this->const_int_bound(this->rewrite_simplify(expr)); if (bd->min_value >= lower_bound) return true; return false; } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 91613867115b..4c247c8a7b59 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -23,7 +23,9 @@ namespace arith { * \return nullptr if constant fold fails, otherwise return folded result. */ template -inline Expr TryConstFold(Expr a, Expr b); +inline Expr TryConstFold(Expr a, Expr b) { + return Expr(); +} /*! * \brief Try to run unary compute with constant folding. diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 50f2300dd4b7..20c24b330cbd 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -49,6 +49,7 @@ #include #include +#include "const_fold.h" namespace tvm { namespace arith { @@ -242,7 +243,11 @@ class PBinaryExpr : } Expr Eval() const { - return NodeType::make(a_.Eval(), b_.Eval()); + Expr lhs = a_.Eval(); + Expr rhs = b_.Eval(); + Expr ret = TryConstFold(lhs, rhs); + if (ret.defined()) return ret; + return NodeType::make(lhs, rhs); } private: @@ -250,12 +255,48 @@ class PBinaryExpr : typename TB::Nested b_; }; +template +class PConstWithTypeLike : + public Pattern > { + public: + PConstWithTypeLike(const TA& ref, int64_t value) + : ref_(ref), value_(value) {} + + void InitMatch_() const {} + + bool Match_(const NodeRef& node) const { + if (const ir::IntImm* ptr = node.as()) { + return ptr->value == value_; + } else { + return false; + } + } + + Expr Eval() const { + return make_const(ref_.Eval().type(), value_); + } + + private: + typename TA::Nested ref_; + int64_t value_; +}; + -#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ - template \ - inline PBinaryExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ + template \ + inline PBinaryExpr \ + FuncName(const Pattern& a, const Pattern& b) { \ return PBinaryExpr(a.derived(), b.derived()); \ + } \ + template \ + inline PBinaryExpr > \ + FuncName(const Pattern& a, int64_t b) { \ + return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ + } \ + template \ + inline PBinaryExpr, TA> \ + FuncName(int64_t b, const Pattern& a) { \ + return FuncName(PConstWithTypeLike(a.derived(), b), a); \ } // arithmetic expressions diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc new file mode 100644 index 000000000000..b304a8dc4bf2 --- /dev/null +++ b/src/arithmetic/rewrite_simplify.cc @@ -0,0 +1,650 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file rewrite_simplify.cc + * \brief Rewrite-rule based simplification. + */ +// Acknowledgement: Most rewrite-rules are from Halide. +#include +#include +#include +#include "const_fold.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +// macro for doing simple rewrite +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ + } + +// macro for rewrite + recursively rewrite ResExpr +#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ + } + +// macro rewrite only if CondExor is true after match. +#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return (ResExpr).Eval(); \ + } + +// macro rewrite + recursive_rewrite only if CondExor is true after match. +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ + } + + +// NOTE for developers: +// +// We mainly focus on index expression simplification. +// Besides the RewriteSimplifier, some cases can be better +// handled by CanonicalSimplifier. +// +class RewriteSimplifier::Impl : public IRMutator { + public: + explicit Impl(Analyzer* parent) + : parent_(parent) {} + + void Update(const Var& var, + const Expr& info, + bool override) { + if (!override) { + CHECK(!var_map_.count(var)); + } + var_map_[var] = info; + } + + // Run simplification in post order + Expr PostOrderSimplify(Expr expr, int max_iter = 2) { + for (int i = 0; i < max_iter; ++i) { + Expr new_expr = this->Mutate(expr); + if (new_expr.same_as(expr)) return expr; + expr = new_expr; + } + return expr; + } + + Expr Mutate_(const Add* op, const Expr& self) final; + Expr Mutate_(const Sub* op, const Expr& self) final; + Expr Mutate_(const Mul* op, const Expr& self) final; + Expr Mutate_(const Div* op, const Expr& self) final; + Expr Mutate_(const Mod* op, const Expr& self) final; + + private: + // reference to the main analyzer + Analyzer* parent_; + // counter to record recursive rewrite depth. + int recur_depth_{0}; + // internal variable map + std::unordered_map var_map_; + // maximum number of recursion allowed during a single pass. + static const constexpr int kMaxRecurDepth = 5; + // Whether x >= val + bool CanProveGreaterEqual(const Expr& x, int64_t val) { + return parent_->CanProveGreaterEqual(x, val); + } + // Whether x == val + bool CanProveEqual(const Expr& x, int64_t val) { + // TODO(tqchen) refer back to super-analyzer. + Expr res = Mutate(x); + if (const auto* ptr = res.as()) { + return ptr->value == val; + } + return false; + } + // Recursive rewrite x + // we limit maximum depth of recursive rewrite allowed to + // avoid infinite loop + Expr RecursiveRewrite(const Expr& x) { + if (recur_depth_ >= kMaxRecurDepth) return x; + ++recur_depth_; + Expr res = Mutate(x); + --recur_depth_; + return res; + } + + template + PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { + return PConstWithTypeLike(pattern.derived(), 0); + } +}; + +Expr RewriteSimplifier::Impl:: +Mutate_(const Add* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + // Pattern var to match any expression + PVar x, y, z, b1, b2, s1, s2; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + // Vector rules + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), + ramp(b1 + b2, s1 + s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), + ramp(b1 + x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), + ramp(x + b1, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), + broadcast(x + y, lanes)); + } + + if (IsIndexType(op->type)) { + // Index rules + // cancelation rules + TVM_TRY_REWRITE((x - y) + y, x); + TVM_TRY_REWRITE(x + (y - x), y); + + TVM_TRY_REWRITE((x - y) + (y - z), x - z); + TVM_TRY_REWRITE((x - y) + (z - x), z - y); + + TVM_TRY_REWRITE(min(x, y - z) + z, min(x + z, y)); + TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z)); + TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); + TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); + TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); + TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); + TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); + TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y); + + TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), + c1.Eval()->value == -c2.Eval()->value); + + // constant folding + // NOTE: canonicalization might better at this. + TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2)); + + // mul co-efficient folding + TVM_TRY_REWRITE(x + x, x * 2); + TVM_TRY_REWRITE(x * y + x, x * (y + 1)); + TVM_TRY_REWRITE(y * x + x, x * (y + 1)); + TVM_TRY_REWRITE(x + y * x, x * (1 + y)); + TVM_TRY_REWRITE(x + x * y, x * (1 + y)); + TVM_TRY_REWRITE(x * y + x * z, x * (y + z)); + TVM_TRY_REWRITE(y * x + x * z, x * (y + z)); + TVM_TRY_REWRITE(x * y + z * x, x * (y + z)); + TVM_TRY_REWRITE(y * x + z * x, x * (y + z)); + + // modular-div simplification + // Always pre-condition on positive integer domain + TVM_TRY_REWRITE_IF( + (x / c1) * c1 + x % c1, x, + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + + // canonicalization rule + // will try rewrite again after canonicalization. + TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1)); + } + + // condition rules. + TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), + select(x, b1 + s1, b2 + s2)); + // default value + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Sub* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + // Pattern var to match any expression + PVar x, y, z, b1, b2, s1, s2; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + // Vector rules + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), + ramp(b1 - b2, s1 - s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), + ramp(b1 - x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), + ramp(x - b1, 0 - s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), + broadcast(x - y, lanes)); + } + + if (IsIndexType(op->type)) { + // Index rules + // cancelation rules + TVM_TRY_REWRITE((x + y) - y, x); + TVM_TRY_REWRITE((x + y) - x, y); + TVM_TRY_REWRITE(x - (y + x), 0 - y); + TVM_TRY_REWRITE(x - (x + y), 0 - y); + + TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x)); + TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0)); + TVM_TRY_REWRITE(max(x, y) - x, max(0, y - x)); + TVM_TRY_REWRITE(max(x, y) - y, max(x - y, 0)); + + TVM_TRY_REWRITE(x - max(x, y), min(0, x - y)); + TVM_TRY_REWRITE(y - max(x, y), min(y - x, 0)); + TVM_TRY_REWRITE(x - min(x, y), max(0, x - y)); + TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0)); + + // mul co-efficient folding + TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x)); + TVM_TRY_REWRITE(x * y - x, x * (y - 1)); + TVM_TRY_REWRITE(y * x - x, x * (y - 1)); + TVM_TRY_REWRITE(x - y * x, x * (1 - y)); + TVM_TRY_REWRITE(x - x * y, x * (1 - y)); + TVM_TRY_REWRITE(x * y - x * z, x * (y - z)); + TVM_TRY_REWRITE(y * x - x * z, x * (y - z)); + TVM_TRY_REWRITE(x * y - z * x, x * (y - z)); + TVM_TRY_REWRITE(y * x - z * x, x * (y - z)); + + // constant cancelation + TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); + TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2)); + + // cancelization rule involving 4 operands + TVM_TRY_REWRITE((x + y) - (x + z), y - z); + TVM_TRY_REWRITE((x + y) - (z + x), y - z); + TVM_TRY_REWRITE((y + x) - (z + x), y - z); + TVM_TRY_REWRITE((y + x) - (x + z), y - z); + + TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); + + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); + + TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x)); + TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x)); + + TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s1, + CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0)); + + TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s2, + CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0)); + TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s1, + CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0)); + TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s2, + CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0)); + + // modular-div simplification + // Always pre-condition on positive integer domain + TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, + ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, + CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && + c1.Eval()->value >= c2.Eval()->value && + c3.Eval()->value > 0); + TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3, + ((x + (c1 % c3)) % c3 + c1) / c3, + CanProveGreaterEqual(x.Eval(), 0) && + c1.Eval()->value >= 0 && + c3.Eval()->value > 0); + // canonicalization rule + // will try rewrite again after canonicalization. + TVM_TRY_REWRITE(x - c1, x + (0 - c1)); + TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); + TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); + } + + // condition rules. + TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), + select(x, b1 - s1, b2 - s2)); + TVM_TRY_REWRITE(select(x, y, z) - z, + select(x, y - z, ZeroWithTypeLike(z))); + TVM_TRY_REWRITE(select(x, y, z) - y, + select(x, ZeroWithTypeLike(y), z - y)); + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Mul* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + // Pattern var to match any expression + PVar x, y, z, b1, b2, s1, s2; + // Pattern var match IntImm + PVar c1, c2; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + // Vector rules + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), + broadcast(x * y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), + ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), + ramp(b1 * x, s1 * x, lanes)); + } + + if (IsIndexType(op->type)) { + // constant simplification rule + TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); + TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); + TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y); + TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y); + + // canonicalization + TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE_IF( + (x - y) * c1, (y - x) * (0 - c1), + c1.Eval()->value < 0); + } + return ret; +} + +Expr RewriteSimplifier::Impl:: +Mutate_(const Div* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as
(); + Expr const_res = TryConstFold
(op->a, op->b); + if (const_res.defined()) return const_res; + // Pattern var to match any expression + PVar x, y, z, b1; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + + // Vector rules + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes), + broadcast(x / y, lanes)); + // ramp / bcast + if ((ramp(b1, c1, lanes) / broadcast(c2, lanes)).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + if (c1val % c2val == 0) { + return ramp(b1 / c2, c1 / c2, lanes).Eval(); + } + // If all possible indices in ramp are the same. + if (CanProveGreaterEqual(b1.Eval(), 0)) { + ModularSet bmod = parent_->modular_set(b1.Eval()); + int64_t ramp_min = bmod->base / c2val; + int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; + if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { + return broadcast(b1 / c2, lanes).Eval(); + } + } + } + } + + if (IsIndexType(op->type)) { + // Be-aware of the division rules: + // We adopt the default C division uses truncation instead of floordiv. + // This means most rules need to check non-negativeness of the operands. + + // while it is always true for trunc div + // restrict to common case(positive div) + TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0); + + TVM_TRY_REWRITE_IF((x / c1 + c2) / c3, (x + c1 * c2) / (c1 * c3), + c1.Eval()->value > 0 && + c2.Eval()->value >= 0 && + c3.Eval()->value > 0 && + CanProveGreaterEqual(x.Eval(), 0)); + + if (((x * c1) / c2).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + if (c1val > 0 && c2val > 0) { + if (c1val % c2val == 0) return (x * (c1 / c2)).Eval(); + if (c2val % c1val == 0) return (x / (c2 / c1)).Eval(); + } + } + + // Rules involving 2-operands. + TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(min(x * c1, y) / c2, min(x * (c1 / c2), y / c2), + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(max(x * c1, y) / c2, max(x * (c1 / c2), y / c2), + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF((y + x * c1) / c2, y / c2 + x * (c1 / c2), + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(min(y, x * c1) / c2, min(y / c2, x * (c1 / c2)), + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(max(y, x * c1) / c2, max(y / c2, x * (c1 / c2)), + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + // Rules involving 3-operands. + TVM_TRY_REWRITE_IF((x * c1 + y + z) / c2, x * (c1 / c2) + (y + z)/ c2, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF((x * c1 - y + z) / c2, x * (c1 / c2) + (z - y)/ c2, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((z - y).Eval(), 0)); + + TVM_TRY_REWRITE_IF((x * c1 + y - z) / c2, x * (c1 / c2) + (y - z)/ c2, + c1.Eval()->value >= 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y - z).Eval(), 0)); + + TVM_TRY_REWRITE_IF((y + x * c1 + z) / c2, x * (c1 / c2) + (y + z) / c2, + c1.Eval()->value > 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF((x + c1) / c2, x / c2 + c1 / c2, + c1.Eval()->value > 0 && + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF((x + y) / x, y / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF((y + x) / x, y / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(((x + y) + z) / x, (y + z) / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF(((y + x) + z) / x, (y + z) / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF((y + (z + x)) / x, (y + z) / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF((y + (x + z)) / x, (y + z) / x + 1, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF((x * y) / y, x, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF((y * x) / y, x, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF((x * z + y) / z, x + y / z, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); + TVM_TRY_REWRITE_IF((z * x + y) / z, x + y / z, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); + TVM_TRY_REWRITE_IF((y + x * z) / z, y / z + x, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); + TVM_TRY_REWRITE_IF((y + z * x) / z, y / z + x, + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); + } + return ret; +} + + +Expr RewriteSimplifier::Impl:: +Mutate_(const Mod* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); + if (const_res.defined()) return const_res; + + // Pattern var to match any expression + PVar x, y, z, b1; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + + // Vector rules + if (op->type.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) % broadcast(y, lanes), + broadcast(x % y, lanes)); + + // ramp % bcast + if ((ramp(b1, c1, lanes) % broadcast(c2, lanes)).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + if (c1val % c2val == 0) { + return broadcast(b1 % c2, lanes).Eval(); + } + // If all possible indices in ramp are the same. + if (CanProveGreaterEqual(b1.Eval(), 0)) { + ModularSet bmod = parent_->modular_set(b1.Eval()); + int64_t ramp_min = bmod->base / c2val; + int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; + if (bmod->coeff % c2val == 0) { + if (ramp_min == ramp_max) { + return ramp(bmod->base % c2, c1, lanes).Eval(); + } else { + return (ramp(bmod->base % c2, c1, lanes) % broadcast(c2, lanes)).Eval(); + } + } + } + } + } + + if (IsIndexType(op->type)) { + // Be-aware of the division rules: + // We adopt the default C division uses truncation instead of floordiv. + // This means most rules need to check non-negativeness of the operands. + TVM_TRY_REWRITE_IF((x * c1) % c2, ZeroWithTypeLike(x), + c2.Eval()->value != 0 && + c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2, + c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); + + // try modular analysis + if ((x % c1).Match(ret)) { + ModularSet mod = parent_->modular_set(x.Eval()); + int64_t c1val = c1.Eval()->value; + if (mod->coeff % c1val == 0 && + CanProveGreaterEqual(x.Eval(), 0)) { + return (mod->base % c1).Eval(); + } + } + } + return ret; +} + + +Expr RewriteSimplifier::operator()(const Expr& expr) { + return impl_->PostOrderSimplify(expr); +} + +void RewriteSimplifier::Update(const Var& var, + const Expr& info, + bool override) { + impl_->Update(var, info, override); +} + + +RewriteSimplifier::RewriteSimplifier(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +RewriteSimplifier::~RewriteSimplifier() { + delete impl_; +} + +} // namespace arith +} // namespace tvm diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 1945339a259c..ea1a8427e61a 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -117,6 +117,7 @@ TEST(Pattern, Integer) { // special case container of Expr CHECK((v * c).Match(tx * 3)); CHECK_EQ(c.Eval()->value, 3); + CHECK((v * 3).Match(tx * 3)); } // cannot match c to ty CHECK(!(v * c).Match(tx * ty)); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py new file mode 100644 index 000000000000..bbfddddd41da --- /dev/null +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -0,0 +1,252 @@ +import tvm + +class RewriteChecker: + def __init__(self): + self.analyzer = tvm.arith.Analyzer() + + def verify(self, data, expected): + res = self.analyzer.rewrite_simplify(data) + assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format( + data, res, expected) + + +def test_vector_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + # Add rules + ck.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4), + tvm.expr.Ramp(x + y, 3, 4)) + ck.verify(tvm.expr.Ramp(x, 1, 2) + y, + tvm.expr.Ramp(x + y, 1, 2)) + ck.verify(y + tvm.expr.Ramp(x, 1, 2) , + tvm.expr.Ramp(y + x, 1, 2)) + ck.verify(y.astype("int32x2") + x.astype("int32x2"), + (y + x).astype("int32x2")) + # Sub rules + ck.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4), + tvm.expr.Ramp(x - y, 2, 4)) + ck.verify(tvm.expr.Ramp(x, 1, 2) - y, + tvm.expr.Ramp(x - y, 1, 2)) + ck.verify(y - tvm.expr.Ramp(x, 1, 2) , + tvm.expr.Ramp(y - x, -1, 2)) + ck.verify(y.astype("int32x2") - x.astype("int32x2"), + (y - x).astype("int32x2")) + + # Mul rules + ck.verify(y.astype("int32x2") * x.astype("int32x2"), + (y * x).astype("int32x2")) + ck.verify(tvm.expr.Ramp(x, 4, 4) * 2, + tvm.expr.Ramp(x * 2, 8, 4)) + ck.verify(2 * tvm.expr.Ramp(x, 4, 4), + tvm.expr.Ramp(x * 2, 8, 4)) + + ## Div rules + ck.verify(y.astype("int32x2") / x.astype("int32x2"), + (y / x).astype("int32x2")) + ck.verify(tvm.expr.Ramp(x, 4, 4) / 2, + tvm.expr.Ramp(x/ 2, 2, 4)) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8, + (x).astype("int32x4")) + ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8, + tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8) + + ## Mod rules + ck.verify(y.astype("int32x2") % x.astype("int32x2"), + (y % x).astype("int32x2")) + ck.verify(tvm.expr.Ramp(x, 4, 4) % 2, + tvm.expr.Broadcast(x % 2, 4)) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8, + tvm.expr.Ramp(1, 1, 4)) + ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8, + tvm.expr.Ramp(1, 15, 4) % 8) + + + +def test_select_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + # Add rules + ck.verify(tvm.expr.Select(x > 0, y, 0) + tvm.expr.Select(x > 0, 1, z), + tvm.expr.Select(x > 0, y + 1, z)) + ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z), + tvm.expr.Select(x > 0, y + (-1), 1 - z)) + ck.verify(tvm.expr.Select(x > 0, y, z) - y, + tvm.expr.Select(x > 0, 0, z - y)) + ck.verify(tvm.expr.Select(x > 0, y, z) - z, + tvm.expr.Select(x > 0, y - z, 0)) + + +def test_add_index_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + + ck.verify(x + (y - x), y) + ck.verify(x - (y + 1) + (y + 1), x) + ck.verify((x - 10) + (10 - z), x - z) + ck.verify((x - y) + (z - x), z - y) + + ck.verify(tvm.min(x, y - z) + z, tvm.min(x + z, y)) + ck.verify(tvm.min(x - z, y) + z, tvm.min(x, y + z)) + ck.verify(tvm.max(x, y - 10) + 10, tvm.max(x + 10, y)) + ck.verify(tvm.max(x - 11, y) + 11, tvm.max(x, y + 11)) + + ck.verify(tvm.max(x, y * 2) + tvm.min(x, y * 2), x + y * 2); + ck.verify(tvm.min(x, y * 2) + tvm.max(x, y * 2), x + y * 2); + + ck.verify(tvm.max(x, y + 2) + (-2), tvm.max(x + (-2), y)); + ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y)); + ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1)); + + ck.verify(x * y + x * 10, x * (y + 10)) + ck.verify(y * x + x * 10, x * (y + 10)) + ck.verify(y * x + 10 * x, x * (y + 10)) + ck.verify(x * y + 10 * x, x * (y + 10)) + + ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10)) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.verify((x / 8) * 8 + x % 8, x) + + # canonicalization + ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9); + ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9); + + # conservative bound + try: + ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True) + ck.verify((x / 8) * 8 + x % 8, x) + raise RuntimeError("bad") + except AssertionError: + pass + + +def test_sub_index_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + + ck.verify(x + y - y, x) + ck.verify(x + y - x, y) + ck.verify(x - (y + x), 0 - y) + ck.verify(x - (x + y), 0 - y) + + ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x)) + ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0)) + ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x)) + ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0)) + + ck.verify(x - tvm.min(x, y), tvm.max(0, x - y)) + ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0)) + ck.verify(x - tvm.max(x, y), tvm.min(0, x - y)) + ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0)) + + # mul co-efficient foldng + ck.verify(x - x, 0) + ck.verify(x * y - x, x * (y + (-1))) + ck.verify(x * y - 10 * x, x * (y + (-10))) + ck.verify(y * x - x * z, x * (y - z)) + ck.verify(y * x - z * x, x * (y - z)) + + ck.verify(x + 10 - 20, x + (-10)) + + # 4-operands pattern + ck.verify((x + y) - (x + z), y - z) + ck.verify((y + x) - (x + z), y - z) + ck.verify((x + y) - (z + x), y - z) + ck.verify((y + x) - (z + x), y - z) + + ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x)) + ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x)) + ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y)) + ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y)) + + ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z)) + ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z)) + ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y)) + ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y)) + + ck.verify(tvm.min(x, y) - tvm.min(y, x), 0) + ck.verify(tvm.max(x, y) - tvm.max(y, x), 0) + ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10) + ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10) + + # div pattern + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.verify(x - (x / 3) * 3, x % 3) + ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3) + + +def test_mul_index_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + ck.verify((x + 2) * 3, x * 3 + 6) + ck.verify((x * 2) * 3, x * 6) + ck.verify(tvm.min(x, y) * tvm.max(x, y), x * y) + ck.verify(tvm.max(x, y) * tvm.min(x, y), x * y) + ck.verify((x - y) * (-2), (y - x) * 2) + + +def test_div_index_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True) + + ck.verify(x / 2 / 3, x / 6) + ck.verify((x / 2 + 1) / 3, (x + 2) / 6) + ck.verify(x * 2 / 4, x / 2) + ck.verify(x * 4 / 2, x * 2) + + ck.verify((x * 4 + y) / 2, x * 2 + y / 2) + ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2)) + ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2)) + + ck.verify((y + x * 4) / 2, y / 2 + x * 2) + ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3)) + ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3)) + + # 3-operands + ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2) + ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1) + ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1) + ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2) + ck.verify((x + 4) / 2, x / 2 + 2) + + ck.verify((x + y) / x, y / x + 1) + ck.verify((y + x) / x, y / x + 1) + ck.verify(((x + y) + z) / x, (y + z) / x + 1) + ck.verify(((y + x) + z) / x, (y + z) / x + 1) + ck.verify((y + (x + z)) / x, (y + z) / x + 1) + ck.verify((y + (z + x)) / x, (y + z) / x + 1) + + ck.verify((x * y) / y, x) + ck.verify((y * x) / y, x) + + ck.verify((x * z + y) / z, x + y / z) + ck.verify((z * x + y) / z, x + y / z) + ck.verify((y + x * z) / z, y / z + x) + ck.verify((y + z * x) / z, y / z + x) + + +def test_mod_index_simplify(): + ck = RewriteChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) + + ck.verify(x * 10 % 2, 0) + ck.verify((x * 10 + y) % 2, y % 2) + ck.verify((x + 10) % 2, x % 2) + ck.verify((x + y * 10) % 2, x % 2) + ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1) + + +if __name__ == "__main__": + test_mod_index_simplify() + test_vector_simplify() + test_add_index_simplify() + test_sub_index_simplify() + test_mul_index_simplify() + test_div_index_simplify() + test_select_simplify()