From 1ee5937ea2939ea989760181e0c1ab2729a3da8e Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 10 Apr 2020 08:11:21 -0700 Subject: [PATCH] [Arith] linear system and equation solver (#5171) * [arith] linear system and equation solver Co-authored-by: Sergei Grechanik * avoid constructing analyzer every time * generate random test cases and address comments Co-authored-by: Sergei Grechanik * rename linear_system to int_constraints * add comments and use random seed * message for reporting failure with seed * add SEqualReduce to IntConstraints; allow variables & ranges to be None Co-authored-by: Sergei Grechanik Co-authored-by: Sergei Grechanik --- include/tvm/arith/analyzer.h | 6 + include/tvm/arith/int_solver.h | 208 ++++++++ include/tvm/arith/util.h | 45 ++ python/tvm/arith/__init__.py | 1 + python/tvm/arith/int_solver.py | 99 ++++ src/arith/analyzer.cc | 5 + src/arith/int_constraints.cc | 96 ++++ src/arith/solve_linear_equation.cc | 480 ++++++++++++++++++ src/arith/util.cc | 53 ++ .../test_arith_solve_linear_system.py | 237 +++++++++ 10 files changed, 1230 insertions(+) create mode 100644 include/tvm/arith/int_solver.h create mode 100644 include/tvm/arith/util.h create mode 100644 python/tvm/arith/int_solver.py create mode 100644 src/arith/int_constraints.cc create mode 100644 src/arith/solve_linear_equation.cc create mode 100644 src/arith/util.cc create mode 100644 tests/python/unittest/test_arith_solve_linear_system.py diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 1889e16fef661..3a71e5eb5fbf8 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -423,6 +423,12 @@ class Analyzer { * \param range The range we bind to. */ void Bind(const Var& var, const Range& range); + /*! + * \brief Bind all the vars in the Map + * + * \param variables The {variable -> range} map. + */ + void Bind(const Map& variables); /*! * \brief Whether can we prove expr >= val. 
diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h
new file mode 100644
index 0000000000000..57f3af4bb67b3
--- /dev/null
+++ b/include/tvm/arith/int_solver.h
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/int_solver.h
+ * \brief integer constraints data structures and solvers
+ */
+#ifndef TVM_ARITH_INT_SOLVER_H_
+#define TVM_ARITH_INT_SOLVER_H_
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace arith {
+
+using tir::Var;
+using tir::VarNode;
+using tir::IterVar;
+
+/*!
+ * \brief Represent integer constraints including (integer) variables, their ranges and
+ *        the relations between them (either equations or inequalities).
+ * \sa IntConstraints
+ */
+class IntConstraintsNode : public Object {
+ public:
+  // e.g., \alpha, \beta, must be integers
+  Array variables;
+  // e.g., 1 <= \alpha <= N, etc.
+  // it is absolutely ok to include ranges for parameters
+  // (variables that are not in this->variables) in this map
+  Map ranges;
+  // linear equalities or inequalities
+  // e.g., A \alpha = \beta or A \alpha <= \beta
+  Array relations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("variables", &variables);
+    v->Visit("ranges", &ranges);
+    v->Visit("relations", &relations);
+  }
+
+  bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
+    return
+        equal(variables, other->variables) &&
+        equal(ranges, other->ranges) &&
+        equal(relations, other->relations);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(variables);
+    hash_reduce(ranges);
+    hash_reduce(relations);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const char* _type_key = "arith.IntConstraints";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsNode.
+ * \sa IntConstraintsNode
+ */
+class IntConstraints : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor by fields
+   * \param variables The variables in the constraints, must be integers.
+   * \param ranges    The ranges of the variables.
+   * \param relations The linear relations between the variables
+   *                  (either equations or inequalities)
+   */
+  TVM_DLL IntConstraints(Array variables,
+                         Map ranges,
+                         Array relations);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
+};
+
+/*!
+ * \brief We can have different sets of variables to represent the same constraints.
+ *        For example, the following two systems are equivalent,
+ *        {a + b = 0 | a >= 0, b >= 0} and
+ *        {m - n = 0 | m >= 0, n <= 0}
+ *        This data structure represents the transformation
+ *        between two equivalent linear systems.
+ *        In the above example,
+ *        src        : {a + b = 0 | a >= 0, b >= 0}
+ *        dst        : {m - n = 0 | m >= 0, n <= 0}
+ *        src_to_dst : {a -> m, b -> -n}
+ *        dst_to_src : {m -> a, n -> -b}
+ * \sa IntConstraintsTransform
+ */
+class IntConstraintsTransformNode : public Object {
+ public:
+  IntConstraints src;
+  IntConstraints dst;
+  Map src_to_dst;
+  Map dst_to_src;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("src_to_dst", &src_to_dst);
+    v->Visit("dst_to_src", &dst_to_src);
+  }
+
+  bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
+    return
+        equal(src, other->src) &&
+        equal(dst, other->dst) &&
+        equal(src_to_dst, other->src_to_dst) &&
+        equal(dst_to_src, other->dst_to_src);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(src);
+    hash_reduce(dst);
+    hash_reduce(src_to_dst);
+    hash_reduce(dst_to_src);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const char* _type_key = "arith.IntConstraintsTransform";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsTransformNode.
+ * \sa IntConstraintsTransformNode
+ */
+class IntConstraintsTransform : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor by fields
+   * \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+   * \param dst integer constraints equivalent to the source,
+   *        e.g., {m - n = 0 | m >= 0, n <= 0}
+   * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
+   *                   e.g., {a -> m, b -> -n}
+   * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
+   *                   e.g., {m -> a, n -> -b}
+   */
+  TVM_DLL IntConstraintsTransform(IntConstraints src,
+                                  IntConstraints dst,
+                                  Map src_to_dst,
+                                  Map dst_to_src);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
+};
+
+/*!
+ * \brief Obtain Smith Normal Form of linear equation A x = y.
+ *        Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
+ *        in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
+ *        NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
+ *              s_i | s_{i+1} (| means divides), the implementation here does not guarantee it.
+ *        TODO(yzhliu): From sergei-grechanik:
+ *          computing the proper Smith normal form may improve stability of automatic differentiation
+ *          (generating the same gradient code for slightly different but equivalent input code).
+ *        U_{mxm} and V_{nxn} are invertible matrices.
+ *        This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
+ *        \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
+ * \param S  the original A_{mxn}, it will be modified to S_{mxn}
+ * \param V  an identity matrix, it will be modified to V_{nxn}
+ * \param x  the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
+ * \param y  the y in A x = y. it will be modified to U_{mxm} y_{mx1}
+ */
+void SmithNormalFormDiag(std::vector> *S,
+                         std::vector> *V,
+                         std::vector* x,
+                         std::vector *y);
+
+/*!
+ * \brief Solve linear equations.
+ * \param system_to_solve the variables to solve, their ranges, and a list of equations.
+ * \return A new linear system, with fewer variables (if \p system_to_solve is NOT of full rank),
+ *         or no variable (if \p system_to_solve is of full rank),
+ *         or an empty linear system (if \p system_to_solve is unsolvable).
+ *         It also provides the ranges of the variables in the new system,
+ *         as well as inequalities inferred from the \p system_to_solve.
+ *         You can get the mapping from the original variables to the solution via ret->src_to_dst.
+ */
+IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
+
+}  // namespace arith
+}  // namespace tvm
+#endif  // TVM_ARITH_INT_SOLVER_H_
diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h
new file mode 100644
index 0000000000000..adfcefcd2e21d
--- /dev/null
+++ b/include/tvm/arith/util.h
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/util.h
+ * \brief Utils for arithmetic analysis.
+ */
+#ifndef TVM_ARITH_UTIL_H_
+#define TVM_ARITH_UTIL_H_
+
+#include
+#include
+
+namespace tvm {
+/*! \brief namespace of arithmetic analysis. */
+namespace arith {
+
+/*!
+ * \brief Calculate the extended greatest common divisor for two values.
+ *        See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
+ * \param a an integer number
+ * \param b an integer number
+ * \return 3 integers (div, m, n) where div = gcd(a, b) up to sign and a*m + b*n = div
+ */
+std::tuple xgcd(int64_t a, int64_t b);
+
+}  // namespace arith
+}  // namespace tvm
+#endif  // TVM_ARITH_UTIL_H_
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index 40e977e61d75b..017934a03b34f 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -20,3 +20,4 @@
 from .analyzer import ModularSet, ConstIntBound, Analyzer
 from .bound import deduce_bound
 from .pattern import detect_linear_equation, detect_clip_bound
