diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 2c72db169a2d..d95332c245b7 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -10,7 +10,7 @@ #include "base.h" #include "expr.h" -#include "ir_operator.h" +#include "expr_operator.h" #include "tvm/node/container.h" namespace tvm { diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index 99aebc3a1c31..3f5cb9a29546 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -16,7 +16,7 @@ #include #include -#include "ir_operator.h" +#include "expr_operator.h" namespace tvm { diff --git a/include/tvm/ir_operator.h b/include/tvm/expr_operator.h similarity index 99% rename from include/tvm/ir_operator.h rename to include/tvm/expr_operator.h index c2cdc5e7a923..c4d2d555f3a3 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/expr_operator.h @@ -1,13 +1,13 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/ir_operator.h + * \file tvm/expr_operator.h * \brief Common operators defined for Expr. * * \note Most of the operator defined here perform simple constant folding * when the type is int32 or int64 for simplifying the index expressions. */ -#ifndef TVM_IR_OPERATOR_H_ -#define TVM_IR_OPERATOR_H_ +#ifndef TVM_EXPR_OPERATOR_H_ +#define TVM_EXPR_OPERATOR_H_ #include #include @@ -617,4 +617,4 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&); TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); } // namespace tvm -#endif // TVM_IR_OPERATOR_H_ +#endif // TVM_EXPR_OPERATOR_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 5e1f1fc73917..eafce72375cf 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -10,7 +10,7 @@ #include #include #include "expr.h" -#include "ir_operator.h" +#include "expr_operator.h" #include "tensor.h" #include "schedule.h" #include "arithmetic.h" diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 16f7363a9e73..87ced8b3cb2a 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -14,7 +14,7 @@ #include "base.h" #include "expr.h" -#include "ir_operator.h" +#include "expr_operator.h" #include "arithmetic.h" namespace tvm { diff --git a/include/tvm/tvm.h b/include/tvm/tvm.h index 645c68357f13..5f81cb52fa31 100644 --- a/include/tvm/tvm.h +++ b/include/tvm/tvm.h @@ -8,7 +8,7 @@ #include "base.h" #include "expr.h" -#include "ir_operator.h" +#include "expr_operator.h" #include "tensor.h" #include "operation.h" #include "packed_func_ext.h" diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index fa2d52e9fe85..a4c7842ffe90 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -5,9 +5,8 @@ */ #include #include -#include #include -#include +#include namespace tvm { namespace ir { diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h new file mode 100644 index 000000000000..91613867115b --- /dev/null +++ b/src/arithmetic/const_fold.h @@ -0,0 +1,289 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file const_fold.h + * \brief Centralized location for constant folding. + */ +#ifndef TVM_ARITHMETIC_CONST_FOLD_H_ +#define TVM_ARITHMETIC_CONST_FOLD_H_ + +#include +#include + +namespace tvm { +namespace arith { + +/*! + * \brief Try to run binary compute with constant folding. + * + * \param a The left operand. + * \param b The right operand. + * \tparam Op The operator type. + * + * \note a and b Must already matched data types with each other. + * \return nullptr if constant fold fails, otherwise return folded result. + */ +template +inline Expr TryConstFold(Expr a, Expr b); + +/*! + * \brief Try to run unary compute with constant folding. + * + * \param a The left operand. + * \tparam Op The operator type. + * + * \note a and b Must already matched data types with each other. + * \return nullptr if constant fold fails, otherwise return folded result. + */ +template +inline Expr TryConstFold(Expr a); + +/*! + * \brief Check whether type is used to represent index. + * + * Index types are frequently used in shape computation + * and need to be aggressively constant-folded. + * + * \param type The type to represent index. + * \return the checked result. + */ +inline bool IsIndexType(const Type& type) { + return type.is_int() && type.lanes() == 1 && + (type.bits() == 32 || type.bits() == 64); +} + + +#define TVM_ARITH_CONST_PROPAGATION(BODY) \ + using ir::IntImm; \ + using ir::UIntImm; \ + using ir::FloatImm; \ + const IntImm* pa = a.as(); \ + const IntImm* pb = b.as(); \ + const FloatImm* fa = a.as(); \ + const FloatImm* fb = b.as(); \ + BODY; + + +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ + using ir::IntImm; \ + using ir::UIntImm; \ + const IntImm* pa = a.as(); \ + const IntImm* pb = b.as(); \ + const Type& ta = a.type(); \ + const Type& tb = b.type(); \ + if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ + BODY; \ + } \ + + +// specialization of constant folders. +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); + if (pa && pa->value == 0) return b; + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); + if (fa && fa->value == 0) return b; + if (fb && fb->value == 0) return a; + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); + if (fb && fb->value == 0) return a; + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); + if (pa) { + if (pa->value == 1) return b; + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + if (pb->value == 0) return b; + } + if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); + if (fa) { + if (fa->value == 1) return b; + if (fa->value == 0) return a; + } + if (fb) { + if (fb->value == 1) return a; + if (fb->value == 0) return b; + } + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + // due to division and mod can have different modes + // only constant fold positive number where rule is fixed. + if (pa && pb && pa->value >= 0 && pb->value > 0) { + return IntImm::make(rtype, pa->value / pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm::make(rtype, fa->value / fb->value); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_INDEX_CONST_PROPAGATION({ + const Type& rtype = a.type(); + // due to division and mod can have different modes + // only constant fold positive number where rule is fixed. + if (pa && pb && pa->value >= 0 && pb->value > 0) { + return IntImm::make(rtype, pa->value % pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + const Type& rtype = a.type(); + if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + TVM_ARITH_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); + }); + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + using ir::UIntImm; + const UIntImm* pa = a.as(); + const UIntImm* pb = b.as(); + if (pa && pa->value) return b; + if (pa && !pa->value) return a; + if (pb && pb->value) return a; + if (pb && !pb->value) return b; + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a, Expr b) { + using ir::UIntImm; + const UIntImm* pa = a.as(); + const UIntImm* pb = b.as(); + if (pa && pa->value) return a; + if (pa && !pa->value) return b; + if (pb && pb->value) return b; + if (pb && !pb->value) return a; + return Expr(); +} + +template<> +inline Expr TryConstFold(Expr a) { + using ir::UIntImm; + const UIntImm* pa = a.as(); + if (pa) { + return UIntImm::make(UInt(1), !(pa->value)); + } + return Expr(); +} + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITHMETIC_CONST_FOLD_H_ diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 8da6e91fc7fa..8112beef7551 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -4,7 +4,7 @@ * \brief Modular set analysis */ #include -#include +#include #include #include #include "pattern_match.h" diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 7ac0e372371c..3bf8fc9191fb 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/lang/ir_operator.cc b/src/lang/expr_operator.cc similarity index 58% rename from src/lang/ir_operator.cc rename to src/lang/expr_operator.cc index beceb094c620..edbe0be3d5c5 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/expr_operator.cc @@ -1,28 +1,16 @@ /*! * Copyright (c) 2017 by Contributors - * \file ir_operator.cc + * \file expr_operator.cc */ #include #include -#include +#include #include +// Centralized header for constant folders. +#include "../arithmetic/const_fold.h" namespace tvm { -/*! - * \brief Check whether type is used to represent index. - * - * Index types are frequently used in shape computation - * and need to be aggressively constant-folded. - * - * \param type The type to represent index. - * \return the checked result. - */ -inline bool IsIndexType(const Type& type) { - return type.is_int() && type.lanes() == 1 && - (type.bits() == 32 || type.bits() == 64); -} - // simple cast that only checks if type matches and cast inline Expr SimpleCast(const Type& t, Expr value) { if (value.type() == t) return value; @@ -135,45 +123,14 @@ Expr reinterpret(const Type& t, Expr value) { return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); } -#define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImm; \ - using ir::UIntImm; \ - const IntImm* pa = a.as(); \ - const IntImm* pb = b.as(); \ - const Type& ta = a.type(); \ - const Type& tb = b.type(); \ - if (IsIndexType(ta) && IsIndexType(tb)) { \ - BODY; \ - } \ - BinaryOpMatchTypes(a, b); - -#define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImm; \ - using ir::UIntImm; \ - using ir::FloatImm; \ - BinaryOpMatchTypes(a, b); \ - const IntImm* pa = a.as(); \ - const IntImm* pb = b.as(); \ - const FloatImm* fa = a.as(); \ - const FloatImm* fb = b.as(); \ - BODY; - - Expr operator+(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); - if (pa && pa->value == 0) return SimpleCast(rtype, b); - if (pb && pb->value == 0) return SimpleCast(rtype, a); - if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); - if (fa && fa->value == 0) return SimpleCast(rtype, b); - if (fb && fb->value == 0) return SimpleCast(rtype, a); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Add::make(a, b); } +// negation Expr operator-(Expr a) { using ir::IntImm; using ir::FloatImm; @@ -185,114 +142,44 @@ Expr operator-(Expr a) { } Expr operator-(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); - if (pb && pb->value == 0) return SimpleCast(rtype, a); - if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); - if (fb && fb->value == 0) return SimpleCast(rtype, a); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Sub::make(a, b); } Expr operator*(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); - if (pa) { - if (pa->value == 1) return SimpleCast(rtype, b); - if (pa->value == 0) return SimpleCast(rtype, a); - } - if (pb) { - if (pb->value == 1) return SimpleCast(rtype, a); - if (pb->value == 0) return SimpleCast(rtype, b); - } - if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); - if (fa) { - if (fa->value == 1) return SimpleCast(rtype, b); - if (fa->value == 0) return SimpleCast(rtype, a); - } - if (fb) { - if (fb->value == 1) return SimpleCast(rtype, a); - if (fb->value == 0) return SimpleCast(rtype, b); - } - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Mul::make(a, b); } Expr operator/(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { - return IntImm::make(rtype, pa->value / pb->value); - } - if (pa) { - if (pa->value == 0) return SimpleCast(rtype, a); - } - if (pb) { - if (pb->value == 1) return SimpleCast(rtype, a); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm::make(rtype, fa->value / fb->value); - } - if (fa && fa->value == 0) { - return SimpleCast(rtype, a); - } - if (fb) { - if (fb->value == 1) return SimpleCast(rtype, a); - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Div::make(a, b); } Expr operator%(Expr a, Expr b) { - TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { - return IntImm::make(rtype, pa->value % pb->value); - } - if (pa) { - if (pa->value == 0) return SimpleCast(rtype, a); - } - if (pb) { - if (pb->value == 1) return make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Mod::make(a, b); } Expr min(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Min::make(a, b); } Expr max(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - const Type& ta = a.type(); - const Type& tb = b.type(); - Type rtype = ta.bits() >= tb.bits() ? ta : tb; - if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Max::make(a, b); } @@ -328,129 +215,116 @@ Expr likely(Expr cond) { } Expr operator>(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::GT::make(a, b); } Expr operator>=(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::GE::make(a, b); } Expr operator<(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::LT::make(a, b); } Expr operator<=(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::LE::make(a, b); } Expr operator==(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::EQ::make(a, b); } Expr operator!=(Expr a, Expr b) { - TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); - }); + BinaryOpMatchTypes(a, b); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::NE::make(a, b); } Expr operator&&(Expr a, Expr b) { - using ir::UIntImm; - if (a.type().is_bool() && b.type().is_bool()) { - const UIntImm* pa = a.as(); - const UIntImm* pb = b.as(); - if (pa && pa->value) return b; - if (pa && !pa->value) return a; - if (pb && pb->value) return a; - if (pb && !pb->value) return b; - } + CHECK(a.type().is_bool()); + CHECK(b.type().is_bool()); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::And::make(a, b); } Expr operator||(Expr a, Expr b) { - using ir::UIntImm; - if (a.type().is_bool() && b.type().is_bool()) { - const UIntImm* pa = a.as(); - const UIntImm* pb = b.as(); - if (pa && pa->value) return a; - if (pa && !pa->value) return b; - if (pb && pb->value) return b; - if (pb && !pb->value) return a; - } + CHECK(a.type().is_bool()); + CHECK(b.type().is_bool()); + Expr ret = arith::TryConstFold(a, b); + if (ret.defined()) return ret; return ir::Or::make(a, b); } Expr operator!(Expr a) { - using ir::UIntImm; - const UIntImm* pa = a.as(); - if (pa) { - return UIntImm::make(UInt(1), !(pa->value)); - } + CHECK(a.type().is_bool()); + Expr ret = arith::TryConstFold(a); + if (ret.defined()) return ret; return ir::Not::make(a); } Expr operator>>(Expr a, Expr b) { + BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; + const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); if (pb) { - if (pb->value == 0) return SimpleCast(rtype, a); + if (pb->value == 0) return a; } }); return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); } Expr operator<<(Expr a, Expr b) { + BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; + const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); if (pb) { - if (pb->value == 0) return SimpleCast(rtype, a); + if (pb->value == 0) return a; } }); return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); } Expr operator&(Expr a, Expr b) { + BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; + const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); }); return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); } Expr operator|(Expr a, Expr b) { + BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; + const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); }); return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); } Expr operator^(Expr a, Expr b) { + BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - Type rtype = ta.bits() >= tb.bits() ? ta : tb; + const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); }); return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 0268498c7db2..31c45258abc8 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -7,8 +7,8 @@ #include #include #include -#include #include +#include #include #include #include diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 3cef4486ee1b..6af8421398de 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -7,7 +7,7 @@ #define TVM_PASS_IR_UTIL_H_ #include -#include +#include #include #include diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 488d44544c31..12913dde95af 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index c24203cebdb3..5bab6399151a 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -4,7 +4,7 @@ * \brief Implementation of operator pad */ #include -#include +#include #include #include #include diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index df23b22512e3..55892e5c73a1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -5,7 +5,7 @@ */ #include #include -#include +#include #include #include #include diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 11a376b2b657..11f96c48a311 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -6,7 +6,7 @@ * \brief This is a backend-aware optimization pass. * Fuse necessary ops into a single one. */ -#include +#include #include #include #include diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc index 0802d405bbe4..eecced8d90ab 100644 --- a/tests/cpp/ir_mutator_test.cc +++ b/tests/cpp/ir_mutator_test.cc @@ -1,7 +1,7 @@ #include #include #include -#include +#include namespace { using namespace tvm::ir;