Skip to content

Commit

Permalink
[Arith] linear system and equation solver (apache#5171)
Browse files Browse the repository at this point in the history
* [arith] linear system and equation solver

Co-authored-by: Sergei Grechanik <[email protected]>

* avoid constructing analyzer every time

* generate random test cases and address comments

Co-authored-by: Sergei Grechanik <[email protected]>

* rename linear_system to int_constraints

* add comments and use random seed

* message for reporting failure with seed

* add SEqualReduce to IntConstraints; allow variables & ranges to be None

Co-authored-by: Sergei Grechanik <[email protected]>
Co-authored-by: Sergei Grechanik <[email protected]>
  • Loading branch information
3 people authored and dpankratz committed Apr 24, 2020
1 parent f1c2d55 commit 1ee5937
Show file tree
Hide file tree
Showing 10 changed files with 1,230 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ class Analyzer {
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
*/
void Bind(const Map<Var, Range>& variables);
/*!
* \brief Whether can we prove expr >= val.
Expand Down
208 changes: 208 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/arith/int_solver.h
* \brief integer constraints data structures and solvers
*/
#ifndef TVM_ARITH_INT_SOLVER_H_
#define TVM_ARITH_INT_SOLVER_H_

#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace arith {

using tir::Var;
using tir::VarNode;
using tir::IterVar;

/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
* \sa LinearSystem
*/
class IntConstraintsNode : public Object {
public:
// e.g., \alpha, \beta, must be integers
Array<Var> variables;
// e.g., 1 <= \alpha <= N, etc.
// it is absolutely ok to include ranges for parameters
// (variables that are not in this->variables) in this map
Map<Var, Range> ranges;
// linear equalities or inequalities
// e.g., A \alpha = \beta or A \alpha <= \beta
Array<PrimExpr> relations;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("variables", &variables);
v->Visit("ranges", &ranges);
v->Visit("relations", &relations);
}

bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
return
equal(variables, other->variables) &&
equal(ranges, other->ranges) &&
equal(relations, other->relations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(variables);
hash_reduce(ranges);
hash_reduce(relations);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
};

/*!
* \brief Managed reference to IntConstraintsNode.
* \sa IntConstraintsNode
*/
class IntConstraints : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param variables The variables in the constraints, must be integers.
* \param ranges The ranges of the variables.
* \param relations The linear relations between the variables
* (either equations or inequalities)
*/
TVM_DLL IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations);

TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
};

/*!
* \brief We can have different set of variables to represent the same constraints.
* For example, the following two systems are equivalent,
* {a + b = 0 | a >= 0, b >= 0} and
* {m - n = 0 | m >= 0, n <= 0}
* This data structure represents the transformation
* between two equivalent linear systems.
* In the above example,
* src : {a + b = 0 | a >= 0, b >= 0}
* dst : {m - n = 0 | m >= 0, n <= 0}
* src_to_dst : {a -> m, b -> -n}
* dst_to_src : {m -> a, n -> -b}
* \sa IntConstraintsTransform
*/
class IntConstraintsTransformNode : public Object {
public:
IntConstraints src;
IntConstraints dst;
Map<Var, PrimExpr> src_to_dst;
Map<Var, PrimExpr> dst_to_src;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("src", &src);
v->Visit("dst", &dst);
v->Visit("src_to_dst", &src_to_dst);
v->Visit("dst_to_src", &dst_to_src);
}

bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
return
equal(src, other->src) &&
equal(dst, other->dst) &&
equal(src_to_dst, other->src_to_dst) &&
equal(dst_to_src, other->dst_to_src);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_to_dst);
hash_reduce(dst_to_src);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
};

/*!
* \brief Managed reference to IntConstraintsTransformNode.
* \sa IntConstraintsTransformNode
*/
class IntConstraintsTransform : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
* \param dst integer constraints equivalent to the source,
* e.g., {m - n = 0 | m >= 0, n <= 0}
* \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
* e.g., {a -> m, b -> -n}
* \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
* e.g., {m -> a, n -> -b}
*/
TVM_DLL IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src);

TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
* in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
* NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
* s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
* TODO(yzhliu): From sergei-grechanik:
* computing the proper Smith normal form may improve stability of automatic differentiation
* (generating the same gradient code for slightly different but equivalent input code
* U_{mxm} and V_{nxn} are invertible matrices.
* This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
* \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
* \param S the original A_{mxn}, it will be modified to S_{mxn}
* \param V an identity matrix, it will be modified to V_{nxn}
* \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
* \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
*/
void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
std::vector<std::vector<int64_t>> *V,
std::vector<PrimExpr>* x,
std::vector<PrimExpr> *y);

/*!
* \brief Solve linear equations.
* \param system_to_solve the variables to solve, their ranges, and a list of equations.
* \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
* or no variable (if \p system_to_solve is of full rank),
* or an empty linear system (if \p system_to_solve is unsolvable).
* It also provides the ranges of the variables in the new system,
* as well as inequalities inferred from the \p system_to_solve.
* You can get the mapping from the original variables to the solution via ret->src_to_dst.
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
45 changes: 45 additions & 0 deletions include/tvm/arith/util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/arith/util.h
* \brief Utils for arithmetic analysis.
*/
#ifndef TVM_ARITH_UTIL_H_
#define TVM_ARITH_UTIL_H_

#include <cstdint>
#include <tuple>

namespace tvm {
/*! \brief namespace of arithmetic analysis. */
namespace arith {

/*!
* \brief Calculate the extended greatest common divisor for two values.
* See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
* \param a an integer number
* \param b an integer number
* \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
*/
std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_UTIL_H_
1 change: 1 addition & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
99 changes: 99 additions & 0 deletions python/tvm/arith/int_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""integer constraints data structures and solvers"""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
the relations between them (either equations or inequalities)
Parameters
----------
variables : List[tvm.tir.Var]
The variables in the constraints. Must be integers
ranges : Map[tvm.tir.Var, tvm.ir.Range]
The ranges of the variables.
relations : List[tvm.ir.PrimExpr]
The relations between the variables (either equations or inequalities)
"""
def __init__(self, variables, ranges, relations):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraints, variables, ranges, relations)


@tvm._ffi.register_object("arith.IntConstraintsTransform")
class IntConstraintsTransform(Object):
"""We can have different set of variables to represent the same integer constraints.
For example, the following two constrains are equivalent,
{a + b = 0 | a >= 0, b >= 0} and
{m - n = 0 | m >= 0, n <= 0}
This data structure represents the transformation
between two equivalent integer constraints.
In the above example,
src : {a + b = 0 | a >= 0, b >= 0}
dst : {m - n = 0 | m >= 0, n <= 0}
src_to_dst : {a -> m, b -> -n}
dst_to_src : {m -> a, n -> -b}
Parameters
----------
src : arith.IntConstraints
source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
dst : arith.IntConstraints
integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the src to the variables in the dst,
e.g., {a -> m, b -> -n}
dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the dst to the variables in the src,
e.g., {m -> a, n -> -b}
"""
def __init__(self, src, dst, src_to_dst, dst_to_src):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)


def solve_linear_equations(equations, variables=None, ranges=None):
"""Solve linear equations.
Parameters
----------
equations: List[tvm.ir.PrimExpr] or IntConstraints
The equations of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
Returns
-------
int_constraints_transform : IntConstraintsTransform
New integer constraints, with less variables (if the problem is NOT of full rank),
or no variable (if the problem is of full rank),
or an empty integer constraints (if the problem is unsolvable).
It also provides the ranges of the variables in the new system,
as well as inequalities inferred from the problem.
You can get the mapping from the original variables to the solution via
int_constraints_transform.src_to_dst.
"""
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)
5 changes: 5 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) {
// skip rewrite simplify
}

void Analyzer::Bind(const Map<Var, Range>& variables) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second);
}
}

void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
Expand Down
Loading

0 comments on commit 1ee5937

Please sign in to comment.