From 24df0bac7df49fd8d7b2f98820b0ab020257003b Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Fri, 19 Jun 2020 04:51:55 +0530 Subject: [PATCH] Additional canonicalization added for AddNode (#5846) --- python/tvm/tir/expr.py | 2 +- src/arith/rewrite_simplify.cc | 1 + tests/python/unittest/test_arith_rewrite_simplify.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f8cb05431a5b..3b580efe2b62 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -61,7 +61,7 @@ def __add__(self, other): return _generic.add(self, other) def __radd__(self, other): - return self.__add__(other) + return _generic.add(other, self) def __sub__(self, other): return _generic.subtract(self, other) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ce3f2a6223f2..4887ef0ee47d 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -191,6 +191,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // canonicalization rule // will try rewrite again after canonicalization. TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1); + TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1); TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1); TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 813e10a58707..53ba93dc65e7 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -184,6 +184,16 @@ def test_add_index_simplify(): ck.verify(y * x + 10 * x, x * (y + 10)) ck.verify(x * y + 10 * x, x * (y + 10)) + ck.verify((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)) + ck.verify(y * x + x, x * (y + 1)) + ck.verify(x * y + x, x * (y + 1)) + ck.verify((x + 10) + 13, x + 23) + ck.verify((x + 10) + (13 + z), x + z + 23) + ck.verify(x * y + 10 * x, x * (y + 10)) + ck.verify(y * x + x * 3, x * (y + 3)) + ck.verify(x + 3 + y, x + y + 3) + ck.verify((3 - y) + x, x - y + 3) + # canonicalization ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9);