diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 81a4d616d4323..78ae446d0321d 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -69,6 +69,7 @@ #include #include +#include #include #include "const_fold.h" @@ -145,6 +146,14 @@ class PEqualChecker { bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; +template <> +class PEqualChecker { + public: + bool operator()(const FloatImm& lhs, const FloatImm& rhs) const { + return std::fabs(lhs->value - rhs->value) < 1e-20; + } +}; + template <> class PEqualChecker { public: diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e9d640ad660f8..6c8498fa1912a 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -125,6 +125,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm PVar c1, c2, c3; + // Pattern var match FloatImm + PVar c4; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -133,6 +135,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); + TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f); } if (IsIndexType(op->dtype)) { @@ -416,6 +419,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm PVar c1, c2; + // Pattern var match FloatImm + PVar c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -423,6 +428,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f); } if (IsIndexType(op->dtype)) { diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c01898635488d..ae7b432a9f338 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -40,6 +40,8 @@ def test_vector_simplify(): (y + x).astype("int32x2")) ck.verify(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)) + ck.verify(tvm.tir.Ramp(x, 1, 4).astype('float32x4') + tvm.tir.Broadcast(0.0, 4), + tvm.tir.Ramp(x, 1, 4).astype('float32x4')) # Sub rules ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x - y, 2, 4)) @@ -59,6 +61,8 @@ def test_vector_simplify(): tvm.tir.Ramp(x * 2, 8, 4)) ck.verify(tvm.tir.Broadcast(0, 4) * x, tvm.tir.Broadcast(0, 4)) + ck.verify(tvm.tir.Broadcast(0.0, 4) * x, + tvm.tir.Broadcast(0.0, 4)) ## DivMod rules tdiv = tvm.tir.truncdiv