From 397cf8781eba7a2bcc35e832130801c1d1419c43 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 15 Sep 2022 06:39:20 -0500 Subject: [PATCH] [Arith][Refactor] Return Optional from TryConstFold (#12784) Prior to this commit, the templated `TryConstFold` utility returned an undefined `PrimExpr` to represent a failure to perform constant folding. This commit makes this explicit by returning `Optional` instead. --- src/arith/canonical_simplify.cc | 21 +++----- src/arith/const_fold.h | 91 +++++++++++++++++---------------- src/arith/int_set.cc | 10 ++-- src/arith/iter_affine_map.cc | 15 ++---- src/arith/pattern_match.h | 3 +- src/arith/rewrite_simplify.cc | 42 +++++---------- src/tir/op/op.cc | 57 +++++++-------------- 7 files changed, 99 insertions(+), 140 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 9f45317cba11..f5d2667aa64e 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -716,8 +716,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // canonical form simplification. SumExpr ret = ToSumExpr(std::move(a)); @@ -741,8 +740,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // canonical form simplification. SumExpr ret = ToSumExpr(std::move(a)); @@ -766,8 +764,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // x * c if (a.as()) { @@ -870,8 +867,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold
(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold
(a, b)) return const_res.value(); PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { @@ -928,8 +924,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { @@ -1037,8 +1032,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); PVar c1; // x % c1 @@ -1105,8 +1099,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); PVar c1; // x % c1 diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index d0e09a1a7429..a7466cf38c85 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -24,6 +24,7 @@ #ifndef TVM_ARITH_CONST_FOLD_H_ #define TVM_ARITH_CONST_FOLD_H_ +#include #include #include @@ -44,10 +45,10 @@ namespace arith { * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. - * \return nullptr if constant fold fails, otherwise return folded result. + * \return NullOpt if constant fold fails, otherwise return folded result. */ template -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b); +inline Optional TryConstFold(PrimExpr a, PrimExpr b); /*! * \brief Try to run unary compute with constant folding. @@ -56,10 +57,10 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b); * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. - * \return nullptr if constant fold fails, otherwise return folded result. + * \return NullOpt if constant fold fails, otherwise return folded result. */ template -inline PrimExpr TryConstFold(PrimExpr a); +inline Optional TryConstFold(PrimExpr a); /*! * \brief Check whether type is used to represent index. @@ -126,7 +127,7 @@ inline double GetFoldResultDoubleRepr(float x) { // specialization of constant folders. template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -142,17 +143,17 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value + fb->value); } else { - return PrimExpr(); + return NullOpt; } } if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && (pb && pb->dtype.is_uint() && pb->value > 0U))) @@ -171,16 +172,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value - fb->value); } else { - return PrimExpr(); + return NullOpt; } } if (fb && fb->value == 0) return a; }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -202,7 +203,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value * fb->value); } else { - return PrimExpr(); + return NullOpt; } } if (fa) { @@ -214,11 +215,11 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (fb->value == 0) return b; } }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -242,7 +243,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value / fb->value); } else { - return PrimExpr(); + return NullOpt; } } if (fa && fa->value == 0) return a; @@ -251,11 +252,11 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -271,11 +272,11 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -297,7 +298,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, std::floor(fa->value / fb->value)); } else { - return PrimExpr(); + return NullOpt; } } if (fa && fa->value == 0) return a; @@ -306,11 +307,11 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -326,114 +327,114 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; if (pb && !pb->value) return b; - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; if (pb && !pb->value) return a; - return PrimExpr(); + return NullOpt; } template <> -inline PrimExpr TryConstFold(PrimExpr a) { +inline Optional TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); } - return PrimExpr(); + return NullOpt; } /*! \brief Helper namespace for symbolic value limits */ diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index e8e223ceca09..35b12bb35238 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -108,9 +108,13 @@ TVM_DECLARE_LOGICAL_OP(Not); template inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { if (a->IsSinglePoint() && b->IsSinglePoint()) { - PrimExpr res = TryConstFold(a->min_value, b->min_value); - if (!res.defined()) res = Op(a->min_value, b->min_value); - return IntervalSet::SinglePoint(res); + PrimExpr expr; + if (auto res = TryConstFold(a->min_value, b->min_value)) { + expr = res.value(); + } else { + expr = Op(a->min_value, b->min_value); + } + return IntervalSet::SinglePoint(expr); } if (is_logical_op::value) { return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 83e2821c9800..182eada24d96 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1205,8 +1205,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { PrimExpr b = this->DirectMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { @@ -1240,8 +1239,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { PrimExpr b = this->DirectMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { @@ -1276,8 +1274,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { PrimExpr b = this->DirectMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { @@ -1572,8 +1569,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { PrimExpr b = this->DirectMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { @@ -1657,8 +1653,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { PrimExpr b = this->DirectMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(a, b)) return const_res.value(); // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 6abcc728fc8d..69f064e11931 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -330,8 +330,7 @@ class PBinaryExpr : public Pattern> { PrimExpr Eval() const { PrimExpr lhs = a_.Eval(); PrimExpr rhs = b_.Eval(); - PrimExpr ret = TryConstFold(lhs, rhs); - if (ret.defined()) return ret; + if (auto ret = TryConstFold(lhs, rhs)) return ret.value(); return OpType(lhs, rhs); } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d7866fc1307b..e3e9db62d0bd 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -124,8 +124,7 @@ void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -258,8 +257,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -450,8 +448,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -490,8 +487,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold
(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold
(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm @@ -666,8 +662,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1; @@ -748,8 +743,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm @@ -895,8 +889,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, b1; @@ -977,8 +970,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, s1, s2; @@ -1149,8 +1141,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); // Pattern var to match any expression PVar x, y, z, s1, s2; @@ -1327,8 +1318,7 @@ Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression @@ -1376,8 +1366,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression @@ -1508,8 +1497,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression @@ -1534,8 +1522,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression @@ -1574,8 +1561,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); - if (const_res.defined()) return const_res; + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index b9e0c3c37068..509badbebb92 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -327,8 +327,7 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, b); } PrimExpr add(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Add(a, b, span); } @@ -349,23 +348,20 @@ PrimExpr operator-(PrimExpr a, PrimExpr b) { return sub(a, b); } PrimExpr sub(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Sub(a, b, span); } PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); } PrimExpr mul(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Mul(a, b, span); } PrimExpr div(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Div(a, b, span); } @@ -377,8 +373,7 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) { PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Mod(a, b, span); } @@ -397,8 +392,7 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::FloorDiv(a, b, span); } @@ -406,8 +400,7 @@ PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a + b - 1, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a + b - 1, b)) return ret.value(); return tir::FloorDiv(a + b - 1, b, span); } @@ -415,8 +408,7 @@ PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::FloorMod(a, b, span); } @@ -429,8 +421,7 @@ PrimExpr min(PrimExpr a, PrimExpr b, Span span) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Min(a, b, span); } @@ -443,8 +434,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Max(a, b, span); } @@ -475,48 +465,42 @@ PrimExpr likely(PrimExpr cond, Span span) { PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); } PrimExpr greater(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::GT(a, b, span); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); } PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::GE(a, b, span); } PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); } PrimExpr less(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::LT(a, b, span); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); } PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::LE(a, b, span); } PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); } PrimExpr equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::EQ(a, b, span); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); } PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::NE(a, b, span); } @@ -551,24 +535,21 @@ void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const cha PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); } PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) { type_check_boolean_args(a, b, "&& operator (logical AND)"); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::And(a, b, span); } PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); } PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) { type_check_boolean_args(a, b, "|| operator (logical OR)"); - PrimExpr ret = arith::TryConstFold(a, b); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::Or(a, b, span); } PrimExpr operator!(PrimExpr a) { return logical_not(a); } PrimExpr logical_not(PrimExpr a, Span span) { type_check_boolean_args(a, "! operator (logical NOT)"); - PrimExpr ret = arith::TryConstFold(a); - if (ret.defined()) return ret; + if (auto ret = arith::TryConstFold(a)) return ret.value(); return tir::Not(a, span); }