diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 2528fe80c4b0..6f7b4d78da05 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -150,43 +150,26 @@ class BoundDeducer: public IRVisitor { // always use relax bound bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); - // TODO(tvm-team): use floordiv, which could give better bound. - result_ = truncdiv(result_, operand); + result_ = floordiv(result_, operand); // rounding down here if (!divided) { - // Handle non-divisible case - // NOTE: this accounts for trunc div behavior. - bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); - if (comp_op == kGreater) { + // System will round down in all the cases, so add one for result_ for kGreater + // (x >= 3/2 --> x >= 2) + // (x >= -3/2 --> x >= -1) + // (x >= 3/-2 --> x >= -1) + // (x >= -3/-2 --> x >= 2) result_ += 1; } else if (comp_op == kEqual) { - // condition unsatisfiable as with trunc div, it will change the expression + // condition unsatisfiable as with floor div, it will change the expression success_ = false; return; } else { - // NOTE: this is a bit sutble hack. - // - // condition: - // - x * operand <= result - // - operand > 0 - // - x >= 0 - // - // Then it is fine to deduce that x <= result / operand. - // - if result > 0, this division round down - // - if result < 0, (result / operand) rounds up and may violate the constraint - // however, given that x is always non-negative, - // it is fine to have this relaxed bound, given that the user of deduce bound - // will respect the bound of x - // - // TODO(tvm-team): think about a better API to incorporate constraint of x. - // e.g. specify an interval of x and return a bound - // that is in the interval and satisfies the condition. - if (target_is_non_neg && sign_operand == kPositive) { - // do nothing - } else { - result_ -= 1; - } + // System rounds down in all cases, do nothing for kLess. + // ( x <= 3/2 --> x <= 1) + // ( x <= -3/2 --> x <= -2) + // ( x <= 3/-2 --> x <= -2) + // ( x <= -3/-2 --> x <= 1) } } Visit(left ? op->a : op->b); diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 235c93538225..33e31c766950 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -35,11 +35,11 @@ def test_deduce(): d_s = tvm.arith.IntervalSet(-3, -1) zero = tvm.const(0, "int32") - tdiv = tvm.truncdiv + fdiv = tvm.floordiv e0 = (-b)*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = (tdiv(d - c, b*-1) + (-1)) + ans0 = fdiv(d - c, b*-1) assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs @@ -48,7 +48,7 @@ def test_deduce(): e0 = d*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = (tdiv(d-c,d) - 1) + ans0 = fdiv(d-c, d) assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs @@ -58,7 +58,7 @@ def test_deduce(): e1 = (a*4+b < c) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (tdiv((c - b) + -1,4) -1) + ans1 = fdiv(c-1-b, 4) assert_expr_equal(res1.max_value, ans1) @@ -81,7 +81,7 @@ def test_deduce(): e3 = (-b)+a*c-d res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - ans3 = tdiv(2,c)+1 + ans3 = fdiv(2,c)+1 assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})