Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Inequalities solver #5618

Merged
merged 39 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8207da5
[arith] inequalities solver
yzhliu Apr 30, 2020
39a93f7
introduce IntGroupedBounds
yzhliu May 16, 2020
af6708d
add no deskewed solution
yzhliu May 16, 2020
7f551e1
keep IntConstraints def
yzhliu May 18, 2020
e56d310
add test case and fix for equations
yzhliu Jun 5, 2020
1b216e9
add random consistency test cases
yzhliu Jun 6, 2020
059ea2e
improve test cases
yzhliu Jun 6, 2020
9c5becc
add doc & comments
yzhliu Jun 7, 2020
0fd69f7
test case refactored
yzhliu Jun 7, 2020
9af9d24
test file rename
yzhliu Jun 7, 2020
8a44d61
merge from upstream
yzhliu Jun 7, 2020
6349c33
fix lint
yzhliu Jun 7, 2020
8caa406
apply clang-format
yzhliu Jun 7, 2020
bc0587f
apply cl-format-10
yzhliu Jun 7, 2020
e84467a
fix cpplint
yzhliu Jun 7, 2020
45c6a4c
add more docs
yzhliu Jun 17, 2020
bb22dd2
Add Co-author.
yzhliu Jun 17, 2020
0e0b5f0
add check_solution
yzhliu Jun 17, 2020
ff05fe0
add support for unsolvable inequalities
yzhliu Jun 19, 2020
e6845bc
fix for non-divisible equation
yzhliu Jun 20, 2020
7166020
fix non-divisible case again and add test case for no solution
yzhliu Jun 20, 2020
eaca5c9
revise testing
yzhliu Jun 20, 2020
7fd0c57
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 20, 2020
12b5fcd
fix merging
yzhliu Jun 20, 2020
182a5d5
fix lint
yzhliu Jun 20, 2020
c8b2370
add comments
yzhliu Jun 20, 2020
4c6cec4
fix order of as_condition
yzhliu Jun 22, 2020
a4dbba4
fix lint
yzhliu Jun 22, 2020
3a29517
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 22, 2020
666d569
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 26, 2020
0605343
remove special dealing with equations
yzhliu Jun 26, 2020
485cfad
address several comments; add steps to python side analyzer.simplify
yzhliu Jun 26, 2020
98cc3ac
fix a dumb compilation failure
yzhliu Jun 26, 2020
d8f4ee5
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 27, 2020
14588cb
merge from upstream
yzhliu Jul 1, 2020
51eb652
IntGrpBounds -> IntGroupBounds
yzhliu Jul 1, 2020
aaa53f3
move if check to the root
yzhliu Jul 1, 2020
d42f2ca
fix lint
yzhliu Jul 1, 2020
b21c13f
fix clang format check
yzhliu Jul 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,16 @@ class TVM_DLL Analyzer {
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \param steps The simplification runs in the order of
* rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
* rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
* param steps controls how many steps to run.
* Default is 2, i.e., rewrite_simplify + canonical_simplify.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr);
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
};

} // namespace arith
Expand Down
145 changes: 145 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,110 @@

#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <unordered_map>
#include <utility>
#include <vector>

#include "analyzer.h"

namespace tvm {
namespace arith {

using tir::IterVar;
using tir::Var;
using tir::VarNode;

/*!
* \brief Represent integer grouped bounds which are classified into
* lower bounds (inclusive), upper bounds (inclusive) and equalities.
* It also contains coefficient as a multiplier for the bounds, i.e.,
* coef * var >= lower
* coef * var == equal
* coef * var <= upper
* \sa IntGroupBounds
*/
class IntGroupBoundsNode : public Object {
public:
PrimExpr coef;
Array<PrimExpr> lower;
Array<PrimExpr> equal;
Array<PrimExpr> upper;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coef", &coef);
v->Visit("lower", &lower);
v->Visit("equal", &equal);
v->Visit("upper", &upper);
}

bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) &&
eq(upper, other->upper);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(coef);
hash_reduce(lower);
hash_reduce(equal);
hash_reduce(upper);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGroupBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
};

/*!
* \brief Managed reference to IntGroupBoundsNode.
* \sa IntGroupBoundsNode
*/
class IntGroupBounds : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param coef The coefficient. Must be integer.
* coef * var >= lower
* coef * var == equal
* coef * var >= upper
* \param lower the lower bounds (include)
* \param equal equalities
* \param upper the upper bounds (include)
*/
TVM_DLL IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
Array<PrimExpr> upper);

/*!
* \brief Construct bounds from a range.
* \param r The range
* \return constructed bounds.
*/
static IntGroupBounds FromRange(const Range& r);

/*!
* \brief Perform substitution on all components of the struct.
*/
IntGroupBounds Substitute(const Map<Var, PrimExpr>& subst) const;

