From 0faf7310d94ffaad903c65a32516f102d291e435 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Mon, 3 Jun 2019 18:52:31 +0300 Subject: [PATCH] [ARITH] Bugfix: check arg positiveness for mod rules (#3279) --- src/arithmetic/rewrite_simplify.cc | 4 +++- .../unittest/test_arith_rewrite_simplify.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 00198d9b140a..ee3265618876 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -634,10 +634,12 @@ Mutate_(const Mod* op, const Expr& self) { TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, c2.Eval()->value > 0 && + c1.Eval()->value >= 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0)); @@ -645,7 +647,7 @@ Mutate_(const Mod* op, const Expr& self) { c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual((y * c1).Eval(), 0)); // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 1b03253c9a0f..ee113e101cce 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -302,9 +302,11 @@ def test_div_index_simplify(): def test_mod_index_simplify(): ck = RewriteChecker() - x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z") ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True) + ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True) ck.verify(x * 10 % 2, 0) ck.verify((x * 10 + y) % 2, y % 2) @@ -317,6 +319,19 @@ def test_mod_index_simplify(): ck.verify((x + y * 10) % -2, x % 2) ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1) + ck.verify(x * (-10) % 2, 0) + ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2) + ck.verify((x + (-10)) % 2, (x + (-10)) % 2) + ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2) + ck.verify(x * (-10) % -2, 0) + + ck.verify(nx * 10 % 2, 0) + ck.verify((nx * (-10) + y) % 2, y % 2) + ck.verify((x + ny * (-10)) % 2, x % 2) + ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1) + ck.verify(nx * 10 % -2, 0) + ck.verify((nx * (-10) + y) % -2, y % 2) + ck.verify((x + ny * (-10)) % -2, x % 2) def test_min_index_simplify(): ck = RewriteChecker()