From b5297697cabf2786c4b9c52de39fbd687a9641ac Mon Sep 17 00:00:00 2001 From: Salem Derisavi Date: Tue, 12 Mar 2019 14:26:02 -0400 Subject: [PATCH] target variable can now appear in either lhs or rhs of the expression to be analyzed --- src/arithmetic/bound_deducer.cc | 57 ++++++++++++++++------ tests/python/unittest/test_arith_intset.py | 36 ++++++++++++-- 2 files changed, 75 insertions(+), 18 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index c9779bbbe24d..2aaf4fec6f88 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -188,24 +188,53 @@ void BoundDeducer::Init() { } void BoundDeducer::Transform() { + // We will ensure to set expr_ such that it contains target_ if (const LT* op = expr_.as()) { - is_greater = false; - expr_ = op->a; - // a < b -> a <= b - 1 - result = op->b - 1; + if (GetPath(target_, op->a).empty()) { + // a < b -> b >= a + 1 + is_greater = true; + expr_ = op->b; + result = op->a + 1; + } else { + // a < b -> a <= b - 1 + is_greater = false; + expr_ = op->a; + result = op->b - 1; + } } else if (const LE* op = expr_.as()) { - is_greater = false; - expr_ = op->a; - result = op->b; + if (GetPath(target_, op->a).empty()) { + // a <= b -> b >= a + is_greater = true; + expr_ = op->b; + result = op->a; + } else { + is_greater = false; + expr_ = op->a; + result = op->b; + } } else if (const GT* op = expr_.as()) { - is_greater = true; - expr_ = op->a; - // a > b -> a >= b + 1 - result = op->b + 1; + if (GetPath(target_, op->a).empty()) { + // a > b -> b <= a - 1 + is_greater = false; + expr_ = op->b; + result = op->a - 1; + } else { + // a > b -> a >= b + 1 + is_greater = true; + expr_ = op->a; + result = op->b + 1; + } } else if (const GE* op = expr_.as()) { - is_greater = true; - expr_ = op->a; - result = op->b; + if (GetPath(target_, op->a).empty()) { + // a >= b -> b <= a + is_greater = false; + expr_ = op->b; + result = op->a; + } else { + is_greater = true; + expr_ = op->a; + result = op->b; + } } else { success = false; } diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 9b869feddc9d..d8428afe514a 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -22,32 +22,56 @@ def test_deduce(): b_s = tvm.arith.intset_interval(2, 3) c_s = tvm.arith.intset_interval(10, 15) d_s = tvm.arith.intset_interval(-3, -1) + zero = tvm.const(0, "int32") e0 = (-b)*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = ((d - c) /(b*-1)) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) + e0 = d*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = ((0-c)/d + 1) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) + e1 = (a*4+b < c) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = (((c - b) + -1)/4) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) + # expression containing variable a is on rhs + e1 = (c > a*4+b) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) + e2 = (tvm.max(5, a * 4) < 0) res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max()) == "neg_inf" assert str(res2.min()) == "pos_inf" + # expression containing variable a is on rhs + e2 = (zero < tvm.max(5, a * 4)) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + assert str(res2.max()) == "neg_inf" + assert str(res2.min()) == "pos_inf" + + e3 = (-b)+a*c-d res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = 2/c+1 assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) + res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) + def test_check(): a = tvm.var('a') b = tvm.var('b') @@ -81,11 +105,13 @@ def test_basic(a1, a2, coff): [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0>17, {b: b_s}, {b: b_s}) + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s}) + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 @@ -111,7 +137,8 @@ def test_complex(a1, a2, coff): [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0<=63, {b: b_s}, {b: b_s}) + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 @@ -119,7 +146,8 @@ def test_complex(a1, a2, coff): [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0>=63, {b: b_s}, {b: b_s}) + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1