Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Inequalities solver #5618

Merged
merged 39 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8207da5
[arith] inequalities solver
yzhliu Apr 30, 2020
39a93f7
introduce IntGroupedBounds
yzhliu May 16, 2020
af6708d
add no deskewed solution
yzhliu May 16, 2020
7f551e1
keep IntConstraints def
yzhliu May 18, 2020
e56d310
add test case and fix for equations
yzhliu Jun 5, 2020
1b216e9
add random consistency test cases
yzhliu Jun 6, 2020
059ea2e
improve test cases
yzhliu Jun 6, 2020
9c5becc
add doc & comments
yzhliu Jun 7, 2020
0fd69f7
test case refactored
yzhliu Jun 7, 2020
9af9d24
test file rename
yzhliu Jun 7, 2020
8a44d61
merge from upstream
yzhliu Jun 7, 2020
6349c33
fix lint
yzhliu Jun 7, 2020
8caa406
apply clang-format
yzhliu Jun 7, 2020
bc0587f
apply cl-format-10
yzhliu Jun 7, 2020
e84467a
fix cpplint
yzhliu Jun 7, 2020
45c6a4c
add more docs
yzhliu Jun 17, 2020
bb22dd2
Add Co-author.
yzhliu Jun 17, 2020
0e0b5f0
add check_solution
yzhliu Jun 17, 2020
ff05fe0
add support for unsolvable inequalities
yzhliu Jun 19, 2020
e6845bc
fix for non-divisible equation
yzhliu Jun 20, 2020
7166020
fix non-divisible case again and add test case for no solution
yzhliu Jun 20, 2020
eaca5c9
revise testing
yzhliu Jun 20, 2020
7fd0c57
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 20, 2020
12b5fcd
fix merging
yzhliu Jun 20, 2020
182a5d5
fix lint
yzhliu Jun 20, 2020
c8b2370
add comments
yzhliu Jun 20, 2020
4c6cec4
fix order of as_condition
yzhliu Jun 22, 2020
a4dbba4
fix lint
yzhliu Jun 22, 2020
3a29517
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 22, 2020
666d569
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 26, 2020
0605343
remove special dealing with equations
yzhliu Jun 26, 2020
485cfad
address several comments; add steps to python side analyzer.simplify
yzhliu Jun 26, 2020
98cc3ac
fix a dumb compilation failure
yzhliu Jun 26, 2020
d8f4ee5
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 27, 2020
14588cb
merge from upstream
yzhliu Jul 1, 2020
51eb652
IntGrpBounds -> IntGroupBounds
yzhliu Jul 1, 2020
aaa53f3
move if check to the root
yzhliu Jul 1, 2020
d42f2ca
fix lint
yzhliu Jul 1, 2020
b21c13f
fix clang format check
yzhliu Jul 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ tvm_t.*
.python_history
.pytest_cache
.local
cmake-build-debug

# Visual Studio Code
.vscode
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,16 @@ class TVM_DLL Analyzer {
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \param steps The simplification runs in the order of
* rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
* rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
* param steps controls how many steps to run.
* Default is 2, i.e., rewrite_simplify + canonical_simplify.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr);
PrimExpr Simplify(const PrimExpr& expr, size_t steps = 2);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
};

} // namespace arith
Expand Down
108 changes: 108 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,99 @@

#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <unordered_map>
#include <utility>
#include <vector>

#include "analyzer.h"

namespace tvm {
namespace arith {

using tir::IterVar;
using tir::Var;
using tir::VarNode;

/*!
* \brief Represent integer grouped bounds which are classified into
* lower bounds (include), upper bounds (include) and equalities.
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
* It also contains coefficient as a multiplier for the bounds, i.e.,
* coef * var >= lower
* coef * var == equal
* coef * var <= upper
* \sa IntGrpBounds
*/
class IntGrpBoundsNode : public Object {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IntGroupBounds

public:
PrimExpr coef;
Array<PrimExpr> lower;
Array<PrimExpr> equal;
Array<PrimExpr> upper;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coef", &coef);
v->Visit("lower", &lower);
v->Visit("equal", &equal);
v->Visit("upper", &upper);
}

bool SEqualReduce(const IntGrpBoundsNode* other, SEqualReducer eq) const {
return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) &&
eq(upper, other->upper);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(coef);
hash_reduce(lower);
hash_reduce(equal);
hash_reduce(upper);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGrpBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGrpBoundsNode, Object);
};

/*!
* \brief Managed reference to IntGrpBoundsNode.
* \sa IntGrpBoundsNode
*/
class IntGrpBounds : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param coef The coefficient. Must be integer.
* coef * var >= lower
* coef * var == equal
* coef * var >= upper
* \param lower the lower bounds (include)
* \param equal equalities
* \param upper the upper bounds (include)
*/
TVM_DLL IntGrpBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
Array<PrimExpr> upper);

/*!
* \brief Construct bounds from a range.
* \param r The range
* \return constructed bounds.
*/
static IntGrpBounds range(const Range& r);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe range --> CreateIntGrpBounds (Or something like that, as we are creating Bounds from Range)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I see now. I am also okay if it is general convention:)
@tqchen : Do you have any thoughts on the comment ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about FromRange


