From 8703d9fb263be42b01d482c678fbb32e99e76adc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 17 Jun 2019 21:51:33 -0700 Subject: [PATCH] [ARITH] Bugfix min/max const canonicalize rule (#3386) --- 3rdparty/dmlc-core | 2 +- src/arithmetic/rewrite_simplify.cc | 7 +++++-- tests/python/unittest/test_arith_rewrite_simplify.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index fbe142b267a8..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index ee3265618876..ea6530631880 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -813,7 +813,9 @@ Mutate_(const Min* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1)); + TVM_TRY_RECURSIVE_REWRITE_IF( + min(c1 - x, c2), c1 - max(x, c1 - c2), + c2.Eval()->value != 0); } // condition rules. @@ -961,7 +963,8 @@ Mutate_(const Max* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1)); + TVM_TRY_RECURSIVE_REWRITE_IF( + max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 596e54d338b5..07d460eee7fe 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -392,6 +392,7 @@ def test_min_index_simplify(): ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10) ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10)) ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3) + ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x, 1)) def test_max_index_simplify(): @@ -448,6 +449,7 @@ def test_max_index_simplify(): ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10) ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10)) ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3) + ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x, 2)) def test_cmp_simplify():