Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Inequalities solver #5618

Merged
merged 39 commits into from
Jul 6, 2020
Merged
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8207da5
[arith] inequalities solver
yzhliu Apr 30, 2020
39a93f7
introduce IntGroupedBounds
yzhliu May 16, 2020
af6708d
add no deskewed solution
yzhliu May 16, 2020
7f551e1
keep IntConstraints def
yzhliu May 18, 2020
e56d310
add test case and fix for equations
yzhliu Jun 5, 2020
1b216e9
add random consistency test cases
yzhliu Jun 6, 2020
059ea2e
improve test cases
yzhliu Jun 6, 2020
9c5becc
add doc & comments
yzhliu Jun 7, 2020
0fd69f7
test case refactored
yzhliu Jun 7, 2020
9af9d24
test file rename
yzhliu Jun 7, 2020
8a44d61
merge from upstream
yzhliu Jun 7, 2020
6349c33
fix lint
yzhliu Jun 7, 2020
8caa406
apply clang-format
yzhliu Jun 7, 2020
bc0587f
apply cl-format-10
yzhliu Jun 7, 2020
e84467a
fix cpplint
yzhliu Jun 7, 2020
45c6a4c
add more docs
yzhliu Jun 17, 2020
bb22dd2
Add Co-author.
yzhliu Jun 17, 2020
0e0b5f0
add check_solution
yzhliu Jun 17, 2020
ff05fe0
add support for unsolvable inequalities
yzhliu Jun 19, 2020
e6845bc
fix for non-divisible equation
yzhliu Jun 20, 2020
7166020
fix non-divisible case again and add test case for no solution
yzhliu Jun 20, 2020
eaca5c9
revise testing
yzhliu Jun 20, 2020
7fd0c57
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 20, 2020
12b5fcd
fix merging
yzhliu Jun 20, 2020
182a5d5
fix lint
yzhliu Jun 20, 2020
c8b2370
add comments
yzhliu Jun 20, 2020
4c6cec4
fix order of as_condition
yzhliu Jun 22, 2020
a4dbba4
fix lint
yzhliu Jun 22, 2020
3a29517
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 22, 2020
666d569
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 26, 2020
0605343
remove special dealing with equations
yzhliu Jun 26, 2020
485cfad
address several comments; add steps to python side analyzer.simplify
yzhliu Jun 26, 2020
98cc3ac
fix a dumb compilation failure
yzhliu Jun 26, 2020
d8f4ee5
Merge remote-tracking branch 'upstream/master' into inequality_solver
yzhliu Jun 27, 2020
14588cb
merge from upstream
yzhliu Jul 1, 2020
51eb652
IntGrpBounds -> IntGroupBounds
yzhliu Jul 1, 2020
aaa53f3
move if check to the root
yzhliu Jul 1, 2020
d42f2ca
fix lint
yzhliu Jul 1, 2020
b21c13f
fix clang format check
yzhliu Jul 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test case refactored
yzhliu committed Jun 7, 2020
commit 0fd69f769b29b929614904a064fa70d89868cfe0
24 changes: 24 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
@@ -290,6 +290,30 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return A map of variables and their solved bounds,
* and constrains that cannot be solved to bounds.
*/
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return The result ranges for each variables.
* Constrains that cannot be transformed to Range will be stored in relations.
*/
IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities);

/*!
* \brief Solve linear inequalities.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
* \return Solved ranges are deskewed to be started from zero.
* New variables and the mapping are created accordingly.
*/
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
43 changes: 43 additions & 0 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import numpy as np
import tvm
import tvm._ffi
from tvm import te, tir


def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
@@ -168,4 +169,46 @@ def compare_derivative(j, n_der, grad):
x_name, grad.shape, dist, max_diff, avg_diff)


