diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index be830d3892094..fbb52a9ebe7a9 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -205,17 +205,14 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const MulNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); - return BinaryOpBoundry(a, b, InfAwareMul); + return BinaryOpBoundary(a, b, InfAwareMul); } Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "divide by zero"; - // assume no division by 0 - if (b.min_value == 0) b.min_value = 1; - if (b.max_value == 0) b.max_value = -1; - return BinaryOpBoundry(a, b, InfAwareDiv); + return HandleDivision(a, b, op->dtype, InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { @@ -244,10 +241,7 @@ class ConstIntBoundAnalyzer::Impl Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "floordiv by zero"; - // assume no division by 0 - if (b.min_value == 0) b.min_value = 1; - if (b.max_value == 0) b.max_value = -1; - return BinaryOpBoundry(a, b, InfAwareFloorDiv); + return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); } Entry VisitExpr_(const FloorModNode* op) final { @@ -331,7 +325,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitRightShift(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); - return BinaryOpBoundry(a, b, InfAwareRightShift); + return BinaryOpBoundary(a, b, InfAwareRightShift); } Entry VisitBitwiseAnd(const CallNode* op) { @@ -380,14 +374,14 @@ class ConstIntBoundAnalyzer::Impl // internal helper functions /*! * \brief Get boundary of binary op who are monotonic wrt to one argument. - * \param param a The entry of the left operand. - * \param param a The entry of the right operand. + * \param a The entry of the left operand. + * \param b The entry of the right operand. * \param op The operator. * \tparam F the operator function type. * \return The result. */ template - static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) { + static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) { Entry ret; // The boundary point must be shihft of the original boundary. int64_t v1 = op(a.min_value, b.min_value); @@ -398,6 +392,38 @@ class ConstIntBoundAnalyzer::Impl ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4); return ret; } + /*! + * \brief Get value boundaries of division (e.g. Div or FloorDiv). + * \param a The entry of the left operand. + * \param b The entry of the right operand. + * \param dt The data type of the division operator. + * \param op The division operator. + * \tparam F the operator function type. + * \return The result. + */ + template + static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) { + // Here we have a / b. + // The largest value of the division will be for the smallest (with + // respect to the absolute value) value of b. If the range of b starts + // at a negative value and ends at a positive one, narrow it down to + // be closer to 0, because BinaryOpBoundary only checks end-points of + // the domain ranges. + + // If the range of b contains 0, then some infinity will be involved + if (b.min_value <= 0 && 0 <= b.max_value) { + Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); + Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt); + + Entry e_neg = BinaryOpBoundary(a, b_neg, op); + Entry e_pos = BinaryOpBoundary(a, b_pos, op); + + return MakeBound(std::min(e_neg.min_value, e_pos.min_value), + std::max(e_neg.max_value, e_pos.max_value)); + } + // If the range of b does not have 0, use BinaryOpBoundary. + return BinaryOpBoundary(a, b, op); + } /*! * \brief Compute x + y, aware of inf. * \param x The left operand. diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index c5794cd126ef2..9ead0d488408e 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -122,6 +122,12 @@ def test_truncdiv_bound(): assert bd.min_value == bd.NEG_INF assert bd.max_value == bd.POS_INF + analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) + bd = analyzer.const_int_bound(tdiv(x, y)) + assert bd.min_value == -9 + assert bd.max_value == 9 + def test_truncmod_bound(): analyzer = tvm.arith.Analyzer() @@ -169,6 +175,12 @@ def test_floordiv_bound(): assert bd.min_value == bd.NEG_INF assert bd.max_value == bd.POS_INF + analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) + analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) + bd = analyzer.const_int_bound(fld(x, y)) + assert bd.min_value == -9 + assert bd.max_value == 9 + def test_floormod_bound(): analyzer = tvm.arith.Analyzer()