Skip to content

Commit

Permalink
[Arith][Refactor] Return Optional<PrimExpr> from TryConstFold (#12784)
Browse files Browse the repository at this point in the history
Prior to this commit, the templated `TryConstFold` utility returned an
undefined `PrimExpr` to represent a failure to perform constant
folding.  This commit makes this explicit by returning
`Optional<PrimExpr>` instead.
  • Loading branch information
Lunderberg authored Sep 15, 2022
1 parent e5adb83 commit 397cf87
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 140 deletions.
21 changes: 7 additions & 14 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<Add>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();

// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
Expand All @@ -741,8 +740,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<Sub>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();

// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
Expand All @@ -766,8 +764,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<Mul>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();

// x * c
if (a.as<IntImmNode>()) {
Expand Down Expand Up @@ -870,8 +867,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<Div>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<Div>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
Expand Down Expand Up @@ -928,8 +924,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
Expand Down Expand Up @@ -1037,8 +1032,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<Mod>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<Mod>(a, b)) return const_res.value();

PVar<IntImm> c1;
// x % c1
Expand Down Expand Up @@ -1105,8 +1099,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);

// const folding
PrimExpr const_res = TryConstFold<FloorMod>(a, b);
if (const_res.defined()) return const_res;
if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();

PVar<IntImm> c1;
// x % c1
Expand Down
91 changes: 46 additions & 45 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_ARITH_CONST_FOLD_H_
#define TVM_ARITH_CONST_FOLD_H_

#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

Expand All @@ -44,10 +45,10 @@ namespace arith {
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
* \return NullOpt if constant fold fails, otherwise return folded result.
*/
template <typename Op>
inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
inline Optional<PrimExpr> TryConstFold(PrimExpr a, PrimExpr b);

/*!
* \brief Try to run unary compute with constant folding.
Expand All @@ -56,10 +57,10 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
* \return NullOpt if constant fold fails, otherwise return folded result.
*/
template <typename Op>
inline PrimExpr TryConstFold(PrimExpr a);
inline Optional<PrimExpr> TryConstFold(PrimExpr a);

/*!
* \brief Check whether type is used to represent index.
Expand Down Expand Up @@ -126,7 +127,7 @@ inline double GetFoldResultDoubleRepr(float x) {

// specialization of constant folders.
template <>
inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -142,17 +143,17 @@ inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return PrimExpr();
return NullOpt;
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) &&
(pb && pb->dtype.is_uint() && pb->value > 0U)))
Expand All @@ -171,16 +172,16 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return PrimExpr();
return NullOpt;
}
}
if (fb && fb->value == 0) return a;
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -202,7 +203,7 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return PrimExpr();
return NullOpt;
}
}
if (fa) {
Expand All @@ -214,11 +215,11 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (fb->value == 0) return b;
}
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -242,7 +243,7 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return PrimExpr();
return NullOpt;
}
}
if (fa && fa->value == 0) return a;
Expand All @@ -251,11 +252,11 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -271,11 +272,11 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -297,7 +298,7 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
return PrimExpr();
return NullOpt;
}
}
if (fa && fa->value == 0) return a;
Expand All @@ -306,11 +307,11 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
Expand All @@ -326,114 +327,114 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
});
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
inline Optional<PrimExpr> TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
return PrimExpr();
return NullOpt;
}

template <>
inline PrimExpr TryConstFold<tir::Not>(PrimExpr a) {
inline Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
return IntImm(DataType::UInt(1), !(pa->value));
}
return PrimExpr();
return NullOpt;
}

/*! \brief Helper namespace for symbolic value limits */
Expand Down
10 changes: 7 additions & 3 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,13 @@ TVM_DECLARE_LOGICAL_OP(Not);
template <typename Op>
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
if (!res.defined()) res = Op(a->min_value, b->min_value);
return IntervalSet::SinglePoint(res);
PrimExpr expr;
if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
expr = res.value();
} else {
expr = Op(a->min_value, b->min_value);
}
return IntervalSet::SinglePoint(expr);
}
if (is_logical_op<Op>::value) {
return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
Expand Down
Loading

0 comments on commit 397cf87

Please sign in to comment.