diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 2528fe80c4b0c..5768fb19fdcdd 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -150,43 +150,26 @@ class BoundDeducer: public IRVisitor { // always use relax bound bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); - // TODO(tvm-team): use floordiv, which could give better bound. - result_ = truncdiv(result_, operand); + result_ = floordiv(result_, operand); // rounding down here if (!divided) { - // Handle non-divisible case - // NOTE: this accounts for trunc div behavior. - bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); - if (comp_op == kGreater) { + // System will round down in all the cases, so add one for result_ for kGreater + // (x >= 3/2 --> x >= 2) + // (x >= -3/2 --> x >= -1) + // (x >= 3/-2 --> x >= -1) + // (x >= -3/-2 --> x >= 2) result_ += 1; } else if (comp_op == kEqual) { // condition unsatisfiable as with trunc div, it will change the expression success_ = false; return; } else { - // NOTE: this is a bit sutble hack. - // - // condition: - // - x * operand <= result - // - operand > 0 - // - x >= 0 - // - // Then it is fine to deduce that x <= result / operand. - // - if result > 0, this division round down - // - if result < 0, (result / operand) rounds up and may violate the constraint - // however, given that x is always non-negative, - // it is fine to have this relaxed bound, given that the user of deduce bound - // will respect the bound of x - // - // TODO(tvm-team): think about a better API to incorporate constraint of x. - // e.g. specify an interval of x and return a bound - // that is in the interval and satisfies the condition. - if (target_is_non_neg && sign_operand == kPositive) { - // do nothing - } else { - result_ -= 1; - } + // System rounds down in all cases, do nothing for kLess. + // ( x <= 3/2 --> x <= 1) + // ( x <= -3/2 --> x <= -2) + // ( x <= 3/-2 --> x <= -2) + // ( x <= -3/-2 --> x <= 1) } } Visit(left ? op->a : op->b);