Skip to content

Commit

Permalink
[ARITH] Add CombineInterval<Div> in IntSet (#48)
Browse files Browse the repository at this point in the history
* [FIX] add CombineInterval<Div>

* fix error message and add comment about rounding

* fix comment
  • Loading branch information
Ziheng Jiang authored and tqchen committed Feb 21, 2017
1 parent c8ec411 commit 3555769
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
// This is relaxiation
// no relaxation is needed in here due to set is inclusive
// TODO(tqchen): consider convert to StrideSet.
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
Expand All @@ -259,6 +259,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
return IntSet::everything();
}

template<>
inline IntSet CombineInterval<Div>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
}
if (b.is_single_point()) {
if (is_zero(b.min)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
// no relaxation is needed in here due to set is inclusive
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Div";
return IntSet::everything();
}

template<>
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
Expand Down

0 comments on commit 3555769

Please sign in to comment.