Skip to content

Commit

Permalink
Fix division range estimation error in simplifier (apache#6244)
Browse files Browse the repository at this point in the history
Division a/b assumes maximum values when b is close to 0. Account
for that when estimating the range for a/b when 0 belongs to the
estimated range for b.

Assume that a division by zero cannot happen in a valid program,
so in such cases treat the range for b as a union
  [b.min_value, -1] u [1, b.max_value]
  • Loading branch information
Krzysztof Parzyszek authored and Trevor Morris committed Aug 26, 2020
1 parent e8f52c6 commit 824d575
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
52 changes: 39 additions & 13 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,14 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return BinaryOpBoundry(a, b, InfAwareMul);
return BinaryOpBoundary(a, b, InfAwareMul);
}

Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "divide by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareDiv);
return HandleDivision(a, b, op->dtype, InfAwareDiv);
}

Entry VisitExpr_(const ModNode* op) final {
Expand Down Expand Up @@ -244,10 +241,7 @@ class ConstIntBoundAnalyzer::Impl
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
}

Entry VisitExpr_(const FloorModNode* op) final {
Expand Down Expand Up @@ -331,7 +325,7 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundry(a, b, InfAwareRightShift);
return BinaryOpBoundary(a, b, InfAwareRightShift);
}

Entry VisitBitwiseAnd(const CallNode* op) {
Expand Down Expand Up @@ -380,14 +374,14 @@ class ConstIntBoundAnalyzer::Impl
// internal helper functions
/*!
* \brief Get boundary of binary op who are monotonic wrt to one argument.
* \param param a The entry of the left operand.
* \param param a The entry of the right operand.
* \param a The entry of the left operand.
* \param b The entry of the right operand.
* \param op The operator.
* \tparam F the operator function type.
* \return The result.
*/
template <typename F>
static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) {
static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) {
Entry ret;
// The boundary point must be shihft of the original boundary.
int64_t v1 = op(a.min_value, b.min_value);
Expand All @@ -398,6 +392,38 @@ class ConstIntBoundAnalyzer::Impl
ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4);
return ret;
}
/*!
* \brief Get value boundaries of division (e.g. Div or FloorDiv).
* \param a The entry of the left operand.
* \param b The entry of the right operand.
* \param dt The data type of the division operator.
* \param op The division operator.
* \tparam F the operator function type.
* \return The result.
*/
template <typename F>
static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) {
// Here we have a / b.
// The largest value of the division will be for the smallest (with
// respect to the absolute value) value of b. If the range of b starts
// at a negative value and ends at a positive one, narrow it down to
// be closer to 0, because BinaryOpBoundary only checks end-points of
// the domain ranges.

// If the range of b contains 0, then some infinity will be involved
if (b.min_value <= 0 && 0 <= b.max_value) {
Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt);

Entry e_neg = BinaryOpBoundary(a, b_neg, op);
Entry e_pos = BinaryOpBoundary(a, b_pos, op);

return MakeBound(std::min(e_neg.min_value, e_pos.min_value),
std::max(e_neg.max_value, e_pos.max_value));
}
// If the range of b does not have 0, use BinaryOpBoundary.
return BinaryOpBoundary(a, b, op);
}
/*!
* \brief Compute x + y, aware of inf.
* \param x The left operand.
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def test_truncdiv_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -9
assert bd.max_value == 9


def test_truncmod_bound():
analyzer = tvm.arith.Analyzer()
Expand Down Expand Up @@ -169,6 +175,12 @@ def test_floordiv_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -9
assert bd.max_value == 9


def test_floormod_bound():
analyzer = tvm.arith.Analyzer()
Expand Down

0 comments on commit 824d575

Please sign in to comment.