/*!
* \brief Find the best range from the grouped bounds.
* \param vranges_addl additional variable ranges that help infer the best range.
* \return The best range (has the least difference between the lower bound and upper bound).
* undefined if (-inf, +inf).
*/
Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;

/*!
* \brief Combine the bounds with another range.
* \param r range to be combined.
* \return combined bounds.
*/
IntGroupBounds operator+(const Range& r);

TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode);
};

/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
Expand Down Expand Up @@ -161,6 +254,8 @@ class IntConstraintsTransform : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities;

/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
Expand Down Expand Up @@ -191,6 +286,56 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* The inequalities are rewritten using Fourier-Motzkin elimination.
* This function takes an array of (in)equalities and an array of variables, and essentially
* rewrites the (in)equalities into an array of (in)equalities of the following form,
*
* x0 >= f0(x1, x2, ..., xn)
* x0 <= g0(x1, x2, ..., xn)
* x1 >= f1(x2, ..., xn)
* x1 <= g1(x2, ..., xn)
* ...
* xn >= fn() // just a constant
* xn <= gn() // just a constant
*
* \return A map of variables and their solved bounds,
* and constrains that cannot be solved to bounds.
*/
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities and infer the range of each variable.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return The result ranges for each variables.
* The returned IntConstraints(variables, ranges, relations) contains,
* 1. variables - the variables that have been solved.
* 2. ranges - the best range of each variable.
* 3. relations - constraints that cannot be transformed to
* Range will be stored in relations.
*/
IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Solve linear inequalities and deskew the ranges towards zero.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return A transform (src IntConstraints -> dst IntConstraints)
* from original variables to a set of new variables.
* The ranges of new variables always start from zero,
* their extents are solved from \p system_to_solve.
* src IntConstraints is the same as \p system_to_solve.
* dst IntConstraints(variables, ranges, relations) contains,
* 1. variables - the variables that have been solved.
* 2. ranges - the best range (start from zero) of each variable.
* 3. relations - constraints that cannot be transformed to
* Range will be stored in relations.
* Variable mapping can be obtained from
* IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src.
*/
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
from .int_solver import solve_linear_equations, solve_linear_inequalities
9 changes: 7 additions & 2 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,25 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def simplify(self, expr):
def simplify(self, expr, steps=2):
"""Simplify expression via both rewrite and canonicalization.

Parameters
----------
expr : PrimExpr
The expression.
steps : The simplification runs in the order of
rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
param steps controls how many steps to run.
Default is 2, i.e., rewrite_simplify + canonical_simplify.

Returns
-------
result : Expr
The result.
"""
return self._simplify(expr)
return self._simplify(expr, steps)

def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Expand Down
78 changes: 78 additions & 0 deletions python/tvm/arith/int_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@
from . import _ffi_api


@tvm._ffi.register_object("arith.IntGroupBounds")
class IntGroupBounds(Object):
"""Represent integer grouped bounds which are classified into
lower bounds (include), upper bounds (include) and equalities.

Parameters
----------
coef : tvm.ir.PrimExpr
The coefficient. Must be integer type.
coef * var >= lower
coef * var == equal
coef * var >= upper
lower : List[tvm.ir.PrimExpr]
the lower bounds (include)
equal : List[tvm.ir.PrimExpr]
equalities
upper : List[tvm.ir.PrimExpr]
the upper bounds (include)
"""
def __init__(self, coef, lower, equal, upper):
self.__init_handle_by_constructor__(
_ffi_api.IntGroupBounds, coef, lower, equal, upper)

@staticmethod
def from_range(rng):
"""Construct a IntGroupedBounds by Range.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from_range


Parameters
----------
rng : tvm.ir.Range


Returns
-------
ret : Range
The constructed range.
"""
return _ffi_api.IntGroupBounds_from_range(rng)

def find_best_range(self):
"""Return the best range from the grouped bounds.
None if (-inf, +inf).
"""
return _ffi_api.IntGroupBounds_FindBestRange(self)


@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
Expand Down Expand Up @@ -97,3 +143,35 @@ def solve_linear_equations(equations, variables=None, ranges=None):
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)


def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False):
"""Solve linear inequalities.

Parameters
----------
equations : List[tvm.ir.PrimExpr] or IntConstraints
The inequalities of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
deskew_range: Optional[bool]
Whether deskew the result ranges to be started from zero.
Default false.

Returns
-------
ret_ranges: IntConstraints or IntConstraintsTransform
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
The result ranges for each variables.
Constrains that cannot be transformed to Range will be stored in IntConstraints.relations.
If deskew_range is set (=True), the result ranges will be deskewed to be started from zero.
New variables are created accordingly therefore IntConstraintsTransform is returned.
"""
solver = _ffi_api.SolveInequalitiesDeskewRange \
if deskew_range else _ffi_api.SolveInequalitiesToRange
if isinstance(equations, IntConstraints):
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
assert variables is None
assert ranges is None
return solver(equations)
return solver(variables, ranges, equations)
Loading