Skip to content

Commit

Permalink
[ARITH] More recursive rewrite rule, cleanup simplify tests (apache#3502
Browse files Browse the repository at this point in the history
)
  • Loading branch information
tqchen authored Jul 7, 2019
1 parent eadc4e3 commit 2a7aebe
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 235 deletions.
30 changes: 0 additions & 30 deletions python/tvm/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"""
from __future__ import absolute_import as _abs
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType


def range_by_min_extent(min_value, extent):
Expand All @@ -48,35 +47,6 @@ def range_by_min_extent(min_value, extent):
return _range_by_min_extent(min_value, extent)


def static_cast(dtype, expr):
"""Cast expr to dtype.
If expr is scalar and dtype is a corresponding vector
type, a Broadcast is generated. Otherwise it is a Cast.
Parameters
----------
dtype : str
The target data type.
expr : Expr
The expression to be casted.
Returns
-------
casted : Expr
The casted expression.
"""
target_type = TVMType(dtype)
src_type = TVMType(expr.dtype)
if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits:
if src_type.lanes == target_type.lanes:
return expr
if src_type.lanes == 1 and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)


def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1194,9 +1194,9 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y);
TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);


TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
}
return ret;
}
Expand Down
35 changes: 16 additions & 19 deletions src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,24 @@

namespace tvm {
namespace arith {
// statement simplifier

using namespace ir;

class StmtSimplifier : public IRMutator {
public:
using IRMutator::Mutate;

Expr Mutate(Expr expr) final {
return analyzer_.Simplify(expr);
}

Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}

Stmt Mutate_(const For* op, const Stmt& s) final {
Var loop_var(op->loop_var.node_);
analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
Expand Down Expand Up @@ -124,28 +137,12 @@ class StmtSimplifier : public IRMutator {
std::unordered_map<const Variable*, Range> var_dom_;
};


class CanonicalStmtSimplifier : public StmtSimplifier {
public:
using StmtSimplifier::Mutate;
Expr Mutate(Expr expr) final {
return analyzer_.canonical_simplify(expr);
}

Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}
};

} // namespace arith

namespace ir {

Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
return arith::StmtSimplifier().Simplify(
stmt, vrange);
}

Expand All @@ -167,7 +164,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}

Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
return arith::StmtSimplifier().Simplify(
stmt, vrange);
}
} // namespace ir
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def test_canonical_mixed():
z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0)
ck.verify(x / (z+z) - x / (z+z), 0)
ck.verify(x - 2 < 3, x < 5)
ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
ck.verify(x * x - x * x, 0)


def test_reduce_combiner_simplify():
Expand Down Expand Up @@ -211,6 +215,8 @@ def test_complex_cases():
ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))




if __name__ == "__main__":
test_simplify_if_then_else()
test_div_simplify()
Expand Down
125 changes: 0 additions & 125 deletions tests/python/unittest/test_arith_simplify.py

This file was deleted.

58 changes: 0 additions & 58 deletions tests/python/unittest/test_arith_stmt_simplify.py

This file was deleted.

18 changes: 18 additions & 0 deletions tests/python/unittest/test_lang_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,25 @@ def check_throws(f):
assert tvm.any(x, true).same_as(true)
assert tvm.any(true, x).same_as(true)


def test_const_fold4():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.55
assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6)
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6


if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
2 changes: 1 addition & 1 deletion topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,4 @@ def cast(x, dtype):
if isinstance(x, tvm.tensor.Tensor):
return tvm.compute(
x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
return tvm.make.static_cast(dtype, x)
return tvm.make._cast(dtype, x)

0 comments on commit 2a7aebe

Please sign in to comment.