From 80ec68218f6e8a7ecc24fa8eaf72ec3d942bfa00 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 6 Jul 2019 22:48:43 -0700 Subject: [PATCH] [ARITH] More recursive rewrite rule, cleanup simplify tests (#3502) --- python/tvm/make.py | 30 ----- src/arithmetic/rewrite_simplify.cc | 4 +- src/arithmetic/stmt_simplify.cc | 35 +++-- .../unittest/test_arith_canonical_simplify.py | 6 + tests/python/unittest/test_arith_simplify.py | 125 ------------------ .../unittest/test_arith_stmt_simplify.py | 58 -------- tests/python/unittest/test_lang_operator.py | 18 +++ topi/python/topi/math.py | 2 +- 8 files changed, 43 insertions(+), 235 deletions(-) delete mode 100644 tests/python/unittest/test_arith_simplify.py delete mode 100644 tests/python/unittest/test_arith_stmt_simplify.py diff --git a/python/tvm/make.py b/python/tvm/make.py index 7439952ad7ad..241edd6b0948 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -24,7 +24,6 @@ """ from __future__ import absolute_import as _abs from ._ffi.function import _init_api -from ._ffi.runtime_ctypes import TVMType def range_by_min_extent(min_value, extent): @@ -48,35 +47,6 @@ def range_by_min_extent(min_value, extent): return _range_by_min_extent(min_value, extent) -def static_cast(dtype, expr): - """Cast expr to dtype. - - If expr is scalar and dtype is a corresponding vector - type, a Broadcast is generated. Otherwise it is a Cast. - - Parameters - ---------- - dtype : str - The target data type. - - expr : Expr - The expression to be casted. - - Returns - ------- - casted : Expr - The casted expression. - """ - target_type = TVMType(dtype) - src_type = TVMType(expr.dtype) - if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits: - if src_type.lanes == target_type.lanes: - return expr - if src_type.lanes == 1 and target_type.lanes > 1: - return Broadcast(expr, target_type.lanes) - return Cast(dtype, expr) - - def node(type_key, **kwargs): """Make a new DSL node by its type key and fields diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 06a28b519492..773f6c3a85c4 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1194,9 +1194,9 @@ Mutate_(const LT* op, const Expr& self) { TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y); TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); - + TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); + TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); - TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1); } return ret; } diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 162cb1e5fd16..fc6b92a87ce1 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -31,11 +31,24 @@ namespace tvm { namespace arith { -// statement simplifier + using namespace ir; class StmtSimplifier : public IRMutator { public: + using IRMutator::Mutate; + + Expr Mutate(Expr expr) final { + return analyzer_.Simplify(expr); + } + + Stmt Simplify(Stmt stmt, Map vrange) { + for (auto kv : vrange) { + analyzer_.Bind(kv.first, kv.second); + } + return Mutate(stmt); + } + Stmt Mutate_(const For* op, const Stmt& s) final { Var loop_var(op->loop_var.node_); analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); @@ -124,28 +137,12 @@ class StmtSimplifier : public IRMutator { std::unordered_map var_dom_; }; - -class CanonicalStmtSimplifier : public StmtSimplifier { - public: - using StmtSimplifier::Mutate; - Expr Mutate(Expr expr) final { - return analyzer_.canonical_simplify(expr); - } - - Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - for (auto kv : vrange) { - analyzer_.Bind(kv.first, kv.second); - } - return Mutate(stmt); - } -}; - } // namespace arith namespace ir { Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - return arith::CanonicalStmtSimplifier().CanonicalSimplify( + return arith::StmtSimplifier().Simplify( stmt, vrange); } @@ -167,7 +164,7 @@ Expr Simplify(Expr expr, Map vrange) { } Stmt Simplify(Stmt stmt, Map vrange) { - return arith::CanonicalStmtSimplifier().CanonicalSimplify( + return arith::StmtSimplifier().Simplify( stmt, vrange); } } // namespace ir diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 56d2bb1e67e0..d38dfac77a1f 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -81,6 +81,10 @@ def test_canonical_mixed(): z = tvm.const(3, "int32") ck.verify(x / (z*z) - x / (z*z), 0) ck.verify(x / (z+z) - x / (z+z), 0) + ck.verify(x - 2 < 3, x < 5) + ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0) + ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0) + ck.verify(x * x - x * x, 0) def test_reduce_combiner_simplify(): @@ -211,6 +215,8 @@ def test_complex_cases(): ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4)) + + if __name__ == "__main__": test_simplify_if_then_else() test_div_simplify() diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py deleted file mode 100644 index b8c5a3a16d5c..000000000000 --- a/tests/python/unittest/test_arith_simplify.py +++ /dev/null @@ -1,125 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm - -def csimplify(z): - return tvm.ir_pass.CanonicalSimplify( - tvm.make.Evaluate(z)).value - -def test_simplify(): - x = tvm.var('n') - z = x * 4 - x * 2 - zz = csimplify(z) - assert zz.b.value == 2 - - z = (x / 4) * 2 - (x / 4) - zz = csimplify(z) - assert zz.a == x and zz.b.value == 4 - - z = (x % 4) * 3 + (x % 4) - zz = csimplify(z) - assert zz.b.value == 4 - zz = zz.a - assert zz.a == x and zz.b.value == 4 - - n = tvm.var('n') - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32")) - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n) - tvm.ir_pass.CanonicalSimplify(n / (-1)) - # This is not true in the current implementation - # assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)), - # tvm.ir_pass.CanonicalSimplify(-n)) - -def test_simplify_mod(): - ib = tvm.ir_builder.create() - n = tvm.var('n') - A = ib.pointer("float32", name="A") - with ib.for_range(0, 10, name="j") as j: - with ib.for_range(0, 16, name="i") as i: - A[i] = A[(j * 32 + i+1) % 16] - body = ib.get() - stmt = tvm.ir_pass.CanonicalSimplify(body) - diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) - assert diff.value == 0 - # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16 - index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16) - assert index != j - index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)}) - assert index == j - # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 - index = tvm.ir_pass.CanonicalSimplify( - (j + n * 32) % 16, {j: tvm.Range(0, 6)}) - assert index != j - index = tvm.ir_pass.CanonicalSimplify( - (j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)}) - assert index == j - -def test_simplify_minmax(): - x = tvm.var('x') - e1 = tvm.max(x, 1) - tvm.max(x, 1) - e1s = tvm.ir_pass.CanonicalSimplify(e1) - assert e1s.value == 0 - - e2 = tvm.min(x, 1) - tvm.min(x, 1) - e2s = tvm.ir_pass.CanonicalSimplify(e2) - assert e2s.value == 0 - -def test_mul(): - x = tvm.var('x') - e = x * x - x * x - es = tvm.ir_pass.CanonicalSimplify(e) - assert es.value == 0 - -def test_modular(): - rx = tvm.var("rx") - ry = tvm.var("ry") - y = tvm.var("y") - x = tvm.var("x") - i32_const = lambda x: tvm.const(x, "int32") - vmap = {rx: tvm.Range(i32_const(0), i32_const(3)), - ry: tvm.Range(i32_const(0), i32_const(3)), - y: tvm.Range(i32_const(0), i32_const(2)), - x: tvm.Range(i32_const(0), i32_const(14))} - idx = ry * 16 + rx + y * 16 + x - z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap) - z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap) - assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 - assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 - -def test_const_propagation(): - x1 = tvm.const(4, "int32") - x2 = x1 + 5 - assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 - x3 = x2 / 3 - assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 - x4 = x3 + 0.5 - assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5 - x5 = tvm.ceil(x4) - assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4 - x6 = x5.astype('int') - assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4 - y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int') - assert isinstance(y, tvm.expr.IntImm) and y.value == 6 - - -if __name__ == "__main__": - test_modular() - test_simplify() - test_mul() - test_simplify_minmax() - test_const_propagation() - test_simplify_mod() diff --git a/tests/python/unittest/test_arith_stmt_simplify.py b/tests/python/unittest/test_arith_stmt_simplify.py deleted file mode 100644 index 44c301ab7ef6..000000000000 --- a/tests/python/unittest/test_arith_stmt_simplify.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -import numpy -from tvm import comm_reducer -from tvm.ir_pass import Simplify, CanonicalSimplify, Equal - -def test_simplify(): - """Not yet working, mock design""" - dtype = 'int64' - n = tvm.var('n') - Ab = tvm.decl_buffer((n, ), dtype) - i = tvm.var('i') - j = tvm.var('j') - # for i in 0 to n-1: - stmt = tvm.make.For( - i, 2, n, 0, 0, - tvm.make.For(j, 0, n, 0, 0, - tvm.make.IfThenElse( - tvm.make.LT(i + 2, n), - tvm.make.Store(Ab.data, - tvm.make.Load(dtype, Ab.data, i + 4) + 1, - (j + 1) * 4 - 4 * j + i), - None))) - stmt = tvm.ir_pass.CanonicalSimplify(stmt) - - -def test_basic(): - m = tvm.var('m') - ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1)) - assert str(ret.value) == "(m - 1)" - - -def test_bound(): - m = tvm.var('m') - vrange = tvm.convert({m: tvm.Range(tvm.const(0, "int32"), tvm.const(10, "int32"))}) - ret = tvm.ir_pass.Simplify(m % 10, vrange) - assert ret == m - - -if __name__ == "__main__": - test_bound() - test_basic() - test_simplify() diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index da309815b9f4..8e7dcba3230c 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -83,7 +83,25 @@ def check_throws(f): assert tvm.any(x, true).same_as(true) assert tvm.any(true, x).same_as(true) + +def test_const_fold4(): + x1 = tvm.const(4, "int32") + x2 = x1 + 5 + assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 + x3 = x2 / 3 + assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 + x4 = x3 + 0.55 + assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6 + x5 = tvm.ceil(x4) + assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4 + x6 = x5.astype('int') + assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6) + y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int') + assert isinstance(y, tvm.expr.IntImm) and y.value == 6 + + if __name__ == "__main__": test_const_fold() test_const_fold2() test_const_fold3() + test_const_fold4() diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 5a1742b12c56..406d48969682 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -342,4 +342,4 @@ def cast(x, dtype): if isinstance(x, tvm.tensor.Tensor): return tvm.compute( x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE) - return tvm.make.static_cast(dtype, x) + return tvm.make._cast(dtype, x)