+from .int_solver import solve_linear_equations
diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py
new file mode 100644
index 0000000000000..e35435c1da03e
--- /dev/null
+++ b/python/tvm/arith/int_solver.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""integer constraints data structures and solvers"""
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("arith.IntConstraints")
+class IntConstraints(Object):
+    """Represent a set of integer constraints including variables, their ranges and
+    the relations between them (either equations or inequalities)
+
+    Parameters
+    ----------
+    variables : List[tvm.tir.Var]
+        The variables in the constraints. Must be integers
+    ranges    : Map[tvm.tir.Var, tvm.ir.Range]
+        The ranges of the variables.
+    relations : List[tvm.ir.PrimExpr]
+        The relations between the variables (either equations or inequalities)
+    """
+    def __init__(self, variables, ranges, relations):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IntConstraints, variables, ranges, relations)
+
+
+@tvm._ffi.register_object("arith.IntConstraintsTransform")
+class IntConstraintsTransform(Object):
+    """We can have different sets of variables to represent the same integer constraints.
+    For example, the following two constraints are equivalent,
+    {a + b = 0 | a >= 0, b >= 0} and
+    {m - n = 0 | m >= 0, n <= 0}
+    This data structure represents the transformation
+    between two equivalent integer constraints.
+    In the above example,
+    src        : {a + b = 0 | a >= 0, b >= 0}
+    dst        : {m - n = 0 | m >= 0, n <= 0}
+    src_to_dst : {a -> m, b -> -n}
+    dst_to_src : {m -> a, n -> -b}
+
+    Parameters
+    ----------
+    src : arith.IntConstraints
+        source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+    dst : arith.IntConstraints
+        integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
+    src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+        mapping from variables in the src to the variables in the dst,
+                e.g., {a -> m, b -> -n}
+    dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+        mapping from variables in the dst to the variables in the src,
+                e.g., {m -> a, n -> -b}
+    """
+    def __init__(self, src, dst, src_to_dst, dst_to_src):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
+
+
+def solve_linear_equations(equations, variables=None, ranges=None):
+    """Solve linear equations.
+
+    Parameters
+    ----------
+    equations: List[tvm.ir.PrimExpr] or IntConstraints
+        The equations of the variables
+    variables : Optional[List[tvm.tir.Var]]
+        The variables in the system.
+    ranges    : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
+        The ranges of the variables.
+
+    Returns
+    -------
+    int_constraints_transform : IntConstraintsTransform
+        New integer constraints, with fewer variables (if the problem is NOT of full rank),
+        or no variable (if the problem is of full rank),
+        or an empty set of integer constraints (if the problem is unsolvable).
+        It also provides the ranges of the variables in the new system,
+        as well as inequalities inferred from the problem.
+        You can get the mapping from the original variables to the solution via
+        int_constraints_transform.src_to_dst.
+    """
+    if isinstance(equations, IntConstraints):
+        return _ffi_api.SolveLinearEquations(equations)
+    return _ffi_api.SolveLinearEquations(variables, ranges, equations)
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 9df5aa2d246de..83dfc64009cf3 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) {
   // skip rewrite simplify
 }
 
+void Analyzer::Bind(const Map& variables) {
+  for (const auto& iter : variables) {
+    this->Bind(iter.first, iter.second);
+  }
+}
 
 void ConstraintContext::EnterWithScope() {
   CHECK(exit_ == nullptr);
diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc
new file mode 100644
index 0000000000000..34efa986e9856
--- /dev/null
+++ b/src/arith/int_constraints.cc
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file int_constraints.cc
+ * \brief The integer constraints data structures.
+ */
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace arith {
+
+IntConstraints::IntConstraints(Array variables,
+                               Map ranges,
+                               Array relations) {
+  ObjectPtr node = make_object();
+  if (!variables.defined()) {  // an undefined (None) variable list is treated as empty
+    variables = Array();
+  }
+  if (!ranges.defined()) {  // an undefined (None) range map is treated as empty
+    ranges = Map();
+  }
+  CHECK(relations.defined());  // unlike variables/ranges, relations are mandatory
+  for (const auto& var : variables) {
+    CHECK(var.dtype().is_int() || var.dtype().is_uint())
+      << "Variables in IntConstraints must be integers";
+  }
+  node->variables = std::move(variables);
+  node->ranges = std::move(ranges);
+  node->relations = std::move(relations);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {  // IntConstraints(vars, ranges, relations)
+    auto* op = static_cast(node.get());
+    p->stream << "IntConstraints("
+              << op->variables
+              << ", " << op->ranges
+              << ", " << op->relations
+              << ")";
+  });
+
+
+IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
+                                                 IntConstraints dst,
+                                                 Map src_to_dst,
+                                                 Map dst_to_src) {
+  ObjectPtr node = make_object();  // the four fields are stored as given, without validation
+  node->src = std::move(src);
+  node->dst = std::move(dst);
+  node->src_to_dst = std::move(src_to_dst);
+  node->dst_to_src = std::move(dst_to_src);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {  // one field per line, tab-indented
+    auto* op = static_cast(node.get());
+    p->stream << "IntConstraintsTransform("
+              << "\n\t" << op->src
+              << "\n\t" << op->dst
+              << "\n\t" << op->src_to_dst
+              << "\n\t" << op->dst_to_src
+              << "\n)";
+  });
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
new file mode 100644
index 0000000000000..8142a03155c8c
--- /dev/null
+++ b/src/arith/solve_linear_equation.cc
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation
(ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/solve_linear_equation.cc
+ * \brief Solve linear equations.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace arith {
+
+using namespace tvm::runtime;
+
+void SmithNormalFormDiag(std::vector >* S,
+                         std::vector >* V,
+                         std::vector* x,
+                         std::vector* y) {
+  if (S->empty() || V->empty()) return;  // nothing to diagonalize
+  size_t m = S->size();
+  size_t n = (*S)[0].size();  // n is # of variables
+  CHECK_EQ(V->size(), n);
+  CHECK_EQ((*V)[0].size(), n);
+
+  for (size_t index = 0; index < std::min(m, n); ++index) {
+    // Here S (initially A) is partially diagonalized, that is S[i, j] is zero for all i, j
+    // such that (i < index) or (j < index), unless (i == j).
+    // That is, now we are diagonalizing the submatrix with i >= index and j >= index
+
+    // Find a row with a nonzero element in the index-th column
+    // (We also prefer rows where this element has minimal abs value)
+    size_t best_i = index;
+    for (size_t i = best_i; i < m; ++i) {
+      int64_t s_old = (*S)[best_i][index];
+      int64_t s_new = (*S)[i][index];
+      if (s_new != 0) {
+        if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
+          best_i = i;
+        }
+      }
+    }
+    // Move the row we found to the index-th position
+    std::swap((*S)[index], (*S)[best_i]);
+    std::swap((*y)[index], (*y)[best_i]);
+
+    // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
+    // element and move it to the index-th position
+    if ((*S)[index][index] == 0) {
+      for (size_t j = index + 1; j < n; ++j) {
+        if ((*S)[index][j] != 0) {
+          for (size_t i = index; i < m; ++i) {
+            std::swap((*S)[i][index], (*S)[i][j]);
+          }
+          // swapping columns corresponds to swapping the corresponding x
+          std::swap((*x)[index], (*x)[j]);
+          for (size_t i = 0; i < n; ++i) {
+            std::swap((*V)[i][index], (*V)[i][j]);
+          }
+          break;
+        }
+      }
+    }
+
+    // If the index-th diagonal element is still zero, then both the index-th row and the index-th
+    // column are completely zero, and we don't need to do anything; just go to the next index
+    if ((*S)[index][index] == 0) {
+      continue;
+    }
+
+    // Now the index-th diagonal element is non-zero and we can zero all the index-th column
+    // below it by subtracting rows from each other
+    for (auto i = index + 1; i < m; ++i) {
+      if ((*S)[i][index] != 0) {
+        int64_t g, a, b;
+        // g = a*S[index][index] + b*S[i][index]
+        if ((*S)[i][index] % (*S)[index][index] != 0) {
+          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
+        } else {
+          // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
+          g = (*S)[index][index];
+          a = 1;
+          b = 0;
+        }
+
+        // Let m = S[index][index], n = S[i][index], then the following is true:
+        //
+        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
+        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
+        //
+        // Note that the two matrices are integer (since g = gcd(m, n)).
+        // We will essentially multiply our matrix on the left by a dilated and transposed version
+        // of the first of these two matrices. The second matrix is not needed here, however we will
+        // use it while zeroing the index-th row.
+
+        int64_t m_g = (*S)[index][index] / g;
+        int64_t n_g = (*S)[i][index] / g;
+
+        // Note that j is the index of the column, not the row
+        for (size_t j = index; j < (*S)[i].size(); ++j) {
+          // Multiply index-th row by a and add the i-th row multiplied by b
+          // This will make the index-th diagonal element equal to the gcd
+          int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
+          // This transformation performs zeroing of S[i][index]
+          int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
+          (*S)[index][j] = new_index_j;
+          (*S)[i][j] = new_i_j;
+        }
+        // We have to do the same with rhs
+        PrimExpr ea = te::make_const((*y)[index].dtype(), a);
+        PrimExpr eb = te::make_const((*y)[i].dtype(), b);
+        PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
+        PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
+        PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
+        PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
+        (*y)[index] = new_index_rhs;
+        (*y)[i] = new_i_rhs;
+      }
+    }
+
+    bool changed = false;
+
+    // Now we have to zero the elements of the index-th row by manipulating columns.
+    // This is more difficult because column manipulation corresponds to variable manipulation,
+    // but the algorithm is essentially the same as before.
+    for (size_t j = index + 1; j < n; ++j) {
+      if ((*S)[index][j] != 0) {
+        int64_t g, a, b;
+        // g = a*S[index][index] + b*S[index][j]
+        if ((*S)[index][j] % (*S)[index][index] != 0) {
+          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
+          // During this phase we may disrupt the zeroness of the index-th column, so we will
+          // have to take some action if this might have happened.
+          changed = true;
+        } else {
+          // Explicitly avoid changing the index-th column. This is important to avoid infinite
+          // loop. Note that here we don't have to set `changed` to true since we don't change the
+          // index-th column.
+          g = (*S)[index][index];
+          a = 1;
+          b = 0;
+        }
+
+        // Let m = S[index][index], n = S[index][j], then the following is true:
+        //
+        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
+        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
+        //
+        // Now we are going to multiply our matrix on the right (to manipulate columns instead of
+        // rows), we will also transform V (the old_to_new matrix) the same way, and we will use
+        // the second matrix to transform x (new_to_old).
+
+        int64_t m_g = (*S)[index][index] / g;
+        int64_t n_g = (*S)[index][j] / g;
+
+        for (size_t i = index; i < m; ++i) {
+          int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
+          int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
+          (*S)[i][index] = new_i_index;
+          (*S)[i][j] = new_i_j;
+        }
+        // We do exactly the same transformations with V
+        for (size_t i = 0; i < n; ++i) {
+          int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
+          int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
+          (*V)[i][index] = new_i_index;
+          (*V)[i][j] = new_i_j;
+        }
+        // And apply reverse transformations to new_to_old, i.e. to the entries of x.
+        PrimExpr ea = te::make_const((*x)[j].dtype(), a);
+        PrimExpr eb = te::make_const((*x)[index].dtype(), b);
+        PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
+        PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
+        PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
+        PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
+        (*x)[index] = new_index;
+        (*x)[j] = new_j;
+      }
+    }
+
+    if (changed) {
+      // We might have changed the index-th column, so redo this iteration to zero it once more.
+      // index is a size_t: for index == 0 this wraps around and the loop's ++index restores 0.
+      index -= 1;
+    }
+  }
+}
+
+Map InferRange(const Map& vars_to_infer,
+               const Array& ori_vars,
+               const Map& ori_ranges) {
+  // The resulting ranges: ranges of variables outside ori_vars plus ranges of vars_to_infer
+  Map new_ranges;
+
+  std::unordered_set ori_vset;
+  for (const Var& v : ori_vars) {
+    ori_vset.insert(v.get());
+  }
+
+  std::unordered_map var_intsets;
+  for (const auto& p : ori_ranges) {
+    if (!ori_vset.count(p.first.get())) {
+      // First of all, fill the new ranges with outer variable ranges
+      new_ranges.Set(p.first, p.second);
+    }
+    // Convert original ranges to IntSets
+    var_intsets[p.first.get()] = IntSet::range(p.second);
+  }
+
+  // Infer ranges for the new variables and add them to the resulting ranges
+  for (const auto& p : vars_to_infer) {
+    const auto& var = p.first;
+    const auto& expr = p.second;
+    Range range = EvalSet(expr, var_intsets).cover_range(Range());
+    if (range.defined()) {
+      new_ranges.Set(var, range);
+    }
+  }
+  return new_ranges;
+}
+
+// debugging helper: pretty print the matrix equation (S with rhs, V and V^{-1} x) to std::cout
+void DebugPrint(const std::vector>& S,
+                const std::vector>& V,
+                const std::vector& V_inv_x,
+                const std::vector& rhs) {
+  std::cout << "S:\n";
+  for (size_t i = 0; i < S.size(); ++i) {
+    for (auto e : S[i]) {
+      std::cout << e << "\t";
+    }
+    std::cout << "\t->\t" << rhs[i];
+    std::cout << "\n";
+  }
+  std::cout << "V:\n";
+  for (const auto& r : V) {
+    for (auto e : r) {
+      std::cout << e << "\t";
+    }
+    std::cout << "\n";
+  }
+  std::cout << "V_inv x:\n" << Array(V_inv_x);
+
std::cout << "\n" << std::endl; +} + +IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) { + // m: # of equations + // n: # of variables + // we first construct A_{mxn} x_{nx1} = y_{mx1} + // then get Smith normal form of matrix A, + // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} + // => U^{-1} S V^{-1} x = y + // S V^{-1} x = U y + std::vector Uy; // mx1 + std::vector> S; // mxn + std::vector> V; // nxn + std::vector V_inv_x; // V^{-1} x, nx1 + // Conditions we don't know what to do with + std::vector rest; + + Analyzer analyzer_problem; + analyzer_problem.Bind(system_to_solve->ranges); + + size_t num_vars = system_to_solve->variables.size(); + + // initialize V_{nxn} with identity matrix, + // initialize V^{-1} x as x + for (size_t i = 0; i < num_vars; ++i) { + V.emplace_back(num_vars); + V.back()[i] = 1; + V_inv_x.push_back(system_to_solve->variables[i]); + } + + // Transform formulas into rows of the matrix + // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables + // here we initialize S_{mxn} to be A, U to be identity matrix. + for (const PrimExpr& equation : system_to_solve->relations) { + if (const tir::EQNode* eq = equation.as()) { + // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] + Array coeffs = arith::DetectLinearEquation( + analyzer_problem.Simplify(eq->a - eq->b), + system_to_solve->variables); + if (!coeffs.empty()) { + std::vector row; + for (size_t j = 0; j < coeffs.size() - 1; ++j) { + PrimExpr c = coeffs[j]; + if (const IntImmNode* ic = c.as()) { + row.push_back(ic->value); + } else { + // elements in matrix S V must be integers + // ignore equations that we cannot deal with. 
+ LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation " + << equation; + row.clear(); + break; + } + } + + if (!row.empty()) { + // S V^{-1} (a-b) = Uy + // V is identity for now + S.push_back(row); + Uy.push_back(-coeffs[coeffs.size() - 1]); + continue; + } + } + } + + // otherwise + rest.push_back(equation); + } + + // After diagonalizing, we have + // S_{mxn} is the Smith normal form (diagonal matrix) + // V_{nxn} is invertible + // V_inv_x is V^{-1} \times x + // Uy is U \times y + SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); + + Array new_vars; + Array new_relations; + Map new_to_old_map; + Map old_to_new_map; + + // Simplify right hand sides + for (PrimExpr r : Uy) { + r = analyzer_problem.Simplify(r); + } + + // Create the relations of the existence of a solution + for (size_t j = 0; j < S.size(); ++j) { + PrimExpr new_relation; + if (j >= num_vars || S[j][j] == 0) { + // The row of matrix is zero. A solution exists only if the Ub[j] is also zero + new_relation = (Uy[j] == 0); + } else { + // The diagonal element is non-zero. A solution exists only if the diagonal element + // is a divisor of the Ub[j] + new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); + } + new_relation = analyzer_problem.Simplify(new_relation); + if (tir::is_const_int(new_relation, 0)) { + // unable to solve the system. 
+ return IntConstraintsTransform( + system_to_solve, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{te::make_zero(DataType::Bool())}), + {}, {}); + } else if (!tir::is_const_int(new_relation, 1)) { + new_relations.push_back(new_relation); + } + } + + Array solution_for_V_inv_x; + // Now create new variables or directly solve the equations + // suppose the rank of A is r, aka r = # of non-zeros in S + // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b + // is + // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r} + // = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r} + // in which K is the right n-r columns of V, z is variable vector + // thus, + // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} + + // [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r} + for (size_t j = 0; j < num_vars; ++j) { + if (j >= S.size() || S[j][j] == 0) { + // The j-th variable can take any integer value, create a tvm variable for it + PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]); + std::string name_hint = "n" + std::to_string(new_vars.size()); + if (const VarNode* v_old = to_old.as()) { + name_hint += "_" + v_old->name_hint; + } + Var v = Var(name_hint, V_inv_x[j].dtype()); + solution_for_V_inv_x.push_back(v); + new_vars.push_back(v); + new_to_old_map.Set(v, to_old); + } else { + // The j-th variable is just a single value, don't create a tvm variable + // S^{-1}_{nxm} Uy_{mxn} + if (S[j][j] >= 0) { + PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]); + solution_for_V_inv_x.push_back( + analyzer_problem.Simplify(floordiv(Uy[j], a))); + } else { + // This is required because some simplifiers + // have problems with dividing by negative numbers + PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]); + solution_for_V_inv_x.push_back( + analyzer_problem.Simplify(floordiv(-Uy[j], a))); + } + } + } + + // V V^{-1} x = x + for (size_t i = 0; i < num_vars; ++i) { + PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype()); + 
for (size_t j = 0; j < num_vars; ++j) { + e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + } + e = analyzer_problem.Simplify(e); + old_to_new_map.Set(system_to_solve->variables[i], e); + } + + // The resulting ranges + Map new_ranges = InferRange( + new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + Analyzer analyzer_solution; + analyzer_solution.Bind(new_ranges); + + // We have to transform ranges of the old variables into relations over new variables because + // new ranges are not enough usually. + for (const auto& p : system_to_solve->ranges) { + const Var& old_var = p.first; + const Range& old_range = p.second; + if (old_to_new_map.count(old_var)) { + PrimExpr express_by_new_vars = old_to_new_map[old_var]; + PrimExpr lower_cond = analyzer_solution.Simplify( + old_range->min <= express_by_new_vars); + PrimExpr upper_cond = analyzer_solution.Simplify( + express_by_new_vars < old_range->min + old_range->extent); + if (!tir::is_const_int(lower_cond, 1)) { + new_relations.push_back(lower_cond); + } + if (!tir::is_const_int(upper_cond, 1)) { + new_relations.push_back(upper_cond); + } + } + } + + // Add the rest conditions + for (const PrimExpr& cond : rest) { + new_relations.push_back(Substitute(cond, old_to_new_map)); + } + + IntConstraints solution(new_vars, new_ranges, new_relations); + IntConstraintsTransform transform( + system_to_solve, solution, old_to_new_map, new_to_old_map); + + return transform; +} + +TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 1) { + *ret = SolveLinearEquations(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveLinearEquations(problem); + } else { + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + } + }); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/util.cc b/src/arith/util.cc new file mode 100644 
index 0000000000000..058c3e9595281 --- /dev/null +++ b/src/arith/util.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file util.cc + * \brief The utils for arithmetic analysis. + */ +#include +#include + +namespace tvm { +namespace arith { + +std::tuple xgcd(int64_t a, int64_t b) { + int64_t s = 0, old_s = 1; + int64_t t = 1, old_t = 0; + int64_t r = b, old_r = a; + + while (r != 0) { + int64_t q = old_r / r; + std::swap(r, old_r); + r -= q * old_r; + std::swap(s, old_s); + s -= q * old_s; + std::swap(t, old_t); + t -= q * old_t; + } + + CHECK_EQ(a % old_r, 0); + CHECK_EQ(b % old_r, 0); + CHECK(old_r == old_s*a + old_t*b); + + return std::make_tuple(old_r, old_s, old_t); +} + +} // namespace arith +} // namespace tvm diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py new file mode 100644 index 0000000000000..45f8fc10aaf06 --- /dev/null +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
# See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import random
import numpy as np
import sys
import pytest
import tvm
from tvm import te, arith, ir, tir


def run_expr(expr, vranges):
    """ Evaluate expr for every value of free variables
    given by vranges and return the tensor of results.

    `vranges` maps each free variable to a range; the range's `min` is used
    as the offset of the variable and `extent.value` as the size of the
    corresponding axis of the result.
    TODO(yzhliu): move to utils
    """
    def _compute_body(*us):
        # Shift every zero-based axis index by the lower bound of its variable.
        vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
        return tir.ir_pass.Substitute(expr, vmap)

    A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
    args = [tvm.nd.empty(A.shape, A.dtype)]
    sch = te.create_schedule(A.op)
    mod = tvm.build(sch, [A])
    mod(*args)
    return args[0].asnumpy()


def check_bruteforce(bool_expr, vranges, cond=None):
    """ Check that bool_expr holds given the condition cond
    for every value of free variables from vranges.

    Raises AssertionError with the first counterexample found.
    TODO(yzhliu): move to utils
    """
    if cond is not None:
        # "cond implies bool_expr" written as "not cond or bool_expr".
        bool_expr = te.any(tir.Not(cond), bool_expr)

    res = run_expr(bool_expr, vranges)
    if not np.all(res):
        # Report the first point at which the expression evaluated to false.
        indices = list(np.argwhere(res == 0)[0])
        counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
        counterex = sorted(counterex, key=lambda x: x[0])
        counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
        raise AssertionError("Expression {}\nis not true on {}\n"
                             "Counterexample: {}"
                             .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))


def check_solution(solution, vranges=None):
    """Check that solution is a bijective transformation

    `vranges` optionally gives ranges for variables that are treated as
    parameters; it is copied, never modified. Omitting it is the same as
    passing an empty dict.
    """
    if vranges is None:
        vranges = {}

    def _check_forward(constraints1, constraints2, varmap, backvarmap):
        all_vranges = vranges.copy()
        all_vranges.update({v: r for v, r in constraints1.ranges.items()})

        # Check that the transformation is injective
        cond_on_vars = tir.const(1, 'bool')
        for v in constraints1.variables:
            # variable mapping is consistent
            v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
            cond_on_vars = te.all(cond_on_vars, v == v_back)
        # Also we have to check that the new relations are true when old relations are true
        cond_subst = tir.ir_pass.Substitute(
            te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
        # We have to include relations from vranges too
        for v in constraints2.variables:
            if v in constraints2.ranges:
                r = constraints2.ranges[v]
                range_cond = te.all(v >= r.min, v < r.min + r.extent)
                range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
                cond_subst = te.all(cond_subst, range_cond)
        cond_subst = tir.ir_pass.Simplify(cond_subst)
        check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
                         cond=te.all(tir.const(1, 'bool'), *constraints1.relations))

    rels = solution.dst.relations
    if len(rels) == 1 and ir.structural_equal(rels[0], False):
        # not solvable, skip
        return
    _check_forward(solution.src, solution.dst,
                   solution.src_to_dst, solution.dst_to_src)
    _check_forward(solution.dst, solution.src,
                   solution.dst_to_src, solution.src_to_dst)


def test_solution_consistency():
    """Solve randomly generated systems and brute-force check each solution."""
    seed = random.randrange(sys.maxsize)
    print("\nThis test is intentionally non-deterministic, "
          "if it fails please report it in github issue together with this seed {}\n".format(seed))
    random.seed(seed)

    def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
        variables = [te.var("x" + str(i)) for i in range(num_vars)]

        # NOTE: the order of the random.* calls below determines which system
        # a given seed produces; keep it unchanged so reported seeds reproduce.
        relations = []
        for _ in range(num_formulas):
            s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
            s1 += random.randint(coef[0], coef[1])
            s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
            s2 += random.randint(coef[0], coef[1])
            if random.random() < 0.7:
                op = tvm.tir.EQ
            else:
                # we also make sure it can correctly handle inequalities
                op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT])
            relations.append(op(s1, s2))

        vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables}
        solution = arith.solve_linear_equations(relations, variables, vranges)

        check_solution(solution)

        # leaving some variables as parameters should also be ok
        for k in [1, 2]:
            if len(variables) > k:
                solution = arith.solve_linear_equations(relations, variables[:-k], vranges)
                param_ranges = {v: vranges[v] for v in variables[-k:]}
                check_solution(solution, param_ranges)

    for _ in range(2):
        _check(num_vars=1, num_formulas=1)
    for _ in range(2):
        _check(num_vars=1, num_formulas=2)

    for _ in range(2):
        _check(num_vars=2, num_formulas=1)
    for _ in range(2):
        _check(num_vars=2, num_formulas=2)
    for _ in range(2):
        _check(num_vars=2, num_formulas=3)

    for _ in range(3):
        _check(num_vars=3, num_formulas=3, coef=(-2, 2))
    for _ in range(3):
        _check(num_vars=3, num_formulas=4, coef=(-2, 2))

    for _ in range(3):
        _check(num_vars=4, num_formulas=3, coef=(-1, 1))

    for _ in range(3):
        _check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4))
    for _ in range(3):
        _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4))


def test_empty_var_to_solve():
    """With no variables to solve for, the system is returned unchanged."""
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]
    solution = arith.solve_linear_equations(equations)
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0
    assert len(solution.src.variables) == 0
    assert len(solution.src.ranges) == 0
    assert ir.structural_equal(solution.src.relations, equations)
    assert ir.structural_equal(solution.src, solution.dst)


def test_unique_solution():
    """A full-rank system resolves every variable to a constant."""
    x, y = te.var("x"), te.var("y")

    solution = arith.solve_linear_equations([
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ], [x, y])
    assert list(solution.dst.variables) == []
    assert ir.structural_equal(solution.src_to_dst[x], 15)
    assert ir.structural_equal(solution.src_to_dst[y], 5)


def test_low_rank():
    """An under-determined system introduces one new free variable."""
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ranges = {}

    solution = arith.solve_linear_equations([
        tvm.tir.EQ(x + y + z, 15),
        tvm.tir.EQ(x + y, 10),
    ], [x, y, z], ranges)
    [n0] = solution.dst.variables
    assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
    assert ir.structural_equal(solution.src_to_dst[y], -n0)
    assert ir.structural_equal(solution.src_to_dst[z], 5)


def test_infer_range():
    """Ranges of the old variables yield a range and a relation on the new one."""
    x, y = te.var("x"), te.var("y")
    ranges = {
        x: tvm.ir.Range.make_by_min_extent(-5, 10),
        y: tvm.ir.Range.make_by_min_extent(0, 10),
    }

    solution = arith.solve_linear_equations([
        tvm.tir.EQ(x + y, 0),
    ], [x, y], ranges)
    [n0] = solution.dst.variables
    assert ir.structural_equal(solution.src_to_dst[x], n0)
    assert ir.structural_equal(solution.src_to_dst[y], -n0)
    # inferred from y's range
    assert ir.structural_equal(solution.dst.ranges[n0].min, -9)
    assert ir.structural_equal(solution.dst.ranges[n0].extent, 10)
    # additional inequality is added into the system for x
    [ineq] = solution.dst.relations
    assert isinstance(ineq, tvm.tir.LE)
    assert ir.structural_equal(ineq.a, -5)
    assert ir.structural_equal(ineq.b, n0)


def test_ill_formed():
    """An inconsistent system is reported as the single relation `False`."""
    x, y = te.var("x"), te.var("y")

    solution = arith.solve_linear_equations([
        tvm.tir.EQ(x + y, 0),
        tvm.tir.EQ(x - y, 0),
        tvm.tir.EQ(x, 5),
    ], [x, y], {})
    assert list(solution.dst.variables) == []
    [rel] = solution.dst.relations
    assert ir.structural_equal(rel, False)
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0


if __name__ == "__main__":
    pytest.main([__file__])