Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix division range estimation error in simplifier #6244

Merged
merged 1 commit into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,14 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return BinaryOpBoundry(a, b, InfAwareMul);
return BinaryOpBoundary(a, b, InfAwareMul);
}

Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "divide by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareDiv);
return HandleDivision(a, b, op->dtype, InfAwareDiv);
}

Entry VisitExpr_(const ModNode* op) final {
Expand Down Expand Up @@ -244,10 +241,7 @@ class ConstIntBoundAnalyzer::Impl
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
}

Entry VisitExpr_(const FloorModNode* op) final {
Expand Down Expand Up @@ -331,7 +325,7 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundry(a, b, InfAwareRightShift);
return BinaryOpBoundary(a, b, InfAwareRightShift);
}

Entry VisitBitwiseAnd(const CallNode* op) {
Expand Down Expand Up @@ -380,14 +374,14 @@ class ConstIntBoundAnalyzer::Impl
// internal helper functions
/*!
* \brief Get boundary of binary op who are monotonic wrt to one argument.
* \param param a The entry of the left operand.
* \param param a The entry of the right operand.
* \param a The entry of the left operand.
* \param b The entry of the right operand.
* \param op The operator.
* \tparam F the operator function type.
* \return The result.
*/
template <typename F>
static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) {
static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) {
Entry ret;
// The boundary point must be shihft of the original boundary.
int64_t v1 = op(a.min_value, b.min_value);
Expand All @@ -398,6 +392,38 @@ class ConstIntBoundAnalyzer::Impl
ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4);
return ret;
}
/*!
* \brief Get value boundaries of division (e.g. Div or FloorDiv).
* \param a The entry of the left operand.
* \param b The entry of the right operand.
* \param dt The data type of the division operator.
* \param op The division operator.
* \tparam F the operator function type.
* \return The result.
*/
template <typename F>
static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) {
// Here we have a / b.
// The largest value of the division will be for the smallest (with
// respect to the absolute value) value of b. If the range of b starts
// at a negative value and ends at a positive one, narrow it down to
// be closer to 0, because BinaryOpBoundary only checks end-points of
// the domain ranges.

// If the range of b contains 0, then some infinity will be involved
if (b.min_value <= 0 && 0 <= b.max_value) {
Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt);

Entry e_neg = BinaryOpBoundary(a, b_neg, op);
Entry e_pos = BinaryOpBoundary(a, b_pos, op);

return MakeBound(std::min(e_neg.min_value, e_pos.min_value),
std::max(e_neg.max_value, e_pos.max_value));
}
// If the range of b does not have 0, use BinaryOpBoundary.
return BinaryOpBoundary(a, b, op);
}
/*!
* \brief Compute x + y, aware of inf.
* \param x The left operand.
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def test_truncdiv_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -9
assert bd.max_value == 9


def test_truncmod_bound():
analyzer = tvm.arith.Analyzer()
Expand Down Expand Up @@ -169,6 +175,12 @@ def test_floordiv_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -9
assert bd.max_value == 9


def test_floormod_bound():
analyzer = tvm.arith.Analyzer()
Expand Down