Skip to content

Commit

Permalink
[ARITH] Bugfix div subtract rewrite rule (apache#3504)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jul 7, 2019
1 parent f978887 commit eadc4e3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,16 @@ Mutate_(const Sub* op, const Expr& self) {
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

// Proof in the case of floordiv, need positive condition.
// let x = a * c3 + r
// (x + c1) / c3 - x / c3 => (r + c1) / c3
TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3,
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
c1.Eval()->value >= c2.Eval()->value &&
c3.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3,
((x + (c1 % c3)) % c3 + c1) / c3,
(x % c3 + c1) / c3,
CanProveGreaterEqual(x.Eval(), 0) &&
c1.Eval()->value >= 0 &&
c3.Eval()->value > 0);
Expand Down
5 changes: 4 additions & 1 deletion tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def test_sub_index_simplify():
# div pattern
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x - (x / 3) * 3, x % 3)
ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)

ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3)
ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3)

ck.verify(y - (y / (-5)) * (-5), y % 5)
ck.verify((y / 3) * 3 - y, 0 - y % 3)
Expand All @@ -258,6 +260,7 @@ def test_sub_index_simplify():
ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)


def test_mul_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
Expand Down

0 comments on commit eadc4e3

Please sign in to comment.