Skip to content

Commit

Permalink
[ARITH] Bugfix: check arg positiveness for mod rules (#3279)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h authored and tqchen committed Jun 3, 2019
1 parent fc2b2a0 commit 0faf731
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,18 +634,20 @@ Mutate_(const Mod* op, const Expr& self) {
TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2,
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual((x * c1).Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0));

TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2,
c2.Eval()->value > 0 &&
c1.Eval()->value >= 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0));

TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2,
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0));
CanProveGreaterEqual((y * c1).Eval(), 0));

// canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required
Expand Down
17 changes: 16 additions & 1 deletion tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,11 @@ def test_div_index_simplify():

def test_mod_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z")
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True)
ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True)

ck.verify(x * 10 % 2, 0)
ck.verify((x * 10 + y) % 2, y % 2)
Expand All @@ -317,6 +319,19 @@ def test_mod_index_simplify():
ck.verify((x + y * 10) % -2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)

ck.verify(x * (-10) % 2, 0)
ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2)
ck.verify((x + (-10)) % 2, (x + (-10)) % 2)
ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2)
ck.verify(x * (-10) % -2, 0)

ck.verify(nx * 10 % 2, 0)
ck.verify((nx * (-10) + y) % 2, y % 2)
ck.verify((x + ny * (-10)) % 2, x % 2)
ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1)
ck.verify(nx * 10 % -2, 0)
ck.verify((nx * (-10) + y) % -2, y % 2)
ck.verify((x + ny * (-10)) % -2, x % 2)

def test_min_index_simplify():
ck = RewriteChecker()
Expand Down

0 comments on commit 0faf731

Please sign in to comment.