Skip to content

Commit

Permalink
[Arith] Inequalities solver (#5618)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored Jul 6, 2020
1 parent a64feed commit 151f3f5
Show file tree
Hide file tree
Showing 12 changed files with 1,363 additions and 87 deletions.
7 changes: 6 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,16 @@ class TVM_DLL Analyzer {
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \param steps The simplification runs in the order of
* rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
* rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
* param steps controls how many steps to run.
* Default is 2, i.e., rewrite_simplify + canonical_simplify.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr);
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
};

} // namespace arith
Expand Down
145 changes: 145 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,110 @@

#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <unordered_map>
#include <utility>
#include <vector>

#include "analyzer.h"

namespace tvm {
namespace arith {

using tir::IterVar;
using tir::Var;
using tir::VarNode;

/*!
* \brief Represent integer grouped bounds which are classified into
* lower bounds (inclusive), upper bounds (inclusive) and equalities.
* It also contains coefficient as a multiplier for the bounds, i.e.,
* coef * var >= lower
* coef * var == equal
* coef * var <= upper
* \sa IntGroupBounds
*/
class IntGroupBoundsNode : public Object {
public:
PrimExpr coef;
Array<PrimExpr> lower;
Array<PrimExpr> equal;
Array<PrimExpr> upper;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coef", &coef);
v->Visit("lower", &lower);
v->Visit("equal", &equal);
v->Visit("upper", &upper);
}

bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) &&
eq(upper, other->upper);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(coef);
hash_reduce(lower);
hash_reduce(equal);
hash_reduce(upper);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGroupBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
};

/*!
* \brief Managed reference to IntGroupBoundsNode.
* \sa IntGroupBoundsNode
*/
class IntGroupBounds : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param coef The coefficient. Must be integer.
* coef * var >= lower
* coef * var == equal
* coef * var >= upper
* \param lower the lower bounds (include)
* \param equal equalities
* \param upper the upper bounds (include)
*/
TVM_DLL IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
Array<PrimExpr> upper);

/*!
* \brief Construct bounds from a range.
* \param r The range
* \return constructed bounds.
*/
static IntGroupBounds FromRange(const Range& r);

/*!
* \brief Perform substitution on all components of the struct.
*/
IntGroupBounds Substitute(const Map<Var, PrimExpr>& subst) const;

/*!
* \brief Find the best range from the grouped bounds.
* \param vranges_addl additional variable ranges that help infer the best range.
* \return The best range (has the least difference between the lower bound and upper bound).
* undefined if (-inf, +inf).
*/
Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;

/*!
* \brief Combine the bounds with another range.
* \param r range to be combined.
* \return combined bounds.
*/
IntGroupBounds operator+(const Range& r);

TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode);
};

/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
Expand Down Expand Up @@ -161,6 +254,8 @@ class IntConstraintsTransform : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities;

/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
Expand Down Expand Up @@ -191,6 +286,56 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* The inequalities are rewritten using Fourier-Motzkin elimination.
* This function takes an array of (in)equalities and an array of variables, and essentially
* rewrites the (in)equalities into an array of (in)equalities of the following form,
*
* x0 >= f0(x1, x2, ..., xn)
* x0 <= g0(x1, x2, ..., xn)
* x1 >= f1(x2, ..., xn)
* x1 <= g1(x2, ..., xn)
* ...
* xn >= fn() // just a constant
* xn <= gn() // just a constant
*
* \return A map of variables and their solved bounds,
* and constrains that cannot be solved to bounds.
*/
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities and infer the range of each variable.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return The result ranges for each variables.
* The returned IntConstraints(variables, ranges, relations) contains,
* 1. variables - the variables that have been solved.
* 2. ranges - the best range of each variable.
* 3. relations - constraints that cannot be transformed to
* Range will be stored in relations.
*/
IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities and deskew the ranges towards zero.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return A transform (src IntConstraints -> dst IntConstraints)
* from original variables to a set of new variables.
* The ranges of new variables always start from zero,
* their extents are solved from \p system_to_solve.
* src IntConstraints is the same as \p system_to_solve.
* dst IntConstraints(variables, ranges, relations) contains,
* 1. variables - the variables that have been solved.
* 2. ranges - the best range (start from zero) of each variable.
* 3. relations - constraints that cannot be transformed to
* Range will be stored in relations.
* Variable mapping can be obtained from
* IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src.
*/
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
from .int_solver import solve_linear_equations, solve_linear_inequalities
9 changes: 7 additions & 2 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,25 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def simplify(self, expr):
def simplify(self, expr, steps=2):
"""Simplify expression via both rewrite and canonicalization.
Parameters
----------
expr : PrimExpr
The expression.
steps : The simplification runs in the order of
rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
param steps controls how many steps to run.
Default is 2, i.e., rewrite_simplify + canonical_simplify.
Returns
-------
result : Expr
The result.
"""
return self._simplify(expr)
return self._simplify(expr, steps)

def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Expand Down
78 changes: 78 additions & 0 deletions python/tvm/arith/int_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@
from . import _ffi_api


@tvm._ffi.register_object("arith.IntGroupBounds")
class IntGroupBounds(Object):
"""Represent integer grouped bounds which are classified into
lower bounds (include), upper bounds (include) and equalities.
Parameters
----------
coef : tvm.ir.PrimExpr
The coefficient. Must be integer type.
coef * var >= lower
coef * var == equal
coef * var >= upper
lower : List[tvm.ir.PrimExpr]
the lower bounds (include)
equal : List[tvm.ir.PrimExpr]
equalities
upper : List[tvm.ir.PrimExpr]
the upper bounds (include)
"""
def __init__(self, coef, lower, equal, upper):
self.__init_handle_by_constructor__(
_ffi_api.IntGroupBounds, coef, lower, equal, upper)

@staticmethod
def from_range(rng):
"""Construct a IntGroupedBounds by Range.
Parameters
----------
rng : tvm.ir.Range
Returns
-------
ret : Range
The constructed range.
"""
return _ffi_api.IntGroupBounds_from_range(rng)

def find_best_range(self):
"""Return the best range from the grouped bounds.
None if (-inf, +inf).
"""
return _ffi_api.IntGroupBounds_FindBestRange(self)


@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
Expand Down Expand Up @@ -97,3 +143,35 @@ def solve_linear_equations(equations, variables=None, ranges=None):
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)


def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False):
"""Solve linear inequalities.
Parameters
----------
equations : List[tvm.ir.PrimExpr] or IntConstraints
The inequalities of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
deskew_range: Optional[bool]
Whether deskew the result ranges to be started from zero.
Default false.
Returns
-------
ret_ranges: IntConstraints or IntConstraintsTransform
The result ranges for each variables.
Constrains that cannot be transformed to Range will be stored in IntConstraints.relations.
If deskew_range is set (=True), the result ranges will be deskewed to be started from zero.
New variables are created accordingly therefore IntConstraintsTransform is returned.
"""
solver = _ffi_api.SolveInequalitiesDeskewRange \
if deskew_range else _ffi_api.SolveInequalitiesToRange
if isinstance(equations, IntConstraints):
assert variables is None
assert ranges is None
return solver(equations)
return solver(variables, ranges, equations)
Loading

0 comments on commit 151f3f5

Please sign in to comment.