/*!
* \brief Perform substitution on all components of the struct.
*/
IntGrpBounds Substitute(const Map<Var, PrimExpr>& subst) const;

Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;

IntGrpBounds operator+(const Range& r);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

TVM_DEFINE_OBJECT_REF_METHODS(IntGrpBounds, ObjectRef, IntGrpBoundsNode);
};

/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
Expand Down Expand Up @@ -161,6 +243,8 @@ class IntConstraintsTransform : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

typedef std::pair<Map<Var, IntGrpBounds>, Array<PrimExpr>> PartialSolvedInequalities;

/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
Expand Down Expand Up @@ -191,6 +275,30 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return A map of variables and their solved bounds,
* and constrains that cannot be solved to bounds.
*/
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return The result ranges for each variables.
* Constrains that cannot be transformed to Range will be stored in relations.
*/
IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return Solved ranges are deskewed to be started from zero.
* New variables and the mapping are created accordingly.
*/
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve);
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
from .int_solver import solve_linear_equations, solve_linear_inequalities
76 changes: 76 additions & 0 deletions python/tvm/arith/int_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@
from . import _ffi_api


@tvm._ffi.register_object("arith.IntGrpBounds")
class IntGrpBounds(Object):
"""Represent integer grouped bounds which are classified into
lower bounds (include), upper bounds (include) and equalities.

Parameters
----------
coef : tvm.ir.PrimExpr
The coefficient. Must be integer.
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
coef * var >= lower
coef * var == equal
coef * var >= upper
lower : List[tvm.ir.PrimExpr]
the lower bounds (include)
equal : List[tvm.ir.PrimExpr]
equalities
upper : List[tvm.ir.PrimExpr]
the upper bounds (include)
"""
def __init__(self, coef, lower, equal, upper):
self.__init_handle_by_constructor__(
_ffi_api.IntGrpBounds, coef, lower, equal, upper)

@staticmethod
def make_by_range(rng):
"""Construct a IntGroupedBounds by Range.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from_range


Parameters
----------
rng : tvm.ir.Range


Returns
-------
ret : Range
The constructed range.
"""
return _ffi_api.int_grouped_bounds_by_range(rng)

def find_best_range(self):
"""Return the best range from the grouped bounds.
None if (-inf, +inf).
"""
return _ffi_api.IntGrpBounds_FindBestRange(self)


@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
Expand Down Expand Up @@ -97,3 +143,33 @@ def solve_linear_equations(equations, variables=None, ranges=None):
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)


def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False):
"""Solve linear inequalities.

Parameters
----------
equations : List[tvm.ir.PrimExpr] or IntConstraints
The inequalities of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
deskew_range: Optional[bool]
Whether deskew the result ranges to be started from zero.
Default false.

Returns
-------
ret_ranges: IntConstraints or IntConstraintsTransform
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
The result ranges for each variables.
Constrains that cannot be transformed to Range will be stored in IntConstraints.relations.
If deskew_range is set (=True), the result ranges will be deskewed to be started from zero.
New variables are created accordingly therefore IntConstraintsTransform is returned.
"""
solver = _ffi_api.SolveInequalitiesDeskewRange \
if deskew_range else _ffi_api.SolveInequalitiesToRange
if isinstance(equations, IntConstraints):
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
return solver(equations)
return solver(variables, ranges, equations)
43 changes: 43 additions & 0 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tvm
import tvm.arith
import tvm.tir
import tvm.te
import tvm._ffi


Expand Down Expand Up @@ -188,5 +189,47 @@ def assert_prim_expr_equal(lhs, rhs):
raise ValueError("{} and {} are not equal".format(lhs, rhs))


def check_bool_expr_is_true(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.

Parameters
----------
bool_expr : tvm.ir.expr.PrimExpr
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
Boolean expression to check
vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range]
Free variables and their ranges
cond: tvm.ir.expr.PrimExpr
extra conditions needs to be satisfied.
"""
if cond is not None:
bool_expr = tvm.te.any(tvm.tir.Not(cond), bool_expr)
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

def _run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tvm.tir.stmt_functor.substitute(expr, vmap)

A = tvm.te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = tvm.te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()

res = _run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tvm.tir.ir_pass.CanonicalSimplify(bool_expr),
vranges, counterex))


tvm._ffi._init_api("testing", __name__)
12 changes: 8 additions & 4 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
return false;
}

PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
PrimExpr Analyzer::Simplify(const PrimExpr& expr, size_t steps) {
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
if (tir::is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (tir::is_const(res)) return res;
res = this->canonical_simplify(res);
PrimExpr res = expr;
for (size_t i = 0; i < steps; ++i) {
res = this->rewrite_simplify(res);
if (tir::is_const(res) || ++i == steps) return res;
res = this->canonical_simplify(res);
if (tir::is_const(res)) return res;
}
return res;
}

Expand Down
Loading