diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index c4ee7b5b6279..cc64294c92ca 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -464,11 +464,16 @@ class TVM_DLL Analyzer { * \brief Simplify expr. * * \param expr The expression to be simplified. + * \param steps The simplification runs in the order of + * rewrite_simplify (step 1) -> canonical_simplify (step 2) -> + * rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ... + * param steps controls how many steps to run. + * Default is 2, i.e., rewrite_simplify + canonical_simplify. * \return The result. * * \note Analyzer will call into sub-analyzers to get the result. */ - PrimExpr Simplify(const PrimExpr& expr); + PrimExpr Simplify(const PrimExpr& expr, int steps = 2); }; } // namespace arith diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index ae18cab0a9fa..ad044b288941 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -26,10 +26,14 @@ #include #include +#include #include +#include #include +#include "analyzer.h" + namespace tvm { namespace arith { @@ -37,6 +41,95 @@ using tir::IterVar; using tir::Var; using tir::VarNode; +/*! + * \brief Represent integer grouped bounds which are classified into + * lower bounds (inclusive), upper bounds (inclusive) and equalities. + * It also contains coefficient as a multiplier for the bounds, i.e., + * coef * var >= lower + * coef * var == equal + * coef * var <= upper + * \sa IntGroupBounds + */ +class IntGroupBoundsNode : public Object { + public: + PrimExpr coef; + Array lower; + Array equal; + Array upper; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("coef", &coef); + v->Visit("lower", &lower); + v->Visit("equal", &equal); + v->Visit("upper", &upper); + } + + bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const { + return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) && + eq(upper, other->upper); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(coef); + hash_reduce(lower); + hash_reduce(equal); + hash_reduce(upper); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const char* _type_key = "arith.IntGroupBounds"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); +}; + +/*! + * \brief Managed reference to IntGroupBoundsNode. + * \sa IntGroupBoundsNode + */ +class IntGroupBounds : public ObjectRef { + public: + /*! + * \brief Constructor by fields + * \param coef The coefficient. Must be integer. + * coef * var >= lower + * coef * var == equal + * coef * var >= upper + * \param lower the lower bounds (include) + * \param equal equalities + * \param upper the upper bounds (include) + */ + TVM_DLL IntGroupBounds(PrimExpr coef, Array lower, Array equal, + Array upper); + + /*! + * \brief Construct bounds from a range. + * \param r The range + * \return constructed bounds. + */ + static IntGroupBounds FromRange(const Range& r); + + /*! + * \brief Perform substitution on all components of the struct. + */ + IntGroupBounds Substitute(const Map& subst) const; + + /*! + * \brief Find the best range from the grouped bounds. + * \param vranges_addl additional variable ranges that help infer the best range. + * \return The best range (has the least difference between the lower bound and upper bound). + * undefined if (-inf, +inf). + */ + Range FindBestRange(const Map& vranges_addl = {}) const; + + /*! + * \brief Combine the bounds with another range. + * \param r range to be combined. + * \return combined bounds. + */ + IntGroupBounds operator+(const Range& r); + + TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode); +}; + /*! * \brief Represent integer constrains including (integer) variables, their ranges and * the relations between them (either equations or inequalities). @@ -161,6 +254,8 @@ class IntConstraintsTransform : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; +typedef std::pair, Array> PartialSolvedInequalities; + /*! * \brief Obtain Smith Normal Form of linear equation A x = y. * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, @@ -191,6 +286,56 @@ void SmithNormalFormDiag(std::vector>* S, std::vector= f0(x1, x2, ..., xn) + * x0 <= g0(x1, x2, ..., xn) + * x1 >= f1(x2, ..., xn) + * x1 <= g1(x2, ..., xn) + * ... + * xn >= fn() // just a constant + * xn <= gn() // just a constant + * + * \return A map of variables and their solved bounds, + * and constrains that cannot be solved to bounds. + */ +PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve); + +/*! + * \brief Solve linear inequalities and infer the range of each variable. + * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * \return The result ranges for each variables. + * The returned IntConstraints(variables, ranges, relations) contains, + * 1. variables - the variables that have been solved. + * 2. ranges - the best range of each variable. + * 3. relations - constraints that cannot be transformed to + * Range will be stored in relations. + */ +IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve); + +/*! + * \brief Solve linear inequalities and deskew the ranges towards zero. + * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * \return A transform (src IntConstraints -> dst IntConstraints) + * from original variables to a set of new variables. + * The ranges of new variables always start from zero, + * their extents are solved from \p system_to_solve. + * src IntConstraints is the same as \p system_to_solve. + * dst IntConstraints(variables, ranges, relations) contains, + * 1. variables - the variables that have been solved. + * 2. ranges - the best range (start from zero) of each variable. + * 3. relations - constraints that cannot be transformed to + * Range will be stored in relations. + * Variable mapping can be obtained from + * IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src. + */ +IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_SOLVER_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 017934a03b34..e5af52938f5c 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,4 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound -from .int_solver import solve_linear_equations +from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 5a420ad81755..e841de906d5e 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -119,20 +119,25 @@ def modular_set(self, expr): """ return self._modular_set(expr) - def simplify(self, expr): + def simplify(self, expr, steps=2): """Simplify expression via both rewrite and canonicalization. Parameters ---------- expr : PrimExpr The expression. + steps : The simplification runs in the order of + rewrite_simplify (step 1) -> canonical_simplify (step 2) -> + rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ... + param steps controls how many steps to run. + Default is 2, i.e., rewrite_simplify + canonical_simplify. Returns ------- result : Expr The result. """ - return self._simplify(expr) + return self._simplify(expr, steps) def rewrite_simplify(self, expr): """Simplify expression via rewriting rules. diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index e35435c1da03..91fa459b8aa8 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -20,6 +20,52 @@ from . import _ffi_api +@tvm._ffi.register_object("arith.IntGroupBounds") +class IntGroupBounds(Object): + """Represent integer grouped bounds which are classified into + lower bounds (include), upper bounds (include) and equalities. + + Parameters + ---------- + coef : tvm.ir.PrimExpr + The coefficient. Must be integer type. + coef * var >= lower + coef * var == equal + coef * var >= upper + lower : List[tvm.ir.PrimExpr] + the lower bounds (include) + equal : List[tvm.ir.PrimExpr] + equalities + upper : List[tvm.ir.PrimExpr] + the upper bounds (include) + """ + def __init__(self, coef, lower, equal, upper): + self.__init_handle_by_constructor__( + _ffi_api.IntGroupBounds, coef, lower, equal, upper) + + @staticmethod + def from_range(rng): + """Construct a IntGroupedBounds by Range. + + Parameters + ---------- + rng : tvm.ir.Range + + + Returns + ------- + ret : Range + The constructed range. + """ + return _ffi_api.IntGroupBounds_from_range(rng) + + def find_best_range(self): + """Return the best range from the grouped bounds. + None if (-inf, +inf). + """ + return _ffi_api.IntGroupBounds_FindBestRange(self) + + @tvm._ffi.register_object("arith.IntConstraints") class IntConstraints(Object): """Represent a set of integer constraints including variables, their ranges and @@ -97,3 +143,35 @@ def solve_linear_equations(equations, variables=None, ranges=None): if isinstance(equations, IntConstraints): return _ffi_api.SolveLinearEquations(equations) return _ffi_api.SolveLinearEquations(variables, ranges, equations) + + +def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False): + """Solve linear inequalities. + + Parameters + ---------- + equations : List[tvm.ir.PrimExpr] or IntConstraints + The inequalities of the variables + variables : Optional[List[tvm.tir.Var]] + The variables in the system. + ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]] + The ranges of the variables. + deskew_range: Optional[bool] + Whether deskew the result ranges to be started from zero. + Default false. + + Returns + ------- + ret_ranges: IntConstraints or IntConstraintsTransform + The result ranges for each variables. + Constrains that cannot be transformed to Range will be stored in IntConstraints.relations. + If deskew_range is set (=True), the result ranges will be deskewed to be started from zero. + New variables are created accordingly therefore IntConstraintsTransform is returned. + """ + solver = _ffi_api.SolveInequalitiesDeskewRange \ + if deskew_range else _ffi_api.SolveInequalitiesToRange + if isinstance(equations, IntConstraints): + assert variables is None + assert ranges is None + return solver(equations) + return solver(variables, ranges, equations) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 5a3d394c098f..7483a9fb4cf8 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unnecessary-comprehension """ TVM testing utilities """ import logging import numpy as np import tvm import tvm.arith import tvm.tir +import tvm.te import tvm._ffi @@ -188,5 +189,100 @@ def assert_prim_expr_equal(lhs, rhs): raise ValueError("{} and {} are not equal".format(lhs, rhs)) +def check_bool_expr_is_true(bool_expr, vranges, cond=None): + """ Check that bool_expr holds given the condition cond + for every value of free variables from vranges. + + for example, 2x > 4y solves to x > 2y given x in (0, 10) and y in (0, 10) + here bool_expr is x > 2y, vranges is {x: (0, 10), y: (0, 10)}, cond is 2x > 4y + We creates iterations to check, + for x in range(10): + for y in range(10): + assert !(2x > 4y) || (x > 2y) + + Parameters + ---------- + bool_expr : tvm.ir.PrimExpr + Boolean expression to check + vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] + Free variables and their ranges + cond: tvm.ir.PrimExpr + extra conditions needs to be satisfied. + """ + if cond is not None: + bool_expr = tvm.te.any(tvm.tir.Not(cond), bool_expr) + + def _run_expr(expr, vranges): + """ Evaluate expr for every value of free variables + given by vranges and return the tensor of results. + """ + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tvm.tir.stmt_functor.substitute(expr, vmap) + + A = tvm.te.compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.nd.empty(A.shape, A.dtype)] + sch = tvm.te.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + return args[0].asnumpy() + + res = _run_expr(bool_expr, vranges) + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = sorted(counterex, key=lambda x: x[0]) + counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) + ana = tvm.arith.Analyzer() + raise AssertionError("Expression {}\nis not true on {}\n" + "Counterexample: {}" + .format(ana.simplify(bool_expr), vranges, counterex)) + + +def check_int_constraints_trans_consistency(constraints_trans, vranges=None): + """ Check IntConstraintsTransform is a bijective transformation. + + Parameters + ---------- + constraints_trans : arith.IntConstraintsTransform + Integer constraints transformation + vranges: Dict[tvm.tir.Var, tvm.ir.Range] + Free variables and their ranges + """ + if vranges is None: + vranges = {} + + def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() + all_vranges = vranges.copy() + all_vranges.update({v: r for v, r in constraints1.ranges.items()}) + + # Check that the transformation is injective + cond_on_vars = tvm.tir.const(1, 'bool') + for v in constraints1.variables: + if v in varmap: + # variable mapping is consistent + v_back = ana.simplify(tvm.tir.stmt_functor.substitute(varmap[v], backvarmap)) + cond_on_vars = tvm.te.all(cond_on_vars, v == v_back) + # Also we have to check that the new relations are true when old relations are true + cond_subst = tvm.tir.stmt_functor.substitute( + tvm.te.all(tvm.tir.const(1, 'bool'), *constraints2.relations), backvarmap) + # We have to include relations from vranges too + for v in constraints2.variables: + if v in constraints2.ranges: + r = constraints2.ranges[v] + range_cond = tvm.te.all(v >= r.min, v < r.min + r.extent) + range_cond = tvm.tir.stmt_functor.substitute(range_cond, backvarmap) + cond_subst = tvm.te.all(cond_subst, range_cond) + cond_subst = ana.simplify(cond_subst) + check_bool_expr_is_true( + tvm.te.all(cond_subst, cond_on_vars), all_vranges, + cond=tvm.te.all(tvm.tir.const(1, 'bool'), *constraints1.relations)) + + _check_forward(constraints_trans.src, constraints_trans.dst, + constraints_trans.src_to_dst, constraints_trans.dst_to_src) + _check_forward(constraints_trans.dst, constraints_trans.src, + constraints_trans.dst_to_src, constraints_trans.src_to_dst) + tvm._ffi._init_api("testing", __name__) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index c7a8365b9fda..daf61441b466 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -115,11 +115,15 @@ bool Analyzer::CanProve(const PrimExpr& expr) { return false; } -PrimExpr Analyzer::Simplify(const PrimExpr& expr) { +PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { if (tir::is_const_int(expr)) return expr; - auto res = this->rewrite_simplify(expr); - if (tir::is_const_int(res)) return res; - res = this->canonical_simplify(res); + PrimExpr res = expr; + for (int i = 0; i < steps; ++i) { + res = this->rewrite_simplify(res); + if (tir::is_const_int(res) || ++i == steps) return res; + res = this->canonical_simplify(res); + if (tir::is_const_int(res)) return res; + } return res; } @@ -139,7 +143,15 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu self->const_int_bound.Update(args[0], args[1], args[2]); }); } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0]); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0], args[1]); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); } else if (name == "rewrite_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 62858d2dc9e2..c95f7f855ceb 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -21,18 +21,171 @@ * \file int_constraints.cc * \brief The integer constraints data structures. */ +#include #include #include #include #include +#include +#include #include #include #include +#include "../tir/transforms/ir_util.h" + namespace tvm { namespace arith { +IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Array equal, + Array upper) { + CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) + << "Coefficient in IntGroupBounds must be integers"; + ObjectPtr node = make_object(); + node->coef = std::move(coef); + node->lower = std::move(lower); + node->equal = std::move(equal); + node->upper = std::move(upper); + data_ = std::move(node); +} + +IntGroupBounds IntGroupBounds::FromRange(const Range& r) { + Analyzer analyzer; + PrimExpr coef = tir::make_const(r->min.dtype(), 1); + Array equal; + Array lower; + Array upper; + if (tir::is_one(r->extent)) { + equal.push_back(r->min); + } else { + lower.push_back(r->min); + upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); + } + return IntGroupBounds(coef, lower, equal, upper); +} + +IntGroupBounds IntGroupBounds::operator+(const Range& r) { + Analyzer analyzer; + Array equal; + Array lower; + Array upper; + const PrimExpr& coef = operator->()->coef; + if (tir::is_one(r->extent)) { + equal.push_back(analyzer.Simplify(r->min * coef)); + } else { + lower.push_back(analyzer.Simplify(r->min * coef)); + upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef)); + } + for (const auto& eq : operator->()->equal) equal.push_back(eq); + for (const auto& lb : operator->()->lower) lower.push_back(lb); + for (const auto& ub : operator->()->upper) upper.push_back(ub); + return IntGroupBounds(coef, lower, equal, upper); +} + +IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { + auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; + return IntGroupBounds(tir::Substitute(operator->()->coef, subst), + tir::UpdateArray(operator->()->lower, apply_fun), + tir::UpdateArray(operator->()->equal, apply_fun), + tir::UpdateArray(operator->()->upper, apply_fun)); +} + +Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { + Analyzer analyzer; + analyzer.Bind(vranges_addl); + + std::unordered_map var_intsets; + for (auto kv : vranges_addl) { + var_intsets[kv.first.get()] = IntSet::FromRange(kv.second); + } + + const Array& equal = operator->()->equal; + const PrimExpr& coef = operator->()->coef; + + std::vector lowers(equal.begin(), equal.end()); + std::vector uppers(equal.begin(), equal.end()); + for (const auto& expr : operator->()->lower) { + lowers.push_back(expr); + } + for (const auto& expr : operator->()->upper) { + uppers.push_back(expr); + } + + if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) { + return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); + } + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v, not for v*coef + + // The lower bound of the best pair so far + PrimExpr best_lower; + // The difference between the upper and the lower of the best pair, maybe overapproximation + PrimExpr best_diff_over; + + for (const PrimExpr& low : lowers) { + for (const PrimExpr& upp : uppers) { + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); + // Since diff may depend on some other variables, we compute its overapproximation + PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); + + // low is the lower bound for v*coef, but we need the lower bound for v. + // We use rounding-up division to compute it. Since we want to use a single formula + PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3); + + // Compute another difference which may be more precise (or not). + PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); + PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); + + PrimExpr diff_over = + analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { + best_lower = low_divided; + best_diff_over = diff_over; + } + } + } + + if (!best_lower.defined()) { + CHECK(!best_diff_over.defined()); + return Range(); + } + return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); +} + +TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); + +TVM_REGISTER_GLOBAL("arith.IntGroupBounds") + .set_body_typed([](PrimExpr coef, Array lower, Array equal, + Array upper) { + return IntGroupBounds(coef, lower, equal, upper); + }); + +TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); + +TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") + .set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK(args.size() == 1 || args.size() == 2); + IntGroupBounds bounds = args[0]; + if (args.size() == 1) { + *ret = bounds.FindBestRange(); + } else if (args.size() == 2) { + *ret = bounds.FindBestRange(args[1]); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntGroupBounds(coef=" << op->coef << ", lower=" << op->lower + << ", equal=" << op->equal << ", upper=" << op->upper << ")"; + }); + IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); @@ -55,6 +208,11 @@ IntConstraints::IntConstraints(Array variables, Map ranges, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); +TVM_REGISTER_GLOBAL("arith.IntConstraints") + .set_body_typed([](Array variables, Map ranges, Array relations) { + return IntConstraints(variables, ranges, relations); + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -75,6 +233,12 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); +TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") + .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, + Map dst_to_src) { + return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index b69ce4fe5858..eff52308f389 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -182,6 +182,17 @@ inline int64_t ZeroAwareGCD(int64_t a, int64_t b) { return b; } +/*! + * \brief Calculate the least common multiple for two values. + * \param a an integer number + * \param b an integer number + * \return the least common multiple. + */ +inline int64_t LeastCommonMultiple(int64_t a, int64_t b) { + int64_t x, y; + return (a * b) / ExtendedEuclidean(a, b, &x, &y); +} + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_OPERATOR_H_ diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc new file mode 100644 index 000000000000..f489d046835d --- /dev/null +++ b/src/arith/solve_linear_inequality.cc @@ -0,0 +1,646 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/solve_linear_inequality.cc + * \brief Solve linear inequalities. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "int_operator.h" + +namespace tvm { +namespace arith { + +using namespace tvm::runtime; +using namespace tvm::tir; + +#define PLUS_ONE(OP) \ + void VisitExpr_(const OP* op) final { num_symbols_++; } + +#define PLUS_ONE_BINARY(OP) \ + void VisitExpr_(const OP* op) final { \ + num_symbols_++; \ + VisitExpr(op->a); \ + VisitExpr(op->b); \ + } + +/*! + * \brief Calculate the expresion complexity based on number of symbols it contains. + */ +class ExprComplexity : public ExprVisitor { + public: + size_t Eval(const PrimExpr& expr) { + VisitExpr(expr); + return num_symbols_; + } + + PLUS_ONE_BINARY(AddNode) + PLUS_ONE_BINARY(SubNode) + PLUS_ONE_BINARY(MulNode) + PLUS_ONE_BINARY(DivNode) + PLUS_ONE_BINARY(ModNode) + PLUS_ONE_BINARY(FloorDivNode) + PLUS_ONE_BINARY(FloorModNode) + PLUS_ONE_BINARY(MinNode) + PLUS_ONE_BINARY(MaxNode) + PLUS_ONE_BINARY(EQNode) + PLUS_ONE_BINARY(NENode) + PLUS_ONE_BINARY(LTNode) + PLUS_ONE_BINARY(LENode) + PLUS_ONE_BINARY(GTNode) + PLUS_ONE_BINARY(GENode) + PLUS_ONE_BINARY(AndNode) + PLUS_ONE_BINARY(OrNode) + PLUS_ONE(VarNode) + PLUS_ONE(FloatImmNode) + PLUS_ONE(IntImmNode) + void VisitExpr_(const NotNode* op) final { + num_symbols_++; + VisitExpr(op->a); + } + + private: + size_t num_symbols_{0}; +}; + +struct ExprLess { + bool operator()(const PrimExpr& l, const PrimExpr& r) const { + return ExprComplexity().Eval(l) < ExprComplexity().Eval(r); + } +}; + +/*! + * \brief Combine the information into an array of (in)equalities. + */ +Array as_conditions(const Array& variables, const Map& bounds, + const Array& relations) { + Array res; + // use variables to keep the order of iteration + // so as to get rid of any non-determinism. + CHECK_EQ(variables.size(), bounds.size()); + for (const auto v : variables) { + CHECK(bounds.count(v)); + const auto& bnds = bounds[v]; + PrimExpr lhs = bnds->coef * v; + for (const PrimExpr& rhs : bnds->equal) { + res.push_back(tir::EQ(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->lower) { + res.push_back(tir::GE(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->upper) { + res.push_back(tir::LE(lhs, rhs)); + } + } + for (const PrimExpr& e : relations) { + res.push_back(e); + } + return res; +} + +void DebugPrint( + const std::unordered_set& current_ineq_set, + const std::unordered_set& next_ineq_set, + const std::vector& rest, const std::vector>& coef_pos, + const std::vector>& coef_neg) { + std::cout << "Current ineq set:\n["; + for (auto& ineq : current_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "Next ineq set:\n["; + for (auto& ineq : next_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "coef_pos:\n["; + for (auto& coef : coef_pos) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; + + std::cout << "coef_neg:\n["; + for (auto& coef : coef_neg) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; +} + +/*! + * \brief normalize to the form `expr <= 0` + */ +class NormalizeComparisons : public ExprMutator { + public: + PrimExpr VisitExpr_(const EQNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const NENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LTNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const GTNode* op) override { return Make(op->b, op->a); } + PrimExpr VisitExpr_(const GENode* op) override { return Make(op->b, op->a); } + + private: + template + PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { + // rewrite LT to LE for ints + if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { + return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); + } + return T(analyzer_.Simplify(a - b), make_zero(a.dtype())); + } + arith::Analyzer analyzer_; +}; + +void AddInequality(std::unordered_set* inequality_set, + const PrimExpr& new_ineq, Analyzer* analyzer) { + if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { + // redundant: follows from the vranges + // or has already been added + return; + } + if (const LENode* new_le = new_ineq.as()) { + for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { + const LENode* le = iter->as(); + if (le && analyzer->CanProve(new_le->a - le->a <= 0)) { + return; + } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) { + iter = inequality_set->erase(iter); + } else { + ++iter; + } + } + } + + inequality_set->insert(new_ineq); +} + +void ClassifyByPolarity( + const Var& var, + const std::unordered_set& current_ineq_set, + std::unordered_set* next_ineq_set, + std::vector* rest, std::vector>* coef_pos, + std::vector>* coef_neg, Analyzer* analyzer) { + // Take formulas from current_ineq_set and classify them according to polarity wrt var + // and store to coef_pos and coef_neg respectively. + for (const PrimExpr& ineq : current_ineq_set) { + if (const LENode* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {var}); + if (!coef.empty() && is_const_int(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to next_ineq_set + AddInequality(next_ineq_set, ineq, analyzer); + } else if (coef0 > 0) { + coef_pos->push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg->push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQNode* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {var}); + if (!coef.empty() && is_const_int(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to next_ineq_set + AddInequality(next_ineq_set, ineq, analyzer); + } else if (coef0 > 0) { + // Equalities may be considered as pairs of two inequalities + coef_pos->push_back({coef0, coef[1]}); + coef_neg->push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos->push_back({-coef0, -coef[1]}); + coef_neg->push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest->push_back(ineq); + } +} + +void MoveEquality(std::unordered_set* upper_bounds, + std::unordered_set* lower_bounds, + std::unordered_set* equalities) { + // those exist in both upper & lower bounds will be moved to equalities + for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { + auto lb = lower_bounds->find(*ub); + if (lb != lower_bounds->end()) { + equalities->insert(*lb); + lower_bounds->erase(lb); + ub = upper_bounds->erase(ub); + } else { + ++ub; + } + } +} + +PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) { + arith::Analyzer analyzer; + analyzer.Bind(system_to_solve->ranges); + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current_ineq_set_to_solve` and + // classify them according to polarity wrt v. + // - Combine each formula of positive polarity (wrt v) + // with each formula of negative polarity. + // - Put the resulting combinations into `next_ineq_set_to_solve` + // along with unclassifiable formulas. + // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve` + // and move to the next variable. + + // normalized inequality + std::unordered_set current_ineq_set_to_solve; + std::unordered_set next_ineq_set_to_solve; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // Simplify each inequality into the form `expr <= 0` and add to current formulas + for (const PrimExpr& ineq : system_to_solve->relations) { + AddInequality(¤t_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)), + &analyzer); + } + + Map res_bounds; + for (const Var& v : system_to_solve->variables) { + CHECK(!res_bounds.count(v)) + << "Variable " << v + << " appears more than one time in the `variables` which might be a bug"; + + next_ineq_set_to_solve.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (system_to_solve->ranges.count(v)) { + const Range& range = system_to_solve->ranges[v]; + PrimExpr range_lbound = analyzer.Simplify(range->min, 3); + PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos, + &coef_neg, &analyzer); + + // Combine each positive inequality with each negative one (by adding them together) + int64_t gcd_x, gcd_y; + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); + PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd); + PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd); + // eliminate the current variable + PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; + PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype())); + // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify + // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 + // with steps = 2 it's (y*2) - 10 <= 0 + new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3)); + AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = LeastCommonMultiple(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds + std::unordered_set upper_bounds; + std::unordered_set lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; + bound = analyzer.Simplify(bound, 3); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &analyzer](const PrimExpr& o) { + return analyzer.CanProve(o - bound <= 0); + })) { + continue; + } + // Erase all worse bounds + for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) { + if (analyzer.CanProve(*iter - bound >= 0)) { + iter = upper_bounds.erase(iter); + } else { + ++iter; + } + } + // Add the upper bound + upper_bounds.insert(bound); + } + for (const auto& neg : coef_neg) { + PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; + bound = analyzer.Simplify(bound, 3); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &analyzer](const PrimExpr& o) { + return analyzer.CanProve(o - bound >= 0); + })) { + continue; + } + // Erase all worse bounds + for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) { + if (analyzer.CanProve(*iter - bound <= 0)) { + iter = lower_bounds.erase(iter); + } else { + ++iter; + } + } + // Add the lower bound + lower_bounds.insert(bound); + } + + std::unordered_set equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + MoveEquality(&upper_bounds, &lower_bounds, &equal); + std::vector equal_list(equal.begin(), equal.end()); + std::sort(equal_list.begin(), equal_list.end(), ExprLess()); + + // Write it to the result. + IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), + Array(lower_bounds.begin(), lower_bounds.end()), + Array(equal_list.begin(), equal_list.end()), + Array(upper_bounds.begin(), upper_bounds.end())); + res_bounds.Set(v, bnds); + + std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); + } + + // Everything that is left goes to res.relations + Array other_conditions; + for (const PrimExpr& e : current_ineq_set_to_solve) { + PrimExpr e_simp = analyzer.Simplify(e, 3); + if (is_const_int(e_simp, 0)) { + // contradiction detected + other_conditions = {const_false()}; + break; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + other_conditions.push_back(e_simp); + } + } + + for (const PrimExpr& e : rest) { + other_conditions.push_back(e); + } + + return {res_bounds, other_conditions}; +} + +IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { + // Resulting ranges will contain ranges for the new variables and for the variables that are + // not in the inequalities->variables but are in inequalities->ranges + // It will be useful when solving Jacobian axes jac_xxx) + Map res_ranges; + // we get a set of equality, lower, upper bound of each variable. + auto solved_system = SolveLinearInequalities(inequalities); + + Map solved_bounds = solved_system.first; + Array solved_other_relations = solved_system.second; + + Array res_relations; + + // this keeps being updated during determining the range of each variable. + Map vranges; + for (std::pair vr : inequalities->ranges) { + vranges.Set(vr.first, vr.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + + const Var& var = *it; + CHECK(solved_bounds.count(var)); + auto bnd = solved_bounds[var]; + if (is_one(bnd->coef) && !bnd->equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, + // so it must be the simplest one. + Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3)); + res_ranges.Set(var, best_range); + vranges.Set(var, best_range); + } else { + if (vranges.count(var) > 0) { + bnd = bnd + vranges[var]; + } + + auto best_range = bnd.FindBestRange(vranges); + + if (best_range.defined()) { + if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + // range.extent <= 0 implies the input inequality system is unsolvable + return IntConstraints(/*variables=*/{}, /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}); + } + res_ranges.Set(var, best_range); + vranges.Set(var, best_range); + } + } + } + + // Add the original conditions to the resulting conditions + arith::Analyzer analyzer; + analyzer.Bind(vranges); + for (const PrimExpr& old_cond : + as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { + if (!analyzer.CanProve(old_cond)) { + // those not represented in vranges (res_ranges) + res_relations.push_back(old_cond); + } + } + + IntConstraints system(inequalities->variables, res_ranges, res_relations); + + return system; +} + +IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) { + // Resulting ranges will contain ranges for the new variables and for the variables that are + // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) + Map res_ranges; + // we get a set of equality, lower, upper bound of each variable. + auto solved_system = SolveLinearInequalities(inequalities); + Map solved_bounds = solved_system.first; + Array solved_other_relations = solved_system.second; + + arith::Analyzer analyzer; + + Map res_src_to_dst; + Map res_dst_to_src; + Array res_variables; + Array res_relations; + + // this keeps being updated during determining the range of each variable. + Map vranges; + for (std::pair vr : inequalities->ranges) { + vranges.Set(vr.first, vr.second); + } + analyzer.Bind(vranges); + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { + const Var& var = *it; + auto bnd = solved_bounds[var]; + // Note that we replace old vars with new ones + bnd = bnd.Substitute(res_src_to_dst); + + if (is_one(bnd->coef) && !bnd->equal.empty()) { + // There is an equation of the form `v == expr`, + // so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, + // so it must be the simplest one. + res_src_to_dst.Set(var, bnd->equal[0]); + } else { + if (vranges.count(var) > 0) { + bnd = bnd + vranges[var]; + } + + auto best_range = bnd.FindBestRange(vranges); + + Var new_var = var.copy_with_suffix(".shifted"); + if (!best_range.defined()) { + res_src_to_dst.Set(var, var); + res_dst_to_src.Set(var, var); + res_variables.push_back(var); + } else if (is_const_int(best_range->extent, 1)) { + // Don't create an itervar, just replace it everywhere with its min + res_src_to_dst.Set(var, best_range->min); + } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + // range.extent <= 0 implies the input inequality system is unsolvable + return IntConstraintsTransform(inequalities, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}), + {}, {}); + } else { + // created new_var starts from 0 + res_src_to_dst.Set(var, new_var + best_range->min); + // Note that we are substituting old with new, so best_range contains new var, + // that is we have to substitute new with old in best_range here + res_dst_to_src.Set(new_var, + analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src))); + + // Add the new var to the resulting axis + auto range = Range(make_zero(new_var.dtype()), best_range->extent); + res_variables.push_back(new_var); + res_ranges.Set(new_var, range); + + vranges.Set(new_var, range); + analyzer.Bind(new_var, range); + } + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + for (const PrimExpr& old_cond : + as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { + PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); + if (!is_const_int(new_cond, 1)) { + // those not represented in vranges (res_ranges) + res_relations.push_back(new_cond); + } + } + + // Reverse the axis so that it matches the order of the original variables + res_variables = Array(res_variables.rbegin(), res_variables.rend()); + + IntConstraints new_inequalities(res_variables, res_ranges, res_relations); + IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src); + + return transform; +} + +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") + .set_body([](TVMArgs args, TVMRetValue* ret) { + IntConstraints problem; + PartialSolvedInequalities ret_ineq; + if (args.size() == 1) { + problem = args[0]; + ret_ineq = SolveLinearInequalities(problem); + } else if (args.size() == 3) { + problem = IntConstraints(args[0], args[1], args[2]); + ret_ineq = SolveLinearInequalities(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " + << args.size(); + } + *ret = as_conditions(problem->variables, ret_ineq.first, ret_ineq.second); + }); + +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesToRange(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveInequalitiesToRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " << args.size(); + } +}); + +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") + .set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesDeskewRange(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveInequalitiesDeskewRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " + << args.size(); + } + }); + +} // namespace arith +} // namespace tvm diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_equations.py similarity index 62% rename from tests/python/unittest/test_arith_solve_linear_system.py rename to tests/python/unittest/test_arith_solve_linear_equations.py index 550dfef995c6..968e40b5d5f7 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_equations.py @@ -15,84 +15,10 @@ # specific language governing permissions and limitations # under the License. import random -import numpy as np import sys import pytest import tvm -from tvm import te, arith, ir, tir - - -def run_expr(expr, vranges): - """ Evaluate expr for every value of free variables - given by vranges and return the tensor of results. - TODO(yzhliu): move to utils - """ - def _compute_body(*us): - vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tir.stmt_functor.substitute(expr, vmap) - - A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) - args = [tvm.nd.empty(A.shape, A.dtype)] - sch = te.create_schedule(A.op) - mod = tvm.build(sch, [A]) - mod(*args) - return args[0].asnumpy() - - -def check_bruteforce(bool_expr, vranges, cond=None): - """ Check that bool_expr holds given the condition cond - for every value of free variables from vranges. - TODO(yzhliu): move to utils - """ - if cond is not None: - bool_expr = te.any(tir.Not(cond), bool_expr) - - res = run_expr(bool_expr, vranges) - if not np.all(res): - indices = list(np.argwhere(res == 0)[0]) - counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] - counterex = sorted(counterex, key=lambda x: x[0]) - counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) - raise AssertionError("Expression {}\nis not true on {}\n" - "Counterexample: {}" - .format(tir.arith.Analyzer().simplify(bool_expr), vranges, counterex)) - - -def check_solution(solution, vranges={}): - """Check that solution is a bijective transformation""" - def _check_forward(constraints1, constraints2, varmap, backvarmap): - ana = tvm.arith.Analyzer() - all_vranges = vranges.copy() - all_vranges.update({v: r for v, r in constraints1.ranges.items()}) - - # Check that the transformation is injective - cond_on_vars = tir.const(1, 'bool') - for v in constraints1.variables: - # variable mapping is consistent - v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) - cond_on_vars = te.all(cond_on_vars, v == v_back) - # Also we have to check that the new relations are true when old relations are true - cond_subst = tir.stmt_functor.substitute( - te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) - # We have to include relations from vranges too - for v in constraints2.variables: - if v in constraints2.ranges: - r = constraints2.ranges[v] - range_cond = te.all(v >= r.min, v < r.min + r.extent) - range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) - cond_subst = te.all(cond_subst, range_cond) - cond_subst = ana.simplify(cond_subst) - check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, - cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) - - rels = solution.dst.relations - if len(rels) == 1 and ir.structural_equal(rels[0], False): - # not solvable, skip - return - _check_forward(solution.src, solution.dst, - solution.src_to_dst, solution.dst_to_src) - _check_forward(solution.dst, solution.src, - solution.dst_to_src, solution.src_to_dst) +from tvm import te, arith, ir, tir, testing def test_solution_consistency(): @@ -120,14 +46,14 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} solution = arith.solve_linear_equations(relations, variables, vranges) - check_solution(solution) + testing.check_int_constraints_trans_consistency(solution) # leaving some variables as parameters should also be ok for k in [1, 2]: if len(variables) > k: solution = arith.solve_linear_equations(relations, variables[:-k], vranges) param_ranges = {v: vranges[v] for v in variables[-k:]} - check_solution(solution, param_ranges) + testing.check_int_constraints_trans_consistency(solution, param_ranges) for i in range(2): _check(num_vars=1, num_formulas=1) diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py new file mode 100644 index 000000000000..acdabecc89f4 --- /dev/null +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import random +import sys +import pytest +import tvm +from tvm import te, arith, ir, tir, testing + + +def test_solution_consistency(): + seed = random.randrange(sys.maxsize) + print("\nThis test is intentionally non-deterministic, " + "if it fails please report it in github issue together with this seed {}\n".format(seed)) + random.seed(seed) + + def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): + vs = [te.var("x" + str(i)) for i in range(variables)] + + fs = [] + for i in range(formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s2 += random.randint(coef[0], coef[1]) + op = random.choice([tir.expr.EQ, tir.expr.LE, tir.expr.LT, tir.expr.GE, tir.expr.GT]) + fs.append(op(s1, s2)) + + vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs} + before = te.all(tir.const(1, 'bool'), *fs) + after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs) + after = te.all(tir.const(1, 'bool'), *after) + testing.check_bool_expr_is_true(before == after, vranges) + + solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True) + testing.check_int_constraints_trans_consistency(solution) + + for i in range(3): + _check(1, 1) + for i in range(3): + _check(1, 2) + + for i in range(3): + _check(2, 1) + for i in range(3): + _check(2, 2) + for i in range(3): + _check(2, 3) + + # Somewhere here coefficients in the results become too large, leading to overflow, + # so we use smaller initial coefficients + for i in range(5): + _check(3, 3, coef=(-2, 2)) + for i in range(5): + _check(3, 4, coef=(-2, 2)) + + for i in range(5): + _check(4, 3, coef=(-1, 1)) + + for i in range(5): + _check(10, 2, coef=(-1, 1), bounds=(0, 4)) + for i in range(5): + _check(10, 3, coef=(0, 1), bounds=(0, 4)) + + +def test_dual_variable(): + x, y = te.var("x"), te.var("y") + + variables = [x, y] + ranges = { + x: tvm.ir.Range(-100, 100), + y: tvm.ir.Range(0, 10), + } + problem = [ + tvm.tir.LE(x + y, 20), + tvm.tir.GE(x - y, 10), + ] + + # solution as conditions + solution = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges, problem) + assert ir.structural_equal(solution[0], x >= (y + 10)) + assert ir.structural_equal(solution[1], x <= (20 - y)) + assert ir.structural_equal(solution[2], y >= 0) + assert ir.structural_equal(solution[3], y <= 5) + + # solve and get the ranges + solution = arith.solve_linear_inequalities(problem, variables, ranges) + # 0 <= y <=5 + assert solution.ranges[y].min == 0 + assert solution.ranges[y].extent == 6 + # y + 10 <= x <= 20 - y + assert ir.structural_equal(solution.ranges[x].min, y + 10) + assert solution.ranges[x].extent == 11 # max(10 - 2y) + + # deskew the solved ranges to be starting from zero + solution = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True) + [x_new, y_new] = solution.dst.variables + [rel] = solution.dst.relations + assert ir.structural_equal(rel, (y_new*2) + x_new <= 10) + assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) + assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) + assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) + assert ir.structural_equal(solution.src_to_dst[y], y_new) + assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) + assert ir.structural_equal(solution.dst_to_src[y_new], y) + + +def test_equal(): + x, y = te.var("x"), te.var("y") + problem = [ + tvm.tir.GE(x + y, 10), + tvm.tir.GE(x - y, 2), + tvm.tir.LE(x, 6), + ] + + solution = arith.solve_linear_inequalities(problem, [x, y]) + assert solution.ranges[x].min == 6 + assert solution.ranges[x].extent == 1 + assert solution.ranges[y].min == 4 + assert solution.ranges[y].extent == 1 + + solution = arith.solve_linear_inequalities(problem, [x, y], deskew_range=True) + assert len(solution.dst.variables) == 0 + assert len(solution.dst.ranges) == 0 + assert len(solution.dst.relations) == 0 + assert solution.src_to_dst[x] == 6 + assert solution.src_to_dst[y] == 4 + + +def test_multi_equal(): + x, y, z = te.var("x"), te.var("y"), te.var("z") + problem = [ + tvm.tir.LE(x, 6), + tvm.tir.GE(x, 6), + tvm.tir.GE(x - z * y, 0), + tvm.tir.LE(x - z * y, 0), + ] + + solution = arith.solve_linear_inequalities(problem, [x, y, z]) + assert solution.ranges[x].min == 6 + assert solution.ranges[x].extent == 1 + assert len(solution.relations) == 3 + assert ir.structural_equal(solution.relations[0], x == z * y) + assert ir.structural_equal(solution.relations[1], z*y - 6 <= 0) + assert ir.structural_equal(solution.relations[2], 6 - z*y <= 0) + + solution = arith.solve_linear_inequalities(problem, [x, y, z], deskew_range=True) + assert solution.src_to_dst[y] == y + assert solution.src_to_dst[z] == z + assert solution.src_to_dst[x] == 6 + + +def test_no_solution(): + x = te.var("x0") + vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)} + problem = [-x - 4 <= -5*x + 2, x*4 + 5 <= x*5] + + solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) + assert list(solution.dst.variables) == [] + [rel] = solution.dst.relations + assert ir.structural_equal(rel, False) + assert len(solution.src_to_dst) == 0 + assert len(solution.dst_to_src) == 0 + + solution = arith.solve_linear_inequalities(problem, [x], vranges) + assert len(solution.variables) == 0 + assert len(solution.ranges) == 0 + [rel] = solution.relations + assert not rel + + +if __name__ == "__main__": + pytest.main([__file__])