Skip to content

Commit

Permalink
target variable can now appear in either lhs or rhs of the expression…
Browse files Browse the repository at this point in the history
… to be analyzed
  • Loading branch information
derisavi committed Apr 5, 2019
1 parent eb82e7b commit b529769
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
57 changes: 43 additions & 14 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,53 @@ void BoundDeducer::Init() {
}

void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) {
is_greater = false;
expr_ = op->a;
// a < b -> a <= b - 1
result = op->b - 1;
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater = true;
expr_ = op->b;
result = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater = false;
expr_ = op->a;
result = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
is_greater = false;
expr_ = op->a;
result = op->b;
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater = true;
expr_ = op->b;
result = op->a;
} else {
is_greater = false;
expr_ = op->a;
result = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
is_greater = true;
expr_ = op->a;
// a > b -> a >= b + 1
result = op->b + 1;
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater = false;
expr_ = op->b;
result = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater = true;
expr_ = op->a;
result = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
is_greater = true;
expr_ = op->a;
result = op->b;
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater = false;
expr_ = op->b;
result = op->a;
} else {
is_greater = true;
expr_ = op->a;
result = op->b;
}
} else {
success = false;
}
Expand Down
36 changes: 32 additions & 4 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,56 @@ def test_deduce():
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(10, 15)
d_s = tvm.arith.intset_interval(-3, -1)
zero = tvm.const(0, "int32")

e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1))
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)

# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)

e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"

# expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"


e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)

res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)

def test_check():
a = tvm.var('a')
b = tvm.var('b')
Expand Down Expand Up @@ -81,11 +105,13 @@ def test_basic(a1, a2, coff):
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1

res1 = tvm.arith.DeduceBound(a, e0>17, {b: b_s}, {b: b_s})
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1

res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s})
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1

Expand All @@ -111,15 +137,17 @@ def test_complex(a1, a2, coff):
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1

res1 = tvm.arith.DeduceBound(a, e0<=63, {b: b_s}, {b: b_s})
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1

res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1

res1 = tvm.arith.DeduceBound(a, e0>=63, {b: b_s}, {b: b_s})
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1

Expand Down

0 comments on commit b529769

Please sign in to comment.