def check_bool_expr_is_true(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.

Parameters
----------
bool_expr : tvm.ir.expr.PrimExpr
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
Boolean expression to check
vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range]
Free variables and their ranges
cond: tvm.ir.expr.PrimExpr
extra conditions needs to be satisfied.
"""
if cond is not None:
bool_expr = te.any(tir.Not(cond), bool_expr)

def _run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)

A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()

res = _run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))


tvm._ffi._init_api("testing", __name__)
13 changes: 0 additions & 13 deletions src/arith/solve_linear_inequality.cc
Original file line number Diff line number Diff line change
@@ -265,7 +265,6 @@ void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>&
}

PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) {
LOG(INFO) << "solving inequalities " << system_to_solve;
arith::Analyzer analyzer;
analyzer.Bind(system_to_solve->ranges);

@@ -296,12 +295,6 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
NormalizeComparisons()(analyzer.Simplify(ineq, 3)), analyzer);
}

DebugPrint(current_ineq_set_to_solve,
next_ineq_set_to_solve,
rest,
coef_pos,
coef_neg);

Map<Var, IntGrpBounds> res_bounds;
for (const Var& v : system_to_solve->variables) {
CHECK(!res_bounds.count(v)) <<
@@ -328,12 +321,6 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
coef_neg,
analyzer);

DebugPrint(current_ineq_set_to_solve,
next_ineq_set_to_solve,
rest,
coef_pos,
coef_neg);

// Combine each positive inequality with each negative one (by adding them together)
for (const auto& pos : coef_pos) {
for (const auto& neg : coef_neg) {
60 changes: 7 additions & 53 deletions tests/python/unittest/test_arith_solve_linear_inequality.py
Original file line number Diff line number Diff line change
@@ -15,51 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import random
import numpy as np
import sys
import pytest
import tvm
from tvm import te, arith, ir, tir


def run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
TODO(yzhliu): move to utils
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)

A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()


def check_bruteforce(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.
TODO(yzhliu): move to utils
"""
if cond is not None:
bool_expr = te.any(tir.Not(cond), bool_expr)

res = run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
from tvm import te, arith, ir, tir, testing


def test_solve_system_of_inequalities():
random.seed(0)
seed = random.randrange(sys.maxsize)
print("\nThis test is intentionally non-deterministic, "
"if it fails please report it in github issue together with this seed {}\n".format(seed))
random.seed(seed)

def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):
vs = [te.var("x" + str(i)) for i in range(variables)]
@@ -75,16 +41,9 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):

vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs}
before = te.all(tir.const(1, 'bool'), *fs)

print("--- before ---")
print(fs)
after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs)
after = te.all(tir.const(1, 'bool'), *after)
print("--- after ---")
print(after)
print()

check_bruteforce(before == after, vranges)
testing.check_bool_expr_is_true(before == after, vranges)
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

for i in range(3):
_check(1, 1)
@@ -140,7 +99,6 @@ def test_dual_variable():
tvm.tir.LE(x + y, 20),
tvm.tir.GE(x - y, 10),
], [x, y], ranges)
print(solution)
# 0 <= y <=5
assert solution.ranges[y].min == 0
assert solution.ranges[y].extent == 6
@@ -150,7 +108,6 @@ def test_dual_variable():

# deskew the solved ranges to be starting from zero
solution = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True)
print(solution)
[x_new, y_new] = solution.dst.variables
[rel] = solution.dst.relations
assert ir.structural_equal(rel, (y_new*2) + x_new <= 10)
@@ -206,7 +163,4 @@ def test_multi_equal():


if __name__ == "__main__":
test_solve_system_of_inequalities()
test_dual_variable()
test_equal()
test_multi_equal()
pytest.main([__file__])
43 changes: 4 additions & 39 deletions tests/python/unittest/test_arith_solve_linear_system.py
Original file line number Diff line number Diff line change
@@ -19,43 +19,7 @@
import sys
import pytest
import tvm
from tvm import te, arith, ir, tir


def run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
TODO(yzhliu): move to utils
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)

A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()


def check_bruteforce(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.
TODO(yzhliu): move to utils
"""
if cond is not None:
bool_expr = te.any(tir.Not(cond), bool_expr)

res = run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
from tvm import te, arith, ir, tir, testing


def check_solution(solution, vranges={}):
@@ -81,8 +45,9 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
cond_subst = te.all(cond_subst, range_cond)
cond_subst = tir.ir_pass.Simplify(cond_subst)
check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
testing.check_bool_expr_is_true(
te.all(cond_subst, cond_on_vars), all_vranges,
cond=te.all(tir.const(1, 'bool'), *constraints1.relations))

rels = solution.dst.relations
if len(rels) == 1 and ir.structural_equal(rels[0], False):