From 8207da5e1a791ce8164492ec4cdc299e55ad9114 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 29 Apr 2020 19:45:09 -0700 Subject: [PATCH 01/33] [arith] inequalities solver --- include/tvm/arith/util.h | 16 + python/tvm/arith/__init__.py | 2 +- python/tvm/arith/int_solver.py | 6 + src/arith/detect_linear_equation.cc | 2 +- src/arith/solve_linear_inequality.cc | 608 ++++++++++++++++++ src/arith/util.cc | 14 + .../test_arith_solve_linear_inequality.py | 105 +++ 7 files changed, 751 insertions(+), 2 deletions(-) create mode 100644 src/arith/solve_linear_inequality.cc create mode 100644 tests/python/unittest/test_arith_solve_linear_inequality.py diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h index adfcefcd2e21..1818eefa8789 100644 --- a/include/tvm/arith/util.h +++ b/include/tvm/arith/util.h @@ -31,6 +31,14 @@ namespace tvm { /*! \brief namespace of arithmetic analysis. */ namespace arith { +/*! + * \brief Calculate the greatest common divisor for two values. + * \param a an integer number + * \param b an integer number + * \return the greatest common divisor. + */ +int gcd(int a, int b); + /*! * \brief Calculate the extended greatest common divisor for two values. * See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm. @@ -40,6 +48,14 @@ namespace arith { */ std::tuple xgcd(int64_t a, int64_t b); +/*! + * \brief Calculate the least common multiple for two values. + * \param a an integer number + * \param b an integer number + * \return the least common multiple. + */ +int lcm(int a, int b); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_UTIL_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 017934a03b34..e5af52938f5c 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,4 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound -from .int_solver import solve_linear_equations +from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index e35435c1da03..c27e39f52917 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -97,3 +97,9 @@ def solve_linear_equations(equations, variables=None, ranges=None): if isinstance(equations, IntConstraints): return _ffi_api.SolveLinearEquations(equations) return _ffi_api.SolveLinearEquations(variables, ranges, equations) + + +def solve_linear_inequalities(equations, variables=None, ranges=None): + if isinstance(equations, IntConstraints): + return _ffi_api.SolveLinearInequalities(equations) + return _ffi_api.SolveLinearInequalities(variables, ranges, equations) diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index cc9c745a24b8..587b2d8f04ae 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -142,7 +142,7 @@ class LinearEqDetector }; Array DetectLinearEquation(const PrimExpr& e, - const Array& vars) { + const Array& vars) { PrimExpr base = e; Array coeff; diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc new file mode 100644 index 000000000000..71f434b277c4 --- /dev/null +++ b/src/arith/solve_linear_inequality.cc @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/solve_linear_inequality.cc + * \brief Solve linear inequalities. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO +#include "../tir/pass/ir_util.h" +// TODO: testing +// https://github.com/sergei-grechanik/tvm/blob/8cad2d1e62272b3e192bfe08b896e07bc9550e94/tests/python/unittest/test_pass_zero_elimination.py#L367 + +namespace tvm { +namespace arith { + +using namespace tvm::runtime; +using namespace tvm::te; + +struct VarBounds { + PrimExpr coef; + Array lower; + Array equal; + Array upper; + + /*! + * \brief Perform substitution on all components of the struct. + */ + VarBounds substitute(const Map& subst) const { + auto apply_fun = [&subst](const PrimExpr& e) { return Substitute(e, subst); }; + return {Substitute(coef, subst), + tir::UpdateArray(lower, apply_fun), + tir::UpdateArray(equal, apply_fun), + tir::UpdateArray(upper, apply_fun)}; + } +}; + +// TODO +struct SolveSystemOfInequalitiesResult { + Array variables; + std::unordered_map bounds; + Array other_conditions; + + /*! + * \brief Combine the information into an array of (in)equalities. + */ + Array as_conditions() const { + Array res; + for (const Var& v : variables) { + auto it = bounds.find(v.get()); + CHECK(it != bounds.end()); + const VarBounds& bnds = it->second; + PrimExpr lhs = bnds.coef * v; + for (const PrimExpr& rhs : bnds.equal) { + res.push_back(EQNode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds.lower) { + res.push_back(GENode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds.upper) { + res.push_back(LENode::make(lhs, rhs)); + } + } + for (const PrimExpr& e : other_conditions) { + res.push_back(e); + } + return res; + } +}; + +//////////////////////////////////////////////////// + +/* +struct ExprLess { + bool operator()(const PrimExpr& l, const PrimExpr& r) const { + // FIXME: + // After https://github.com/apache/incubator-tvm/pull/5206 + // we no longer have ExprLess, + // it was comparing VarNode* raw pointers + return Compare(l, r) < 0; + } +}; +*/ + +void DebugPrint(std::unordered_set& current_ineq_set, + std::unordered_set& next_ineq_set, + std::vector& rest, + std::vector >& coef_pos, + std::vector >& coef_neg) { + std::cout << "Current ineq set:\n["; + for (auto& ineq : current_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "Next ineq set:\n["; + for (auto& ineq : next_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "coef_pos:\n["; + for (auto& coef : coef_pos) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; + + std::cout << "coef_neg:\n["; + for (auto& coef : coef_neg) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; +} + +// normalize to the form `expr <= 0` +class NormalizeComparisons : public ExprMutator { + public: + PrimExpr VisitExpr_(const EQNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const NENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LTNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const GTNode* op) override { return Make(op->b, op->a); } + PrimExpr VisitExpr_(const GENode* op) override { return Make(op->b, op->a); } + + private: + template + PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { + LOG(INFO) << "a = " << a << " b = " << b; + // rewrite LT to LE for ints + if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { + return LENode::make(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); + } + return TNode::make(analyzer_.Simplify(a - b), make_zero(a.dtype())); + } + arith::Analyzer analyzer_; +}; + +void AddInequality(std::unordered_set& inequality_set, + const PrimExpr& new_ineq, + Analyzer& analyzer) { + LOG(INFO) << "insert ineq " << new_ineq; + if (analyzer.CanProve(new_ineq) || inequality_set.find(new_ineq) != inequality_set.end()) { + // redundant: follows from the vranges + // or has already been added + return; + } + for (auto iter = inequality_set.begin(); iter != inequality_set.end();) { + if (const LENode* new_le = new_ineq.as()) { + const LENode* le = iter->as(); + if (le && analyzer.CanProve(new_le->a - le->a <= 0)) { + return; + } else if (le && analyzer.CanProve(le->a - new_le->a <= 0)) { + iter = inequality_set.erase(iter); + } else { + ++iter; + } + } else { + ++iter; + } + } + + inequality_set.insert(new_ineq); +} + +void ClassifyPolarity(const Var& var, + std::unordered_set& current_ineq_set, + std::unordered_set& next_ineq_set, + std::vector& rest, + std::vector >& coef_pos, + std::vector >& coef_neg, + Analyzer &analyzer) { + // Take formulas from current_ineq_set and classify them according to polarity wrt var + // and store to coef_pos and coef_neg respectively. + for (const PrimExpr& ineq : current_ineq_set) { + if (const LENode* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {var}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to next_ineq_set + AddInequality(next_ineq_set, ineq, analyzer); + } else if (coef0 > 0) { + coef_pos.push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQNode* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {var}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + AddInequality(next_ineq_set, ineq, analyzer); + } else if (coef0 > 0) { + // Equalities may be considered as pairs of two inequalities + coef_pos.push_back({coef0, coef[1]}); + coef_neg.push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos.push_back({-coef0, -coef[1]}); + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest.push_back(ineq); + } +} + +void MoveEquality(std::unordered_set& upper_bounds, + std::unordered_set& lower_bounds, + std::unordered_set& equalities) { + // those exist in both upper & lower bounds will be moved to equalities + for (auto ub = upper_bounds.begin(); ub != upper_bounds.end();) { + auto lb = lower_bounds.find(*ub); + if (lb != lower_bounds.end()) { + equalities.insert(*lb); + lower_bounds.erase(lb); + ub = upper_bounds.erase(ub); + } else { + ++ub; + } + } +} + +SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& system_to_solve) { + LOG(INFO) << "solving inequalities " << system_to_solve; + + arith::Analyzer analyzer; + for (auto kv : system_to_solve->ranges) { + analyzer.Bind(kv.first, kv.second); + } + + SolveSystemOfInequalitiesResult res; + res.variables = system_to_solve->variables; + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current_ineq_set_to_solve` and classify them according to polarity wrt v + // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity + // - Put the resulting combinations into `next_ineq_set_to_solve` along with unclassifiable formulas + // - Replace `current` with `next_ineq_set_to_solve` and move to the next variable + + // normalized inequality + // current and next_ineq_set_to_solve are sorted to enable some heuristics + std::unordered_set current_ineq_set_to_solve; + std::unordered_set next_ineq_set_to_solve; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // Simplify each inequality into the form `expr <= 0` and add to current formulas + for (const PrimExpr& ineq : system_to_solve->relations) { + // TODO: SuperSimplify(ineq, vranges) + AddInequality(current_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq)), analyzer); + } + + DebugPrint(current_ineq_set_to_solve, + next_ineq_set_to_solve, + rest, + coef_pos, + coef_neg); + + for (const Var& v : system_to_solve->variables) { + std::cout << "Working on " << v << "\n"; + // TODO: + CHECK(!res.bounds.count(v.get())) << + "Variable " << v << " appears more than one time in the `variables` which might be a bug"; + + next_ineq_set_to_solve.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (system_to_solve->ranges.count(v)) { + const Range& range = (system_to_solve->ranges)[v]; + PrimExpr range_lbound = analyzer.Simplify(range->min); + PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + ClassifyPolarity(v, current_ineq_set_to_solve, next_ineq_set_to_solve, rest, coef_pos, coef_neg, analyzer); + + DebugPrint(current_ineq_set_to_solve, + next_ineq_set_to_solve, + rest, + coef_pos, + coef_neg); + + // Combine each positive inequality with each negative one (by adding them together) + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = gcd(pos.first, -neg.first); + PrimExpr c_pos = make_const(v.dtype(), neg.first/first_gcd); + PrimExpr c_neg = make_const(v.dtype(), pos.first/first_gcd); + // eliminate the current variable + PrimExpr new_lhs = c_neg*neg.second - c_pos*pos.second; + PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); + LOG(INFO) << "get new ineq " << new_ineq; + // it helps to simplify (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 + // otherwise it's (y*2) - 10 <= 0 + new_ineq = NormalizeComparisons()(analyzer.rewrite_simplify(analyzer.Simplify(new_ineq))); + AddInequality(next_ineq_set_to_solve, new_ineq, analyzer); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = lcm(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = lcm(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds stored in sorted vectors + std::unordered_set upper_bounds; + std::unordered_set lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + PrimExpr bound = make_const(v.dtype(), -coef_lcm/pos.first)*pos.second; + bound = analyzer.Simplify(bound); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &analyzer](const PrimExpr& o) + { return analyzer.CanProve(o - bound <= 0); })) { + continue; + } + // Erase all worse bounds + for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) { + if (analyzer.CanProve(*iter - bound >= 0)) { + iter = upper_bounds.erase(iter); + } else { + ++iter; + } + } + /* + upper_bounds.erase( + std::remove_if(upper_bounds.begin(), upper_bounds.end(), + [&bound, &analyzer](const PrimExpr& o) + { return analyzer.CanProve(o - bound >= 0); }), + upper_bounds.end()); + */ + // Add the upper bound + upper_bounds.insert(bound); + } + for (const auto& neg : coef_neg) { + PrimExpr bound = make_const(v.dtype(), -coef_lcm/neg.first)*neg.second; + bound = analyzer.Simplify(bound); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &analyzer](const PrimExpr& o) + { return analyzer.CanProve(o - bound >= 0); })) { + continue; + } + // Erase all worse bounds + for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) { + if (analyzer.CanProve(*iter - bound <= 0)) { + iter = lower_bounds.erase(iter); + } else { + ++iter; + } + } +// lower_bounds.erase( +// std::remove_if(lower_bounds.begin(), lower_bounds.end(), +// [&bound, &analyzer](const PrimExpr& o) +// { return analyzer.CanProve(o - bound <= 0); }), +// lower_bounds.end()); + // Add the lower bound + lower_bounds.insert(bound); + } + + std::unordered_set equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + MoveEquality(upper_bounds, lower_bounds, equal); + + // Write it to the result. + auto& bnds = res.bounds[v.get()]; + bnds.coef = make_const(v.dtype(), coef_lcm); + bnds.equal = Array(equal.begin(), equal.end()); + bnds.lower = Array(lower_bounds.begin(), lower_bounds.end()); + bnds.upper = Array(upper_bounds.begin(), upper_bounds.end()); + LOG(INFO) << "Bound of " << v << " coef = " << bnds.coef + << " EQUAL: " << bnds.equal + << " LOWER: " << bnds.lower + << " UPPER: " << bnds.upper; + + std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); + } + + // Everything that is left goes to res.other_conditions + for (const PrimExpr& e : current_ineq_set_to_solve) { + PrimExpr e_simp = analyzer.Simplify(e); + if (is_const_int(e_simp, 0)) { + // contradiction detected + res.other_conditions = {const_false()}; + return res; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + res.other_conditions.push_back(e_simp); + } + } + + for (const PrimExpr& e : rest) { + res.other_conditions.push_back(e); + } + + return res; +} + +// Deskew the given domain +IntConstraintsTransform DeskewDomain(const IntConstraints& inequalities) { + // Resulting ranges will contain ranges for the new variables and for the variables that are + // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) + Map res_ranges; + // we get a set of equality, lower, upper bound of each variable. + auto solved_system = SolveLinearInequalities(inequalities); + LOG(INFO) << "solved system = " << solved_system.as_conditions(); + + arith::Analyzer analyzer; + + Map res_old_to_new; + Map res_new_to_old; + Array res_variables; + Array res_relations; + std::unordered_map new_var_intsets; + + // this keeps being updated during determining the range of each variable. + // TODO: vrange must be non-empty ? + Map vranges = inequalities->ranges; + // Initialize new_var_intsets with the old var intsets + for (const auto& pair : inequalities->ranges) { + new_var_intsets[pair.first.get()] = IntSet::range(pair.second); + analyzer.Bind(pair.first, pair.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { + const Var& var = *it; + auto& bnd = solved_system.bounds[var.get()]; + // Note that we replace old vars with new ones + bnd = bnd.substitute(res_old_to_new); + + if (is_one(bnd.coef) && !bnd.equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + // TODO + res_old_to_new.Set(var, bnd.equal[0]); + } else { + std::vector lowers(bnd.equal.begin(), bnd.equal.end()); + std::vector uppers(bnd.equal.begin(), bnd.equal.end()); + for (const auto& expr : bnd.lower) lowers.push_back(expr); + for (const auto& expr : bnd.upper) uppers.push_back(expr); + + // TODO: remove following +// std::sort(lowers.begin(), lowers.end(), ExprLess()); +// std::sort(uppers.begin(), uppers.end(), ExprLess()); + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v, not for v*coef, because we will need bounds for v anyway + + // The lower bound of the best pair so far + PrimExpr best_lower; + if (vranges.count(var) > 0) { + best_lower = vranges[var]->min; + } + // The difference between the upper and the lower of the best pair, maybe overapproximation + PrimExpr best_diff_over; + if (vranges.count(var) > 0) { + best_diff_over = vranges[var]->extent - 1; + } + + for (const PrimExpr& low : lowers) { + for (const PrimExpr& upp : uppers) { + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, bnd.coef)); + // Since diff may depend on some other variables, we compute its overapproximation + PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, new_var_intsets).max()); + + // low is the lower bound for v*coef, but we need the lower bound for v. + // We use rounding-up division to compute it. Since we want to use a single formula + PrimExpr low_divided = analyzer.Simplify(floordiv(low + bnd.coef - 1, bnd.coef)); + + // Compute another difference which may be more precise (or not). + PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, bnd.coef) - low_divided); + PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, new_var_intsets).max()); + + PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) + ? diff_over_2 : diff_over_1; + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { + best_lower = low_divided; + best_diff_over = diff_over_1; + } + } + } + LOG(INFO) << var << " DONE for each low - upp best lower = " << best_lower; + if (!best_lower.defined()) { + continue; + } + + // TODO: +// std::string suffix = ExprDeepEqual()(best_lower, vranges[var]->min) ? "" : ".shifted"; + std::string suffix = ".shifted"; + Var new_var = var.copy_with_suffix(suffix); + + PrimExpr diff = analyzer.Simplify(best_diff_over); + + if (is_const_int(diff, 0)) { + // Don't create an itervar, just replace it everywhere with its min + res_old_to_new.Set(var, best_lower); + } else { + // created new_var starts from 0 + res_old_to_new.Set(var, new_var + best_lower); + // Note that we are substituting old with new, so best_lower contains new var, + // that is we have to substitute new with old in best_lower here + res_new_to_old.Set(new_var, + analyzer.Simplify(var - Substitute(best_lower, res_new_to_old))); + + new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.dtype()), diff); + + // Add the new var to the resulting axis + auto range = Range(make_zero(new_var.dtype()), analyzer.Simplify(diff + 1)); + res_variables.push_back(new_var); + res_ranges.Set(new_var, range); + + vranges.Set(new_var, range); + analyzer.Bind(new_var, range); + } + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + for (const PrimExpr& old_cond : solved_system.as_conditions()) { + PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_old_to_new)); + if (!is_const_int(new_cond, 1)) { + // those not represented in vranges (res_ranges) + res_relations.push_back(new_cond); + } + } + + // Reverse the axis so that it matches the order of the original variables + res_variables = Array(res_variables.rbegin(), res_variables.rend()); + + IntConstraints new_inequalities(res_variables, res_ranges, res_relations); + IntConstraintsTransform transform(inequalities, new_inequalities, res_old_to_new, res_new_to_old); + + return transform; +} + +TVM_REGISTER_GLOBAL("arith.SolveLinearInequalities") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 1) { + *ret = DeskewDomain(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = DeskewDomain(problem); + } else { + LOG(FATAL) << "arith.SolveLinearInequalities expects 1 or 3 arguments, gets " << args.size(); + } + }); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/util.cc b/src/arith/util.cc index 058c3e959528..e0e115cdda74 100644 --- a/src/arith/util.cc +++ b/src/arith/util.cc @@ -27,6 +27,16 @@ namespace tvm { namespace arith { +int gcd(int a, int b) { + if (a < b) std::swap(a, b); + while (b != 0) { + int64_t tmp = b; + b = a % b; + a = tmp; + } + return a; +} + std::tuple xgcd(int64_t a, int64_t b) { int64_t s = 0, old_s = 1; int64_t t = 1, old_t = 0; @@ -49,5 +59,9 @@ std::tuple xgcd(int64_t a, int64_t b) { return std::make_tuple(old_r, old_s, old_t); } +int lcm(int a, int b) { + return (a*b)/gcd(a, b); +} + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py new file mode 100644 index 000000000000..c12243a41a07 --- /dev/null +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import random +import numpy as np +import sys +import pytest +import tvm +from tvm import te, arith, ir, tir + + +def test_solve_system_of_inequalities(): + random.seed(0) + + def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): + vs = [te.var("x" + str(i)) for i in range(variables)] + + fs = [] + for i in range(formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s2 += random.randint(coef[0], coef[1]) + op = random.choice([tir.expr.EQ, tir.expr.LE, tir.expr.LT, tir.expr.GE, tir.expr.GT]) + fs.append(op(s1, s2)) + + vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs} + + print("--- before ---") + print(fs) + after = arith.solve_linear_inequalities(fs, vs, vranges) + print("--- after ---") + print(after) + print() + + # check_bruteforce(before == after, vranges) + + _check(2, 2) + + # for i in range(3): + # _check(1, 1) + # for i in range(3): + # _check(1, 2) + # + # for i in range(3): + # _check(2, 1) + # for i in range(3): + # _check(2, 2) + # for i in range(3): + # _check(2, 3) + # + # # Somewhere here coefficients in the results become too large, leading to overflow, + # # so we use smaller initial coefficients + # + # for i in range(5): + # _check(3, 3, coef=(-2,2)) + # for i in range(5): + # _check(3, 4, coef=(-2,2)) + # + # for i in range(5): + # _check(4, 3, coef=(-1,1)) + # + # for i in range(5): + # _check(10, 2, coef=(-1,1), bounds=(0, 4)) + # for i in range(5): + # _check(10, 3, coef=(0,1), bounds=(0, 4)) + + +def test_simple(): + x, y = te.var("x"), te.var("y") + # TODO: following will hang forever + ranges = { + x: tvm.ir.Range(-100, 0), + y: tvm.ir.Range(0, 100), + } + + ranges = { + x: tvm.ir.Range(-100, 100), + y: tvm.ir.Range(0, 10), + } + + solution = arith.solve_linear_inequalities([ + tvm.tir.LE(x + y, 20), + tvm.tir.GE(x - y, 10), + ], [x, y], ranges) + + print(solution) + + +if __name__ == "__main__": + # test_solve_system_of_inequalities() + test_simple() From 39a93f70528b7f6c4eac57ae329e4f47d7dfdc5c Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 15 May 2020 23:03:36 -0700 Subject: [PATCH 02/33] introduce IntGroupedBounds --- include/tvm/arith/int_solver.h | 100 +++++++- python/tvm/arith/int_solver.py | 44 +++- src/arith/int_constraints.cc | 184 +++++++++++++- src/arith/solve_linear_equation.cc | 13 +- src/arith/solve_linear_inequality.cc | 239 +++++------------- .../test_arith_solve_linear_inequality.py | 51 +++- 6 files changed, 440 insertions(+), 191 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 57f3af4bb67b..1e4248c77a71 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -26,8 +26,10 @@ #include #include +#include #include #include +#include "analyzer.h" namespace tvm { namespace arith { @@ -36,6 +38,73 @@ using tir::Var; using tir::VarNode; using tir::IterVar; +class IntGroupedBoundsNode : public Object { + public: + PrimExpr coef; + Array lower; + Array equal; + Array upper; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("coef", &coef); + v->Visit("lower", &lower); + v->Visit("equal", &equal); + v->Visit("upper", &upper); + } + + bool SEqualReduce(const IntGroupedBoundsNode* other, SEqualReducer eq) const { + return + eq(coef, other->coef) && + eq(lower, other->lower) && + eq(equal, other->equal) && + eq(upper, other->upper); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(coef); + hash_reduce(lower); + hash_reduce(equal); + hash_reduce(upper); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const char* _type_key = "arith.IntGroupedBounds"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupedBoundsNode, Object); +}; + +/*! + * \brief Managed reference to IntGroupedBoundsNode. + * \sa IntGroupedBoundsNode + */ +class IntGroupedBounds : public ObjectRef { + public: + /*! TODO: comments + * \brief Constructor by fields + */ + TVM_DLL IntGroupedBounds(PrimExpr coef, + Array lower, + Array equal, + Array upper); + + /*! + * \brief Construct bounds from a range. + * \param r The range + * \return constructed bounds. + */ + static IntGroupedBounds range(const Range& r); + + /*! + * \brief Perform substitution on all components of the struct. + */ + IntGroupedBounds Substitute(const Map& subst) const; + + Range FindBestRange(const Map& vranges_addl = {}) const; + + IntGroupedBounds operator+(const Range& r); + + TVM_DEFINE_OBJECT_REF_METHODS(IntGroupedBounds, ObjectRef, IntGroupedBoundsNode); +}; + /*! * \brief Represent integer constrains including (integer) variables, their ranges and * the relations between them (either equations or inequalities). @@ -48,7 +117,7 @@ class IntConstraintsNode : public Object { // e.g., 1 <= \alpha <= N, etc. // it is absolutely ok to include ranges for parameters // (variables that are not in this->variables) in this map - Map ranges; + Map ranges; // linear equalities or inequalities // e.g., A \alpha = \beta or A \alpha <= \beta Array relations; @@ -91,9 +160,34 @@ class IntConstraints : public ObjectRef { * (either equations or inequalities) */ TVM_DLL IntConstraints(Array variables, - Map ranges, + Map ranges, Array relations); + /*! + * \brief Combine the information into an array of (in)equalities. + */ + Array as_conditions() const { + Array res; + for (const Var& v : operator->()->variables) { + CHECK(operator->()->ranges.count(v) > 0); + const auto& bnds = operator->()->ranges[v]; + PrimExpr lhs = bnds->coef * v; + for (const PrimExpr& rhs : bnds->equal) { + res.push_back(tir::EQNode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->lower) { + res.push_back(tir::GENode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->upper) { + res.push_back(tir::LENode::make(lhs, rhs)); + } + } + for (const PrimExpr& e : operator->()->relations) { + res.push_back(e); + } + return res; + } + TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; @@ -169,6 +263,8 @@ class IntConstraintsTransform : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; +Map ConvertGroupedBoundToRange(Map bounds); + /*! * \brief Obtain Smith Normal Form of linear equation A x = y. * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index c27e39f52917..3db27d61d693 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -20,6 +20,38 @@ from . import _ffi_api +@tvm._ffi.register_object("arith.IntGroupedBounds") +class IntGroupedBounds(Object): + # TODO: doc + def __init__(self, coef, lower, equal, upper): + self.__init_handle_by_constructor__( + _ffi_api.IntGroupedBounds, coef, lower, equal, upper) + + @staticmethod + def make_by_range(r): + """Construct a IntGroupedBounds by Range. + + Parameters TODO + ---------- + min_value : PrimExpr + The minimum value of the range. + + extent : PrimExpr + The extent of the range. + + Returns + ------- + rng : Range + The constructed range. + """ + return _ffi_api.int_grouped_bounds_by_range(r) + + def find_best_range(self): + """Return the best range from the grouped bounds. + """ + return _ffi_api.IntGroupedBounds_FindBestRange(self) + + @tvm._ffi.register_object("arith.IntConstraints") class IntConstraints(Object): """Represent a set of integer constraints including variables, their ranges and @@ -96,10 +128,18 @@ def solve_linear_equations(equations, variables=None, ranges=None): """ if isinstance(equations, IntConstraints): return _ffi_api.SolveLinearEquations(equations) + if ranges is not None: + assert isinstance(ranges, dict) + ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) + for (v, r) in ranges.items()} return _ffi_api.SolveLinearEquations(variables, ranges, equations) def solve_linear_inequalities(equations, variables=None, ranges=None): if isinstance(equations, IntConstraints): - return _ffi_api.SolveLinearInequalities(equations) - return _ffi_api.SolveLinearInequalities(variables, ranges, equations) + return _ffi_api.DeskewRange(equations) + if ranges is not None: + assert isinstance(ranges, dict) + ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) + for (v, r) in ranges.items()} + return _ffi_api.DeskewRange(variables, ranges, equations) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 34efa986e985..4ce5a8000588 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -24,24 +24,189 @@ #include #include #include +#include #include #include #include #include +#include +#include + +#include "../tir/pass/ir_util.h" namespace tvm { namespace arith { +Map ConvertGroupedBoundToRange(Map bounds) { + Map vranges; + for (const auto& kv : bounds) { + vranges.Set(kv.first, kv.second.FindBestRange()); + } + return vranges; +} + +IntGroupedBounds::IntGroupedBounds(PrimExpr coef, + Array lower, + Array equal, + Array upper) { + ObjectPtr node = make_object(); + node->coef = std::move(coef); + node->lower = std::move(lower); + node->equal = std::move(equal); + node->upper = std::move(upper); + data_ = std::move(node); +} + +IntGroupedBounds IntGroupedBounds::range(const Range& r) { + Analyzer analyzer; + PrimExpr coef = tir::make_const(r->min.dtype(), 1); + Array equal; + Array lower; + Array upper; + if (te::is_one(r->extent)) { + equal.push_back(r->min); + } else { + lower.push_back(r->min); + upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); + } + return IntGroupedBounds(coef, lower, equal, upper); +} + +IntGroupedBounds IntGroupedBounds::operator+(const Range& r) { + Analyzer analyzer; + Array equal; + Array lower; + Array upper; + if (te::is_one(r->extent)) { + equal.push_back(analyzer.Simplify(r->min * operator->()->coef)); + } else { + lower.push_back(analyzer.Simplify(r->min * operator->()->coef)); + upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * operator->()->coef)); + } + for (const auto& eq : operator->()->equal) equal.push_back(eq); + for (const auto& lb : operator->()->lower) lower.push_back(lb); + for (const auto& ub : operator->()->upper) upper.push_back(ub); + return IntGroupedBounds(operator->()->coef, lower, equal, upper); +} + +IntGroupedBounds IntGroupedBounds::Substitute(const Map& subst) const { + auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; + return IntGroupedBounds(tir::Substitute(operator->()->coef, subst), + tir::UpdateArray(operator->()->lower, apply_fun), + tir::UpdateArray(operator->()->equal, apply_fun), + tir::UpdateArray(operator->()->upper, apply_fun)); +} + +Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const { + Analyzer analyzer; + analyzer.Bind(vranges_addl); + + std::unordered_map var_intsets; + for (auto kv : vranges_addl) { + var_intsets[kv.first.get()] = IntSet::range(kv.second); + } + + std::vector lowers(operator->()->equal.begin(), operator->()->equal.end()); + std::vector uppers(operator->()->equal.begin(), operator->()->equal.end()); + for (const auto& expr : operator->()->lower) { + lowers.push_back(expr); + } + for (const auto& expr : operator->()->upper) { + uppers.push_back(expr); + } + + if (lowers.size() == 1 && uppers.size() == 1 && te::is_one(operator->()->coef)) { + return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); + } + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v, not for v*coef + + // The lower bound of the best pair so far + PrimExpr best_lower; + // The difference between the upper and the lower of the best pair, maybe overapproximation + PrimExpr best_diff_over; + + for (const PrimExpr& low : lowers) { + for (const PrimExpr& upp : uppers) { + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, operator->()->coef)); + // Since diff may depend on some other variables, we compute its overapproximation + PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max()); + + // low is the lower bound for v*coef, but we need the lower bound for v. + // We use rounding-up division to compute it. Since we want to use a single formula + PrimExpr low_divided = analyzer.Simplify(floordiv(low + operator->()->coef - 1, operator->()->coef)); + + // Compute another difference which may be more precise (or not). + PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, operator->()->coef) - low_divided); + PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max()); + + PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) + ? diff_over_2 : diff_over_1; + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { + best_lower = low_divided; + best_diff_over = diff_over; + } + } + } + + if (!best_lower.defined()) { + CHECK(!best_diff_over.defined()); + return Range(); + } + return Range::make_by_min_extent(best_lower, analyzer.Simplify(best_diff_over + 1)); +} + +TVM_REGISTER_NODE_TYPE(IntGroupedBoundsNode); + +TVM_REGISTER_GLOBAL("arith.IntGroupedBounds") +.set_body_typed([](PrimExpr coef, + Array lower, + Array equal, + Array upper) { + return IntGroupedBounds(coef, lower, equal, upper); +}); + +TVM_REGISTER_GLOBAL("arith.int_grouped_bounds_by_range") +.set_body_typed(IntGroupedBounds::range); + +TVM_REGISTER_GLOBAL("arith.IntGroupedBounds_FindBestRange") +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK(args.size() == 1 || args.size() == 2); + IntGroupedBounds bounds = args[0]; + if (args.size() == 1) { + *ret = bounds.FindBestRange(); + } else if (args.size() == 2) { + *ret = bounds.FindBestRange(args[1]); + } +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntGroupedBounds(coef=" + << op->coef + << ", lower=" << op->lower + << ", equal=" << op->equal + << ", upper=" << op->upper + << ")"; + }); + + IntConstraints::IntConstraints(Array variables, - Map ranges, + Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { variables = Array(); } if (!ranges.defined()) { - ranges = Map(); + ranges = Map(); } CHECK(relations.defined()); for (const auto& var : variables) { @@ -56,6 +221,13 @@ IntConstraints::IntConstraints(Array variables, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); +TVM_REGISTER_GLOBAL("arith.IntConstraints") +.set_body_typed([](Array variables, + Map ranges, + Array relations) { + return IntConstraints(variables, ranges, relations); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -81,6 +253,14 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); +TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") +.set_body_typed([](IntConstraints src, + IntConstraints dst, + Map src_to_dst, + Map dst_to_src) { + return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 8142a03155c8..a5601061ef23 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -281,8 +281,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // Conditions we don't know what to do with std::vector rest; + Map vranges = ConvertGroupedBoundToRange(system_to_solve->ranges); Analyzer analyzer_problem; - analyzer_problem.Bind(system_to_solve->ranges); + analyzer_problem.Bind(vranges); size_t num_vars = system_to_solve->variables.size(); @@ -428,13 +429,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // The resulting ranges Map new_ranges = InferRange( - new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + new_to_old_map, system_to_solve->variables, vranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. - for (const auto& p : system_to_solve->ranges) { + for (const auto& p : vranges) { const Var& old_var = p.first; const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { @@ -457,7 +458,11 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol new_relations.push_back(Substitute(cond, old_to_new_map)); } - IntConstraints solution(new_vars, new_ranges, new_relations); + Map new_bounds; + for (const auto& kv : new_ranges) { + new_bounds.Set(kv.first, IntGroupedBounds::range(kv.second)); + } + IntConstraints solution(new_vars, new_bounds, new_relations); IntConstraintsTransform transform( system_to_solve, solution, old_to_new_map, new_to_old_map); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 71f434b277c4..cc44563a7de8 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -44,59 +44,6 @@ namespace arith { using namespace tvm::runtime; using namespace tvm::te; -struct VarBounds { - PrimExpr coef; - Array lower; - Array equal; - Array upper; - - /*! - * \brief Perform substitution on all components of the struct. - */ - VarBounds substitute(const Map& subst) const { - auto apply_fun = [&subst](const PrimExpr& e) { return Substitute(e, subst); }; - return {Substitute(coef, subst), - tir::UpdateArray(lower, apply_fun), - tir::UpdateArray(equal, apply_fun), - tir::UpdateArray(upper, apply_fun)}; - } -}; - -// TODO -struct SolveSystemOfInequalitiesResult { - Array variables; - std::unordered_map bounds; - Array other_conditions; - - /*! - * \brief Combine the information into an array of (in)equalities. - */ - Array as_conditions() const { - Array res; - for (const Var& v : variables) { - auto it = bounds.find(v.get()); - CHECK(it != bounds.end()); - const VarBounds& bnds = it->second; - PrimExpr lhs = bnds.coef * v; - for (const PrimExpr& rhs : bnds.equal) { - res.push_back(EQNode::make(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds.lower) { - res.push_back(GENode::make(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds.upper) { - res.push_back(LENode::make(lhs, rhs)); - } - } - for (const PrimExpr& e : other_conditions) { - res.push_back(e); - } - return res; - } -}; - -//////////////////////////////////////////////////// - /* struct ExprLess { bool operator()(const PrimExpr& l, const PrimExpr& r) const { @@ -189,13 +136,13 @@ void AddInequality(std::unordered_set inequality_set.insert(new_ineq); } -void ClassifyPolarity(const Var& var, - std::unordered_set& current_ineq_set, - std::unordered_set& next_ineq_set, - std::vector& rest, - std::vector >& coef_pos, - std::vector >& coef_neg, - Analyzer &analyzer) { +void ClassifyByPolarity(const Var &var, + std::unordered_set ¤t_ineq_set, + std::unordered_set &next_ineq_set, + std::vector &rest, + std::vector > &coef_pos, + std::vector > &coef_neg, + Analyzer &analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -253,16 +200,13 @@ void MoveEquality(std::unordered_set& } } -SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& system_to_solve) { +IntConstraints SolveLinearInequalities(const IntConstraints& system_to_solve) { LOG(INFO) << "solving inequalities " << system_to_solve; - arith::Analyzer analyzer; - for (auto kv : system_to_solve->ranges) { - analyzer.Bind(kv.first, kv.second); - } + Map vranges = ConvertGroupedBoundToRange(system_to_solve->ranges); - SolveSystemOfInequalitiesResult res; - res.variables = system_to_solve->variables; + arith::Analyzer analyzer; + analyzer.Bind(vranges); // The algorithm consists in doing the following things for each variable v // - Take formulas from `current_ineq_set_to_solve` and classify them according to polarity wrt v @@ -294,10 +238,9 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy coef_pos, coef_neg); + Map res_bounds; for (const Var& v : system_to_solve->variables) { - std::cout << "Working on " << v << "\n"; - // TODO: - CHECK(!res.bounds.count(v.get())) << + CHECK(!res_bounds.count(v)) << "Variable " << v << " appears more than one time in the `variables` which might be a bug"; next_ineq_set_to_solve.clear(); @@ -306,14 +249,14 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy // Add bounds from vranges if (system_to_solve->ranges.count(v)) { - const Range& range = (system_to_solve->ranges)[v]; + const Range& range = vranges[v]; PrimExpr range_lbound = analyzer.Simplify(range->min); PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1); coef_neg.push_back({-1, range_lbound}); coef_pos.push_back({1, -range_ubound}); } - ClassifyPolarity(v, current_ineq_set_to_solve, next_ineq_set_to_solve, rest, coef_pos, coef_neg, analyzer); + ClassifyByPolarity(v, current_ineq_set_to_solve, next_ineq_set_to_solve, rest, coef_pos, coef_neg, analyzer); DebugPrint(current_ineq_set_to_solve, next_ineq_set_to_solve, @@ -330,7 +273,6 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy // eliminate the current variable PrimExpr new_lhs = c_neg*neg.second - c_pos*pos.second; PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); - LOG(INFO) << "get new ineq " << new_ineq; // it helps to simplify (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // otherwise it's (y*2) - 10 <= 0 new_ineq = NormalizeComparisons()(analyzer.rewrite_simplify(analyzer.Simplify(new_ineq))); @@ -373,13 +315,6 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy ++iter; } } - /* - upper_bounds.erase( - std::remove_if(upper_bounds.begin(), upper_bounds.end(), - [&bound, &analyzer](const PrimExpr& o) - { return analyzer.CanProve(o - bound >= 0); }), - upper_bounds.end()); - */ // Add the upper bound upper_bounds.insert(bound); } @@ -400,11 +335,6 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy ++iter; } } -// lower_bounds.erase( -// std::remove_if(lower_bounds.begin(), lower_bounds.end(), -// [&bound, &analyzer](const PrimExpr& o) -// { return analyzer.CanProve(o - bound <= 0); }), -// lower_bounds.end()); // Add the lower bound lower_bounds.insert(bound); } @@ -414,48 +344,49 @@ SolveSystemOfInequalitiesResult SolveLinearInequalities(const IntConstraints& sy MoveEquality(upper_bounds, lower_bounds, equal); // Write it to the result. - auto& bnds = res.bounds[v.get()]; - bnds.coef = make_const(v.dtype(), coef_lcm); - bnds.equal = Array(equal.begin(), equal.end()); - bnds.lower = Array(lower_bounds.begin(), lower_bounds.end()); - bnds.upper = Array(upper_bounds.begin(), upper_bounds.end()); - LOG(INFO) << "Bound of " << v << " coef = " << bnds.coef - << " EQUAL: " << bnds.equal - << " LOWER: " << bnds.lower - << " UPPER: " << bnds.upper; + IntGroupedBounds bnds(make_const(v.dtype(), coef_lcm), + Array(lower_bounds.begin(), lower_bounds.end()), + Array(equal.begin(), equal.end()), + Array(upper_bounds.begin(), upper_bounds.end()) + ); + res_bounds.Set(v, bnds); + LOG(INFO) << "Bound of " << v << bnds; std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); } - // Everything that is left goes to res.other_conditions + // Everything that is left goes to res.relations + Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { PrimExpr e_simp = analyzer.Simplify(e); if (is_const_int(e_simp, 0)) { // contradiction detected - res.other_conditions = {const_false()}; - return res; + other_conditions = {const_false()}; + break; } else if (is_const_int(e_simp, 1)) { continue; } else { - res.other_conditions.push_back(e_simp); + other_conditions.push_back(e_simp); } } for (const PrimExpr& e : rest) { - res.other_conditions.push_back(e); + other_conditions.push_back(e); } + IntConstraints res(system_to_solve->variables, res_bounds, other_conditions); + return res; } // Deskew the given domain -IntConstraintsTransform DeskewDomain(const IntConstraints& inequalities) { +IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) - Map res_ranges; + Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - LOG(INFO) << "solved system = " << solved_system.as_conditions(); + LOG(INFO) << "solved system = " << solved_system; arith::Analyzer analyzer; @@ -463,110 +394,52 @@ IntConstraintsTransform DeskewDomain(const IntConstraints& inequalities) { Map res_new_to_old; Array res_variables; Array res_relations; - std::unordered_map new_var_intsets; // this keeps being updated during determining the range of each variable. // TODO: vrange must be non-empty ? - Map vranges = inequalities->ranges; - // Initialize new_var_intsets with the old var intsets - for (const auto& pair : inequalities->ranges) { - new_var_intsets[pair.first.get()] = IntSet::range(pair.second); - analyzer.Bind(pair.first, pair.second); - } + Map vranges = ConvertGroupedBoundToRange(inequalities->ranges); + analyzer.Bind(vranges); // We process variables in the reverse direction to start with the most independent one. // This order is needed to compute new ranges. for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { const Var& var = *it; - auto& bnd = solved_system.bounds[var.get()]; + auto bnd = solved_system->ranges[var]; // Note that we replace old vars with new ones - bnd = bnd.substitute(res_old_to_new); + bnd = bnd.Substitute(res_old_to_new); - if (is_one(bnd.coef) && !bnd.equal.empty()) { + if (is_one(bnd->coef) && !bnd->equal.empty()) { // There is an equation of the form `v == expr`, so this variable can be completely removed. // Note that we use the 0-th expression because they are ordered by complexity, so it must be // the simplest one. // TODO - res_old_to_new.Set(var, bnd.equal[0]); + res_old_to_new.Set(var, bnd->equal[0]); } else { - std::vector lowers(bnd.equal.begin(), bnd.equal.end()); - std::vector uppers(bnd.equal.begin(), bnd.equal.end()); - for (const auto& expr : bnd.lower) lowers.push_back(expr); - for (const auto& expr : bnd.upper) uppers.push_back(expr); - - // TODO: remove following -// std::sort(lowers.begin(), lowers.end(), ExprLess()); -// std::sort(uppers.begin(), uppers.end(), ExprLess()); - - // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the - // pair with the minimal difference between the upper and the lower. - // Note that the bounds are for v, not for v*coef, because we will need bounds for v anyway - - // The lower bound of the best pair so far - PrimExpr best_lower; - if (vranges.count(var) > 0) { - best_lower = vranges[var]->min; - } - // The difference between the upper and the lower of the best pair, maybe overapproximation - PrimExpr best_diff_over; if (vranges.count(var) > 0) { - best_diff_over = vranges[var]->extent - 1; + bnd = bnd + vranges[var]; } - for (const PrimExpr& low : lowers) { - for (const PrimExpr& upp : uppers) { - PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, bnd.coef)); - // Since diff may depend on some other variables, we compute its overapproximation - PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, new_var_intsets).max()); - - // low is the lower bound for v*coef, but we need the lower bound for v. - // We use rounding-up division to compute it. Since we want to use a single formula - PrimExpr low_divided = analyzer.Simplify(floordiv(low + bnd.coef - 1, bnd.coef)); - - // Compute another difference which may be more precise (or not). - PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, bnd.coef) - low_divided); - PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, new_var_intsets).max()); - - PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) - ? diff_over_2 : diff_over_1; - - // If it is provable that the new one is strictly better than the current best one, - // then replace it. Note that we are biased towards earlier pairs which should be simpler. - if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { - best_lower = low_divided; - best_diff_over = diff_over_1; - } - } - } - LOG(INFO) << var << " DONE for each low - upp best lower = " << best_lower; - if (!best_lower.defined()) { - continue; - } + auto best_range = bnd.FindBestRange(vranges); + LOG(INFO) << "best range for " << var << " = " << best_range; - // TODO: -// std::string suffix = ExprDeepEqual()(best_lower, vranges[var]->min) ? "" : ".shifted"; std::string suffix = ".shifted"; Var new_var = var.copy_with_suffix(suffix); - PrimExpr diff = analyzer.Simplify(best_diff_over); - - if (is_const_int(diff, 0)) { + if (is_const_int(best_range->extent, 1)) { // Don't create an itervar, just replace it everywhere with its min - res_old_to_new.Set(var, best_lower); + res_old_to_new.Set(var, best_range->min); } else { // created new_var starts from 0 - res_old_to_new.Set(var, new_var + best_lower); + res_old_to_new.Set(var, new_var + best_range->min); // Note that we are substituting old with new, so best_lower contains new var, // that is we have to substitute new with old in best_lower here res_new_to_old.Set(new_var, - analyzer.Simplify(var - Substitute(best_lower, res_new_to_old))); - - new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.dtype()), diff); + analyzer.Simplify(var - Substitute(best_range->min, res_new_to_old))); // Add the new var to the resulting axis - auto range = Range(make_zero(new_var.dtype()), analyzer.Simplify(diff + 1)); + auto range = Range(make_zero(new_var.dtype()), best_range->extent); res_variables.push_back(new_var); - res_ranges.Set(new_var, range); + res_ranges.Set(new_var, IntGroupedBounds::range(range)); vranges.Set(new_var, range); analyzer.Bind(new_var, range); @@ -595,14 +468,26 @@ IntConstraintsTransform DeskewDomain(const IntConstraints& inequalities) { TVM_REGISTER_GLOBAL("arith.SolveLinearInequalities") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { - *ret = DeskewDomain(args[0]); + *ret = SolveLinearInequalities(args[0]); } else if (args.size() == 3) { IntConstraints problem(args[0], args[1], args[2]); - *ret = DeskewDomain(problem); + *ret = SolveLinearInequalities(problem); } else { LOG(FATAL) << "arith.SolveLinearInequalities expects 1 or 3 arguments, gets " << args.size(); } }); +TVM_REGISTER_GLOBAL("arith.DeskewRange") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 1) { + *ret = DeskewRange(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = DeskewRange(problem); + } else { + LOG(FATAL) << "arith.DeskewRange expects 1 or 3 arguments, gets " << args.size(); + } + }); + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index c12243a41a07..4a57a4b5d398 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -82,10 +82,10 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): def test_simple(): x, y = te.var("x"), te.var("y") # TODO: following will hang forever - ranges = { - x: tvm.ir.Range(-100, 0), - y: tvm.ir.Range(0, 100), - } + # ranges = { + # x: tvm.ir.Range(-100, 0), + # y: tvm.ir.Range(0, 100), + # } ranges = { x: tvm.ir.Range(-100, 100), @@ -99,7 +99,50 @@ def test_simple(): print(solution) + [x_new, y_new] = solution.dst.variables + [rel] = solution.dst.relations + + assert ir.structural_equal(rel, (y_new*2) + x_new <= 10) + + assert ir.structural_equal(solution.dst.ranges[x_new].find_best_range().min, 0) + assert ir.structural_equal(solution.dst.ranges[x_new].find_best_range().extent, 11) + + assert ir.structural_equal(solution.dst.ranges[y_new].find_best_range().min, 0) + assert ir.structural_equal(solution.dst.ranges[y_new].find_best_range().extent, 6) + + assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) + assert ir.structural_equal(solution.src_to_dst[y], y_new) + assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) + assert ir.structural_equal(solution.dst_to_src[y_new], y) + + +def test_equal(): + x, y = te.var("x"), te.var("y") + + solution = arith.solve_linear_inequalities([ + tvm.tir.GE(x + y, 10), + tvm.tir.GE(x - y, 2), + tvm.tir.LE(x, 6), + ], [x, y]) + + print(solution) + + +def test_multi_equal(): + x, y = te.var("x"), te.var("y") + + solution = arith.solve_linear_inequalities([ + tvm.tir.LE(x, 6), + tvm.tir.GE(x, 6), + tvm.tir.GE(x - 2 * y, 0), + tvm.tir.LE(x - 2 * y, 0), + ], [x, y]) + + print(solution) + if __name__ == "__main__": # test_solve_system_of_inequalities() test_simple() + test_equal() + test_multi_equal() From af6708dc8b128687ba8ca4fabcb90e267752dd1f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 16 May 2020 16:14:42 -0700 Subject: [PATCH 03/33] add no deskewed solution --- include/tvm/arith/int_solver.h | 2 + python/tvm/arith/__init__.py | 2 +- python/tvm/arith/int_solver.py | 10 +++ src/arith/solve_linear_inequality.cc | 72 ++++++++++++++++++- .../test_arith_solve_linear_inequality.py | 15 +++- 5 files changed, 98 insertions(+), 3 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 1e4248c77a71..51ac6d422a50 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -265,6 +265,8 @@ class IntConstraintsTransform : public ObjectRef { Map ConvertGroupedBoundToRange(Map bounds); +typedef std::pair > PartialSolvedInequalities; + /*! * \brief Obtain Smith Normal Form of linear equation A x = y. * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index e5af52938f5c..e4797b6108b1 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,4 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound -from .int_solver import solve_linear_equations, solve_linear_inequalities +from .int_solver import solve_linear_equations, solve_linear_inequalities, solve_linear_inequalities2 diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 3db27d61d693..ce8e346a7a99 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -143,3 +143,13 @@ def solve_linear_inequalities(equations, variables=None, ranges=None): ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) for (v, r) in ranges.items()} return _ffi_api.DeskewRange(variables, ranges, equations) + + +def solve_linear_inequalities2(equations, variables=None, ranges=None): + if isinstance(equations, IntConstraints): + return _ffi_api.SolveInequalitiesRange(equations) + if ranges is not None: + assert isinstance(ranges, dict) + ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) + for (v, r) in ranges.items()} + return _ffi_api.SolveInequalitiesRange(variables, ranges, equations) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index cc44563a7de8..729db10b8b6d 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -379,6 +379,65 @@ IntConstraints SolveLinearInequalities(const IntConstraints& system_to_solve) { return res; } +IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { + // Resulting ranges will contain ranges for the new variables and for the variables that are + // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) + Map res_ranges; + // we get a set of equality, lower, upper bound of each variable. + auto solved_system = SolveLinearInequalities(inequalities); + LOG(INFO) << "solved system = " << solved_system; + + Array res_variables; + Array res_relations; + + // this keeps being updated during determining the range of each variable. + Map vranges = ConvertGroupedBoundToRange(inequalities->ranges); + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + + const Var& var = *it; + auto bnd = solved_system->ranges[var]; + if (is_one(bnd->coef) && !bnd->equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1)); + res_ranges.Set(var, IntGroupedBounds::range(best_range)); + vranges.Set(var, best_range); + } else { + if (vranges.count(var) > 0) { + bnd = bnd + vranges[var]; + } + LOG(INFO) << "bnd = " << bnd; + LOG(INFO) << "vranges = " << vranges; + + auto best_range = bnd.FindBestRange(vranges); + LOG(INFO) << "best range for " << var << " = " << best_range; + + res_ranges.Set(var, IntGroupedBounds::range(best_range)); + vranges.Set(var, best_range); + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + arith::Analyzer analyzer; + analyzer.Bind(vranges); + for (const PrimExpr& old_cond : solved_system.as_conditions()) { + if (!analyzer.CanProve(old_cond)) { + // those not represented in vranges (res_ranges) + res_relations.push_back(old_cond); + } + } + + IntConstraints system(inequalities->variables, res_ranges, res_relations); + + return system; +} + // Deskew the given domain IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are @@ -396,7 +455,6 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { Array res_relations; // this keeps being updated during determining the range of each variable. - // TODO: vrange must be non-empty ? Map vranges = ConvertGroupedBoundToRange(inequalities->ranges); analyzer.Bind(vranges); @@ -477,6 +535,18 @@ TVM_REGISTER_GLOBAL("arith.SolveLinearInequalities") } }); +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesRange") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesRange(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveInequalitiesRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesRange expects 1 or 3 arguments, gets " << args.size(); + } + }); + TVM_REGISTER_GLOBAL("arith.DeskewRange") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 4a57a4b5d398..bd4358dd6183 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -115,6 +115,12 @@ def test_simple(): assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) assert ir.structural_equal(solution.dst_to_src[y_new], y) + sol = arith.solve_linear_inequalities2([ + tvm.tir.LE(x + y, 20), + tvm.tir.GE(x - y, 10), + ], [x, y], ranges) + print(sol) + def test_equal(): x, y = te.var("x"), te.var("y") @@ -127,6 +133,13 @@ def test_equal(): print(solution) + sol = arith.solve_linear_inequalities2([ + tvm.tir.GE(x + y, 10), + tvm.tir.GE(x - y, 2), + tvm.tir.LE(x, 6), + ], [x, y]) + print(sol) + def test_multi_equal(): x, y = te.var("x"), te.var("y") @@ -145,4 +158,4 @@ def test_multi_equal(): # test_solve_system_of_inequalities() test_simple() test_equal() - test_multi_equal() + # test_multi_equal() From 7f551e1a94a25ab27cdda4348eb6e6094af17f3b Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 18 May 2020 12:19:56 -0700 Subject: [PATCH 04/33] keep IntConstraints def --- include/tvm/arith/int_solver.h | 31 +------ python/tvm/arith/int_solver.py | 12 --- src/arith/int_constraints.cc | 6 +- src/arith/solve_linear_equation.cc | 13 +-- src/arith/solve_linear_inequality.cc | 86 +++++++++++-------- .../test_arith_solve_linear_inequality.py | 8 +- 6 files changed, 66 insertions(+), 90 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 51ac6d422a50..627e4bbe83fe 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -117,7 +117,7 @@ class IntConstraintsNode : public Object { // e.g., 1 <= \alpha <= N, etc. // it is absolutely ok to include ranges for parameters // (variables that are not in this->variables) in this map - Map ranges; + Map ranges; // linear equalities or inequalities // e.g., A \alpha = \beta or A \alpha <= \beta Array relations; @@ -160,34 +160,9 @@ class IntConstraints : public ObjectRef { * (either equations or inequalities) */ TVM_DLL IntConstraints(Array variables, - Map ranges, + Map ranges, Array relations); - /*! - * \brief Combine the information into an array of (in)equalities. - */ - Array as_conditions() const { - Array res; - for (const Var& v : operator->()->variables) { - CHECK(operator->()->ranges.count(v) > 0); - const auto& bnds = operator->()->ranges[v]; - PrimExpr lhs = bnds->coef * v; - for (const PrimExpr& rhs : bnds->equal) { - res.push_back(tir::EQNode::make(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds->lower) { - res.push_back(tir::GENode::make(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds->upper) { - res.push_back(tir::LENode::make(lhs, rhs)); - } - } - for (const PrimExpr& e : operator->()->relations) { - res.push_back(e); - } - return res; - } - TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; @@ -265,7 +240,7 @@ class IntConstraintsTransform : public ObjectRef { Map ConvertGroupedBoundToRange(Map bounds); -typedef std::pair > PartialSolvedInequalities; +typedef std::pair, Array > PartialSolvedInequalities; /*! * \brief Obtain Smith Normal Form of linear equation A x = y. diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index ce8e346a7a99..dba67fde78b6 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -128,28 +128,16 @@ def solve_linear_equations(equations, variables=None, ranges=None): """ if isinstance(equations, IntConstraints): return _ffi_api.SolveLinearEquations(equations) - if ranges is not None: - assert isinstance(ranges, dict) - ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) - for (v, r) in ranges.items()} return _ffi_api.SolveLinearEquations(variables, ranges, equations) def solve_linear_inequalities(equations, variables=None, ranges=None): if isinstance(equations, IntConstraints): return _ffi_api.DeskewRange(equations) - if ranges is not None: - assert isinstance(ranges, dict) - ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) - for (v, r) in ranges.items()} return _ffi_api.DeskewRange(variables, ranges, equations) def solve_linear_inequalities2(equations, variables=None, ranges=None): if isinstance(equations, IntConstraints): return _ffi_api.SolveInequalitiesRange(equations) - if ranges is not None: - assert isinstance(ranges, dict) - ranges = {v: r if isinstance(r, IntGroupedBounds) else IntGroupedBounds.make_by_range(r) - for (v, r) in ranges.items()} return _ffi_api.SolveInequalitiesRange(variables, ranges, equations) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 4ce5a8000588..f0b74cc1ab12 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -199,14 +199,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) IntConstraints::IntConstraints(Array variables, - Map ranges, + Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { variables = Array(); } if (!ranges.defined()) { - ranges = Map(); + ranges = Map(); } CHECK(relations.defined()); for (const auto& var : variables) { @@ -223,7 +223,7 @@ TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_REGISTER_GLOBAL("arith.IntConstraints") .set_body_typed([](Array variables, - Map ranges, + Map ranges, Array relations) { return IntConstraints(variables, ranges, relations); }); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index a5601061ef23..8142a03155c8 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -281,9 +281,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // Conditions we don't know what to do with std::vector rest; - Map vranges = ConvertGroupedBoundToRange(system_to_solve->ranges); Analyzer analyzer_problem; - analyzer_problem.Bind(vranges); + analyzer_problem.Bind(system_to_solve->ranges); size_t num_vars = system_to_solve->variables.size(); @@ -429,13 +428,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // The resulting ranges Map new_ranges = InferRange( - new_to_old_map, system_to_solve->variables, vranges); + new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. - for (const auto& p : vranges) { + for (const auto& p : system_to_solve->ranges) { const Var& old_var = p.first; const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { @@ -458,11 +457,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol new_relations.push_back(Substitute(cond, old_to_new_map)); } - Map new_bounds; - for (const auto& kv : new_ranges) { - new_bounds.Set(kv.first, IntGroupedBounds::range(kv.second)); - } - IntConstraints solution(new_vars, new_bounds, new_relations); + IntConstraints solution(new_vars, new_ranges, new_relations); IntConstraintsTransform transform( system_to_solve, solution, old_to_new_map, new_to_old_map); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 729db10b8b6d..983c41e0198e 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -56,6 +56,31 @@ struct ExprLess { }; */ +/*! + * \brief Combine the information into an array of (in)equalities. + */ +Array as_conditions(const Map& bounds, const Array& relations) { + Array res; + for (const auto iter : bounds) { + const Var& v = iter.first; + const auto& bnds = iter.second; + PrimExpr lhs = bnds->coef * v; + for (const PrimExpr& rhs : bnds->equal) { + res.push_back(tir::EQNode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->lower) { + res.push_back(tir::GENode::make(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->upper) { + res.push_back(tir::LENode::make(lhs, rhs)); + } + } + for (const PrimExpr& e : relations) { + res.push_back(e); + } + return res; +} + void DebugPrint(std::unordered_set& current_ineq_set, std::unordered_set& next_ineq_set, std::vector& rest, @@ -200,13 +225,10 @@ void MoveEquality(std::unordered_set& } } -IntConstraints SolveLinearInequalities(const IntConstraints& system_to_solve) { +PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) { LOG(INFO) << "solving inequalities " << system_to_solve; - - Map vranges = ConvertGroupedBoundToRange(system_to_solve->ranges); - arith::Analyzer analyzer; - analyzer.Bind(vranges); + analyzer.Bind(system_to_solve->ranges); // The algorithm consists in doing the following things for each variable v // - Take formulas from `current_ineq_set_to_solve` and classify them according to polarity wrt v @@ -249,7 +271,7 @@ IntConstraints SolveLinearInequalities(const IntConstraints& system_to_solve) { // Add bounds from vranges if (system_to_solve->ranges.count(v)) { - const Range& range = vranges[v]; + const Range& range = system_to_solve->ranges[v]; PrimExpr range_lbound = analyzer.Simplify(range->min); PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1); coef_neg.push_back({-1, range_lbound}); @@ -374,24 +396,27 @@ IntConstraints SolveLinearInequalities(const IntConstraints& system_to_solve) { other_conditions.push_back(e); } - IntConstraints res(system_to_solve->variables, res_bounds, other_conditions); - - return res; + return {res_bounds, other_conditions}; } IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) - Map res_ranges; + Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - LOG(INFO) << "solved system = " << solved_system; + + Map solved_bounds = solved_system.first; + Array solved_other_relations = solved_system.second; Array res_variables; Array res_relations; // this keeps being updated during determining the range of each variable. - Map vranges = ConvertGroupedBoundToRange(inequalities->ranges); + Map vranges; + for (std::pair vr : inequalities->ranges) { + vranges.Set(vr.first, vr.second); + } // We process variables in the reverse direction to start with the most independent one. // This order is needed to compute new ranges. @@ -400,13 +425,14 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { analyzer.Bind(vranges); const Var& var = *it; - auto bnd = solved_system->ranges[var]; + CHECK(solved_bounds.count(var)); + auto bnd = solved_bounds[var]; if (is_one(bnd->coef) && !bnd->equal.empty()) { // There is an equation of the form `v == expr`, so this variable can be completely removed. // Note that we use the 0-th expression because they are ordered by complexity, so it must be // the simplest one. Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1)); - res_ranges.Set(var, IntGroupedBounds::range(best_range)); + res_ranges.Set(var, best_range); vranges.Set(var, best_range); } else { if (vranges.count(var) > 0) { @@ -418,7 +444,7 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { auto best_range = bnd.FindBestRange(vranges); LOG(INFO) << "best range for " << var << " = " << best_range; - res_ranges.Set(var, IntGroupedBounds::range(best_range)); + res_ranges.Set(var, best_range); vranges.Set(var, best_range); } } @@ -426,7 +452,7 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { // Add the original conditions (with variables substituted) to the resulting conditions arith::Analyzer analyzer; analyzer.Bind(vranges); - for (const PrimExpr& old_cond : solved_system.as_conditions()) { + for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { if (!analyzer.CanProve(old_cond)) { // those not represented in vranges (res_ranges) res_relations.push_back(old_cond); @@ -442,10 +468,11 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) - Map res_ranges; + Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - LOG(INFO) << "solved system = " << solved_system; + Map solved_bounds = solved_system.first; + Array solved_other_relations = solved_system.second; arith::Analyzer analyzer; @@ -455,14 +482,17 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { Array res_relations; // this keeps being updated during determining the range of each variable. - Map vranges = ConvertGroupedBoundToRange(inequalities->ranges); + Map vranges; + for (std::pair vr : inequalities->ranges) { + vranges.Set(vr.first, vr.second); + } analyzer.Bind(vranges); // We process variables in the reverse direction to start with the most independent one. // This order is needed to compute new ranges. for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { const Var& var = *it; - auto bnd = solved_system->ranges[var]; + auto bnd = solved_bounds[var]; // Note that we replace old vars with new ones bnd = bnd.Substitute(res_old_to_new); @@ -497,7 +527,7 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // Add the new var to the resulting axis auto range = Range(make_zero(new_var.dtype()), best_range->extent); res_variables.push_back(new_var); - res_ranges.Set(new_var, IntGroupedBounds::range(range)); + res_ranges.Set(new_var, range); vranges.Set(new_var, range); analyzer.Bind(new_var, range); @@ -506,7 +536,7 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { } // Add the original conditions (with variables substituted) to the resulting conditions - for (const PrimExpr& old_cond : solved_system.as_conditions()) { + for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_old_to_new)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) @@ -523,18 +553,6 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearInequalities") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveLinearInequalities(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveLinearInequalities(problem); - } else { - LOG(FATAL) << "arith.SolveLinearInequalities expects 1 or 3 arguments, gets " << args.size(); - } - }); - TVM_REGISTER_GLOBAL("arith.SolveInequalitiesRange") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index bd4358dd6183..daee3f63e922 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -104,11 +104,11 @@ def test_simple(): assert ir.structural_equal(rel, (y_new*2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].find_best_range().min, 0) - assert ir.structural_equal(solution.dst.ranges[x_new].find_best_range().extent, 11) + assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) - assert ir.structural_equal(solution.dst.ranges[y_new].find_best_range().min, 0) - assert ir.structural_equal(solution.dst.ranges[y_new].find_best_range().extent, 6) + assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) From e56d31059913c2f4409d94a772760c9be8c98aff Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 4 Jun 2020 22:08:11 -0700 Subject: [PATCH 05/33] add test case and fix for equations --- .gitignore | 1 + python/tvm/arith/__init__.py | 2 +- python/tvm/arith/int_solver.py | 13 +-- src/arith/solve_linear_inequality.cc | 94 +++++++++++++++---- .../test_arith_solve_linear_inequality.py | 29 ++++-- 5 files changed, 100 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 068cb87484a0..9892a2ab1ec3 100644 --- a/.gitignore +++ b/.gitignore @@ -195,6 +195,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index e4797b6108b1..e5af52938f5c 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,4 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound -from .int_solver import solve_linear_equations, solve_linear_inequalities, solve_linear_inequalities2 +from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index dba67fde78b6..4519d49a2156 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -131,13 +131,8 @@ def solve_linear_equations(equations, variables=None, ranges=None): return _ffi_api.SolveLinearEquations(variables, ranges, equations) -def solve_linear_inequalities(equations, variables=None, ranges=None): +def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False): + solver = _ffi_api.DeskewRange if deskew_range else _ffi_api.SolveInequalitiesRange if isinstance(equations, IntConstraints): - return _ffi_api.DeskewRange(equations) - return _ffi_api.DeskewRange(variables, ranges, equations) - - -def solve_linear_inequalities2(equations, variables=None, ranges=None): - if isinstance(equations, IntConstraints): - return _ffi_api.SolveInequalitiesRange(equations) - return _ffi_api.SolveInequalitiesRange(variables, ranges, equations) + return solver(equations) + return solver(variables, ranges, equations) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 983c41e0198e..6bc278b13fd6 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -44,17 +44,59 @@ namespace arith { using namespace tvm::runtime; using namespace tvm::te; -/* +#define PLUS_ONE(OP) \ + void VisitExpr_(const OP* op) final { \ + num_symbol++; \ + } + +#define PLUS_ONE_BINARY(OP) \ + void VisitExpr_(const OP* op) final { \ + num_symbol++; \ + VisitExpr(op->a); \ + VisitExpr(op->b); \ + } + +class ExprComplexity : public ExprVisitor { + public: + size_t Eval(const PrimExpr& expr) { + VisitExpr(expr); + return num_symbol; + } + + PLUS_ONE_BINARY(AddNode) + PLUS_ONE_BINARY(SubNode) + PLUS_ONE_BINARY(MulNode) + PLUS_ONE_BINARY(DivNode) + PLUS_ONE_BINARY(ModNode) + PLUS_ONE_BINARY(FloorDivNode) + PLUS_ONE_BINARY(FloorModNode) + PLUS_ONE_BINARY(MinNode) + PLUS_ONE_BINARY(MaxNode) + PLUS_ONE_BINARY(EQNode) + PLUS_ONE_BINARY(NENode) + PLUS_ONE_BINARY(LTNode) + PLUS_ONE_BINARY(LENode) + PLUS_ONE_BINARY(GTNode) + PLUS_ONE_BINARY(GENode) + PLUS_ONE_BINARY(AndNode) + PLUS_ONE_BINARY(OrNode) + PLUS_ONE(VarNode) + PLUS_ONE(FloatImmNode) + PLUS_ONE(IntImmNode) + void VisitExpr_(const NotNode* op) final { + num_symbol++; + VisitExpr(op->a); + } + + private: + size_t num_symbol{0}; +}; + struct ExprLess { bool operator()(const PrimExpr& l, const PrimExpr& r) const { - // FIXME: - // After https://github.com/apache/incubator-tvm/pull/5206 - // we no longer have ExprLess, - // it was comparing VarNode* raw pointers - return Compare(l, r) < 0; + return ExprComplexity().Eval(l) < ExprComplexity().Eval(r); } }; -*/ /*! * \brief Combine the information into an array of (in)equalities. @@ -161,12 +203,12 @@ void AddInequality(std::unordered_set inequality_set.insert(new_ineq); } -void ClassifyByPolarity(const Var &var, - std::unordered_set ¤t_ineq_set, - std::unordered_set &next_ineq_set, - std::vector &rest, - std::vector > &coef_pos, - std::vector > &coef_neg, +void ClassifyByPolarity(const Var& var, + std::unordered_set& current_ineq_set, + std::unordered_set& next_ineq_set, + std::vector& rest, + std::vector >& coef_pos, + std::vector >& coef_neg, Analyzer &analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. @@ -278,7 +320,13 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t coef_pos.push_back({1, -range_ubound}); } - ClassifyByPolarity(v, current_ineq_set_to_solve, next_ineq_set_to_solve, rest, coef_pos, coef_neg, analyzer); + ClassifyByPolarity(v, + current_ineq_set_to_solve, + next_ineq_set_to_solve, + rest, + coef_pos, + coef_neg, + analyzer); DebugPrint(current_ineq_set_to_solve, next_ineq_set_to_solve, @@ -364,11 +412,13 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::unordered_set equal; equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); MoveEquality(upper_bounds, lower_bounds, equal); + std::vector equal_list(equal.begin(), equal.end()); + std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. IntGroupedBounds bnds(make_const(v.dtype(), coef_lcm), Array(lower_bounds.begin(), lower_bounds.end()), - Array(equal.begin(), equal.end()), + Array(equal_list.begin(), equal_list.end()), Array(upper_bounds.begin(), upper_bounds.end()) ); res_bounds.Set(v, bnds); @@ -444,8 +494,10 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { auto best_range = bnd.FindBestRange(vranges); LOG(INFO) << "best range for " << var << " = " << best_range; - res_ranges.Set(var, best_range); - vranges.Set(var, best_range); + if (best_range.defined()) { + res_ranges.Set(var, best_range); + vranges.Set(var, best_range); + } } } @@ -500,7 +552,6 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // There is an equation of the form `v == expr`, so this variable can be completely removed. // Note that we use the 0-th expression because they are ordered by complexity, so it must be // the simplest one. - // TODO res_old_to_new.Set(var, bnd->equal[0]); } else { if (vranges.count(var) > 0) { @@ -512,8 +563,11 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { std::string suffix = ".shifted"; Var new_var = var.copy_with_suffix(suffix); - - if (is_const_int(best_range->extent, 1)) { + if (!best_range.defined()) { + res_old_to_new.Set(var, var); + res_new_to_old.Set(var, var); + res_variables.push_back(var); + } else if (is_const_int(best_range->extent, 1)) { // Don't create an itervar, just replace it everywhere with its min res_old_to_new.Set(var, best_range->min); } else { diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index daee3f63e922..71db7767c599 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -95,7 +95,7 @@ def test_simple(): solution = arith.solve_linear_inequalities([ tvm.tir.LE(x + y, 20), tvm.tir.GE(x - y, 10), - ], [x, y], ranges) + ], [x, y], ranges, deskew_range=True) print(solution) @@ -115,7 +115,7 @@ def test_simple(): assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) assert ir.structural_equal(solution.dst_to_src[y_new], y) - sol = arith.solve_linear_inequalities2([ + sol = arith.solve_linear_inequalities([ tvm.tir.LE(x + y, 20), tvm.tir.GE(x - y, 10), ], [x, y], ranges) @@ -133,29 +133,40 @@ def test_equal(): print(solution) - sol = arith.solve_linear_inequalities2([ + sol = arith.solve_linear_inequalities([ tvm.tir.GE(x + y, 10), tvm.tir.GE(x - y, 2), tvm.tir.LE(x, 6), - ], [x, y]) + ], [x, y], deskew_range=True) print(sol) def test_multi_equal(): - x, y = te.var("x"), te.var("y") + x, y, z = te.var("x"), te.var("y"), te.var("z") solution = arith.solve_linear_inequalities([ tvm.tir.LE(x, 6), tvm.tir.GE(x, 6), - tvm.tir.GE(x - 2 * y, 0), - tvm.tir.LE(x - 2 * y, 0), - ], [x, y]) + tvm.tir.GE(x - z * y, 0), + tvm.tir.LE(x - z * y, 0), + ], [x, y, z], deskew_range=True) print(solution) + # TODO: should it be y & z ? + assert ir.structural_equal(solution.src_to_dst[y], y) + assert ir.structural_equal(solution.src_to_dst[z], z) + assert ir.structural_equal(solution.src_to_dst[x], 6) + + print(arith.solve_linear_inequalities([ + tvm.tir.LE(x, 6), + tvm.tir.GE(x, 6), + tvm.tir.GE(x - z * y, 0), + tvm.tir.LE(x - z * y, 0), + ], [x, y, z])) if __name__ == "__main__": # test_solve_system_of_inequalities() test_simple() test_equal() - # test_multi_equal() + test_multi_equal() From 1b216e9cf91e445b98971bdfd8395044de8c623d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 5 Jun 2020 17:35:37 -0700 Subject: [PATCH 06/33] add random consistency test cases --- src/arith/int_constraints.cc | 4 + src/arith/solve_linear_inequality.cc | 19 ++- .../test_arith_solve_linear_inequality.py | 121 +++++++++++------- 3 files changed, 97 insertions(+), 47 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index f0b74cc1ab12..f29130bc5ad0 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -143,8 +143,12 @@ Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, operator->()->coef) - low_divided); PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max()); + LOG(INFO) << "upp = " << upp << " low = " << low; + LOG(INFO) << "diff_1 = " << diff_1 << " diff_over_1 = " << diff_over_1; + LOG(INFO) << "diff_2 = " << diff_2 << " diff_over_2 = " << diff_over_2; PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; + LOG(INFO) << "diff_over = " << diff_over; // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 6bc278b13fd6..3dbcfe94dbdd 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -33,10 +33,7 @@ #include #include -// TODO -#include "../tir/pass/ir_util.h" -// TODO: testing -// https://github.com/sergei-grechanik/tvm/blob/8cad2d1e62272b3e192bfe08b896e07bc9550e94/tests/python/unittest/test_pass_zero_elimination.py#L367 +// TODO: supersimplify namespace tvm { namespace arith { @@ -607,6 +604,20 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { return transform; } +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PartialSolvedInequalities ret_ineq; + if (args.size() == 1) { + ret_ineq = SolveLinearInequalities(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + ret_ineq = SolveLinearInequalities(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " << args.size(); + } + *ret = as_conditions(ret_ineq.first, ret_ineq.second); +}); + TVM_REGISTER_GLOBAL("arith.SolveInequalitiesRange") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 71db7767c599..9e3c2a6e9471 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -22,6 +22,42 @@ from tvm import te, arith, ir, tir +def run_expr(expr, vranges): + """ Evaluate expr for every value of free variables + given by vranges and return the tensor of results. + TODO(yzhliu): move to utils + """ + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tir.ir_pass.Substitute(expr, vmap) + + A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.nd.empty(A.shape, A.dtype)] + sch = te.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + return args[0].asnumpy() + + +def check_bruteforce(bool_expr, vranges, cond=None): + """ Check that bool_expr holds given the condition cond + for every value of free variables from vranges. + TODO(yzhliu): move to utils + """ + if cond is not None: + bool_expr = te.any(tir.Not(cond), bool_expr) + + res = run_expr(bool_expr, vranges) + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = sorted(counterex, key=lambda x: x[0]) + counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) + raise AssertionError("Expression {}\nis not true on {}\n" + "Counterexample: {}" + .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + + def test_solve_system_of_inequalities(): random.seed(0) @@ -38,54 +74,48 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): fs.append(op(s1, s2)) vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs} + before = te.all(tir.const(1, 'bool'), *fs) print("--- before ---") print(fs) - after = arith.solve_linear_inequalities(fs, vs, vranges) + after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs) + after = te.all(tir.const(1, 'bool'), *after) print("--- after ---") print(after) print() - # check_bruteforce(before == after, vranges) - - _check(2, 2) - - # for i in range(3): - # _check(1, 1) - # for i in range(3): - # _check(1, 2) - # - # for i in range(3): - # _check(2, 1) - # for i in range(3): - # _check(2, 2) - # for i in range(3): - # _check(2, 3) - # - # # Somewhere here coefficients in the results become too large, leading to overflow, - # # so we use smaller initial coefficients - # - # for i in range(5): - # _check(3, 3, coef=(-2,2)) - # for i in range(5): - # _check(3, 4, coef=(-2,2)) - # - # for i in range(5): - # _check(4, 3, coef=(-1,1)) - # - # for i in range(5): - # _check(10, 2, coef=(-1,1), bounds=(0, 4)) - # for i in range(5): - # _check(10, 3, coef=(0,1), bounds=(0, 4)) + check_bruteforce(before == after, vranges) + + for i in range(3): + _check(1, 1) + for i in range(3): + _check(1, 2) + + for i in range(3): + _check(2, 1) + for i in range(3): + _check(2, 2) + for i in range(3): + _check(2, 3) + + # Somewhere here coefficients in the results become too large, leading to overflow, + # so we use smaller initial coefficients + for i in range(5): + _check(3, 3, coef=(-2, 2)) + for i in range(5): + _check(3, 4, coef=(-2, 2)) + + for i in range(5): + _check(4, 3, coef=(-1, 1)) + + for i in range(5): + _check(10, 2, coef=(-1, 1), bounds=(0, 4)) + for i in range(5): + _check(10, 3, coef=(0, 1), bounds=(0, 4)) def test_simple(): x, y = te.var("x"), te.var("y") - # TODO: following will hang forever - # ranges = { - # x: tvm.ir.Range(-100, 0), - # y: tvm.ir.Range(0, 100), - # } ranges = { x: tvm.ir.Range(-100, 100), @@ -120,6 +150,12 @@ def test_simple(): tvm.tir.GE(x - y, 10), ], [x, y], ranges) print(sol) + # 0 <= y <=5 + assert sol.ranges[y].min == 0 + assert sol.ranges[y].extent == 6 + # y + 10 <= x <= 20 - y + assert ir.structural_equal(sol.ranges[x].min, y + 10) + assert sol.ranges[x].extent == 11 # max(10 - 2y) def test_equal(): @@ -152,10 +188,9 @@ def test_multi_equal(): ], [x, y, z], deskew_range=True) print(solution) - # TODO: should it be y & z ? - assert ir.structural_equal(solution.src_to_dst[y], y) - assert ir.structural_equal(solution.src_to_dst[z], z) - assert ir.structural_equal(solution.src_to_dst[x], 6) + assert solution.src_to_dst[y] == y + assert solution.src_to_dst[z] == z + assert solution.src_to_dst[x] == 6 print(arith.solve_linear_inequalities([ tvm.tir.LE(x, 6), @@ -168,5 +203,5 @@ def test_multi_equal(): if __name__ == "__main__": # test_solve_system_of_inequalities() test_simple() - test_equal() - test_multi_equal() + # test_equal() + # test_multi_equal() From 059ea2e966f0c2ba9f6fc238e16102e58fc46c3d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 5 Jun 2020 23:44:49 -0700 Subject: [PATCH 07/33] improve test cases --- include/tvm/arith/analyzer.h | 2 +- src/arith/analyzer.cc | 12 ++- src/arith/solve_linear_inequality.cc | 14 +-- .../test_arith_solve_linear_inequality.py | 97 ++++++++++--------- 4 files changed, 67 insertions(+), 58 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3a71e5eb5fbf..e0ab48920de0 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -459,7 +459,7 @@ class Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - PrimExpr Simplify(const PrimExpr& expr); + PrimExpr Simplify(const PrimExpr& expr, size_t repeat=1); }; } // namespace arith diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 83dfc64009cf..fe73dd7cf43b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -107,11 +107,15 @@ bool Analyzer::CanProve(const PrimExpr& expr) { return false; } -PrimExpr Analyzer::Simplify(const PrimExpr& expr) { +PrimExpr Analyzer::Simplify(const PrimExpr& expr, size_t repeat) { if (tir::is_const(expr)) return expr; - auto res = this->rewrite_simplify(expr); - if (tir::is_const(res)) return res; - res = this->canonical_simplify(res); + PrimExpr res = expr; + for (size_t i = 0; i < repeat; ++i) { + res = this->rewrite_simplify(res); + if (tir::is_const(res)) return res; + res = this->canonical_simplify(res); + if (tir::is_const(res)) return res; + } return res; } diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 3dbcfe94dbdd..03d81f526c25 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -42,13 +42,12 @@ using namespace tvm::runtime; using namespace tvm::te; #define PLUS_ONE(OP) \ - void VisitExpr_(const OP* op) final { \ - num_symbol++; \ + void VisitExpr_(const OP* op) final { num_symbols++; \ } #define PLUS_ONE_BINARY(OP) \ void VisitExpr_(const OP* op) final { \ - num_symbol++; \ + num_symbols++; \ VisitExpr(op->a); \ VisitExpr(op->b); \ } @@ -57,7 +56,7 @@ class ExprComplexity : public ExprVisitor { public: size_t Eval(const PrimExpr& expr) { VisitExpr(expr); - return num_symbol; + return num_symbols; } PLUS_ONE_BINARY(AddNode) @@ -81,12 +80,12 @@ class ExprComplexity : public ExprVisitor { PLUS_ONE(FloatImmNode) PLUS_ONE(IntImmNode) void VisitExpr_(const NotNode* op) final { - num_symbol++; + num_symbols++; VisitExpr(op->a); } private: - size_t num_symbol{0}; + size_t num_symbols{0}; }; struct ExprLess { @@ -340,9 +339,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // eliminate the current variable PrimExpr new_lhs = c_neg*neg.second - c_pos*pos.second; PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); + // we need to do analyzer.rewrite_simplify(analyzer.Simplify(new_ineq)) // it helps to simplify (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // otherwise it's (y*2) - 10 <= 0 - new_ineq = NormalizeComparisons()(analyzer.rewrite_simplify(analyzer.Simplify(new_ineq))); + new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 2)); AddInequality(next_ineq_set_to_solve, new_ineq, analyzer); } } diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 9e3c2a6e9471..3022647cb9ec 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -114,94 +114,99 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): _check(10, 3, coef=(0, 1), bounds=(0, 4)) -def test_simple(): +def test_dual_variable(): x, y = te.var("x"), te.var("y") + variables = [x, y] ranges = { x: tvm.ir.Range(-100, 100), y: tvm.ir.Range(0, 10), } + problem = [ + tvm.tir.LE(x + y, 20), + tvm.tir.GE(x - y, 10), + ] + + # solution as conditions + solution = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges, problem) + assert len(solution) == 4 + assert ir.structural_equal(solution[0], y >= 0) + assert ir.structural_equal(solution[1], y <= 5) + assert ir.structural_equal(solution[2], x >= (y + 10)) + assert ir.structural_equal(solution[3], x <= (20 - y)) + # solve and get the ranges solution = arith.solve_linear_inequalities([ tvm.tir.LE(x + y, 20), tvm.tir.GE(x - y, 10), - ], [x, y], ranges, deskew_range=True) - + ], [x, y], ranges) print(solution) + # 0 <= y <=5 + assert solution.ranges[y].min == 0 + assert solution.ranges[y].extent == 6 + # y + 10 <= x <= 20 - y + assert ir.structural_equal(solution.ranges[x].min, y + 10) + assert solution.ranges[x].extent == 11 # max(10 - 2y) + # deskew the solved ranges to be starting from zero + solution = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True) + print(solution) [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations - assert ir.structural_equal(rel, (y_new*2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) - assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) - assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) assert ir.structural_equal(solution.dst_to_src[y_new], y) - sol = arith.solve_linear_inequalities([ - tvm.tir.LE(x + y, 20), - tvm.tir.GE(x - y, 10), - ], [x, y], ranges) - print(sol) - # 0 <= y <=5 - assert sol.ranges[y].min == 0 - assert sol.ranges[y].extent == 6 - # y + 10 <= x <= 20 - y - assert ir.structural_equal(sol.ranges[x].min, y + 10) - assert sol.ranges[x].extent == 11 # max(10 - 2y) - def test_equal(): x, y = te.var("x"), te.var("y") - - solution = arith.solve_linear_inequalities([ + problem = [ tvm.tir.GE(x + y, 10), tvm.tir.GE(x - y, 2), tvm.tir.LE(x, 6), - ], [x, y]) - - print(solution) - - sol = arith.solve_linear_inequalities([ - tvm.tir.GE(x + y, 10), - tvm.tir.GE(x - y, 2), - tvm.tir.LE(x, 6), - ], [x, y], deskew_range=True) - print(sol) + ] + + solution = arith.solve_linear_inequalities(problem, [x, y]) + assert solution.ranges[x].min == 6 + assert solution.ranges[x].extent == 1 + assert solution.ranges[y].min == 4 + assert solution.ranges[y].extent == 1 + + solution = arith.solve_linear_inequalities(problem, [x, y], deskew_range=True) + assert len(solution.dst.variables) == 0 + assert len(solution.dst.ranges) == 0 + assert len(solution.dst.relations) == 0 + assert solution.src_to_dst[x] == 6 + assert solution.src_to_dst[y] == 4 def test_multi_equal(): x, y, z = te.var("x"), te.var("y"), te.var("z") - - solution = arith.solve_linear_inequalities([ + problem = [ tvm.tir.LE(x, 6), tvm.tir.GE(x, 6), tvm.tir.GE(x - z * y, 0), tvm.tir.LE(x - z * y, 0), - ], [x, y, z], deskew_range=True) + ] - print(solution) + solution = arith.solve_linear_inequalities(problem, [x, y, z]) + assert solution.ranges[x].min == 6 + assert solution.ranges[x].extent == 1 + + solution = arith.solve_linear_inequalities(problem, [x, y, z], deskew_range=True) assert solution.src_to_dst[y] == y assert solution.src_to_dst[z] == z assert solution.src_to_dst[x] == 6 - print(arith.solve_linear_inequalities([ - tvm.tir.LE(x, 6), - tvm.tir.GE(x, 6), - tvm.tir.GE(x - z * y, 0), - tvm.tir.LE(x - z * y, 0), - ], [x, y, z])) - if __name__ == "__main__": - # test_solve_system_of_inequalities() - test_simple() - # test_equal() - # test_multi_equal() + test_solve_system_of_inequalities() + test_dual_variable() + test_equal() + test_multi_equal() From 9c5becc60dcaf2a68fb222e8799c733a1c49d939 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 6 Jun 2020 22:20:49 -0700 Subject: [PATCH 08/33] add doc & comments --- include/tvm/arith/analyzer.h | 7 +- include/tvm/arith/int_solver.h | 52 ++++++--- python/tvm/arith/int_solver.py | 58 ++++++++-- src/arith/analyzer.cc | 6 +- src/arith/int_constraints.cc | 71 ++++++------ src/arith/solve_linear_inequality.cc | 165 ++++++++++++++------------- 6 files changed, 205 insertions(+), 154 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index e0ab48920de0..630cd8a11232 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -455,11 +455,16 @@ class Analyzer { * \brief Simplify expr. * * \param expr The expression to be simplified. + * \param steps The simplification runs in the order of + * rewrite_simplify (step 1) -> canonical_simplify (step 2) -> + * rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ... + * param steps controls how many steps to run. + * Default is 2, i.e., rewrite_simplify + canonical_simplify. * \return The result. * * \note Analyzer will call into sub-analyzers to get the result. */ - PrimExpr Simplify(const PrimExpr& expr, size_t repeat=1); + PrimExpr Simplify(const PrimExpr& expr, size_t steps=2); }; } // namespace arith diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 627e4bbe83fe..4b3727f3fd38 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -38,7 +38,16 @@ using tir::Var; using tir::VarNode; using tir::IterVar; -class IntGroupedBoundsNode : public Object { +/*! + * \brief Represent integer grouped bounds which are classified into + * lower bounds (include), upper bounds (include) and equalities. + * It also contains coefficient as a multiplier for the bounds, i.e., + * coef * var >= lower + * coef * var == equal + * coef * var <= upper + * \sa IntGrpBounds + */ +class IntGrpBoundsNode : public Object { public: PrimExpr coef; Array lower; @@ -52,7 +61,7 @@ class IntGroupedBoundsNode : public Object { v->Visit("upper", &upper); } - bool SEqualReduce(const IntGroupedBoundsNode* other, SEqualReducer eq) const { + bool SEqualReduce(const IntGrpBoundsNode* other, SEqualReducer eq) const { return eq(coef, other->coef) && eq(lower, other->lower) && @@ -68,41 +77,48 @@ class IntGroupedBoundsNode : public Object { } static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const char* _type_key = "arith.IntGroupedBounds"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupedBoundsNode, Object); + static constexpr const char* _type_key = "arith.IntGrpBounds"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntGrpBoundsNode, Object); }; /*! - * \brief Managed reference to IntGroupedBoundsNode. - * \sa IntGroupedBoundsNode + * \brief Managed reference to IntGrpBoundsNode. + * \sa IntGrpBoundsNode */ -class IntGroupedBounds : public ObjectRef { +class IntGrpBounds : public ObjectRef { public: - /*! TODO: comments + /*! * \brief Constructor by fields + * \param coef The coefficient. Must be integer. + * coef * var >= lower + * coef * var == equal + * coef * var >= upper + * \param lower the lower bounds (include) + * \param equal equalities + * \param upper the upper bounds (include) */ - TVM_DLL IntGroupedBounds(PrimExpr coef, - Array lower, - Array equal, - Array upper); + TVM_DLL IntGrpBounds(PrimExpr coef, + Array lower, + Array equal, + Array upper); /*! * \brief Construct bounds from a range. * \param r The range * \return constructed bounds. */ - static IntGroupedBounds range(const Range& r); + static IntGrpBounds range(const Range& r); /*! * \brief Perform substitution on all components of the struct. */ - IntGroupedBounds Substitute(const Map& subst) const; + IntGrpBounds Substitute(const Map& subst) const; Range FindBestRange(const Map& vranges_addl = {}) const; - IntGroupedBounds operator+(const Range& r); + IntGrpBounds operator+(const Range& r); - TVM_DEFINE_OBJECT_REF_METHODS(IntGroupedBounds, ObjectRef, IntGroupedBoundsNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntGrpBounds, ObjectRef, IntGrpBoundsNode); }; /*! @@ -238,9 +254,7 @@ class IntConstraintsTransform : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; -Map ConvertGroupedBoundToRange(Map bounds); - -typedef std::pair, Array > PartialSolvedInequalities; +typedef std::pair, Array > PartialSolvedInequalities; /*! * \brief Obtain Smith Normal Form of linear equation A x = y. diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 4519d49a2156..ad3a49e2966a 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -20,24 +20,37 @@ from . import _ffi_api -@tvm._ffi.register_object("arith.IntGroupedBounds") -class IntGroupedBounds(Object): - # TODO: doc +@tvm._ffi.register_object("arith.IntGrpBounds") +class IntGrpBounds(Object): + """Represent integer grouped bounds which are classified into + lower bounds (include), upper bounds (include) and equalities. + + Parameters + ---------- + coef : tvm.ir.PrimExpr + The coefficient. Must be integer. + coef * var >= lower + coef * var == equal + coef * var >= upper + lower : List[tvm.ir.PrimExpr] + the lower bounds (include) + equal : List[tvm.ir.PrimExpr] + equalities + upper : List[tvm.ir.PrimExpr] + the upper bounds (include) + """ def __init__(self, coef, lower, equal, upper): self.__init_handle_by_constructor__( - _ffi_api.IntGroupedBounds, coef, lower, equal, upper) + _ffi_api.IntGrpBounds, coef, lower, equal, upper) @staticmethod def make_by_range(r): """Construct a IntGroupedBounds by Range. - Parameters TODO + Parameters ---------- - min_value : PrimExpr - The minimum value of the range. + r : tvm.ir.Range - extent : PrimExpr - The extent of the range. Returns ------- @@ -48,8 +61,9 @@ def make_by_range(r): def find_best_range(self): """Return the best range from the grouped bounds. + None if (-inf, +inf). """ - return _ffi_api.IntGroupedBounds_FindBestRange(self) + return _ffi_api.IntGrpBounds_FindBestRange(self) @tvm._ffi.register_object("arith.IntConstraints") @@ -132,7 +146,29 @@ def solve_linear_equations(equations, variables=None, ranges=None): def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False): - solver = _ffi_api.DeskewRange if deskew_range else _ffi_api.SolveInequalitiesRange + """Solve linear inequalities. + + Parameters + ---------- + equations : List[tvm.ir.PrimExpr] or IntConstraints + The inequalities of the variables + variables : Optional[List[tvm.tir.Var]] + The variables in the system. + ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]] + The ranges of the variables. + deskew_range: Optional[bool] + Whether deskew the result ranges to be started from zero. + Default false. + + Returns + ------- + ret_ranges: IntConstraints or IntConstraintsTransform + The result ranges for each variables. + Constrains that cannot be transformed to Range will be stored in IntConstraints.relations. + If deskew_range is set (=True), the result ranges will be deskewed to be started from zero. + New variables are created accordingly therefore IntConstraintsTransform is returned. + """ + solver = _ffi_api.SolveInequalitiesDeskewRange if deskew_range else _ffi_api.SolveInequalitiesToRange if isinstance(equations, IntConstraints): return solver(equations) return solver(variables, ranges, equations) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index fe73dd7cf43b..fe2708121c6b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -107,12 +107,12 @@ bool Analyzer::CanProve(const PrimExpr& expr) { return false; } -PrimExpr Analyzer::Simplify(const PrimExpr& expr, size_t repeat) { +PrimExpr Analyzer::Simplify(const PrimExpr& expr, size_t steps) { if (tir::is_const(expr)) return expr; PrimExpr res = expr; - for (size_t i = 0; i < repeat; ++i) { + for (size_t i = 0; i < steps; ++i) { res = this->rewrite_simplify(res); - if (tir::is_const(res)) return res; + if (tir::is_const(res) || ++i == steps) return res; res = this->canonical_simplify(res); if (tir::is_const(res)) return res; } diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index f29130bc5ad0..a4f8cb77bda4 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -38,19 +38,13 @@ namespace tvm { namespace arith { -Map ConvertGroupedBoundToRange(Map bounds) { - Map vranges; - for (const auto& kv : bounds) { - vranges.Set(kv.first, kv.second.FindBestRange()); - } - return vranges; -} - -IntGroupedBounds::IntGroupedBounds(PrimExpr coef, - Array lower, - Array equal, - Array upper) { - ObjectPtr node = make_object(); +IntGrpBounds::IntGrpBounds(PrimExpr coef, + Array lower, + Array equal, + Array upper) { + CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) + << "Coefficient in IntGrpBounds must be integers"; + ObjectPtr node = make_object(); node->coef = std::move(coef); node->lower = std::move(lower); node->equal = std::move(equal); @@ -58,7 +52,7 @@ IntGroupedBounds::IntGroupedBounds(PrimExpr coef, data_ = std::move(node); } -IntGroupedBounds IntGroupedBounds::range(const Range& r) { +IntGrpBounds IntGrpBounds::range(const Range& r) { Analyzer analyzer; PrimExpr coef = tir::make_const(r->min.dtype(), 1); Array equal; @@ -70,10 +64,10 @@ IntGroupedBounds IntGroupedBounds::range(const Range& r) { lower.push_back(r->min); upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); } - return IntGroupedBounds(coef, lower, equal, upper); + return IntGrpBounds(coef, lower, equal, upper); } -IntGroupedBounds IntGroupedBounds::operator+(const Range& r) { +IntGrpBounds IntGrpBounds::operator+(const Range& r) { Analyzer analyzer; Array equal; Array lower; @@ -87,18 +81,18 @@ IntGroupedBounds IntGroupedBounds::operator+(const Range& r) { for (const auto& eq : operator->()->equal) equal.push_back(eq); for (const auto& lb : operator->()->lower) lower.push_back(lb); for (const auto& ub : operator->()->upper) upper.push_back(ub); - return IntGroupedBounds(operator->()->coef, lower, equal, upper); + return IntGrpBounds(operator->()->coef, lower, equal, upper); } -IntGroupedBounds IntGroupedBounds::Substitute(const Map& subst) const { +IntGrpBounds IntGrpBounds::Substitute(const Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; - return IntGroupedBounds(tir::Substitute(operator->()->coef, subst), + return IntGrpBounds(tir::Substitute(operator->()->coef, subst), tir::UpdateArray(operator->()->lower, apply_fun), tir::UpdateArray(operator->()->equal, apply_fun), tir::UpdateArray(operator->()->upper, apply_fun)); } -Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const { +Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { Analyzer analyzer; analyzer.Bind(vranges_addl); @@ -117,7 +111,8 @@ Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const } if (lowers.size() == 1 && uppers.size() == 1 && te::is_one(operator->()->coef)) { - return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); + return Range(analyzer.Simplify(lowers[0]), + analyzer.Simplify(uppers[0] + 1)); } // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the @@ -131,24 +126,22 @@ Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { - PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, operator->()->coef)); + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, operator->()->coef), 3); // Since diff may depend on some other variables, we compute its overapproximation - PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max()); + PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula - PrimExpr low_divided = analyzer.Simplify(floordiv(low + operator->()->coef - 1, operator->()->coef)); + PrimExpr low_divided = analyzer.Simplify( + floordiv(low + operator->()->coef - 1, operator->()->coef), 3); // Compute another difference which may be more precise (or not). - PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, operator->()->coef) - low_divided); - PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max()); + PrimExpr diff_2 = analyzer.Simplify( + floordiv(upp, operator->()->coef) - low_divided, 3); + PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); - LOG(INFO) << "upp = " << upp << " low = " << low; - LOG(INFO) << "diff_1 = " << diff_1 << " diff_over_1 = " << diff_over_1; - LOG(INFO) << "diff_2 = " << diff_2 << " diff_over_2 = " << diff_over_2; PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; - LOG(INFO) << "diff_over = " << diff_over; // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. @@ -166,23 +159,23 @@ Range IntGroupedBounds::FindBestRange(const Map& vranges_addl) const return Range::make_by_min_extent(best_lower, analyzer.Simplify(best_diff_over + 1)); } -TVM_REGISTER_NODE_TYPE(IntGroupedBoundsNode); +TVM_REGISTER_NODE_TYPE(IntGrpBoundsNode); -TVM_REGISTER_GLOBAL("arith.IntGroupedBounds") +TVM_REGISTER_GLOBAL("arith.IntGrpBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, Array upper) { - return IntGroupedBounds(coef, lower, equal, upper); + return IntGrpBounds(coef, lower, equal, upper); }); TVM_REGISTER_GLOBAL("arith.int_grouped_bounds_by_range") -.set_body_typed(IntGroupedBounds::range); +.set_body_typed(IntGrpBounds::range); -TVM_REGISTER_GLOBAL("arith.IntGroupedBounds_FindBestRange") +TVM_REGISTER_GLOBAL("arith.IntGrpBounds_FindBestRange") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK(args.size() == 1 || args.size() == 2); - IntGroupedBounds bounds = args[0]; + IntGrpBounds bounds = args[0]; if (args.size() == 1) { *ret = bounds.FindBestRange(); } else if (args.size() == 2) { @@ -191,9 +184,9 @@ TVM_REGISTER_GLOBAL("arith.IntGroupedBounds_FindBestRange") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntGroupedBounds(coef=" +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntGrpBounds(coef=" << op->coef << ", lower=" << op->lower << ", equal=" << op->equal diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 03d81f526c25..71ae74becd65 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -33,30 +33,29 @@ #include #include -// TODO: supersimplify - namespace tvm { namespace arith { using namespace tvm::runtime; using namespace tvm::te; -#define PLUS_ONE(OP) \ - void VisitExpr_(const OP* op) final { num_symbols++; \ - } +#define PLUS_ONE(OP) void VisitExpr_(const OP* op) final { num_symbols_++; } #define PLUS_ONE_BINARY(OP) \ void VisitExpr_(const OP* op) final { \ - num_symbols++; \ + num_symbols_++; \ VisitExpr(op->a); \ VisitExpr(op->b); \ } +/*! + * \brief Calculate the expresion complexity based on number of symbols it contains. + */ class ExprComplexity : public ExprVisitor { public: size_t Eval(const PrimExpr& expr) { VisitExpr(expr); - return num_symbols; + return num_symbols_; } PLUS_ONE_BINARY(AddNode) @@ -80,12 +79,12 @@ class ExprComplexity : public ExprVisitor { PLUS_ONE(FloatImmNode) PLUS_ONE(IntImmNode) void VisitExpr_(const NotNode* op) final { - num_symbols++; + num_symbols_++; VisitExpr(op->a); } private: - size_t num_symbols{0}; + size_t num_symbols_{0}; }; struct ExprLess { @@ -97,7 +96,8 @@ struct ExprLess { /*! * \brief Combine the information into an array of (in)equalities. */ -Array as_conditions(const Map& bounds, const Array& relations) { +Array as_conditions(const Map& bounds, + const Array& relations) { Array res; for (const auto iter : bounds) { const Var& v = iter.first; @@ -149,7 +149,9 @@ void DebugPrint(std::unordered_set& c std::cout << "]\n"; } -// normalize to the form `expr <= 0` +/*! + * \brief normalize to the form `expr <= 0` + */ class NormalizeComparisons : public ExprMutator { public: PrimExpr VisitExpr_(const EQNode* op) override { return Make(op->a, op->b); } @@ -162,7 +164,6 @@ class NormalizeComparisons : public ExprMutator { private: template PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { - LOG(INFO) << "a = " << a << " b = " << b; // rewrite LT to LE for ints if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { return LENode::make(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); @@ -175,7 +176,6 @@ class NormalizeComparisons : public ExprMutator { void AddInequality(std::unordered_set& inequality_set, const PrimExpr& new_ineq, Analyzer& analyzer) { - LOG(INFO) << "insert ineq " << new_ineq; if (analyzer.CanProve(new_ineq) || inequality_set.find(new_ineq) != inequality_set.end()) { // redundant: follows from the vranges // or has already been added @@ -199,13 +199,14 @@ void AddInequality(std::unordered_set inequality_set.insert(new_ineq); } -void ClassifyByPolarity(const Var& var, - std::unordered_set& current_ineq_set, - std::unordered_set& next_ineq_set, - std::vector& rest, - std::vector >& coef_pos, - std::vector >& coef_neg, - Analyzer &analyzer) { +void ClassifyByPolarity( + const Var& var, + std::unordered_set& current_ineq_set, + std::unordered_set& next_ineq_set, + std::vector& rest, + std::vector >& coef_pos, + std::vector >& coef_neg, + Analyzer &analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -228,7 +229,7 @@ void ClassifyByPolarity(const Var& var, if (!coef.empty() && is_const(coef[0])) { int64_t coef0 = *as_const_int(coef[0]); if (coef0 == 0) { - // zero polarity, straight to new_current + // zero polarity, straight to next_ineq_set AddInequality(next_ineq_set, ineq, analyzer); } else if (coef0 > 0) { // Equalities may be considered as pairs of two inequalities @@ -269,13 +270,16 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t analyzer.Bind(system_to_solve->ranges); // The algorithm consists in doing the following things for each variable v - // - Take formulas from `current_ineq_set_to_solve` and classify them according to polarity wrt v - // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity - // - Put the resulting combinations into `next_ineq_set_to_solve` along with unclassifiable formulas - // - Replace `current` with `next_ineq_set_to_solve` and move to the next variable + // - Take formulas from `current_ineq_set_to_solve` and + // classify them according to polarity wrt v. + // - Combine each formula of positive polarity (wrt v) + // with each formula of negative polarity. + // - Put the resulting combinations into `next_ineq_set_to_solve` + // along with unclassifiable formulas. + // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve` + // and move to the next variable. // normalized inequality - // current and next_ineq_set_to_solve are sorted to enable some heuristics std::unordered_set current_ineq_set_to_solve; std::unordered_set next_ineq_set_to_solve; // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 @@ -288,8 +292,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Simplify each inequality into the form `expr <= 0` and add to current formulas for (const PrimExpr& ineq : system_to_solve->relations) { - // TODO: SuperSimplify(ineq, vranges) - AddInequality(current_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq)), analyzer); + AddInequality(current_ineq_set_to_solve, + NormalizeComparisons()(analyzer.Simplify(ineq, 3)), analyzer); } DebugPrint(current_ineq_set_to_solve, @@ -298,7 +302,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t coef_pos, coef_neg); - Map res_bounds; + Map res_bounds; for (const Var& v : system_to_solve->variables) { CHECK(!res_bounds.count(v)) << "Variable " << v << " appears more than one time in the `variables` which might be a bug"; @@ -310,8 +314,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Add bounds from vranges if (system_to_solve->ranges.count(v)) { const Range& range = system_to_solve->ranges[v]; - PrimExpr range_lbound = analyzer.Simplify(range->min); - PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1); + PrimExpr range_lbound = analyzer.Simplify(range->min, 3); + PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3); coef_neg.push_back({-1, range_lbound}); coef_pos.push_back({1, -range_ubound}); } @@ -339,10 +343,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // eliminate the current variable PrimExpr new_lhs = c_neg*neg.second - c_pos*pos.second; PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); - // we need to do analyzer.rewrite_simplify(analyzer.Simplify(new_ineq)) - // it helps to simplify (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 - // otherwise it's (y*2) - 10 <= 0 - new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 2)); + // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify + // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 + // with steps = 2 it's (y*2) - 10 <= 0 + new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3)); AddInequality(next_ineq_set_to_solve, new_ineq, analyzer); } } @@ -367,7 +371,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { PrimExpr bound = make_const(v.dtype(), -coef_lcm/pos.first)*pos.second; - bound = analyzer.Simplify(bound); + bound = analyzer.Simplify(bound, 3); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), [&bound, &analyzer](const PrimExpr& o) @@ -387,7 +391,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } for (const auto& neg : coef_neg) { PrimExpr bound = make_const(v.dtype(), -coef_lcm/neg.first)*neg.second; - bound = analyzer.Simplify(bound); + bound = analyzer.Simplify(bound, 3); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), [&bound, &analyzer](const PrimExpr& o) @@ -413,13 +417,12 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. - IntGroupedBounds bnds(make_const(v.dtype(), coef_lcm), + IntGrpBounds bnds(make_const(v.dtype(), coef_lcm), Array(lower_bounds.begin(), lower_bounds.end()), Array(equal_list.begin(), equal_list.end()), Array(upper_bounds.begin(), upper_bounds.end()) ); res_bounds.Set(v, bnds); - LOG(INFO) << "Bound of " << v << bnds; std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); } @@ -427,7 +430,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Everything that is left goes to res.relations Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { - PrimExpr e_simp = analyzer.Simplify(e); + PrimExpr e_simp = analyzer.Simplify(e, 3); if (is_const_int(e_simp, 0)) { // contradiction detected other_conditions = {const_false()}; @@ -446,14 +449,15 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t return {res_bounds, other_conditions}; } -IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { +IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are - // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) + // not in the inequalities->variables but are in inequalities->ranges + // It will be useful when solving Jacobian axes jac_xxx) Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; + Map solved_bounds = solved_system.first; Array solved_other_relations = solved_system.second; Array res_variables; @@ -476,20 +480,17 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { auto bnd = solved_bounds[var]; if (is_one(bnd->coef) && !bnd->equal.empty()) { // There is an equation of the form `v == expr`, so this variable can be completely removed. - // Note that we use the 0-th expression because they are ordered by complexity, so it must be - // the simplest one. - Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1)); + // Note that we use the 0-th expression because they are ordered by complexity, + // so it must be the simplest one. + Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3)); res_ranges.Set(var, best_range); vranges.Set(var, best_range); } else { if (vranges.count(var) > 0) { bnd = bnd + vranges[var]; } - LOG(INFO) << "bnd = " << bnd; - LOG(INFO) << "vranges = " << vranges; auto best_range = bnd.FindBestRange(vranges); - LOG(INFO) << "best range for " << var << " = " << best_range; if (best_range.defined()) { res_ranges.Set(var, best_range); @@ -513,20 +514,19 @@ IntConstraints SolveInequalitiesRange(const IntConstraints& inequalities) { return system; } -// Deskew the given domain -IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { +IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; + Map solved_bounds = solved_system.first; Array solved_other_relations = solved_system.second; arith::Analyzer analyzer; - Map res_old_to_new; - Map res_new_to_old; + Map res_src_to_dst; + Map res_dst_to_src; Array res_variables; Array res_relations; @@ -543,37 +543,37 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { const Var& var = *it; auto bnd = solved_bounds[var]; // Note that we replace old vars with new ones - bnd = bnd.Substitute(res_old_to_new); + bnd = bnd.Substitute(res_src_to_dst); if (is_one(bnd->coef) && !bnd->equal.empty()) { - // There is an equation of the form `v == expr`, so this variable can be completely removed. - // Note that we use the 0-th expression because they are ordered by complexity, so it must be - // the simplest one. - res_old_to_new.Set(var, bnd->equal[0]); + // There is an equation of the form `v == expr`, + // so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, + // so it must be the simplest one. + res_src_to_dst.Set(var, bnd->equal[0]); } else { if (vranges.count(var) > 0) { bnd = bnd + vranges[var]; } auto best_range = bnd.FindBestRange(vranges); - LOG(INFO) << "best range for " << var << " = " << best_range; - std::string suffix = ".shifted"; - Var new_var = var.copy_with_suffix(suffix); + Var new_var = var.copy_with_suffix(".shifted"); if (!best_range.defined()) { - res_old_to_new.Set(var, var); - res_new_to_old.Set(var, var); + res_src_to_dst.Set(var, var); + res_dst_to_src.Set(var, var); res_variables.push_back(var); } else if (is_const_int(best_range->extent, 1)) { // Don't create an itervar, just replace it everywhere with its min - res_old_to_new.Set(var, best_range->min); + res_src_to_dst.Set(var, best_range->min); } else { // created new_var starts from 0 - res_old_to_new.Set(var, new_var + best_range->min); - // Note that we are substituting old with new, so best_lower contains new var, - // that is we have to substitute new with old in best_lower here - res_new_to_old.Set(new_var, - analyzer.Simplify(var - Substitute(best_range->min, res_new_to_old))); + res_src_to_dst.Set(var, new_var + best_range->min); + // Note that we are substituting old with new, so best_range contains new var, + // that is we have to substitute new with old in best_range here + res_dst_to_src.Set(new_var, + analyzer.Simplify( + var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis auto range = Range(make_zero(new_var.dtype()), best_range->extent); @@ -588,7 +588,7 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { // Add the original conditions (with variables substituted) to the resulting conditions for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { - PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_old_to_new)); + PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) res_relations.push_back(new_cond); @@ -599,7 +599,7 @@ IntConstraintsTransform DeskewRange(const IntConstraints& inequalities) { res_variables = Array(res_variables.rbegin(), res_variables.rend()); IntConstraints new_inequalities(res_variables, res_ranges, res_relations); - IntConstraintsTransform transform(inequalities, new_inequalities, res_old_to_new, res_new_to_old); + IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src); return transform; } @@ -613,32 +613,35 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") IntConstraints problem(args[0], args[1], args[2]); ret_ineq = SolveLinearInequalities(problem); } else { - LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " << args.size(); + LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " + << args.size(); } *ret = as_conditions(ret_ineq.first, ret_ineq.second); }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesRange") +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { - *ret = SolveInequalitiesRange(args[0]); + *ret = SolveInequalitiesToRange(args[0]); } else if (args.size() == 3) { IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveInequalitiesRange(problem); + *ret = SolveInequalitiesToRange(problem); } else { - LOG(FATAL) << "arith.SolveInequalitiesRange expects 1 or 3 arguments, gets " << args.size(); + LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " + << args.size(); } }); -TVM_REGISTER_GLOBAL("arith.DeskewRange") +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { - *ret = DeskewRange(args[0]); + *ret = SolveInequalitiesDeskewRange(args[0]); } else if (args.size() == 3) { IntConstraints problem(args[0], args[1], args[2]); - *ret = DeskewRange(problem); + *ret = SolveInequalitiesDeskewRange(problem); } else { - LOG(FATAL) << "arith.DeskewRange expects 1 or 3 arguments, gets " << args.size(); + LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " + << args.size(); } }); From 0fd69f769b29b929614904a064fa70d89868cfe0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 6 Jun 2020 22:58:46 -0700 Subject: [PATCH 09/33] test case refactored --- include/tvm/arith/int_solver.h | 24 ++++++++ python/tvm/testing.py | 43 +++++++++++++ src/arith/solve_linear_inequality.cc | 13 ---- .../test_arith_solve_linear_inequality.py | 60 +++---------------- .../test_arith_solve_linear_system.py | 43 ++----------- 5 files changed, 78 insertions(+), 105 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 4b3727f3fd38..3cdabddca566 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -290,6 +290,30 @@ void SmithNormalFormDiag(std::vector> *S, */ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve); +/*! + * \brief Solve linear inequalities. + * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * \return A map of variables and their solved bounds, + * and constrains that cannot be solved to bounds. + */ +PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve); + +/*! + * \brief Solve linear inequalities. + * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * \return The result ranges for each variables. + * Constrains that cannot be transformed to Range will be stored in relations. + */ +IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities); + +/*! + * \brief Solve linear inequalities. + * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * \return Solved ranges are deskewed to be started from zero. + * New variables and the mapping are created accordingly. + */ +IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_SOLVER_H_ diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 0f50636d68d8..d9207ba16f9a 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -21,6 +21,7 @@ import numpy as np import tvm import tvm._ffi +from tvm import te, tir def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): @@ -168,4 +169,46 @@ def compare_derivative(j, n_der, grad): x_name, grad.shape, dist, max_diff, avg_diff) +def check_bool_expr_is_true(bool_expr, vranges, cond=None): + """ Check that bool_expr holds given the condition cond + for every value of free variables from vranges. + + Parameters + ---------- + bool_expr : tvm.ir.expr.PrimExpr + Boolean expression to check + vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] + Free variables and their ranges + cond: tvm.ir.expr.PrimExpr + extra conditions needs to be satisfied. + """ + if cond is not None: + bool_expr = te.any(tir.Not(cond), bool_expr) + + def _run_expr(expr, vranges): + """ Evaluate expr for every value of free variables + given by vranges and return the tensor of results. + """ + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tir.ir_pass.Substitute(expr, vmap) + + A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.nd.empty(A.shape, A.dtype)] + sch = te.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + return args[0].asnumpy() + + res = _run_expr(bool_expr, vranges) + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = sorted(counterex, key=lambda x: x[0]) + counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) + raise AssertionError("Expression {}\nis not true on {}\n" + "Counterexample: {}" + .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + + tvm._ffi._init_api("testing", __name__) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 71ae74becd65..49a8739186c3 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -265,7 +265,6 @@ void MoveEquality(std::unordered_set& } PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) { - LOG(INFO) << "solving inequalities " << system_to_solve; arith::Analyzer analyzer; analyzer.Bind(system_to_solve->ranges); @@ -296,12 +295,6 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t NormalizeComparisons()(analyzer.Simplify(ineq, 3)), analyzer); } - DebugPrint(current_ineq_set_to_solve, - next_ineq_set_to_solve, - rest, - coef_pos, - coef_neg); - Map res_bounds; for (const Var& v : system_to_solve->variables) { CHECK(!res_bounds.count(v)) << @@ -328,12 +321,6 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t coef_neg, analyzer); - DebugPrint(current_ineq_set_to_solve, - next_ineq_set_to_solve, - rest, - coef_pos, - coef_neg); - // Combine each positive inequality with each negative one (by adding them together) for (const auto& pos : coef_pos) { for (const auto& neg : coef_neg) { diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 3022647cb9ec..f2a5428b33a5 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -15,51 +15,17 @@ # specific language governing permissions and limitations # under the License. import random -import numpy as np import sys import pytest import tvm -from tvm import te, arith, ir, tir - - -def run_expr(expr, vranges): - """ Evaluate expr for every value of free variables - given by vranges and return the tensor of results. - TODO(yzhliu): move to utils - """ - def _compute_body(*us): - vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tir.ir_pass.Substitute(expr, vmap) - - A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) - args = [tvm.nd.empty(A.shape, A.dtype)] - sch = te.create_schedule(A.op) - mod = tvm.build(sch, [A]) - mod(*args) - return args[0].asnumpy() - - -def check_bruteforce(bool_expr, vranges, cond=None): - """ Check that bool_expr holds given the condition cond - for every value of free variables from vranges. - TODO(yzhliu): move to utils - """ - if cond is not None: - bool_expr = te.any(tir.Not(cond), bool_expr) - - res = run_expr(bool_expr, vranges) - if not np.all(res): - indices = list(np.argwhere(res == 0)[0]) - counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] - counterex = sorted(counterex, key=lambda x: x[0]) - counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) - raise AssertionError("Expression {}\nis not true on {}\n" - "Counterexample: {}" - .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) +from tvm import te, arith, ir, tir, testing def test_solve_system_of_inequalities(): - random.seed(0) + seed = random.randrange(sys.maxsize) + print("\nThis test is intentionally non-deterministic, " + "if it fails please report it in github issue together with this seed {}\n".format(seed)) + random.seed(seed) def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): vs = [te.var("x" + str(i)) for i in range(variables)] @@ -75,16 +41,9 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs} before = te.all(tir.const(1, 'bool'), *fs) - - print("--- before ---") - print(fs) after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs) after = te.all(tir.const(1, 'bool'), *after) - print("--- after ---") - print(after) - print() - - check_bruteforce(before == after, vranges) + testing.check_bool_expr_is_true(before == after, vranges) for i in range(3): _check(1, 1) @@ -140,7 +99,6 @@ def test_dual_variable(): tvm.tir.LE(x + y, 20), tvm.tir.GE(x - y, 10), ], [x, y], ranges) - print(solution) # 0 <= y <=5 assert solution.ranges[y].min == 0 assert solution.ranges[y].extent == 6 @@ -150,7 +108,6 @@ def test_dual_variable(): # deskew the solved ranges to be starting from zero solution = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True) - print(solution) [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations assert ir.structural_equal(rel, (y_new*2) + x_new <= 10) @@ -206,7 +163,4 @@ def test_multi_equal(): if __name__ == "__main__": - test_solve_system_of_inequalities() - test_dual_variable() - test_equal() - test_multi_equal() + pytest.main([__file__]) diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 45f8fc10aaf0..b3df8c2ace4d 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -19,43 +19,7 @@ import sys import pytest import tvm -from tvm import te, arith, ir, tir - - -def run_expr(expr, vranges): - """ Evaluate expr for every value of free variables - given by vranges and return the tensor of results. - TODO(yzhliu): move to utils - """ - def _compute_body(*us): - vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tir.ir_pass.Substitute(expr, vmap) - - A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) - args = [tvm.nd.empty(A.shape, A.dtype)] - sch = te.create_schedule(A.op) - mod = tvm.build(sch, [A]) - mod(*args) - return args[0].asnumpy() - - -def check_bruteforce(bool_expr, vranges, cond=None): - """ Check that bool_expr holds given the condition cond - for every value of free variables from vranges. - TODO(yzhliu): move to utils - """ - if cond is not None: - bool_expr = te.any(tir.Not(cond), bool_expr) - - res = run_expr(bool_expr, vranges) - if not np.all(res): - indices = list(np.argwhere(res == 0)[0]) - counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] - counterex = sorted(counterex, key=lambda x: x[0]) - counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) - raise AssertionError("Expression {}\nis not true on {}\n" - "Counterexample: {}" - .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) +from tvm import te, arith, ir, tir, testing def check_solution(solution, vranges={}): @@ -81,8 +45,9 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) cond_subst = tir.ir_pass.Simplify(cond_subst) - check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, - cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) + testing.check_bool_expr_is_true( + te.all(cond_subst, cond_on_vars), all_vranges, + cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) rels = solution.dst.relations if len(rels) == 1 and ir.structural_equal(rels[0], False): From 9af9d245d6d947a69cffa05c5a8e98daad9d14a6 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 6 Jun 2020 23:00:34 -0700 Subject: [PATCH 10/33] test file rename --- ...lve_linear_system.py => test_arith_solve_linear_equations.py} | 1 - 1 file changed, 1 deletion(-) rename tests/python/unittest/{test_arith_solve_linear_system.py => test_arith_solve_linear_equations.py} (99%) diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_equations.py similarity index 99% rename from tests/python/unittest/test_arith_solve_linear_system.py rename to tests/python/unittest/test_arith_solve_linear_equations.py index b3df8c2ace4d..d2cf9d283f83 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_equations.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import random -import numpy as np import sys import pytest import tvm From 6349c33d1403a3c1d57dd2daa7bee5bfd2aaa38d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 7 Jun 2020 00:48:06 -0700 Subject: [PATCH 11/33] fix lint --- include/tvm/arith/analyzer.h | 2 +- include/tvm/arith/int_solver.h | 1 + python/tvm/arith/int_solver.py | 11 ++-- src/arith/solve_linear_inequality.cc | 92 ++++++++++++++-------------- 4 files changed, 54 insertions(+), 52 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 8fe19a96e6f5..6f9ba8f16ea5 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -473,7 +473,7 @@ class TVM_DLL Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - PrimExpr Simplify(const PrimExpr& expr, size_t steps=2); + PrimExpr Simplify(const PrimExpr& expr, size_t steps = 2); }; } // namespace arith diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 757709859ae6..3e7d7abc4e65 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "analyzer.h" namespace tvm { diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index ad3a49e2966a..beae4cf5ce88 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -44,20 +44,20 @@ def __init__(self, coef, lower, equal, upper): _ffi_api.IntGrpBounds, coef, lower, equal, upper) @staticmethod - def make_by_range(r): + def make_by_range(rng): """Construct a IntGroupedBounds by Range. Parameters ---------- - r : tvm.ir.Range + rng : tvm.ir.Range Returns ------- - rng : Range + ret : Range The constructed range. """ - return _ffi_api.int_grouped_bounds_by_range(r) + return _ffi_api.int_grouped_bounds_by_range(rng) def find_best_range(self): """Return the best range from the grouped bounds. @@ -168,7 +168,8 @@ def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_ran If deskew_range is set (=True), the result ranges will be deskewed to be started from zero. New variables are created accordingly therefore IntConstraintsTransform is returned. """ - solver = _ffi_api.SolveInequalitiesDeskewRange if deskew_range else _ffi_api.SolveInequalitiesToRange + solver = _ffi_api.SolveInequalitiesDeskewRange \ + if deskew_range else _ffi_api.SolveInequalitiesToRange if isinstance(equations, IntConstraints): return solver(equations) return solver(variables, ranges, equations) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 4cfefd8e5647..29b481b14074 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -119,11 +119,12 @@ Array as_conditions(const Map& bounds, return res; } -void DebugPrint(std::unordered_set& current_ineq_set, - std::unordered_set& next_ineq_set, - std::vector& rest, - std::vector >& coef_pos, - std::vector >& coef_neg) { +void DebugPrint( + const std::unordered_set& current_ineq_set, + const std::unordered_set& next_ineq_set, + const std::vector& rest, + const std::vector >& coef_pos, + const std::vector >& coef_neg) { std::cout << "Current ineq set:\n["; for (auto& ineq : current_ineq_set) { std::cout << ineq << ", "; @@ -173,21 +174,21 @@ class NormalizeComparisons : public ExprMutator { arith::Analyzer analyzer_; }; -void AddInequality(std::unordered_set& inequality_set, +void AddInequality(std::unordered_set* inequality_set, const PrimExpr& new_ineq, - Analyzer& analyzer) { - if (analyzer.CanProve(new_ineq) || inequality_set.find(new_ineq) != inequality_set.end()) { + Analyzer* analyzer) { + if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added return; } - for (auto iter = inequality_set.begin(); iter != inequality_set.end();) { + for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { if (const LENode* new_le = new_ineq.as()) { const LENode* le = iter->as(); - if (le && analyzer.CanProve(new_le->a - le->a <= 0)) { + if (le && analyzer->CanProve(new_le->a - le->a <= 0)) { return; - } else if (le && analyzer.CanProve(le->a - new_le->a <= 0)) { - iter = inequality_set.erase(iter); + } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) { + iter = inequality_set->erase(iter); } else { ++iter; } @@ -196,17 +197,17 @@ void AddInequality(std::unordered_set } } - inequality_set.insert(new_ineq); + inequality_set->insert(new_ineq); } void ClassifyByPolarity( const Var& var, - std::unordered_set& current_ineq_set, - std::unordered_set& next_ineq_set, - std::vector& rest, - std::vector >& coef_pos, - std::vector >& coef_neg, - Analyzer &analyzer) { + const std::unordered_set& current_ineq_set, + std::unordered_set* next_ineq_set, + std::vector* rest, + std::vector >* coef_pos, + std::vector >* coef_neg, + Analyzer* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -218,9 +219,9 @@ void ClassifyByPolarity( // zero polarity, straight to next_ineq_set AddInequality(next_ineq_set, ineq, analyzer); } else if (coef0 > 0) { - coef_pos.push_back({coef0, coef[1]}); + coef_pos->push_back({coef0, coef[1]}); } else if (coef0 < 0) { - coef_neg.push_back({coef0, coef[1]}); + coef_neg->push_back({coef0, coef[1]}); } continue; } @@ -233,31 +234,31 @@ void ClassifyByPolarity( AddInequality(next_ineq_set, ineq, analyzer); } else if (coef0 > 0) { // Equalities may be considered as pairs of two inequalities - coef_pos.push_back({coef0, coef[1]}); - coef_neg.push_back({-coef0, -coef[1]}); + coef_pos->push_back({coef0, coef[1]}); + coef_neg->push_back({-coef0, -coef[1]}); } else if (coef0 < 0) { - coef_pos.push_back({-coef0, -coef[1]}); - coef_neg.push_back({coef0, coef[1]}); + coef_pos->push_back({-coef0, -coef[1]}); + coef_neg->push_back({coef0, coef[1]}); } continue; } } // if nothing worked, put it in rest - rest.push_back(ineq); + rest->push_back(ineq); } } -void MoveEquality(std::unordered_set& upper_bounds, - std::unordered_set& lower_bounds, - std::unordered_set& equalities) { +void MoveEquality(std::unordered_set* upper_bounds, + std::unordered_set* lower_bounds, + std::unordered_set* equalities) { // those exist in both upper & lower bounds will be moved to equalities - for (auto ub = upper_bounds.begin(); ub != upper_bounds.end();) { - auto lb = lower_bounds.find(*ub); - if (lb != lower_bounds.end()) { - equalities.insert(*lb); - lower_bounds.erase(lb); - ub = upper_bounds.erase(ub); + for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { + auto lb = lower_bounds->find(*ub); + if (lb != lower_bounds->end()) { + equalities->insert(*lb); + lower_bounds->erase(lb); + ub = upper_bounds->erase(ub); } else { ++ub; } @@ -291,8 +292,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Simplify each inequality into the form `expr <= 0` and add to current formulas for (const PrimExpr& ineq : system_to_solve->relations) { - AddInequality(current_ineq_set_to_solve, - NormalizeComparisons()(analyzer.Simplify(ineq, 3)), analyzer); + AddInequality(¤t_ineq_set_to_solve, + NormalizeComparisons()(analyzer.Simplify(ineq, 3)), &analyzer); } Map res_bounds; @@ -315,11 +316,11 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t ClassifyByPolarity(v, current_ineq_set_to_solve, - next_ineq_set_to_solve, - rest, - coef_pos, - coef_neg, - analyzer); + &next_ineq_set_to_solve, + &rest, + &coef_pos, + &coef_neg, + &analyzer); // Combine each positive inequality with each negative one (by adding them together) int64_t gcd_x, gcd_y; @@ -335,7 +336,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3)); - AddInequality(next_ineq_set_to_solve, new_ineq, analyzer); + AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer); } } @@ -400,7 +401,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::unordered_set equal; equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); - MoveEquality(upper_bounds, lower_bounds, equal); + MoveEquality(&upper_bounds, &lower_bounds, &equal); std::vector equal_list(equal.begin(), equal.end()); std::sort(equal_list.begin(), equal_list.end(), ExprLess()); @@ -408,8 +409,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t IntGrpBounds bnds(make_const(v.dtype(), coef_lcm), Array(lower_bounds.begin(), lower_bounds.end()), Array(equal_list.begin(), equal_list.end()), - Array(upper_bounds.begin(), upper_bounds.end()) - ); + Array(upper_bounds.begin(), upper_bounds.end())); res_bounds.Set(v, bnds); std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); From 8caa4062c820e44923c0f34b0e40e83259d81231 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 7 Jun 2020 10:57:34 -0700 Subject: [PATCH 12/33] apply clang-format --- include/tvm/arith/int_solver.h | 17 ++-- src/arith/solve_linear_inequality.cc | 134 +++++++++++++-------------- 2 files changed, 70 insertions(+), 81 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 3e7d7abc4e65..d8c15a742caf 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -27,9 +27,11 @@ #include #include #include + #include -#include #include +#include + #include "analyzer.h" namespace tvm { @@ -63,11 +65,8 @@ class IntGrpBoundsNode : public Object { } bool SEqualReduce(const IntGrpBoundsNode* other, SEqualReducer eq) const { - return - eq(coef, other->coef) && - eq(lower, other->lower) && - eq(equal, other->equal) && - eq(upper, other->upper); + return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) && + eq(upper, other->upper); } void SHashReduce(SHashReducer hash_reduce) const { @@ -98,9 +97,7 @@ class IntGrpBounds : public ObjectRef { * \param equal equalities * \param upper the upper bounds (include) */ - TVM_DLL IntGrpBounds(PrimExpr coef, - Array lower, - Array equal, + TVM_DLL IntGrpBounds(PrimExpr coef, Array lower, Array equal, Array upper); /*! @@ -246,7 +243,7 @@ class IntConstraintsTransform : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; -typedef std::pair, Array > PartialSolvedInequalities; +typedef std::pair, Array> PartialSolvedInequalities; /*! * \brief Obtain Smith Normal Form of linear equation A x = y. diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 29b481b14074..8af8699ccb8d 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -21,15 +21,15 @@ * \file tvm/arith/solve_linear_inequality.cc * \brief Solve linear inequalities. */ -#include -#include -#include #include #include -#include #include -#include #include +#include +#include +#include +#include +#include #include "int_operator.h" @@ -39,7 +39,8 @@ namespace arith { using namespace tvm::runtime; using namespace tvm::tir; -#define PLUS_ONE(OP) void VisitExpr_(const OP* op) final { num_symbols_++; } +#define PLUS_ONE(OP) \ + void VisitExpr_(const OP* op) final { num_symbols_++; } #define PLUS_ONE_BINARY(OP) \ void VisitExpr_(const OP* op) final { \ @@ -122,9 +123,8 @@ Array as_conditions(const Map& bounds, void DebugPrint( const std::unordered_set& current_ineq_set, const std::unordered_set& next_ineq_set, - const std::vector& rest, - const std::vector >& coef_pos, - const std::vector >& coef_neg) { + const std::vector& rest, const std::vector>& coef_pos, + const std::vector>& coef_neg) { std::cout << "Current ineq set:\n["; for (auto& ineq : current_ineq_set) { std::cout << ineq << ", "; @@ -175,8 +175,7 @@ class NormalizeComparisons : public ExprMutator { }; void AddInequality(std::unordered_set* inequality_set, - const PrimExpr& new_ineq, - Analyzer* analyzer) { + const PrimExpr& new_ineq, Analyzer* analyzer) { if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added @@ -204,10 +203,8 @@ void ClassifyByPolarity( const Var& var, const std::unordered_set& current_ineq_set, std::unordered_set* next_ineq_set, - std::vector* rest, - std::vector >* coef_pos, - std::vector >* coef_neg, - Analyzer* analyzer) { + std::vector* rest, std::vector>* coef_pos, + std::vector>* coef_neg, Analyzer* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -292,14 +289,15 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Simplify each inequality into the form `expr <= 0` and add to current formulas for (const PrimExpr& ineq : system_to_solve->relations) { - AddInequality(¤t_ineq_set_to_solve, - NormalizeComparisons()(analyzer.Simplify(ineq, 3)), &analyzer); + AddInequality(¤t_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)), + &analyzer); } Map res_bounds; for (const Var& v : system_to_solve->variables) { - CHECK(!res_bounds.count(v)) << - "Variable " << v << " appears more than one time in the `variables` which might be a bug"; + CHECK(!res_bounds.count(v)) + << "Variable " << v + << " appears more than one time in the `variables` which might be a bug"; next_ineq_set_to_solve.clear(); coef_pos.clear(); @@ -314,23 +312,18 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t coef_pos.push_back({1, -range_ubound}); } - ClassifyByPolarity(v, - current_ineq_set_to_solve, - &next_ineq_set_to_solve, - &rest, - &coef_pos, - &coef_neg, - &analyzer); + ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos, + &coef_neg, &analyzer); // Combine each positive inequality with each negative one (by adding them together) int64_t gcd_x, gcd_y; for (const auto& pos : coef_pos) { for (const auto& neg : coef_neg) { auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); - PrimExpr c_pos = make_const(v.dtype(), neg.first/first_gcd); - PrimExpr c_neg = make_const(v.dtype(), pos.first/first_gcd); + PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd); + PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd); // eliminate the current variable - PrimExpr new_lhs = c_neg*neg.second - c_pos*pos.second; + PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 @@ -359,12 +352,13 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t lower_bounds.reserve(coef_neg.size()); for (const auto& pos : coef_pos) { - PrimExpr bound = make_const(v.dtype(), -coef_lcm/pos.first)*pos.second; + PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; bound = analyzer.Simplify(bound, 3); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), - [&bound, &analyzer](const PrimExpr& o) - { return analyzer.CanProve(o - bound <= 0); })) { + [&bound, &analyzer](const PrimExpr& o) { + return analyzer.CanProve(o - bound <= 0); + })) { continue; } // Erase all worse bounds @@ -379,12 +373,13 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t upper_bounds.insert(bound); } for (const auto& neg : coef_neg) { - PrimExpr bound = make_const(v.dtype(), -coef_lcm/neg.first)*neg.second; + PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; bound = analyzer.Simplify(bound, 3); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), - [&bound, &analyzer](const PrimExpr& o) - { return analyzer.CanProve(o - bound >= 0); })) { + [&bound, &analyzer](const PrimExpr& o) { + return analyzer.CanProve(o - bound >= 0); + })) { continue; } // Erase all worse bounds @@ -407,9 +402,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Write it to the result. IntGrpBounds bnds(make_const(v.dtype(), coef_lcm), - Array(lower_bounds.begin(), lower_bounds.end()), - Array(equal_list.begin(), equal_list.end()), - Array(upper_bounds.begin(), upper_bounds.end())); + Array(lower_bounds.begin(), lower_bounds.end()), + Array(equal_list.begin(), equal_list.end()), + Array(upper_bounds.begin(), upper_bounds.end())); res_bounds.Set(v, bnds); std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); @@ -560,8 +555,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ // Note that we are substituting old with new, so best_range contains new var, // that is we have to substitute new with old in best_range here res_dst_to_src.Set(new_var, - analyzer.Simplify( - var - Substitute(best_range->min, res_dst_to_src))); + analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis auto range = Range(make_zero(new_var.dtype()), best_range->extent); @@ -593,45 +587,43 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PartialSolvedInequalities ret_ineq; + .set_body([](TVMArgs args, TVMRetValue* ret) { + PartialSolvedInequalities ret_ineq; + if (args.size() == 1) { + ret_ineq = SolveLinearInequalities(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + ret_ineq = SolveLinearInequalities(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " + << args.size(); + } + *ret = as_conditions(ret_ineq.first, ret_ineq.second); + }); + +TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - ret_ineq = SolveLinearInequalities(args[0]); + *ret = SolveInequalitiesToRange(args[0]); } else if (args.size() == 3) { IntConstraints problem(args[0], args[1], args[2]); - ret_ineq = SolveLinearInequalities(problem); + *ret = SolveInequalitiesToRange(problem); } else { - LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " - << args.size(); + LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " << args.size(); } - *ret = as_conditions(ret_ineq.first, ret_ineq.second); }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveInequalitiesToRange(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveInequalitiesToRange(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " - << args.size(); - } - }); - TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveInequalitiesDeskewRange(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveInequalitiesDeskewRange(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " - << args.size(); - } - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesDeskewRange(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveInequalitiesDeskewRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " + << args.size(); + } + }); } // namespace arith } // namespace tvm From bc0587ff45a0292b12a8fa1f2b4f4e6310599146 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 7 Jun 2020 11:03:05 -0700 Subject: [PATCH 13/33] apply cl-format-10 --- src/arith/int_constraints.cc | 98 ++++++++++++++---------------------- src/arith/int_operator.h | 2 +- 2 files changed, 40 insertions(+), 60 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 35d62d6e55fc..0ba5745b1020 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -21,13 +21,12 @@ * \file int_constraints.cc * \brief The integer constraints data structures. */ +#include #include #include #include #include -#include #include -#include #include #include @@ -39,9 +38,7 @@ namespace tvm { namespace arith { -IntGrpBounds::IntGrpBounds(PrimExpr coef, - Array lower, - Array equal, +IntGrpBounds::IntGrpBounds(PrimExpr coef, Array lower, Array equal, Array upper) { CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGrpBounds must be integers"; @@ -86,11 +83,11 @@ IntGrpBounds IntGrpBounds::operator+(const Range& r) { } IntGrpBounds IntGrpBounds::Substitute(const Map& subst) const { - auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; - return IntGrpBounds(tir::Substitute(operator->()->coef, subst), - tir::UpdateArray(operator->()->lower, apply_fun), - tir::UpdateArray(operator->()->equal, apply_fun), - tir::UpdateArray(operator->()->upper, apply_fun)); + auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; + return IntGrpBounds(tir::Substitute(operator->()->coef, subst), + tir::UpdateArray(operator->()->lower, apply_fun), + tir::UpdateArray(operator->()->equal, apply_fun), + tir::UpdateArray(operator->()->upper, apply_fun)); } Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { @@ -112,8 +109,7 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { } if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(operator->()->coef)) { - return Range(analyzer.Simplify(lowers[0]), - analyzer.Simplify(uppers[0] + 1)); + return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); } // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the @@ -133,16 +129,15 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula - PrimExpr low_divided = analyzer.Simplify( - floordiv(low + operator->()->coef - 1, operator->()->coef), 3); + PrimExpr low_divided = + analyzer.Simplify(floordiv(low + operator->()->coef - 1, operator->()->coef), 3); // Compute another difference which may be more precise (or not). - PrimExpr diff_2 = analyzer.Simplify( - floordiv(upp, operator->()->coef) - low_divided, 3); + PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, operator->()->coef) - low_divided, 3); PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); - PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) - ? diff_over_2 : diff_over_1; + PrimExpr diff_over = + analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. @@ -163,41 +158,30 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { TVM_REGISTER_NODE_TYPE(IntGrpBoundsNode); TVM_REGISTER_GLOBAL("arith.IntGrpBounds") -.set_body_typed([](PrimExpr coef, - Array lower, - Array equal, - Array upper) { - return IntGrpBounds(coef, lower, equal, upper); -}); + .set_body_typed([](PrimExpr coef, Array lower, Array equal, + Array upper) { return IntGrpBounds(coef, lower, equal, upper); }); -TVM_REGISTER_GLOBAL("arith.int_grouped_bounds_by_range") -.set_body_typed(IntGrpBounds::range); +TVM_REGISTER_GLOBAL("arith.int_grouped_bounds_by_range").set_body_typed(IntGrpBounds::range); TVM_REGISTER_GLOBAL("arith.IntGrpBounds_FindBestRange") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args.size() == 1 || args.size() == 2); - IntGrpBounds bounds = args[0]; - if (args.size() == 1) { - *ret = bounds.FindBestRange(); - } else if (args.size() == 2) { - *ret = bounds.FindBestRange(args[1]); - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK(args.size() == 1 || args.size() == 2); + IntGrpBounds bounds = args[0]; + if (args.size() == 1) { + *ret = bounds.FindBestRange(); + } else if (args.size() == 2) { + *ret = bounds.FindBestRange(args[1]); + } + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntGrpBounds(coef=" - << op->coef - << ", lower=" << op->lower - << ", equal=" << op->equal - << ", upper=" << op->upper - << ")"; - }); - - -IntConstraints::IntConstraints(Array variables, - Map ranges, + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntGrpBounds(coef=" << op->coef << ", lower=" << op->lower + << ", equal=" << op->equal << ", upper=" << op->upper << ")"; + }); + +IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { @@ -220,11 +204,9 @@ IntConstraints::IntConstraints(Array variables, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_REGISTER_GLOBAL("arith.IntConstraints") -.set_body_typed([](Array variables, - Map ranges, - Array relations) { - return IntConstraints(variables, ranges, relations); -}); + .set_body_typed([](Array variables, Map ranges, Array relations) { + return IntConstraints(variables, ranges, relations); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -247,12 +229,10 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") -.set_body_typed([](IntConstraints src, - IntConstraints dst, - Map src_to_dst, - Map dst_to_src) { - return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); -}); + .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, + Map dst_to_src) { + return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index a77d2d8bc23d..eff52308f389 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -190,7 +190,7 @@ inline int64_t ZeroAwareGCD(int64_t a, int64_t b) { */ inline int64_t LeastCommonMultiple(int64_t a, int64_t b) { int64_t x, y; - return (a*b) / ExtendedEuclidean(a, b, &x, &y); + return (a * b) / ExtendedEuclidean(a, b, &x, &y); } } // namespace arith From e84467a295045f667af7c4428ab39a56f934ed8d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 7 Jun 2020 11:08:15 -0700 Subject: [PATCH 14/33] fix cpplint --- include/tvm/arith/int_solver.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index d8c15a742caf..d61f03951aa4 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -289,7 +289,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t * \return The result ranges for each variables. * Constrains that cannot be transformed to Range will be stored in relations. */ -IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities); +IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve); /*! * \brief Solve linear inequalities. @@ -297,7 +297,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities); * \return Solved ranges are deskewed to be started from zero. * New variables and the mapping are created accordingly. */ -IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities); +IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve); } // namespace arith } // namespace tvm From 45c6a4cea9bbe7b0df44774efb4fc5e5ee3153b3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 16 Jun 2020 23:56:13 -0700 Subject: [PATCH 15/33] add more docs --- include/tvm/arith/int_solver.h | 37 +++++++++++++++++++++++----- python/tvm/arith/int_solver.py | 4 ++- src/arith/int_constraints.cc | 24 ++++++++++-------- src/arith/solve_linear_inequality.cc | 4 +-- 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index d61f03951aa4..1f4a86032e29 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -43,7 +43,7 @@ using tir::VarNode; /*! * \brief Represent integer grouped bounds which are classified into - * lower bounds (include), upper bounds (include) and equalities. + * lower bounds (inclusive), upper bounds (inclusive) and equalities. * It also contains coefficient as a multiplier for the bounds, i.e., * coef * var >= lower * coef * var == equal @@ -112,8 +112,19 @@ class IntGrpBounds : public ObjectRef { */ IntGrpBounds Substitute(const Map& subst) const; + /*! + * \brief Find the best range from the grouped bounds. + * \param vranges_addl additional variable ranges that help infer the best range. + * \return The best range (has the least difference between the lower bound and upper bound). + * undefined if (-inf, +inf). + */ Range FindBestRange(const Map& vranges_addl = {}) const; + /*! + * \brief Combine the bounds with another range. + * \param range another range to be combined. + * \return combined bounds. + */ IntGrpBounds operator+(const Range& r); TVM_DEFINE_OBJECT_REF_METHODS(IntGrpBounds, ObjectRef, IntGrpBoundsNode); @@ -284,18 +295,32 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve); /*! - * \brief Solve linear inequalities. + * \brief Solve linear inequalities and infer the range of each variable. * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. * \return The result ranges for each variables. - * Constrains that cannot be transformed to Range will be stored in relations. + * The returned IntConstraints(variables, ranges, relations) contains, + * 1. variables - the variables that have been solved. + * 2. ranges - the best range of each variable. + * 3. relations - constraints that cannot be transformed to + * Range will be stored in relations. */ IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve); /*! - * \brief Solve linear inequalities. + * \brief Solve linear inequalities and deskew the ranges towards zero. * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. - * \return Solved ranges are deskewed to be started from zero. - * New variables and the mapping are created accordingly. + * \return A transform (src IntConstraints -> dst IntConstraints) + * from original variables to a set of new variables. + * The ranges of new variables always start from zero, + * their extents are solved from \p system_to_solve. + * src IntConstraints is the same as \p system_to_solve. + * dst IntConstraints(variables, ranges, relations) contains, + * 1. variables - the variables that have been solved. + * 2. ranges - the best range (start from zero) of each variable. + * 3. relations - constraints that cannot be transformed to + * Range will be stored in relations. + * Variable mapping can be obtained from + * IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src. */ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve); diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index beae4cf5ce88..7f28ed406239 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -28,7 +28,7 @@ class IntGrpBounds(Object): Parameters ---------- coef : tvm.ir.PrimExpr - The coefficient. Must be integer. + The coefficient. Must be integer type. coef * var >= lower coef * var == equal coef * var >= upper @@ -171,5 +171,7 @@ def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_ran solver = _ffi_api.SolveInequalitiesDeskewRange \ if deskew_range else _ffi_api.SolveInequalitiesToRange if isinstance(equations, IntConstraints): + assert variables is None + assert ranges is None return solver(equations) return solver(variables, ranges, equations) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 0ba5745b1020..a84acebe612a 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -70,16 +70,17 @@ IntGrpBounds IntGrpBounds::operator+(const Range& r) { Array equal; Array lower; Array upper; + const PrimExpr& coef = operator->()->coef; if (tir::is_one(r->extent)) { - equal.push_back(analyzer.Simplify(r->min * operator->()->coef)); + equal.push_back(analyzer.Simplify(r->min * coef)); } else { - lower.push_back(analyzer.Simplify(r->min * operator->()->coef)); - upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * operator->()->coef)); + lower.push_back(analyzer.Simplify(r->min * coef)); + upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef)); } for (const auto& eq : operator->()->equal) equal.push_back(eq); for (const auto& lb : operator->()->lower) lower.push_back(lb); for (const auto& ub : operator->()->upper) upper.push_back(ub); - return IntGrpBounds(operator->()->coef, lower, equal, upper); + return IntGrpBounds(coef, lower, equal, upper); } IntGrpBounds IntGrpBounds::Substitute(const Map& subst) const { @@ -99,8 +100,11 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { var_intsets[kv.first.get()] = IntSet::range(kv.second); } - std::vector lowers(operator->()->equal.begin(), operator->()->equal.end()); - std::vector uppers(operator->()->equal.begin(), operator->()->equal.end()); + const Array& equal = operator->()->equal; + const PrimExpr& coef = operator->()->coef; + + std::vector lowers(equal.begin(), equal.end()); + std::vector uppers(equal.begin(), equal.end()); for (const auto& expr : operator->()->lower) { lowers.push_back(expr); } @@ -108,7 +112,7 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { uppers.push_back(expr); } - if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(operator->()->coef)) { + if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) { return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); } @@ -123,17 +127,17 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { - PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, operator->()->coef), 3); + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); // Since diff may depend on some other variables, we compute its overapproximation PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula PrimExpr low_divided = - analyzer.Simplify(floordiv(low + operator->()->coef - 1, operator->()->coef), 3); + analyzer.Simplify(floordiv(low + coef - 1, coef), 3); // Compute another difference which may be more precise (or not). - PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, operator->()->coef) - low_divided, 3); + PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); PrimExpr diff_over = diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 8af8699ccb8d..848c385b32ac 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first); } - // The resulting lower and upper bounds stored in sorted vectors + // The resulting lower and upper bounds std::unordered_set upper_bounds; std::unordered_set lower_bounds; upper_bounds.reserve(coef_pos.size()); @@ -482,7 +482,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { } } - // Add the original conditions (with variables substituted) to the resulting conditions + // Add the original conditions to the resulting conditions arith::Analyzer analyzer; analyzer.Bind(vranges); for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { From bb22dd2c42bff8611db8a70dad706138cccfe89f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 17 Jun 2020 00:01:19 -0700 Subject: [PATCH 16/33] Add Co-author. Co-authored-by: Sergei Grechanik --- src/arith/int_constraints.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index a84acebe612a..718cae6ae4bf 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -133,8 +133,7 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula - PrimExpr low_divided = - analyzer.Simplify(floordiv(low + coef - 1, coef), 3); + PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3); // Compute another difference which may be more precise (or not). PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); From 0e0b5f085625d642113c09c8491221e7a5e96c69 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 17 Jun 2020 14:04:40 -0700 Subject: [PATCH 17/33] add check_solution --- include/tvm/arith/int_solver.h | 2 +- python/tvm/testing.py | 4 +- .../test_arith_solve_linear_inequality.py | 49 ++++++++++++++++++- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 1f4a86032e29..18272c1b0462 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -122,7 +122,7 @@ class IntGrpBounds : public ObjectRef { /*! * \brief Combine the bounds with another range. - * \param range another range to be combined. + * \param r range to be combined. * \return combined bounds. */ IntGrpBounds operator+(const Range& r); diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 9c23ce045b2e..a6b00e88c4a9 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -226,10 +226,10 @@ def _compute_body(*us): counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] counterex = sorted(counterex, key=lambda x: x[0]) counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) + ana = tvm.arith.Analyzer() raise AssertionError("Expression {}\nis not true on {}\n" "Counterexample: {}" - .format(tvm.tir.ir_pass.CanonicalSimplify(bool_expr), - vranges, counterex)) + .format(ana.simplify(bool_expr), vranges, counterex)) tvm._ffi._init_api("testing", __name__) diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index f2a5428b33a5..8fb19ff8c3e9 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -21,6 +21,44 @@ from tvm import te, arith, ir, tir, testing +def check_solution(solution, vranges={}): + """Check that solution is a bijective transformation""" + def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() + all_vranges = vranges.copy() + all_vranges.update({v: r for v, r in constraints1.ranges.items()}) + + # Check that the transformation is injective + cond_on_vars = tir.const(1, 'bool') + for v in constraints1.variables: + # variable mapping is consistent + v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) + cond_on_vars = te.all(cond_on_vars, v == v_back) + # Also we have to check that the new relations are true when old relations are true + cond_subst = tir.stmt_functor.substitute( + te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) + # We have to include relations from vranges too + for v in constraints2.variables: + if v in constraints2.ranges: + r = constraints2.ranges[v] + range_cond = te.all(v >= r.min, v < r.min + r.extent) + range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) + cond_subst = te.all(cond_subst, range_cond) + cond_subst = ana.simplify(cond_subst) + testing.check_bool_expr_is_true( + te.all(cond_subst, cond_on_vars), all_vranges, + cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) + + rels = solution.dst.relations + if len(rels) == 1 and ir.structural_equal(rels[0], False): + # not solvable, skip + return + _check_forward(solution.src, solution.dst, + solution.src_to_dst, solution.dst_to_src) + _check_forward(solution.dst, solution.src, + solution.dst_to_src, solution.src_to_dst) + + def test_solve_system_of_inequalities(): seed = random.randrange(sys.maxsize) print("\nThis test is intentionally non-deterministic, " @@ -45,6 +83,14 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): after = te.all(tir.const(1, 'bool'), *after) testing.check_bool_expr_is_true(before == after, vranges) + print("-------------") + print(fs) + print(vs) + print(vranges) + solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True) + print(solution) + check_solution(solution) + for i in range(3): _check(1, 1) for i in range(3): @@ -163,4 +209,5 @@ def test_multi_equal(): if __name__ == "__main__": - pytest.main([__file__]) + test_solve_system_of_inequalities() + # pytest.main([__file__]) From ff05fe0733589df570d6ddd3797fed5dd551862b Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 18 Jun 2020 23:30:58 -0700 Subject: [PATCH 18/33] add support for unsolvable inequalities --- src/arith/solve_linear_inequality.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 848c385b32ac..eb66cf859ef6 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -443,7 +443,6 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { Map solved_bounds = solved_system.first; Array solved_other_relations = solved_system.second; - Array res_variables; Array res_relations; // this keeps being updated during determining the range of each variable. @@ -476,6 +475,11 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { auto best_range = bnd.FindBestRange(vranges); if (best_range.defined()) { + if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + // range.extent <= 0 implies the input inequality system is unsolvable + return IntConstraints(/*variables=*/{}, /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}); + } res_ranges.Set(var, best_range); vranges.Set(var, best_range); } @@ -549,6 +553,14 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } else if (is_const_int(best_range->extent, 1)) { // Don't create an itervar, just replace it everywhere with its min res_src_to_dst.Set(var, best_range->min); + } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + // range.extent <= 0 implies the input inequality system is unsolvable + return IntConstraintsTransform(inequalities, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}), + {}, {}); } else { // created new_var starts from 0 res_src_to_dst.Set(var, new_var + best_range->min); From e6845bccc797271e97f07830819a8b5f338ba370 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 19 Jun 2020 18:40:02 -0700 Subject: [PATCH 19/33] fix for non-divisible equation --- src/arith/int_constraints.cc | 12 +++++++++++- .../unittest/test_arith_solve_linear_inequality.py | 8 +------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 718cae6ae4bf..56ddadb6086b 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -145,7 +145,17 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { - best_lower = low_divided; + if (tir::is_const_int(diff_over, 0)) { + // we need to be very careful with equations, + // as the division cannot be rounded in such case. + if (tir::is_const(best_lower) && !analyzer.CanProve(floormod(low, coef) == 0)) { + // we don't support non-integer case so far. + return Range::make_by_min_extent(best_lower, 0); + } + best_lower = analyzer.Simplify(low / coef); + } else { + best_lower = low_divided; + } best_diff_over = diff_over; } } diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 8fb19ff8c3e9..fbd39dea0d47 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -83,12 +83,7 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): after = te.all(tir.const(1, 'bool'), *after) testing.check_bool_expr_is_true(before == after, vranges) - print("-------------") - print(fs) - print(vs) - print(vranges) solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True) - print(solution) check_solution(solution) for i in range(3): @@ -209,5 +204,4 @@ def test_multi_equal(): if __name__ == "__main__": - test_solve_system_of_inequalities() - # pytest.main([__file__]) + pytest.main([__file__]) From 7166020586c7ac30c6e342298cdbdd72f28486fe Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 20 Jun 2020 12:24:40 -0700 Subject: [PATCH 20/33] fix non-divisible case again and add test case for no solution --- src/arith/int_constraints.cc | 17 +++++++---------- .../test_arith_solve_linear_inequality.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 56ddadb6086b..3501f5fc489d 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -145,18 +145,15 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { - if (tir::is_const_int(diff_over, 0)) { - // we need to be very careful with equations, - // as the division cannot be rounded in such case. - if (tir::is_const(best_lower) && !analyzer.CanProve(floormod(low, coef) == 0)) { - // we don't support non-integer case so far. - return Range::make_by_min_extent(best_lower, 0); - } - best_lower = analyzer.Simplify(low / coef); - } else { + if (analyzer.CanProve(floormod(low, coef) == 0)) { + // we need to be very careful with rounding + // as it could be wrong when we have equations. + // equations can come from + // 1. when it is a single point, i.e., extent == 1. + // 2. when var is substituted by another var in deskew range. best_lower = low_divided; + best_diff_over = diff_over; } - best_diff_over = diff_over; } } } diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index fbd39dea0d47..67f681be2f66 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -203,5 +203,24 @@ def test_multi_equal(): assert solution.src_to_dst[x] == 6 +def test_no_solution(): + x = te.var("x0") + vranges = {x: tvm.ir.Range.make_by_min_extent(-20, 41)} + problem = [-x - 4 <= -5*x + 2, x*4 + 5 <= x*5] + + solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) + assert list(solution.dst.variables) == [] + [rel] = solution.dst.relations + assert ir.structural_equal(rel, False) + assert len(solution.src_to_dst) == 0 + assert len(solution.dst_to_src) == 0 + + solution = arith.solve_linear_inequalities(problem, [x], vranges) + assert len(solution.variables) == 0 + assert len(solution.ranges) == 0 + [rel] = solution.relations + assert not rel + + if __name__ == "__main__": pytest.main([__file__]) From eaca5c92602ba3b6270b48ebe87b3adaa9ef255b Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 20 Jun 2020 12:56:34 -0700 Subject: [PATCH 21/33] revise testing --- python/tvm/testing.py | 46 +++++++++++++++++++ .../test_arith_solve_linear_equations.py | 42 +---------------- .../test_arith_solve_linear_inequality.py | 42 +---------------- 3 files changed, 50 insertions(+), 80 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index a6b00e88c4a9..1a860c901a09 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -232,4 +232,50 @@ def _compute_body(*us): .format(ana.simplify(bool_expr), vranges, counterex)) +def check_int_constraints_trans_consistency(constraints_trans, vranges={}): + """ Check IntConstraintsTransform is a bijective transformation. + + Parameters + ---------- + constraints_trans : arith.IntConstraintsTransform + Integer constraints transformation + vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] + Free variables and their ranges + """ + def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() + all_vranges = vranges.copy() + all_vranges.update({v: r for v, r in constraints1.ranges.items()}) + + # Check that the transformation is injective + cond_on_vars = tvm.tir.const(1, 'bool') + for v in constraints1.variables: + # variable mapping is consistent + v_back = ana.simplify(tvm.tir.stmt_functor.substitute(varmap[v], backvarmap)) + cond_on_vars = tvm.te.all(cond_on_vars, v == v_back) + # Also we have to check that the new relations are true when old relations are true + cond_subst = tvm.tir.stmt_functor.substitute( + tvm.te.all(tvm.tir.const(1, 'bool'), *constraints2.relations), backvarmap) + # We have to include relations from vranges too + for v in constraints2.variables: + if v in constraints2.ranges: + r = constraints2.ranges[v] + range_cond = tvm.te.all(v >= r.min, v < r.min + r.extent) + range_cond = tvm.tir.stmt_functor.substitute(range_cond, backvarmap) + cond_subst = tvm.te.all(cond_subst, range_cond) + cond_subst = ana.simplify(cond_subst) + check_bool_expr_is_true( + tvm.te.all(cond_subst, cond_on_vars), all_vranges, + cond=tvm.te.all(tvm.tir.const(1, 'bool'), *constraints1.relations)) + + rels = constraints_trans.dst.relations + if len(rels) == 1 and tvm.ir.structural_equal(rels[0], False): + # not solvable, skip + return + _check_forward(constraints_trans.src, constraints_trans.dst, + constraints_trans.src_to_dst, constraints_trans.dst_to_src) + _check_forward(constraints_trans.dst, constraints_trans.src, + constraints_trans.dst_to_src, constraints_trans.src_to_dst) + + tvm._ffi._init_api("testing", __name__) diff --git a/tests/python/unittest/test_arith_solve_linear_equations.py b/tests/python/unittest/test_arith_solve_linear_equations.py index 798a580574ec..c01834f79457 100644 --- a/tests/python/unittest/test_arith_solve_linear_equations.py +++ b/tests/python/unittest/test_arith_solve_linear_equations.py @@ -21,44 +21,6 @@ from tvm import te, arith, ir, tir, testing -def check_solution(solution, vranges={}): - """Check that solution is a bijective transformation""" - def _check_forward(constraints1, constraints2, varmap, backvarmap): - ana = tvm.arith.Analyzer() - all_vranges = vranges.copy() - all_vranges.update({v: r for v, r in constraints1.ranges.items()}) - - # Check that the transformation is injective - cond_on_vars = tir.const(1, 'bool') - for v in constraints1.variables: - # variable mapping is consistent - v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) - cond_on_vars = te.all(cond_on_vars, v == v_back) - # Also we have to check that the new relations are true when old relations are true - cond_subst = tir.stmt_functor.substitute( - te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) - # We have to include relations from vranges too - for v in constraints2.variables: - if v in constraints2.ranges: - r = constraints2.ranges[v] - range_cond = te.all(v >= r.min, v < r.min + r.extent) - range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) - cond_subst = te.all(cond_subst, range_cond) - cond_subst = ana.simplify(cond_subst) - testing.check_bool_expr_is_true( - te.all(cond_subst, cond_on_vars), all_vranges, - cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) - - rels = solution.dst.relations - if len(rels) == 1 and ir.structural_equal(rels[0], False): - # not solvable, skip - return - _check_forward(solution.src, solution.dst, - solution.src_to_dst, solution.dst_to_src) - _check_forward(solution.dst, solution.src, - solution.dst_to_src, solution.src_to_dst) - - def test_solution_consistency(): seed = random.randrange(sys.maxsize) print("\nThis test is intentionally non-deterministic, " @@ -84,14 +46,14 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} solution = arith.solve_linear_equations(relations, variables, vranges) - check_solution(solution) + testing.check_int_constraints_trans_consistency(solution) # leaving some variables as parameters should also be ok for k in [1, 2]: if len(variables) > k: solution = arith.solve_linear_equations(relations, variables[:-k], vranges) param_ranges = {v: vranges[v] for v in variables[-k:]} - check_solution(solution, param_ranges) + testing.check_int_constraints_trans_consistency(solution, param_ranges) for i in range(2): _check(num_vars=1, num_formulas=1) diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index 67f681be2f66..d6092cc51e8c 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -21,45 +21,7 @@ from tvm import te, arith, ir, tir, testing -def check_solution(solution, vranges={}): - """Check that solution is a bijective transformation""" - def _check_forward(constraints1, constraints2, varmap, backvarmap): - ana = tvm.arith.Analyzer() - all_vranges = vranges.copy() - all_vranges.update({v: r for v, r in constraints1.ranges.items()}) - - # Check that the transformation is injective - cond_on_vars = tir.const(1, 'bool') - for v in constraints1.variables: - # variable mapping is consistent - v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) - cond_on_vars = te.all(cond_on_vars, v == v_back) - # Also we have to check that the new relations are true when old relations are true - cond_subst = tir.stmt_functor.substitute( - te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) - # We have to include relations from vranges too - for v in constraints2.variables: - if v in constraints2.ranges: - r = constraints2.ranges[v] - range_cond = te.all(v >= r.min, v < r.min + r.extent) - range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) - cond_subst = te.all(cond_subst, range_cond) - cond_subst = ana.simplify(cond_subst) - testing.check_bool_expr_is_true( - te.all(cond_subst, cond_on_vars), all_vranges, - cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) - - rels = solution.dst.relations - if len(rels) == 1 and ir.structural_equal(rels[0], False): - # not solvable, skip - return - _check_forward(solution.src, solution.dst, - solution.src_to_dst, solution.dst_to_src) - _check_forward(solution.dst, solution.src, - solution.dst_to_src, solution.src_to_dst) - - -def test_solve_system_of_inequalities(): +def test_solution_consistency(): seed = random.randrange(sys.maxsize) print("\nThis test is intentionally non-deterministic, " "if it fails please report it in github issue together with this seed {}\n".format(seed)) @@ -84,7 +46,7 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): testing.check_bool_expr_is_true(before == after, vranges) solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True) - check_solution(solution) + testing.check_int_constraints_trans_consistency(solution) for i in range(3): _check(1, 1) From 12b5fcd54558924be8db6dc76ad5f6e2811c0853 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 20 Jun 2020 14:21:21 -0700 Subject: [PATCH 22/33] fix merging --- src/arith/solve_linear_inequality.cc | 28 +++++++++---------- .../test_arith_solve_linear_inequality.py | 9 +++--- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index eb66cf859ef6..9eafa85226f4 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -105,13 +105,13 @@ Array as_conditions(const Map& bounds, const auto& bnds = iter.second; PrimExpr lhs = bnds->coef * v; for (const PrimExpr& rhs : bnds->equal) { - res.push_back(tir::EQNode::make(lhs, rhs)); + res.push_back(tir::EQ(lhs, rhs)); } for (const PrimExpr& rhs : bnds->lower) { - res.push_back(tir::GENode::make(lhs, rhs)); + res.push_back(tir::GE(lhs, rhs)); } for (const PrimExpr& rhs : bnds->upper) { - res.push_back(tir::LENode::make(lhs, rhs)); + res.push_back(tir::LE(lhs, rhs)); } } for (const PrimExpr& e : relations) { @@ -155,21 +155,21 @@ void DebugPrint( */ class NormalizeComparisons : public ExprMutator { public: - PrimExpr VisitExpr_(const EQNode* op) override { return Make(op->a, op->b); } - PrimExpr VisitExpr_(const NENode* op) override { return Make(op->a, op->b); } - PrimExpr VisitExpr_(const LTNode* op) override { return Make(op->a, op->b); } - PrimExpr VisitExpr_(const LENode* op) override { return Make(op->a, op->b); } - PrimExpr VisitExpr_(const GTNode* op) override { return Make(op->b, op->a); } - PrimExpr VisitExpr_(const GENode* op) override { return Make(op->b, op->a); } + PrimExpr VisitExpr_(const EQNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const NENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LTNode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const LENode* op) override { return Make(op->a, op->b); } + PrimExpr VisitExpr_(const GTNode* op) override { return Make(op->b, op->a); } + PrimExpr VisitExpr_(const GENode* op) override { return Make(op->b, op->a); } private: - template + template PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { // rewrite LT to LE for ints - if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { - return LENode::make(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); + if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { + return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); } - return TNode::make(analyzer_.Simplify(a - b), make_zero(a.dtype())); + return T(analyzer_.Simplify(a - b), make_zero(a.dtype())); } arith::Analyzer analyzer_; }; @@ -324,7 +324,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd); // eliminate the current variable PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; - PrimExpr new_ineq = LENode::make(new_lhs, make_zero(pos.second.dtype())); + PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype())); // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index d6092cc51e8c..c778aa555b71 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -91,11 +91,10 @@ def test_dual_variable(): # solution as conditions solution = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges, problem) - assert len(solution) == 4 - assert ir.structural_equal(solution[0], y >= 0) - assert ir.structural_equal(solution[1], y <= 5) - assert ir.structural_equal(solution[2], x >= (y + 10)) - assert ir.structural_equal(solution[3], x <= (20 - y)) + assert ir.structural_equal(solution[0], x >= (y + 10)) + assert ir.structural_equal(solution[1], x <= (20 - y)) + assert ir.structural_equal(solution[2], y >= 0) + assert ir.structural_equal(solution[3], y <= 5) # solve and get the ranges solution = arith.solve_linear_inequalities([ From 182a5d51782bac4622b031c7bf2242c4a0bd6ca0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 20 Jun 2020 15:15:41 -0700 Subject: [PATCH 23/33] fix lint --- python/tvm/testing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1a860c901a09..d597dc4bbaed 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unnecessary-comprehension """ TVM testing utilities """ import logging import numpy as np @@ -232,7 +232,7 @@ def _compute_body(*us): .format(ana.simplify(bool_expr), vranges, counterex)) -def check_int_constraints_trans_consistency(constraints_trans, vranges={}): +def check_int_constraints_trans_consistency(constraints_trans, vranges=None): """ Check IntConstraintsTransform is a bijective transformation. Parameters @@ -242,6 +242,9 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges={}): vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] Free variables and their ranges """ + if vranges is None: + vranges = {} + def _check_forward(constraints1, constraints2, varmap, backvarmap): ana = tvm.arith.Analyzer() all_vranges = vranges.copy() From c8b2370f5dbb6c277162dde8c0796e785c921203 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 20 Jun 2020 16:14:59 -0700 Subject: [PATCH 24/33] add comments --- include/tvm/arith/int_solver.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 18272c1b0462..839a8de0b851 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -289,6 +289,18 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol /*! * \brief Solve linear inequalities. * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. + * The inequalities are rewritten using Fourier-Motzkin elimination. + * This function takes an array of (in)equalities and an array of variables, and essentially + * rewrites the (in)equalities into an array of (in)equalities of the following form, + * + * x0 >= f0(x1, x2, ..., xn) + * x0 <= g0(x1, x2, ..., xn) + * x1 >= f1(x2, ..., xn) + * x1 <= g1(x2, ..., xn) + * ... + * xn >= fn() // just a constant + * xn <= gn() // just a constant + * * \return A map of variables and their solved bounds, * and constrains that cannot be solved to bounds. */ From 4c6cec4bae98351118f5268a1ef508eb02d2a77d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 21 Jun 2020 17:23:12 -0700 Subject: [PATCH 25/33] fix order of as_condition --- src/arith/solve_linear_inequality.cc | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 9eafa85226f4..3d2c7fa006f5 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -97,12 +97,16 @@ struct ExprLess { /*! * \brief Combine the information into an array of (in)equalities. */ -Array as_conditions(const Map& bounds, +Array as_conditions(const Array& variables, + const Map& bounds, const Array& relations) { Array res; - for (const auto iter : bounds) { - const Var& v = iter.first; - const auto& bnds = iter.second; + // use variables to keep the order of iteration + // so as to get rid of any non-determinism. + CHECK_EQ(variables.size(), bounds.size()); + for (const auto v : variables) { + CHECK(bounds.count(v)); + const auto& bnds = bounds[v]; PrimExpr lhs = bnds->coef * v; for (const PrimExpr& rhs : bnds->equal) { res.push_back(tir::EQ(lhs, rhs)); @@ -489,7 +493,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Add the original conditions to the resulting conditions arith::Analyzer analyzer; analyzer.Bind(vranges); - for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { + for (const PrimExpr& old_cond : as_conditions( + inequalities->variables, solved_bounds, solved_other_relations)) { if (!analyzer.CanProve(old_cond)) { // those not represented in vranges (res_ranges) res_relations.push_back(old_cond); @@ -581,7 +586,8 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } // Add the original conditions (with variables substituted) to the resulting conditions - for (const PrimExpr& old_cond : as_conditions(solved_bounds, solved_other_relations)) { + for (const PrimExpr& old_cond : as_conditions( + inequalities->variables, solved_bounds, solved_other_relations)) { PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) @@ -600,17 +606,19 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") .set_body([](TVMArgs args, TVMRetValue* ret) { + IntConstraints problem; PartialSolvedInequalities ret_ineq; if (args.size() == 1) { - ret_ineq = SolveLinearInequalities(args[0]); + problem = args[0]; + ret_ineq = SolveLinearInequalities(problem); } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); + problem = IntConstraints(args[0], args[1], args[2]); ret_ineq = SolveLinearInequalities(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " << args.size(); } - *ret = as_conditions(ret_ineq.first, ret_ineq.second); + *ret = as_conditions(problem->variables, ret_ineq.first, ret_ineq.second); }); TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) { From a4dbba465ac01c7c18fbb543494a33e19a3c65c7 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 21 Jun 2020 17:27:44 -0700 Subject: [PATCH 26/33] fix lint --- src/arith/solve_linear_inequality.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 3d2c7fa006f5..871b5973a60c 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -97,8 +97,7 @@ struct ExprLess { /*! * \brief Combine the information into an array of (in)equalities. */ -Array as_conditions(const Array& variables, - const Map& bounds, +Array as_conditions(const Array& variables, const Map& bounds, const Array& relations) { Array res; // use variables to keep the order of iteration @@ -493,8 +492,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Add the original conditions to the resulting conditions arith::Analyzer analyzer; analyzer.Bind(vranges); - for (const PrimExpr& old_cond : as_conditions( - inequalities->variables, solved_bounds, solved_other_relations)) { + for (const PrimExpr& old_cond : + as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { if (!analyzer.CanProve(old_cond)) { // those not represented in vranges (res_ranges) res_relations.push_back(old_cond); @@ -586,8 +585,8 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } // Add the original conditions (with variables substituted) to the resulting conditions - for (const PrimExpr& old_cond : as_conditions( - inequalities->variables, solved_bounds, solved_other_relations)) { + for (const PrimExpr& old_cond : + as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) From 06053433db4a62c4306b1eb30c180a4719446dd3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 26 Jun 2020 00:00:42 -0700 Subject: [PATCH 27/33] remove special dealing with equations --- src/arith/int_constraints.cc | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 3501f5fc489d..718cae6ae4bf 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -145,15 +145,8 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { - if (analyzer.CanProve(floormod(low, coef) == 0)) { - // we need to be very careful with rounding - // as it could be wrong when we have equations. - // equations can come from - // 1. when it is a single point, i.e., extent == 1. - // 2. when var is substituted by another var in deskew range. - best_lower = low_divided; - best_diff_over = diff_over; - } + best_lower = low_divided; + best_diff_over = diff_over; } } } From 485cfad2875f2ccc96132552071d6daaa386d9df Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 26 Jun 2020 12:14:38 -0700 Subject: [PATCH 28/33] address several comments; add steps to python side analyzer.simplify --- include/tvm/arith/analyzer.h | 2 +- python/tvm/arith/analyzer.py | 9 +++++-- python/tvm/testing.py | 24 +++++++++++-------- src/arith/analyzer.cc | 12 ++++++++-- .../test_arith_solve_linear_inequality.py | 9 +++---- 5 files changed, 37 insertions(+), 19 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6f9ba8f16ea5..aa1a205b661e 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -473,7 +473,7 @@ class TVM_DLL Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - PrimExpr Simplify(const PrimExpr& expr, size_t steps = 2); + PrimExpr Simplify(const PrimExpr& expr, int steps = 2); }; } // namespace arith diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 5a420ad81755..e841de906d5e 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -119,20 +119,25 @@ def modular_set(self, expr): """ return self._modular_set(expr) - def simplify(self, expr): + def simplify(self, expr, steps=2): """Simplify expression via both rewrite and canonicalization. Parameters ---------- expr : PrimExpr The expression. + steps : The simplification runs in the order of + rewrite_simplify (step 1) -> canonical_simplify (step 2) -> + rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ... + param steps controls how many steps to run. + Default is 2, i.e., rewrite_simplify + canonical_simplify. Returns ------- result : Expr The result. """ - return self._simplify(expr) + return self._simplify(expr, steps) def rewrite_simplify(self, expr): """Simplify expression via rewriting rules. diff --git a/python/tvm/testing.py b/python/tvm/testing.py index d597dc4bbaed..7483a9fb4cf8 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -193,13 +193,20 @@ def check_bool_expr_is_true(bool_expr, vranges, cond=None): """ Check that bool_expr holds given the condition cond for every value of free variables from vranges. + for example, 2x > 4y solves to x > 2y given x in (0, 10) and y in (0, 10) + here bool_expr is x > 2y, vranges is {x: (0, 10), y: (0, 10)}, cond is 2x > 4y + We creates iterations to check, + for x in range(10): + for y in range(10): + assert !(2x > 4y) || (x > 2y) + Parameters ---------- - bool_expr : tvm.ir.expr.PrimExpr + bool_expr : tvm.ir.PrimExpr Boolean expression to check vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] Free variables and their ranges - cond: tvm.ir.expr.PrimExpr + cond: tvm.ir.PrimExpr extra conditions needs to be satisfied. """ if cond is not None: @@ -239,7 +246,7 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges=None): ---------- constraints_trans : arith.IntConstraintsTransform Integer constraints transformation - vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] + vranges: Dict[tvm.tir.Var, tvm.ir.Range] Free variables and their ranges """ if vranges is None: @@ -253,9 +260,10 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): # Check that the transformation is injective cond_on_vars = tvm.tir.const(1, 'bool') for v in constraints1.variables: - # variable mapping is consistent - v_back = ana.simplify(tvm.tir.stmt_functor.substitute(varmap[v], backvarmap)) - cond_on_vars = tvm.te.all(cond_on_vars, v == v_back) + if v in varmap: + # variable mapping is consistent + v_back = ana.simplify(tvm.tir.stmt_functor.substitute(varmap[v], backvarmap)) + cond_on_vars = tvm.te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true cond_subst = tvm.tir.stmt_functor.substitute( tvm.te.all(tvm.tir.const(1, 'bool'), *constraints2.relations), backvarmap) @@ -271,10 +279,6 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): tvm.te.all(cond_subst, cond_on_vars), all_vranges, cond=tvm.te.all(tvm.tir.const(1, 'bool'), *constraints1.relations)) - rels = constraints_trans.dst.relations - if len(rels) == 1 and tvm.ir.structural_equal(rels[0], False): - # not solvable, skip - return _check_forward(constraints_trans.src, constraints_trans.dst, constraints_trans.src_to_dst, constraints_trans.dst_to_src) _check_forward(constraints_trans.dst, constraints_trans.src, diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index a3e63aa5d7e7..3f31df455ccd 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -115,7 +115,7 @@ bool Analyzer::CanProve(const PrimExpr& expr) { return false; } -PrimExpr Analyzer::Simplify(const PrimExpr& expr, size_t steps) { +PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { if (tir::is_const(expr)) return expr; PrimExpr res = expr; for (size_t i = 0; i < steps; ++i) { @@ -143,7 +143,15 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu self->const_int_bound.Update(args[0], args[1], args[2]); }); } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0]); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0], args[1]); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); } else if (name == "rewrite_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index c778aa555b71..51e5a7f85f71 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -97,10 +97,7 @@ def test_dual_variable(): assert ir.structural_equal(solution[3], y <= 5) # solve and get the ranges - solution = arith.solve_linear_inequalities([ - tvm.tir.LE(x + y, 20), - tvm.tir.GE(x - y, 10), - ], [x, y], ranges) + solution = arith.solve_linear_inequalities(problem, variables, ranges) # 0 <= y <=5 assert solution.ranges[y].min == 0 assert solution.ranges[y].extent == 6 @@ -157,6 +154,10 @@ def test_multi_equal(): solution = arith.solve_linear_inequalities(problem, [x, y, z]) assert solution.ranges[x].min == 6 assert solution.ranges[x].extent == 1 + assert len(solution.relations) == 3 + assert ir.structural_equal(solution.relations[0], x == z * y) + assert ir.structural_equal(solution.relations[1], z*y - 6 <= 0) + assert ir.structural_equal(solution.relations[2], 6 - z*y <= 0) solution = arith.solve_linear_inequalities(problem, [x, y, z], deskew_range=True) assert solution.src_to_dst[y] == y From 98cc3acfe80354685604253175950b528f82fcef Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 26 Jun 2020 12:54:40 -0700 Subject: [PATCH 29/33] fix a dumb compilation failure --- src/arith/analyzer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3f31df455ccd..73020e9d24bf 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -118,7 +118,7 @@ bool Analyzer::CanProve(const PrimExpr& expr) { PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { if (tir::is_const(expr)) return expr; PrimExpr res = expr; - for (size_t i = 0; i < steps; ++i) { + for (int i = 0; i < steps; ++i) { res = this->rewrite_simplify(res); if (tir::is_const(res) || ++i == steps) return res; res = this->canonical_simplify(res); From 51eb65214e85ba844c65a90d06f06cef0552698e Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 1 Jul 2020 14:29:31 -0700 Subject: [PATCH 30/33] IntGrpBounds -> IntGroupBounds --- include/tvm/arith/int_solver.h | 30 +++++++++++----------- python/tvm/arith/int_solver.py | 12 ++++----- src/arith/int_constraints.cc | 38 ++++++++++++++-------------- src/arith/solve_linear_inequality.cc | 10 ++++---- 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 839a8de0b851..ad044b288941 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -48,9 +48,9 @@ using tir::VarNode; * coef * var >= lower * coef * var == equal * coef * var <= upper - * \sa IntGrpBounds + * \sa IntGroupBounds */ -class IntGrpBoundsNode : public Object { +class IntGroupBoundsNode : public Object { public: PrimExpr coef; Array lower; @@ -64,7 +64,7 @@ class IntGrpBoundsNode : public Object { v->Visit("upper", &upper); } - bool SEqualReduce(const IntGrpBoundsNode* other, SEqualReducer eq) const { + bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const { return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) && eq(upper, other->upper); } @@ -77,15 +77,15 @@ class IntGrpBoundsNode : public Object { } static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const char* _type_key = "arith.IntGrpBounds"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntGrpBoundsNode, Object); + static constexpr const char* _type_key = "arith.IntGroupBounds"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); }; /*! - * \brief Managed reference to IntGrpBoundsNode. - * \sa IntGrpBoundsNode + * \brief Managed reference to IntGroupBoundsNode. + * \sa IntGroupBoundsNode */ -class IntGrpBounds : public ObjectRef { +class IntGroupBounds : public ObjectRef { public: /*! * \brief Constructor by fields @@ -97,20 +97,20 @@ class IntGrpBounds : public ObjectRef { * \param equal equalities * \param upper the upper bounds (include) */ - TVM_DLL IntGrpBounds(PrimExpr coef, Array lower, Array equal, - Array upper); + TVM_DLL IntGroupBounds(PrimExpr coef, Array lower, Array equal, + Array upper); /*! * \brief Construct bounds from a range. * \param r The range * \return constructed bounds. */ - static IntGrpBounds range(const Range& r); + static IntGroupBounds FromRange(const Range& r); /*! * \brief Perform substitution on all components of the struct. */ - IntGrpBounds Substitute(const Map& subst) const; + IntGroupBounds Substitute(const Map& subst) const; /*! * \brief Find the best range from the grouped bounds. @@ -125,9 +125,9 @@ class IntGrpBounds : public ObjectRef { * \param r range to be combined. * \return combined bounds. */ - IntGrpBounds operator+(const Range& r); + IntGroupBounds operator+(const Range& r); - TVM_DEFINE_OBJECT_REF_METHODS(IntGrpBounds, ObjectRef, IntGrpBoundsNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode); }; /*! @@ -254,7 +254,7 @@ class IntConstraintsTransform : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; -typedef std::pair, Array> PartialSolvedInequalities; +typedef std::pair, Array> PartialSolvedInequalities; /*! * \brief Obtain Smith Normal Form of linear equation A x = y. diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 7f28ed406239..91fa459b8aa8 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -20,8 +20,8 @@ from . import _ffi_api -@tvm._ffi.register_object("arith.IntGrpBounds") -class IntGrpBounds(Object): +@tvm._ffi.register_object("arith.IntGroupBounds") +class IntGroupBounds(Object): """Represent integer grouped bounds which are classified into lower bounds (include), upper bounds (include) and equalities. @@ -41,10 +41,10 @@ class IntGrpBounds(Object): """ def __init__(self, coef, lower, equal, upper): self.__init_handle_by_constructor__( - _ffi_api.IntGrpBounds, coef, lower, equal, upper) + _ffi_api.IntGroupBounds, coef, lower, equal, upper) @staticmethod - def make_by_range(rng): + def from_range(rng): """Construct a IntGroupedBounds by Range. Parameters @@ -57,13 +57,13 @@ def make_by_range(rng): ret : Range The constructed range. """ - return _ffi_api.int_grouped_bounds_by_range(rng) + return _ffi_api.IntGroupBounds_from_range(rng) def find_best_range(self): """Return the best range from the grouped bounds. None if (-inf, +inf). """ - return _ffi_api.IntGrpBounds_FindBestRange(self) + return _ffi_api.IntGroupBounds_FindBestRange(self) @tvm._ffi.register_object("arith.IntConstraints") diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index e0227e99473d..6e68f07bd82f 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -38,11 +38,11 @@ namespace tvm { namespace arith { -IntGrpBounds::IntGrpBounds(PrimExpr coef, Array lower, Array equal, +IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Array equal, Array upper) { CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) - << "Coefficient in IntGrpBounds must be integers"; - ObjectPtr node = make_object(); + << "Coefficient in IntGroupBounds must be integers"; + ObjectPtr node = make_object(); node->coef = std::move(coef); node->lower = std::move(lower); node->equal = std::move(equal); @@ -50,7 +50,7 @@ IntGrpBounds::IntGrpBounds(PrimExpr coef, Array lower, Array data_ = std::move(node); } -IntGrpBounds IntGrpBounds::range(const Range& r) { +IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; PrimExpr coef = tir::make_const(r->min.dtype(), 1); Array equal; @@ -62,10 +62,10 @@ IntGrpBounds IntGrpBounds::range(const Range& r) { lower.push_back(r->min); upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); } - return IntGrpBounds(coef, lower, equal, upper); + return IntGroupBounds(coef, lower, equal, upper); } -IntGrpBounds IntGrpBounds::operator+(const Range& r) { +IntGroupBounds IntGroupBounds::operator+(const Range& r) { Analyzer analyzer; Array equal; Array lower; @@ -80,18 +80,18 @@ IntGrpBounds IntGrpBounds::operator+(const Range& r) { for (const auto& eq : operator->()->equal) equal.push_back(eq); for (const auto& lb : operator->()->lower) lower.push_back(lb); for (const auto& ub : operator->()->upper) upper.push_back(ub); - return IntGrpBounds(coef, lower, equal, upper); + return IntGroupBounds(coef, lower, equal, upper); } -IntGrpBounds IntGrpBounds::Substitute(const Map& subst) const { +IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; - return IntGrpBounds(tir::Substitute(operator->()->coef, subst), + return IntGroupBounds(tir::Substitute(operator->()->coef, subst), tir::UpdateArray(operator->()->lower, apply_fun), tir::UpdateArray(operator->()->equal, apply_fun), tir::UpdateArray(operator->()->upper, apply_fun)); } -Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { +Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { Analyzer analyzer; analyzer.Bind(vranges_addl); @@ -158,18 +158,18 @@ Range IntGrpBounds::FindBestRange(const Map& vranges_addl) const { return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); } -TVM_REGISTER_NODE_TYPE(IntGrpBoundsNode); +TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); -TVM_REGISTER_GLOBAL("arith.IntGrpBounds") +TVM_REGISTER_GLOBAL("arith.IntGroupBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, - Array upper) { return IntGrpBounds(coef, lower, equal, upper); }); + Array upper) { return IntGroupBounds(coef, lower, equal, upper); }); -TVM_REGISTER_GLOBAL("arith.int_grouped_bounds_by_range").set_body_typed(IntGrpBounds::range); +TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); -TVM_REGISTER_GLOBAL("arith.IntGrpBounds_FindBestRange") +TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK(args.size() == 1 || args.size() == 2); - IntGrpBounds bounds = args[0]; + IntGroupBounds bounds = args[0]; if (args.size() == 1) { *ret = bounds.FindBestRange(); } else if (args.size() == 2) { @@ -178,9 +178,9 @@ TVM_REGISTER_GLOBAL("arith.IntGrpBounds_FindBestRange") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntGrpBounds(coef=" << op->coef << ", lower=" << op->lower + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntGroupBounds(coef=" << op->coef << ", lower=" << op->lower << ", equal=" << op->equal << ", upper=" << op->upper << ")"; }); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 93cf00dedfaf..fa6998c380c7 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -97,7 +97,7 @@ struct ExprLess { /*! * \brief Combine the information into an array of (in)equalities. */ -Array as_conditions(const Array& variables, const Map& bounds, +Array as_conditions(const Array& variables, const Map& bounds, const Array& relations) { Array res; // use variables to keep the order of iteration @@ -296,7 +296,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t &analyzer); } - Map res_bounds; + Map res_bounds; for (const Var& v : system_to_solve->variables) { CHECK(!res_bounds.count(v)) << "Variable " << v @@ -404,7 +404,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. - IntGrpBounds bnds(make_const(v.dtype(), coef_lcm), + IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), Array(lower_bounds.begin(), lower_bounds.end()), Array(equal_list.begin(), equal_list.end()), Array(upper_bounds.begin(), upper_bounds.end())); @@ -443,7 +443,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; + Map solved_bounds = solved_system.first; Array solved_other_relations = solved_system.second; Array res_relations; @@ -511,7 +511,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; + Map solved_bounds = solved_system.first; Array solved_other_relations = solved_system.second; arith::Analyzer analyzer; From aaa53f37a01a70ea402575fabd95493d44258cd2 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 1 Jul 2020 14:39:47 -0700 Subject: [PATCH 31/33] move if check to the root --- src/arith/solve_linear_inequality.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index fa6998c380c7..d004f6cd3141 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -184,8 +184,8 @@ void AddInequality(std::unordered_set // or has already been added return; } - for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { - if (const LENode* new_le = new_ineq.as()) { + if (const LENode* new_le = new_ineq.as()) { + for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { const LENode* le = iter->as(); if (le && analyzer->CanProve(new_le->a - le->a <= 0)) { return; @@ -194,8 +194,6 @@ void AddInequality(std::unordered_set } else { ++iter; } - } else { - ++iter; } } From d42f2cae521ddab99e1731a9decc584f39c16293 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 1 Jul 2020 14:50:35 -0700 Subject: [PATCH 32/33] fix lint --- src/arith/int_constraints.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 6e68f07bd82f..e4e4d5ac20e3 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -162,7 +162,9 @@ TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); TVM_REGISTER_GLOBAL("arith.IntGroupBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, - Array upper) { return IntGroupBounds(coef, lower, equal, upper); }); + Array upper) { + return IntGroupBounds(coef, lower, equal, upper); + }); TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); From b21c13f641fc54d2dcc81ae823c7a1c0334ef67d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 1 Jul 2020 14:54:23 -0700 Subject: [PATCH 33/33] fix clang format check --- src/arith/int_constraints.cc | 8 ++++---- src/arith/solve_linear_inequality.cc | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index e4e4d5ac20e3..c95f7f855ceb 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -39,7 +39,7 @@ namespace tvm { namespace arith { IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Array equal, - Array upper) { + Array upper) { CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGroupBounds must be integers"; ObjectPtr node = make_object(); @@ -86,9 +86,9 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; return IntGroupBounds(tir::Substitute(operator->()->coef, subst), - tir::UpdateArray(operator->()->lower, apply_fun), - tir::UpdateArray(operator->()->equal, apply_fun), - tir::UpdateArray(operator->()->upper, apply_fun)); + tir::UpdateArray(operator->()->lower, apply_fun), + tir::UpdateArray(operator->()->equal, apply_fun), + tir::UpdateArray(operator->()->upper, apply_fun)); } Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index d004f6cd3141..f489d046835d 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -403,9 +403,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Write it to the result. IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), - Array(lower_bounds.begin(), lower_bounds.end()), - Array(equal_list.begin(), equal_list.end()), - Array(upper_bounds.begin(), upper_bounds.end())); + Array(lower_bounds.begin(), lower_bounds.end()), + Array(equal_list.begin(), equal_list.end()), + Array(upper_bounds.begin(), upper_bounds.end())); res_bounds.Set(v, bnds); std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve);