diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index cc9e5374b888..44b00b5d89fa 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -9,14 +9,282 @@ #include #include #include +#include #include "expr.h" namespace tvm { - +// forward delcare Tensor class Tensor; - /*! \brief namespace of arithmetic */ namespace arith { +//------------------------------------------------------- +// Base integer analysis API. +// +// We have multiple type of analyzers to do relaxed +// integer set analysis(bound analysis, modulo) and +// equivalence checking and simplification. +// +// Importantly, each analyzer may need result from +// another analyzer. +//------------------------------------------------------- + +// Forward declare Analyzer +class Analyzer; +/*! + * \brief reference class to ConstIntBoundNode + * \sa ConstIntBoundNode + */ +class ConstIntBound; +/*! + * \brief Constant integer up and lower bound(inclusive). + * Useful for value bound analysis. + * + * set = [min_value, max_value] + */ +class ConstIntBoundNode : public Node { + public: + int64_t min_value; + int64_t max_value; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("min_value", &min_value); + v->Visit("max_value", &max_value); + } + + TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value); + + /*! \brief Number to represent +inf */ + static const constexpr int64_t kPosInf = std::numeric_limits::max(); + /*! + * \brief Number to represent -inf + * \note We can make use the of fact that -kPosInf == kNegInf in the project. + */ + static const constexpr int64_t kNegInf = -kPosInf; + + static constexpr const char* _type_key = "arith.ConstIntBound"; + TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); +}; + +TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode); + +/*! + * \brief Analyzer to get constant integer bound over expression. + */ +class ConstIntBoundAnalyzer { + public: + /*! 
+ * \brief analyze the expr + * \param expr The expression of interest. + * \return the result of the analysis. + */ + ConstIntBound operator()(const Expr& expr); + + /*! + * \brief Update constant int bound information of var. + * + * \param var The variable of interest. + * \param info The bound information. + * \param override Whether do we allow override of existing information. + */ + void Update(const Var& var, + const ConstIntBound& info, + bool override = false); + /*! + * \brief Bind variable to a range. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const Var& var, const Range& range); + + private: + friend class Analyzer; + friend class ConstraintContext; + explicit ConstIntBoundAnalyzer(Analyzer* parent); + ~ConstIntBoundAnalyzer(); + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return an exit function that must be called to cleanup the constraint can be nullptr. + */ + std::function EnterConstraint(const Expr& constraint); + struct Entry; + class Impl; + /*! \brief Internal impl */ + Impl* impl_; +}; + +/*! + * \brief reference of ModularSetNode + * \sa ModularSetNode + */ +class ModularSet; +/*! + * \brief Range of a linear integer function. + * Use to do specify the possible index values. + * + * set = { coeff * x + base | x in Z } + * + * When coeff != 0, it can also be written as + * set = { n | n % coeff == base } + * + * This is useful to decide if the index is dividable by certain value. + * For example, if index = 0 + 4 x, then we know it can be divided by 4. + */ +class ModularSetNode : public Node { + public: + /*! \brief linear co-efficient */ + int64_t coeff; + /*! 
\brief The base */ + int64_t base; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("coeff", &coeff); + v->Visit("base", &base); + } + + TVM_DLL static ModularSet make(int64_t coeff, int64_t base); + + static constexpr const char* _type_key = "arith.ModularSet"; + TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); +}; + +TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode); + +/*! + * \brief Analyzer to get modular information over expression. + */ +class ModularSetAnalyzer { + public: + /*! + * \brief analyze the expr + * \param expr The expression of interest. + * \return the result of the analysis. + */ + ModularSet operator()(const Expr& expr); + /*! + * \brief Update modular set information of var. + * + * \param var The variable of interest. + * \param info The modular set information. + * \param override Whether we allow overriding existing information. + */ + void Update(const Var& var, + const ModularSet& info, + bool override = false); + + private: + friend class Analyzer; + friend class ConstraintContext; + explicit ModularSetAnalyzer(Analyzer* parent); + ~ModularSetAnalyzer(); + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return an exit function that must be called to clean up the constraint; can be nullptr. + */ + std::function EnterConstraint(const Expr& constraint); + struct Entry; + class Impl; + /*! \brief Internal impl */ + Impl* impl_; +}; + +/*! + * \brief A RAII constraint context. + * + * \code + * + * Var x("x"); + * arith::Analyzer analyzer; + * { + * arith::ConstraintContext cctx(&analyzer, x % 3 == 0); + * CHECK_EQ(analyzer.modular_set(x)->coeff, 3); + * } + * // constraint no longer in effect. + * CHECK_NE(analyzer.modular_set(x)->coeff, 3); + * + * \endcode + */ +class ConstraintContext { + public: + /*! + * \brief Construct a constraint context. + * \param analyzer The analyzer. + * \param constraint The constraint to be applied. 
+ */ + ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION; + /*! \brief destructor */ + ~ConstraintContext() DMLC_THROW_EXCEPTION { + exit_(); + } + + private: + /*! \brief function to be called in recovery */ + std::function exit_; +}; + +/*! + * \brief Analyzer that contains a bunch of sub-analyzers. + * + * Each sub-analyzer can make use of another sub-analyzer + * by weak reference of this. + * + * NOTE for sub-analyzer developers: + * If the analyzer uses memoization, we need to clear the internal + * cache when information about a Var has been overridden. + */ +class Analyzer { + public: + /*! \brief sub-analyzer: const integer bound */ + ConstIntBoundAnalyzer const_int_bound; + /*! \brief sub-analyzer: modular set */ + ModularSetAnalyzer modular_set; + /*! \brief constructor */ + Analyzer(); + /*! + * \brief Notify all the sub-analyzers that var + * is created and bound to expr. + * + * Each var can only be bound once. + * + * \param var The variable. + * \param expr The expression we bind to. + */ + void Bind(const VarExpr& var, const Expr& expr); + /*! + * \brief Notify all the sub-analyzers that var + * is created and bound to a range. + * + * Each var can only be bound once. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const VarExpr& var, const Range& range); + /*! + * \brief Whether we can prove expr >= lower_bound. + * + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param lower_bound The lower bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); +}; + +//----------------------------------------------- +// Integer set abstraction API. +// +// This is an API built on top of the base +// integer analysis API to provide set analysis. 
+//------------------------------------------------ /*! * \brief Sign of an expression or set. */ @@ -118,42 +386,6 @@ class IntSet : public NodeRef { static IntSet interval(Expr min, Expr max); }; -/*! - * \brief Range of a linear integer function. - * Use to do specify the possible index values. - * - * set = { coeff * x + base | x in Z } - * - * When coeff != 0, it can also be written as - * set = { n | n % coeff == base } - * - * This is useful to decide if the index is dividable by certain value. - * For example, if index = 0 + 4 x, then we know it can be divided by 4. - */ -struct ModularEntry { - /*! \brief linear co-efficient */ - int coeff{1}; - /*! \brief The base */ - int base{0}; - - /*! \return entry represent everything */ - static ModularEntry everything() { - // always safe to set 0 + x, so it can be everything. - ModularEntry e; - e.coeff = 1; - e.base = 0; - return e; - } - /*! - * \brief Add two modular entries together to get a new modular entry. - * \param a The left operand. - * \param b The right operand. - * \return The combined modular entry. - */ - static ModularEntry Add(const ModularEntry& a, - const ModularEntry& b); -}; - /*! * \brief Base class of all IntSet containers. */ @@ -300,24 +532,6 @@ IntSet DeduceBound(Expr v, Expr cond, */ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); -/*! - * \brief Evaluate the expression with modular analysis - * \param e The expression to be evaluated. - * \param mod_map Map of modular statistics of known variables. - * \return The ModularEntry covering all possible value of e. - */ -ModularEntry EvalModular( - const Expr& e, - const std::unordered_map& mod_map); - -/*! - * \brief Same as EvalModular, used by front-end. - * \param e The expression to be evaluated. - * \param mod_map Map of modular statistics of known variables. - * \return A ModularSet covering all possible value of e. 
- */ -IntSet EvalModular(const Expr& e, - const Map& mod_map); // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 3ef955e834d0..0f05c98e0722 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -17,6 +17,7 @@ namespace tvm { namespace ir { +using HalideIR::Internal::BaseExprNode; using HalideIR::Internal::ExprNode; using HalideIR::Internal::StmtNode; using HalideIR::Internal::IRNodeType; diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 778d761c659e..92aaa36aa10f 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -33,9 +33,162 @@ class StrideSet(IntSet): """Represent set of strided integers""" -@register_node -class ModularSet(IntSet): +@register_node("arith.ModularSet") +class ModularSet(NodeBase): """Represent range of (coeff * x + base) for x in Z """ + def __init__(self, coeff, base): + self.__init_handle_by_constructor__( + _make_ModularSet, coeff, base) + + +@register_node("arith.ConstIntBound") +class ConstIntBound(NodeBase): + """Represent constant integer bound + + Parameters + ---------- + min_value : int + The minimum value of the bound. + + max_value : int + The maximum value of the bound. + """ + POS_INF = (1 << 63) - 1 + NEG_INF = -POS_INF + + def __init__(self, min_value, max_value): + self.__init_handle_by_constructor__( + _make_ConstIntBound, min_value, max_value) + + +class ConstraintScope: + """Constraint scope. + + Parameters + ---------- + fenter : function + A function that will be called to create an enter context. 
+ + Note + ---- + Do not create object directly, use Analyzer.constraint_scope + """ + def __init__(self, fenter): + self._fenter = fenter + self._fexit = None + + def __enter__(self): + self._fexit = self._fenter() + + def __exit__(self, ptype, value, trace): + self._fexit() + + +class Analyzer: + """Integer arithmetic analyzer + + This is a stateful analyzer class that can + be used to perform various symbolic integer analysis. + """ + def __init__(self): + _mod = _CreateAnalyzer() + self._const_int_bound = _mod("const_int_bound") + self._const_int_bound_update = _mod("const_int_bound_update") + self._bind = _mod("bind") + self._modular_set = _mod("modular_set") + self._enter_constraint_context = _mod("enter_constraint_context") + + def const_int_bound(self, expr): + """Find constant integer bound for expr. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + Returns + ------- + bound : ConstIntBound + The result bound + """ + return self._const_int_bound(expr) + + def modular_set(self, expr): + """Find a modular set that expr belongs to. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + Returns + ------- + result : ModularSet + The result. + """ + return self._modular_set(expr) + + def bind(self, var, expr): + """Bind a variable to the expression. + + Parameters + ---------- + var : tvm.Var + The variable. + + expr : tvm.Expr + The expression. + """ + return self._bind(var, expr) + + def constraint_scope(self, constraint): + """Create a constraint scope. + + Parameters + ---------- + constraint : tvm.Expr + The constraint expression. + + returns + ------- + scope : ConstraintScope + The constraint scope + + Examples + -------- + .. 
code-block:: python + + x = tvm.var("x") + analyzer = tvm.arith.Analyzer() + with analyzer.constraint_scope(x % 3 == 0): + # constraint in effect + assert analyzer.modular_set(x).coeff == 3 + # constraint no longer in effect + assert analyzer.modular_set(x).coeff != 3 + """ + def _fenter(): + return self._enter_constraint_context(constraint) + return ConstraintScope(_fenter) + + def update(self, var, info, override=False): + """Update information about var + + Parameters + ---------- + var : tvm.Var + The variable. + + info : tvm.NodeBase + Related information. + + override : bool + Whether to allow overriding existing information. + """ + if isinstance(info, ConstIntBound): + self._const_int_bound_update(var, info, override) + else: + raise TypeError( + "Do not know how to handle type {}".format(type(info))) _init_api("tvm.arith") diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 31ff5ccb3a15..cba70370f5b6 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -26,11 +26,6 @@ TVM_REGISTER_API("arith.intset_interval") *ret = IntSet::interval(args[0], args[1]); }); -TVM_REGISTER_API("arith.EvalModular") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = EvalModular(args[0], Map()); - }); - TVM_REGISTER_API("arith.DetectLinearEquation") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = DetectLinearEquation(args[0], args[1]); @@ -75,5 +70,56 @@ TVM_REGISTER_API("_IntSetIsEverything") *ret = args[0].operator IntSet().is_everything(); }); +TVM_REGISTER_API("arith._make_ConstIntBound") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConstIntBoundNode::make(args[0], args[1]); + }); + +TVM_REGISTER_API("arith._make_ModularSet") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ModularSetNode::make(args[0], args[1]); + }); + +TVM_REGISTER_API("arith._CreateAnalyzer") +.set_body([](TVMArgs args, TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + auto self = std::make_shared(); + auto f = [self](std::string name) -> 
PackedFunc { + if (name == "const_int_bound") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->const_int_bound(args[0]); + }); + } else if (name == "modular_set") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->modular_set(args[0]); + }); + } else if (name == "const_int_bound_update") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + self->const_int_bound.Update(args[0], args[1], args[2]); + }); + } else if (name == "bind") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + auto& sptr = args[1].node_sptr(); + if (sptr->is_type()) { + self->Bind(args[0], args[1].operator Range()); + } else { + self->Bind(args[0], args[1].operator Expr()); + } + }); + } else if (name == "enter_constraint_context") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + auto ctx = std::make_shared(self.get(), args[0]); + auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { + ctx.reset(); + }; + *ret = PackedFunc(fexit); + }); + } + return PackedFunc(); + }; + *ret = TypedPackedFunc(f); +}); + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc new file mode 100644 index 000000000000..236a21ba71f5 --- /dev/null +++ b/src/arithmetic/analyzer.cc @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/arithmetic/analyzer.cc + */ +#include + +namespace tvm { +namespace arith { + +Analyzer::Analyzer() + : const_int_bound(this), + modular_set(this) { +} + +void Analyzer::Bind(const VarExpr& v, const Expr& expr) { + Var var(v.node_); + this->const_int_bound.Update(var, this->const_int_bound(expr)); + this->modular_set.Update(var, this->modular_set(expr)); +} + +void Analyzer::Bind(const VarExpr& v, const Range& range) { + Var var(v.node_); + this->const_int_bound.Bind(var, range); + // skip modular_set +} + +ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { + // entering the scope. 
+ auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); + auto f1 = analyzer->modular_set.EnterConstraint(constraint); + // recovery function. + exit_ = [f0, f1]() { + if (f1 != nullptr) f1(); + if (f0 != nullptr) f0(); + }; +} + +bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { + auto bd = this->const_int_bound(expr); + if (bd->min_value >= lower_bound) return true; + return false; +} +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc new file mode 100644 index 000000000000..c83be8933b55 --- /dev/null +++ b/src/arithmetic/const_int_bound.cc @@ -0,0 +1,393 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/arithmetic/const_int_bound.cc + */ +#include +#include +#include +#include "int_op_overflow.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); + +ConstIntBound ConstIntBoundNode::make( + int64_t min_value, int64_t max_value) { + auto node = make_node(); + node->min_value = min_value; + node->max_value = max_value; + return ConstIntBound(node); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ConstIntBoundNode *op, IRPrinter *p) { + p->stream << "ConstIntBound" + << "[" << op->min_value << ", " + << op->max_value << ']'; + }); + +// internal entry for const int bound +struct ConstIntBoundAnalyzer::Entry { + int64_t min_value; + int64_t max_value; + + bool is_const(int64_t value) const { + return min_value == max_value && min_value == value; + } +}; + +class ConstIntBoundAnalyzer::Impl : + public ExprFunctor { + public: + void Bind(const Var& var, const Range& range) { + Entry a = VisitExpr(range->min); + Entry b = VisitExpr(range->extent); + Entry ret; + ret.min_value = a.min_value; + ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); + Update(var, ret, false); + } + + void Update(const Var& var, + const Entry& info, + bool override) { + if 
(!override) { + CHECK(!var_map_.count(var)); + } + var_map_[var] = info; + } + + void Update(const Var& var, + const ConstIntBound& info, + bool override) { + Update(var, MakeBound(info->min_value, info->max_value), override); + } + + // Override visitor behaviors + Entry VisitExprDefault_(const Node* op) final { + return Everything( + static_cast(op)->type); + } + + Entry VisitExpr_(const Cast* op) final { + // NOTE(review): the intersection assumes the value already fits the target type; + // confirm that a wrapping (narrowing) cast cannot reach here. + Entry a = VisitExpr(op->value); + Entry b = Everything(op->type); + return Intersect(a, b); + } + + Entry VisitExpr_(const IntImm* op) final { + return MakeBound(op->value, op->value); + } + + Entry VisitExpr_(const UIntImm* op) final { + if (op->value <= static_cast(kPosInf)) { + return MakeBound(op->value, op->value); + } else { + return Everything(op->type); + } + } + + Entry VisitExpr_(const Add* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + ret.min_value = InfAwareAdd(a.min_value, b.min_value); + ret.max_value = InfAwareAdd(a.max_value, b.max_value); + return ret; + } + + Entry VisitExpr_(const Sub* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + // a - b lies in [a_min - b_max, a_max - b_min] + ret.min_value = InfAwareAdd(a.min_value, -b.max_value); + ret.max_value = InfAwareAdd(a.max_value, -b.min_value); + return ret; + } + + Entry VisitExpr_(const Mul* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + return BinaryOpBoundry(a, b, InfAwareMul); + } + + Entry VisitExpr_(const Div* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + CHECK(!b.is_const(0)) << "divide by zero"; + // assume no division by 0: drop 0 from either end of the divisor range + if (b.min_value == 0) b.min_value = 1; + if (b.max_value == 0) b.max_value = -1; + if (b.min_value < 0 && b.max_value > 0) { + // The divisor range straddles zero, so the quotient is not monotonic in b + // over the whole range: its extremes are reached at b == -1 and b == 1, + // which the corner values of [b_min, b_max] do not cover. + // Bound the negative and the positive half separately and merge them. + Entry b_neg = MakeBound(b.min_value, -1); + Entry b_pos = MakeBound(1, b.max_value); + return Union(BinaryOpBoundry(a, b_neg, InfAwareDiv), + BinaryOpBoundry(a, b_pos, InfAwareDiv)); + } + return BinaryOpBoundry(a, b, InfAwareDiv); + } + + Entry VisitExpr_(const Mod* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + if (b.min_value > 0) { + int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + if (a.min_value >= 0) { + // 0 <= [a_min, a_max] < b_min + if 
(a.max_value < b.min_value) return a; + // other case, we can get close to 0 + return MakeBound(0, + std::min(a.max_value, b_max_cap)); + } else { + return MakeBound(std::max(a.min_value, -b_max_cap), + std::min(a.max_value, b_max_cap)); + } + } else { + CHECK(!b.is_const(0)) << "mod by zero"; + // mod by negative value is rare, + // and we just use the simpliest rule. + return Everything(op->type); + } + } + + Entry VisitExpr_(const Min* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + ret.min_value = std::min(a.min_value, b.min_value); + ret.max_value = std::min(a.max_value, b.max_value); + return ret; + } + + Entry VisitExpr_(const Max* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + ret.min_value = std::max(a.min_value, b.min_value); + ret.max_value = std::max(a.max_value, b.max_value); + return ret; + } + + Entry VisitExpr_(const Select* op) final { + Entry a = VisitExpr(op->true_value); + Entry b = VisitExpr(op->false_value); + return Union(a, b); + } + + Entry VisitExpr_(const Call* op) final { + // only special handle >> and & which can be + // used for index calculation. + if (op->is_intrinsic(Call::shift_right)) { + return VisitRightShift(op); + } else if (op->is_intrinsic(Call::bitwise_and)) { + return VisitBitwiseAnd(op); + } else { + return Everything(op->type); + } + } + + Entry VisitExpr_(const Variable* op) final { + Var v = GetRef(op); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + return it->second; + } else { + return Everything(op->type); + } + } + + Entry VisitRightShift(const Call* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + return BinaryOpBoundry(a, b, InfAwareRightShift); + } + + Entry VisitBitwiseAnd(const Call* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + // handle positive index case. 
+ if (a.min_value >= 0 && b.min_value >= 0) { + return MakeBound(0, std::min(a.max_value, b.max_value)); + } else { + if (b.min_value >= 0) { + return MakeBound(0, b.max_value); + } + if (a.min_value >= 0) { + return MakeBound(0, a.max_value); + } + return Everything(op->type); + } + } + + private: + // internal variable map + std::unordered_map var_map_; + // constants: the limit value means unlimited + // NOTE: kNegInf/kPosInf are used to represent infinity. + static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; + static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static_assert(-kNegInf == kPosInf, "invariant of inf"); + // internal helper functions + /*! + * \brief Get the boundary of a binary op that is monotonic wrt one argument. + * \param a The entry of the left operand. + * \param b The entry of the right operand. + * \param op The operator. + * \tparam F the operator function type. + * \return The result. + */ + template + static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) { + Entry ret; + // Evaluate op at the four combinations of the original boundary points and take their min/max. + int64_t v1 = op(a.min_value, b.min_value); + int64_t v2 = op(a.max_value, b.max_value); + int64_t v3 = op(a.min_value, b.max_value); + int64_t v4 = op(a.max_value, b.min_value); + ret.min_value = std::min(std::min(std::min(v1, v2), v3), v4); + ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4); + return ret; + } + /*! + * \brief Compute x + y, aware of inf. + * \param x The left operand. + * \param y The right operand. + * \return the result. + */ + static int64_t InfAwareAdd(int64_t x, int64_t y) { + if (x == kPosInf) { + CHECK(y != kNegInf); + return kPosInf; + } + if (x == kNegInf) { + CHECK(y != kPosInf); + return kNegInf; + } + if (y == kPosInf || y == kNegInf) return y; + if (WillOverflow(x, y, kNegInf, kPosInf)) { + if (x > 0) return kPosInf; + return kNegInf; + } + return x + y; + } + /*! + * \brief Compute x * y, aware of inf. 
+ * \param x The left operand. + * \param y The right operand. + * \return the result. + */ + static int64_t InfAwareMul(int64_t x, int64_t y) { + if (!WillOverflow(x, y, kNegInf, kPosInf)) return x * y; + if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; + return kNegInf; + } + /*! + * \brief Compute x / y, aware of inf. + * \param x The left operand. + * \param y The right operand, must be non-zero. + * \return the result. + */ + static int64_t InfAwareDiv(int64_t x, int64_t y) { + CHECK_NE(y, 0); + if (x == kPosInf || x == kNegInf) { + // an infinite numerator stays infinite; its sign follows the sign of y + if (y > 0) return x; + return -x; + } + return x / y; + } + /*! + * \brief Compute x >> y, aware of inf. + * \param x The left operand. + * \param y The right operand (shift amount), clamped to [0, 63]. + * \return the result. + */ + static int64_t InfAwareRightShift(int64_t x, int64_t y) { + if (x == kPosInf || x == kNegInf) return x; + // Shifting an int64_t by a negative amount or by >= 64 bits is undefined + // behavior in C++, and y can be any bound value (including +-inf) here. + // Clamp it to the valid range: a shift by 63 already yields the limit + // value (0 or -1) that any larger shift would give. + if (y < 0) y = 0; + if (y > 63) y = 63; + return x >> y; + } + /*! + * \brief Make a new bound entry. + */ + static Entry MakeBound(int64_t min_value, int64_t max_value) { + Entry e; + e.min_value = min_value; + e.max_value = max_value; + return e; + } + /*! + * \brief Create union of two sets. + * \param a The left operand. + * \param b the right operand. + */ + static Entry Union(Entry a, Entry b) { + Entry ret; + ret.min_value = std::min(a.min_value, b.min_value); + ret.max_value = std::max(a.max_value, b.max_value); + return ret; + } + /*! + * \brief Create intersect of two sets. + * \param a The left operand. + * \param b the right operand. + */ + static Entry Intersect(Entry a, Entry b) { + Entry ret; + ret.min_value = std::max(a.min_value, b.min_value); + ret.max_value = std::min(a.max_value, b.max_value); + return ret; + } + /*! + * \brief return everything dtype can represent. + * \param dtype The data type. + * \return Bound that represent everything dtype can represent. 
+ */ + static Entry Everything(Type dtype) { + if (!dtype.is_int() && !dtype.is_uint()) { + return MakeBound(kNegInf, kPosInf); + } + Entry ret; + // number of value bits: a signed type spends one bit on the sign + int64_t vbits = dtype.bits() - static_cast(dtype.is_int()); + if (dtype.is_uint()) { + ret.min_value = 0; + } else { + if (vbits >= 63) { + ret.min_value = kNegInf; + } else { + ret.min_value = -(static_cast(1) << vbits); + } + } + if (vbits >= 63) { + ret.max_value = kPosInf; + } else { + ret.max_value = (static_cast(1) << vbits) - 1; + } + return ret; + } +}; + +ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { + Entry ret = impl_->VisitExpr(expr); + return ConstIntBoundNode::make(ret.min_value, ret.max_value); +} + +void ConstIntBoundAnalyzer::Update(const Var& var, + const ConstIntBound& info, + bool override) { + impl_->Update(var, info, override); +} + +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { + impl_->Bind(var, range); +} + +std::function ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) { + return nullptr; +} + +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) + : impl_(new Impl()) { +} + +ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { + delete impl_; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/int_op_overflow.h b/src/arithmetic/int_op_overflow.h new file mode 100644 index 000000000000..ef637b4b9521 --- /dev/null +++ b/src/arithmetic/int_op_overflow.h @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file int_op_overflow.h + * \brief Utility functions to detect if an integer op will overflow. + */ +#ifndef TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ +#define TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ + +#include + +namespace tvm { +namespace arith { + +/*! + * \brief Check if an integer op with operand x, y will overflow. + * \param x The left operand. + * \param y The right operand. + * \param min_value The minimum value of the domain. + * \param max_value The maximum value of the domain. 
+ * \return Whether overflow can happen. + * \tparam Op The integer operator. + */ +template +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { + return false; +} + +// NOTE: explicit specializations defined in a header must be inline, +// otherwise every translation unit that includes it emits its own definition. +template<> +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { + // x + y would leave [min_value, max_value] + if ((y > 0) && (x > max_value - y)) return true; + if ((y < 0) && (x < min_value - y)) return true; + return false; +} + +template<> +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { + // x - y would leave [min_value, max_value] + if ((y > 0) && (x < min_value + y)) return true; + if ((y < 0) && (x > max_value + y)) return true; + return false; +} + +template<> +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { + // x * y would leave [min_value, max_value]; compare against the quotients to avoid overflowing here + if (y == 0) return false; + if (y > 0) { + if (x < min_value / y) return true; + if (x > max_value / y) return true; + } else { + if (y == -1 && x == std::numeric_limits::min()) return true; + if (x > min_value / y) return true; + if (x < max_value / y) return true; + } + return false; +} + +template<> +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { + // only a zero right operand is reported + return y == 0; +} + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index e28fe2a9d958..cc2a4c307997 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -54,23 +54,6 @@ struct StrideSet : public IntSetNode { TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); }; -/*! - * \brief Set represented by range of ModularEntry. - * Used for front-end modular analysis. - */ -struct ModularSet : public IntSetNode { - /*! 
\brief Internal modular entry */ - ModularEntry e; - - void VisitAttrs(AttrVisitor* v) final { - v->Visit("base", &(e.base)); - v->Visit("coeff", &(e.coeff)); - } - static constexpr const char* _type_key = "ModularSet"; - TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode); -}; - - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/modular.cc b/src/arithmetic/modular.cc deleted file mode 100644 index d79300eb7782..000000000000 --- a/src/arithmetic/modular.cc +++ /dev/null @@ -1,168 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file modular.cc - * \brief Modular analysis - */ -#include -#include -#include -#include -#include -#include "int_set_internal.h" - -namespace tvm { -namespace arith { - -using namespace ir; - -class ModularEvaluator - : public ExprFunctor { - public: - explicit ModularEvaluator( - const std::unordered_map< - const Variable*, ModularEntry>& mod_map) - : mod_map_(mod_map) { - } - ModularEntry Eval(const Expr& e) { - return VisitExpr(e); - } - // default - ModularEntry VisitExprDefault_(const Node*) final { - return ModularEntry::everything(); - } - // override combination rules. 
- ModularEntry VisitExpr_(const IntImm* op) final { - if (op->value < std::numeric_limits::max()) { - ModularEntry ret; - ret.base = static_cast(op->value); - ret.coeff = 0; - return ret; - } else { - return ModularEntry::everything(); - } - } - ModularEntry VisitExpr_(const UIntImm* op) final { - if (op->value < static_cast( - std::numeric_limits::max())) { - ModularEntry ret; - ret.base = static_cast(op->value); - ret.coeff = 0; - return ret; - } else { - return ModularEntry::everything(); - } - } - ModularEntry VisitExpr_(const Variable* op) final { - auto it = mod_map_.find(op); - if (it != mod_map_.end()) { - return it->second; - } else { - return ModularEntry::everything(); - } - } - ModularEntry VisitExpr_(const Add* op) final { - ModularEntry a = Eval(op->a); - ModularEntry b = Eval(op->b); - ModularEntry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base + b.base, ret.coeff); - return ret; - } - ModularEntry VisitExpr_(const Sub* op) final { - ModularEntry a = Eval(op->a); - ModularEntry b = Eval(op->b); - ModularEntry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base - b.base, ret.coeff); - return ret; - } - ModularEntry VisitExpr_(const Mul* op) final { - ModularEntry a = Eval(op->a); - ModularEntry b = Eval(op->b); - // Simplification rule, x, y, z are in Z - // (p x + n) (q y + m) - // -> pq xy + pm x + qn y + mn - // -> pq z + pm x + qn y + mn - int pq = a.coeff * b.coeff; - int pm = a.coeff * b.base; - int qn = a.base * b.coeff; - ModularEntry ret; - ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); - ret.base = BaseSimplify(a.base * b.base, ret.coeff); - return ret; - } - ModularEntry VisitExpr_(const Div* op) final { - // a c x / c -> a x - // We cannot do cases where offset is non-zero - // because of different integer rounding in pos/neg - ModularEntry a = Eval(op->a); - ModularEntry b = Eval(op->b); - if (b.coeff == 0 && - a.base == 0) { - CHECK_NE(b.base, 0); - if (a.coeff % 
b.base == 0) { - ModularEntry ret; - ret.coeff = a.coeff / b.base; - ret.base = 0; - return ret; - } - } - return ModularEntry::everything(); - } - - private: - const std::unordered_map< - const Variable*, ModularEntry>& mod_map_; - friend struct ModularEntry; - // simplify the base by putting it in range. - static int BaseSimplify(int base, int coeff) { - if (coeff == 0) return base; - base = base % coeff; - if (base < 0) base += coeff; - return base; - } - static int ZeroAwareGCD(int a, int b) { - CHECK_GE(a, 0); - CHECK_GE(b, 0); - if (a < b) std::swap(a, b); - if (b == 0) return a; - // perform GCD (greatest common divisor) - // ax + by = gcd(a, b) z if a != 0, b != 0 - while (a % b != 0) { - a = a % b; - std::swap(a, b); - } - return b; - } -}; - -ModularEntry ModularEntry::Add(const ModularEntry& a, - const ModularEntry& b) { - ModularEntry ret; - ret.coeff = ModularEvaluator::ZeroAwareGCD(a.coeff, b.coeff); - ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff); - return ret; -} - - -ModularEntry EvalModular( - const Expr& e, - const std::unordered_map& mod_map) { - return ModularEvaluator(mod_map)(e); -} - -IntSet EvalModular(const Expr& e, - const Map& mod_map) { - std::unordered_map mmap; - for (auto& kv : mod_map) { - const ModularSet* m = kv.second.as(); - CHECK(m) << "Need to pass ModularSet for Modular Analysis"; - mmap[kv.first.get()] = m->e; - } - NodePtr n = make_node(); - n->e = ModularEvaluator(mmap)(e); - return IntSet(n); -} - -} // namespace arith -} // namespace tvm diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc new file mode 100644 index 000000000000..8da6e91fc7fa --- /dev/null +++ b/src/arithmetic/modular_set.cc @@ -0,0 +1,344 @@ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file modular_set.cc + * \brief Modular set analysis + */ +#include +#include +#include +#include +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +TVM_REGISTER_NODE_TYPE(ModularSetNode); + +ModularSet ModularSetNode::make(int64_t coeff, int64_t base) { + auto node = make_node(); + node->coeff = coeff; + node->base = base; + return ModularSet(node); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ModularSetNode *op, IRPrinter *p) { + p->stream << "ModularSet(" + << "coeff=" << op->coeff << ", base=" + << op->base << ')'; + }); + + +// internal entry for const int bound +struct ModularSetAnalyzer::Entry { + int64_t coeff{1}; + int64_t base{0}; + + bool is_const() const { + return coeff == 0; + } +}; + +class ModularSetAnalyzer::Impl : + public ExprFunctor { + public: + explicit Impl(Analyzer* parent) + : parent_(parent) {} + + void Update(const Var& var, + const ModularSet& info, + bool override) { + if (!override) { + CHECK(!var_map_.count(var)); + } + Entry e; + e.coeff = info->coeff; + e.base = info->base; + var_map_[var] = e; + } + + // Detect useful constraints and use them in the analysis scope. 
+ std::function EnterConstraint(const Expr& constraint) { + PVar var; + PVar coeff, base; + // pattern match interesting constraints + if (((var % coeff) == base).Match(constraint)) { + Entry entry; + entry.coeff = coeff.Eval()->value; + entry.base = base.Eval()->value; + return UpdateByIntersect(var.Eval(), entry); + } + return nullptr; + } + + // Override visitor behaviors + Entry VisitExprDefault_(const Node* op) final { + return Everything(); + } + + Entry VisitExpr_(const Cast* op) final { + return VisitExpr(op->value); + } + + Entry VisitExpr_(const IntImm* op) final { + Entry ret; + ret.base = op->value; + ret.coeff = 0; + return ret; + } + + Entry VisitExpr_(const UIntImm* op) final { + if (op->value < std::numeric_limits::max()) { + Entry ret; + ret.base = static_cast(op->value); + ret.coeff = 0; + return ret; + } else { + return Everything(); + } + } + + Entry VisitExpr_(const Add* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); + ret.base = BaseSimplify(a.base + b.base, ret.coeff); + return ret; + } + + Entry VisitExpr_(const Sub* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + Entry ret; + ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); + ret.base = BaseSimplify(a.base - b.base, ret.coeff); + return ret; + } + + Entry VisitExpr_(const Mul* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + // Simplification rule, x, y, z are in Z + // (p x + n) (q y + m) + // -> pq xy + pm x + qn y + mn + // -> pq z + pm x + qn y + mn + int64_t pq = a.coeff * b.coeff; + int64_t pm = a.coeff * b.base; + int64_t qn = a.base * b.coeff; + Entry ret; + ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); + ret.base = BaseSimplify(a.base * b.base, ret.coeff); + return ret; + } + + Entry DivByConst(const Expr& lhs, + int64_t val, + bool round_down) { + Entry a = VisitExpr(lhs); + CHECK_NE(val, 0); + if (a.coeff % val == 0) { + Entry ret; + if (a.base 
== 0) { + // a c x / c -> a x + ret.coeff = std::abs(a.coeff / val); + ret.base = 0; + return ret; + } + // positive division have a clear rounding mode. + // Only handle case where we clearly know we need to round down. + if (a.base > 0 && val > 0 && + (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { + ret.coeff = a.coeff / val; + ret.base = a.base / val; + return ret; + } + } + return Everything(); + } + + Entry VisitExpr_(const Div* op) final { + Entry b = VisitExpr(op->b); + if (b.is_const()) { + return DivByConst(op->a, b.base, false); + } + return Everything(); + } + + Entry VisitExpr_(const Min* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + return Union(a, b); + } + + Entry VisitExpr_(const Max* op) final { + Entry a = VisitExpr(op->a); + Entry b = VisitExpr(op->b); + return Union(a, b); + } + + Entry VisitExpr_(const Select* op) final { + Entry a = VisitExpr(op->true_value); + Entry b = VisitExpr(op->false_value); + return Union(a, b); + } + + Entry VisitExpr_(const Call* op) final { + // only special handle >> which can be + // used for index calculation. + if (op->is_intrinsic(Call::shift_right)) { + return VisitRightShift(op); + } else { + return Everything(); + } + } + + Entry VisitExpr_(const Variable* op) final { + Var v = GetRef(op); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + return it->second; + } else { + return Everything(); + } + } + + Entry VisitRightShift(const Call* op) { + Entry b = VisitExpr(op->args[1]); + // a c x / c -> a x + if (b.is_const()) { + return DivByConst(op->args[0], 1 << b.base, true); + } + return Everything(); + } + + private: + /*! \brief pointer to parent. */ + Analyzer* parent_{nullptr}; + // internal variable map + std::unordered_map var_map_; + /*! + * \brief Update var by intersecting entry with var's current set. + * \param var The variable. + * \param entry The entry to be updated. + * \return The recovery function of the scope. 
+ */ + std::function UpdateByIntersect(const Var& var, Entry entry) { + Entry old = Everything(); + auto it = var_map_.find(var); + if (it != var_map_.end()) { + old = it->second; + } + var_map_[var] = Intersect(old, entry); + // reover function. + return [this, old, var]() { + var_map_[var] = old; + }; + } + /*! + * \brief Create union of two sets. + * \param a The left operand. + * \param b the right operand. + */ + static Entry Union(Entry a, Entry b) { + // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}} + int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); + if (coeff == 0) { + if (a.base == b.base) return a; + return Everything(); + } + int64_t base0 = a.base % coeff; + int64_t base1 = b.base % coeff; + Entry ret; + if (base0 == base1) { + ret.coeff = coeff; + ret.base = base0; + return ret; + } else { + ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff); + ret.base = 0; + return ret; + } + } + /*! + * \brief Create interect of two sets. + * \param a The left operand. + * \param b the right operand. + */ + static Entry Intersect(Entry a, Entry b) { + // simple rule for now: pick higher constraints. + // TODO(team-team): Use extended euclidean algorithm. + if (a.coeff == 0) return a; + if (b.coeff == 0) return b; + if (a.coeff >= b.coeff) return a; + return b; + } + /*! + * \brief Simplify base so that it is in [0, coeff) when coeff != 0. + * \param base The base value. + * \param coeff The coeff value. + * \return The simplified base. + */ + static int64_t BaseSimplify(int64_t base, int64_t coeff) { + if (coeff == 0) return base; + base = base % coeff; + if (base < 0) base += coeff; + return base; + } + /*! + * \brief Take GCD of a and b. + * \param a The first operand. + * \param b The second operand. + * \return The result. 
+ */ + static int64_t ZeroAwareGCD(int64_t a, int64_t b) { + if (a < 0) a = -a; + if (b < 0) b = -b; + if (a < b) std::swap(a, b); + if (b == 0) return a; + // perform GCD (greatest common divisor) + // ax + by = gcd(a, b) z if a != 0, b != 0 + while (a % b != 0) { + a = a % b; + std::swap(a, b); + } + return b; + } + /*! + * \brief return everything dtype can represent. + * \return Bound that represent everything dtype can represent. + */ + static Entry Everything() { + Entry ret; + ret.coeff = 1; ret.base = 0; + return ret; + } +}; + +ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { + Entry ret = impl_->VisitExpr(expr); + return ModularSetNode::make(ret.coeff, ret.base); +} + +void ModularSetAnalyzer::Update(const Var& var, + const ModularSet& info, + bool override) { + impl_->Update(var, info, override); +} + +std::function ModularSetAnalyzer::EnterConstraint(const Expr& constraint) { + return impl_->EnterConstraint(constraint); +} + +ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +ModularSetAnalyzer::~ModularSetAnalyzer() { + delete impl_; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index b4140d959759..50f2300dd4b7 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -25,6 +25,17 @@ * // The filled value is valid until the next call to Match. 
* return (max(x, y) + z).Eval(); * } + * + * tvm::Var tx, ty; + * arith::PVar c; + * arith::PVar v; + * // We can match integer and Var, both of which are + * // special case container of Expr + * CHECK((v * c).Match(tx * 3)); + * CHECK_EQ(c.Eval()->value, 3); + * // cannot match c to ty + * CHECK(!(v * c).Match(tx * ty)); + * * \endcode * * \note The pattern matcher is not threadsafe, @@ -109,6 +120,22 @@ class PEqualChecker { } }; +template<> +class PEqualChecker { + public: + bool operator()(const Integer& lhs, const Integer& rhs) const { + return lhs->value == rhs->value; + } +}; + +template<> +class PEqualChecker { + public: + bool operator()(const Var& lhs, const Var& rhs) const { + return lhs.same_as(rhs); + } +}; + /*! * \brief Pattern variable container. * @@ -123,7 +150,7 @@ template class PVar : public Pattern > { public: // Store PVars by reference in the expression. - using Nested = const PVar&; + using Nested = const PVar&; void InitMatch_() const { filled_ = false; @@ -139,12 +166,23 @@ class PVar : public Pattern > { } } + template::value>::type> + bool Match_(const NodeRefType& value) const { + if (const auto* ptr = value.template as()) { + return Match_(GetRef(ptr)); + } else { + return false; + } + } + T Eval() const { CHECK(filled_); return value_; } - private: + protected: /*! \brief The matched value */ mutable T value_; /*! \brief whether the variable has been filled */ @@ -171,6 +209,7 @@ class PConst : public Pattern > { T Eval() const { return value_; } + private: const T value_; }; diff --git a/src/codegen/codegen_common.h b/src/codegen/codegen_common.h deleted file mode 100644 index 5e76af12e583..000000000000 --- a/src/codegen/codegen_common.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file codegen_common.h - * \brief Common utility for codegen. 
- */ -#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_ -#define TVM_CODEGEN_CODEGEN_COMMON_H_ - -#include -#include "../arithmetic/compute_expr.h" - -namespace tvm { -namespace codegen { - -/*! - * \brief Visit AssertStmt recursively, update align_map from condition. - * \param op The AssertStmt - * \param align_map The alignmap - * \param fvisit The recursive visitor - * \tparam FVisit the recursive visitor - */ -template -inline void VisitAssert( - const ir::AssertStmt* op, - std::unordered_map* align_map, - FVisit fvisit) { - using namespace ir; - auto& align_map_ = *align_map; - // Detect useful invariant pattern and use them to visit child. - // Pattern: Var % const == 0 - // TODO(tqchen) merge these pattern to a generic scope info visitor. - if (const EQ* eq = op->condition.as()) { - const Mod* mod = eq->a.as(); - int64_t factor = 0, offset = 0; - if (mod && arith::GetConst(eq->b, &offset)) { - const Variable *var = mod->a.as(); - if (var && arith::GetConst(mod->b, &factor)) { - arith::ModularEntry old = align_map_[var]; - if (factor > old.coeff) { - arith::ModularEntry e; - e.coeff = static_cast(factor); - e.base = static_cast(offset); - // new alignment info, - align_map_[var] = e; - fvisit(op->body); - // restore old info - align_map_[var] = old; - return; - } - } - } - } - fvisit(op->body); -} - -} // namespace codegen -} // namespace tvm - -#endif // TVM_CODEGEN_CODEGEN_COMMON_H_ diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index f80bd9e8d436..6b69f97a66fe 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -9,7 +9,6 @@ #include #include "codegen_llvm.h" #include "codegen_cpu.h" -#include "../codegen_common.h" #include "../../pass/ir_util.h" #include "../../arithmetic/compute_expr.h" @@ -84,9 +83,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { void CodeGenLLVM::InitFuncState() { var_map_.clear(); alias_var_set_.clear(); - align_map_.clear(); alloc_storage_info_.clear(); 
volatile_buf_.clear(); + analyzer_.reset(new arith::Analyzer()); } void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { @@ -381,14 +380,16 @@ void CodeGenLLVM::GetAlignment(Type t, *p_native_bits = native_vector_bits_; } - arith::ModularEntry me = arith::EvalModular(index, align_map_); + arith::ModularSet me = analyzer_->modular_set(index); + int64_t base = me->base; + int64_t coeff = me->coeff; int align_bits = t.bits(); while (align_bits < max_align_bits && - me.base % 2 == 0 && - me.coeff % 2 == 0) { - me.base = me.base / 2; - me.coeff = me.coeff / 2; + base % 2 == 0 && + coeff % 2 == 0) { + base = base / 2; + coeff = coeff / 2; align_bits *= 2; } if (align_bits < 8) { @@ -874,7 +875,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { CHECK(!var_map_.count(op->var.get())); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -998,6 +999,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { void CodeGenLLVM::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); if (op->for_type == ForType::Unrolled) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " << " consider set unroll_explicit=True"; @@ -1078,6 +1080,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } } else if (op->attr_key == ir::attr::storage_scope) { @@ -1099,21 +1102,19 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { - VisitAssert(op, &align_map_, [this](const Stmt& body) { - this->VisitStmt(body); - }); + arith::ConstraintContext 
cctx(analyzer_.get(), op->condition); + this->VisitStmt(op->body); } void CodeGenLLVM::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - CHECK(!align_map_.count(op->var.get())); if (op->var.type().is_handle()) { if (!is_restricted_) { alias_var_set_.insert(op->var.get()); } } var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 080306310370..ead1af883166 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -23,7 +23,6 @@ namespace codegen { using namespace ir; - /*! * \brief A base class to generate a LLVM. */ @@ -267,8 +266,8 @@ class CodeGenLLVM : std::unordered_map str_map_; // Whether current function is restricted bool is_restricted_{true}; - // The alignment information - std::unordered_map align_map_; + // The analyzer information + std::unique_ptr analyzer_; // set of var that are not restricted(can alias) std::unordered_set alias_var_set_; // set of volatile buffer. 
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 812fee4a114e..8b1cabd9e386 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -6,7 +6,7 @@ #include #include #include -#include "../codegen_common.h" +#include "../../arithmetic/compute_expr.h" #include "codegen_spirv.h" namespace tvm { @@ -66,7 +66,7 @@ void CodeGenSPIRV::InitFuncState() { std::fill(workgroup_size_, workgroup_size_ + 3, 1); var_map_.clear(); storage_info_.clear(); - align_map_.clear(); + analyzer_.reset(new arith::Analyzer()); builder_.reset(new spirv::IRBuilder()); builder_->InitHeader(); } @@ -217,7 +217,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) { CHECK(!var_map_.count(op->var.get())); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -378,9 +378,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->type.lanes()); - arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); - CHECK((me.coeff % ramp->lanes) == 0 && - (me.base % ramp->lanes) == 0) + arith::ModularSet me = analyzer_->modular_set(ramp->base); + CHECK((me->coeff % ramp->lanes) == 0 && + (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.type(), ramp->lanes)); @@ -458,9 +458,9 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->value.type().lanes()); - arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); - CHECK((me.coeff % ramp->lanes) == 0 && - (me.base % ramp->lanes) == 0) + arith::ModularSet me = 
analyzer_->modular_set(ramp->base); + CHECK((me->coeff % ramp->lanes) == 0 && + (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.type(), ramp->lanes)); @@ -477,6 +477,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { void CodeGenSPIRV::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); spirv::Value init_value = MakeValue(op->min); spirv::Value extent_value = MakeValue(op->extent); // Must get init label after making value(to make sure they are correct) @@ -589,6 +590,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } } else if (op->attr_key == ir::attr::storage_scope) { @@ -605,17 +607,15 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { - VisitAssert(op, &align_map_, [this](const Stmt& body) { - this->VisitStmt(body); - }); + arith::ConstraintContext cctx(analyzer_.get(), op->condition); + this->VisitStmt(op->body); } void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - CHECK(!align_map_.count(op->var.get())); CHECK(!op->var.type().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 6a43182f7f2e..94cf761b9f84 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -122,8 +122,8 @@ class CodeGenSPIRV: std::unordered_map storage_info_; // The definition of local variable. 
std::unordered_map var_map_; - // The alignment information - std::unordered_map align_map_; + // The analyzer. + std::unique_ptr analyzer_; }; } // namespace codegen diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 9ba9dcde63c9..3f7fd9512eb2 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -936,10 +936,8 @@ class VectorAllocRewriter : public IRMutator { tvec[0].lanes() != op->type.lanes()) { int factor = tvec[0].lanes() / op->type.lanes(); Array extents = op->extents; - arith::ModularEntry me = EvalModular( - extents[extents.size() - 1], - std::unordered_map()); - if (me.base % factor == 0 && me.coeff % factor == 0) { + arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); + if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].type(), factor)); return Allocate::make( @@ -959,6 +957,8 @@ class VectorAllocRewriter : public IRMutator { // Internal access map std::unordered_map > acc_map_; + // internal analyzer + arith::Analyzer analyzer_; }; diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index cb746e65660b..1945339a259c 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -107,6 +107,23 @@ TEST(Pattern, Basic) { } } +TEST(Pattern, Integer) { + using namespace tvm; + tvm::Var tx, ty; + arith::PVar c; + arith::PVar v; + { + // We can match integer and Var, both of which are + // special case container of Expr + CHECK((v * c).Match(tx * 3)); + CHECK_EQ(c.Eval()->value, 3); + } + // cannot match c to ty + CHECK(!(v * c).Match(tx * ty)); + // cannot match tx + 1 to v + CHECK(!(v * c).Match((tx + 1) * 3)); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk deleted file mode 100644 index 
b810d63ee4b1..000000000000 --- a/tests/cpp/unittest.mk +++ /dev/null @@ -1,12 +0,0 @@ -GTEST_LIB=$(GTEST_PATH)/lib/ -GTEST_INC=$(GTEST_PATH)/include/ - -TEST_SRC = $(wildcard tests/cpp/*_test.cc) -TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC)) - -tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.so - $(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d - $(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \ - -L$(GTEST_LIB) $(LDFLAGS) -lgtest -Llib -ltvm - --include tests/cpp/*.d diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py new file mode 100644 index 000000000000..968692208f5d --- /dev/null +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -0,0 +1,219 @@ +import tvm + +def test_dtype_bound(): + analyzer = tvm.arith.Analyzer() + + x = tvm.var("x", dtype="int64") + bd = analyzer.const_int_bound(x) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == bd.POS_INF + + x = tvm.var("x", dtype="int8") + bd = analyzer.const_int_bound(x) + assert bd.min_value == -128 + assert bd.max_value == 127 + + x = tvm.var("x", dtype="uint8") + bd = analyzer.const_int_bound(x) + assert bd.min_value == 0 + assert bd.max_value == 255 + + +def test_cast_bound(): + analyzer = tvm.arith.Analyzer() + x = tvm.var("x", dtype="int8") + bd = analyzer.const_int_bound((x % 3).astype("uint32")) + assert bd.min_value == 0 + assert bd.max_value == 2 + + bd = analyzer.const_int_bound( + (x % 3).astype("float32").astype("int32")) + assert bd.min_value == -2 + assert bd.max_value == 2 + + +def test_add_sub_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x", "int64"), tvm.var("y", "int64") + bd = analyzer.const_int_bound(x + y) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == bd.POS_INF + + analyzer.update(x, tvm.arith.ConstIntBound(0, 4)) + analyzer.update(y, tvm.arith.ConstIntBound(1, 10)) + bd = analyzer.const_int_bound(x + y) + 
assert bd.min_value == 1 + assert bd.max_value == 14 + + bd = analyzer.const_int_bound(x - y) + assert bd.min_value == -10 + assert bd.max_value == 3 + + analyzer.update(x, tvm.arith.ConstIntBound(0, bd.POS_INF), override=True) + bd = analyzer.const_int_bound(x - y) + assert bd.min_value == -10 + assert bd.max_value == bd.POS_INF + + bd = analyzer.const_int_bound(1 - x) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == 1 + + +def test_mul_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-2, 4)) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + bd = analyzer.const_int_bound(x * y + 20) + assert bd.min_value == 0 + assert bd.max_value == 60 + + analyzer.update(x, tvm.arith.ConstIntBound(-3, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) + bd = analyzer.const_int_bound(x * y) + assert bd.min_value == -32 + assert bd.max_value == 24 + + analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) + bd = analyzer.const_int_bound(x * y) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == bd.POS_INF + + +def test_div_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + bd = analyzer.const_int_bound(x / y) + assert bd.min_value == -2 + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) + bd = analyzer.const_int_bound(x / y) + assert bd.min_value == -4 + assert bd.max_value == 9 + + analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) + bd = analyzer.const_int_bound(x / y) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == bd.POS_INF + + 
+def test_mod_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + bd = analyzer.const_int_bound(x % y) + assert bd.min_value == -9 + assert bd.max_value == 4 + + analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) + bd = analyzer.const_int_bound(x % y) + assert bd.min_value == -9 + assert bd.max_value == 9 + + analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) + bd = analyzer.const_int_bound(x % y) + assert bd.min_value == 0 + assert bd.max_value == 9 + + +def test_min_max_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + bd = analyzer.const_int_bound(tvm.min(x, y)) + assert bd.min_value == -9 + assert bd.max_value == 10 + + analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) + bd = analyzer.const_int_bound(tvm.min(x, y)) + assert bd.min_value == bd.NEG_INF + assert bd.max_value == 10 + + bd = analyzer.const_int_bound(tvm.max(x, y)) + assert bd.min_value == 4 + assert bd.max_value == bd.POS_INF + + analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) + bd = analyzer.const_int_bound(tvm.max(x, y)) + assert bd.min_value == 4 + assert bd.max_value == bd.POS_INF + + +def test_select_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) + analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + + bd = analyzer.const_int_bound( + tvm.expr.Select(x > 1, 
(y < 0).astype("int32"), y + 1)) + assert bd.min_value == 0 + assert bd.max_value == 11 + + +def test_shift_and_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + + analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) + analyzer.update(y, tvm.arith.ConstIntBound(2, 10)) + + bd = analyzer.const_int_bound(x >> y) + assert bd.min_value == -3 + assert bd.max_value == 2 + + bd = analyzer.const_int_bound(x & y) + assert bd.min_value == 0 + assert bd.max_value == 10 + + analyzer.update(x, tvm.arith.ConstIntBound(10, 11), override=True) + bd = analyzer.const_int_bound(x & y) + assert bd.min_value == 0 + assert bd.max_value == 10 + + +def test_mix_index_bound(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1)) + analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1)) + bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8) + assert bd.min_value == 0 + assert bd.max_value == 24 - 1 + + bd = analyzer.const_int_bound(y + x * 3) + assert bd.min_value == 0 + assert bd.max_value == 24 * 3 - 1 + + bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7) + assert bd.min_value == 0 + assert bd.max_value == (23 // 7) * 7 + 6 + + +if __name__ == "__main__": + test_dtype_bound() + test_cast_bound() + test_add_sub_bound() + test_mul_bound() + test_div_bound() + test_mod_bound() + test_min_max_bound() + test_select_bound() + test_shift_and_bound() + test_mix_index_bound() diff --git a/tests/python/unittest/test_arith_modular.py b/tests/python/unittest/test_arith_modular.py deleted file mode 100644 index 58b5d3115d5e..000000000000 --- a/tests/python/unittest/test_arith_modular.py +++ /dev/null @@ -1,32 +0,0 @@ -import tvm - -def test_basic(): - a = tvm.var() - b = tvm.var() - m = tvm.arith.EvalModular(a * 4 + b * 6 + 7) - assert m.coeff == 2 - assert m.base == 1 - - m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3)) - assert m.coeff == 4 - assert m.base == 3 - - m = 
tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3)) - assert m.coeff == 1 - assert m.base == 0 - - m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4)) - assert m.coeff == 2 - assert m.base == 0 - - m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2)) - assert m.coeff == 3 - assert m.base == 2 - - - m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2)) - assert m.coeff == 1 - assert m.base == 0 - -if __name__ == "__main__": - test_basic() diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py new file mode 100644 index 000000000000..06ae5197b974 --- /dev/null +++ b/tests/python/unittest/test_arith_modular_set.py @@ -0,0 +1,128 @@ +import tvm + + +def test_cast(): + analyzer = tvm.arith.Analyzer() + x = tvm.var("x", dtype="int8") + m = analyzer.modular_set((x * 3).astype("uint32")) + assert m.coeff == 3 + assert m.base == 0 + m = analyzer.modular_set( + (x * 3 + 1).astype("float32").astype("int32")) + assert m.coeff == 3 + assert m.base == 1 + + +def test_add_sub(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x", "int64"), tvm.var("y", "int64") + m = analyzer.modular_set(x * 6 + y * 4) + assert m.coeff == 2 + assert m.base == 0 + + analyzer.bind(y, x * 4 + 1) + m = analyzer.modular_set(1 - y) + assert m.coeff == 4 + assert m.base == 0 + + +def test_mul(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + m = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1)) + assert m.coeff == 4 + assert m.base == 2 + + +def test_div_shift(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + # not sure if x is non-negative + m = analyzer.modular_set((x * 4 + 2) / 2) + assert m.coeff == 1 + assert m.base == 0 + # right shift always round down so it is fine + m = analyzer.modular_set((x * 4 + 2) >> 1) + assert m.coeff == 2 + assert m.base == 1 + # x is non-negative + analyzer.update(x, tvm.arith.ConstIntBound(0, 100)) + m = analyzer.modular_set((x * 4 + 2) / 2) + assert 
m.coeff == 2 + assert m.base == 1 + + +def test_min_max_select(): + analyzer = tvm.arith.Analyzer() + x, y = tvm.var("x"), tvm.var("y") + m = analyzer.modular_set(tvm.min(x * 3, y * 9)) + assert m.coeff == 3 + assert m.base == 0 + + m = analyzer.modular_set(tvm.max(x * 3 + 1, y * 9 + 4)) + assert m.coeff == 3 + assert m.base == 1 + + m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2)) + assert m.coeff == 1 + assert m.base == 0 + + +def test_mix_index(): + a = tvm.var("a") + b = tvm.var("b") + analyzer = tvm.arith.Analyzer() + m = analyzer.modular_set(a * 4 + b * 6 + 7) + assert m.coeff == 2 + assert m.base == 1 + + m = analyzer.modular_set((a * 4 + 1) * (b * 8 + 3)) + assert m.coeff == 4 + assert m.base == 3 + + m = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3)) + assert m.coeff == 1 + assert m.base == 0 + + m = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4)) + assert m.coeff == 2 + assert m.base == 0 + + m = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7 + 2)) + assert m.coeff == 3 + assert m.base == 2 + + m = analyzer.modular_set(a * 12 + tvm.min(b * 3 * 7, 2)) + assert m.coeff == 1 + assert m.base == 0 + + +def test_constraint_scope(): + a = tvm.var("a") + b = tvm.var("b") + analyzer = tvm.arith.Analyzer() + with analyzer.constraint_scope(b % 4 == 2): + m = analyzer.modular_set(b + 1) + assert m.coeff == 4 + assert m.base == 3 + with analyzer.constraint_scope(a % 2 == 1): + m = analyzer.modular_set(b + a * 2) + assert m.coeff == 4 + assert m.base == 0 + m = analyzer.modular_set(b + a * 2) + assert m.coeff == 2 + assert m.base == 0 + + m = analyzer.modular_set(b + 1) + assert m.coeff == 1 + assert m.base == 0 + + +if __name__ == "__main__": + test_cast() + test_add_sub() + test_mul() + test_div_shift() + test_min_max_select() + test_mix_index() + test_constraint_scope()