diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index ab9a711d28d8..9d9f98e79695 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } @@ -523,8 +521,6 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { *ptr = static_cast(op->value); } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } diff --git a/include/tvm/expr.h b/include/tvm/expr.h index faae303d95dd..62806c667e61 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -115,56 +115,38 @@ class Var : public PrimExpr { using ContainerType = VarNode; }; -class Integer; -/*! \brief ExprNode: constant integer. */ -class IntImmNode : public PrimExprNode { - public: - /*! \brief the Internal value. */ - int64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static Integer make(DataType t, int64_t value); - - static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); -}; - /*! - * \brief Container of constant integer (IntImm). + * \brief Container of constant int that adds more constructors. * * This is used to store and automate type check * attributes that must be constant integer. + * + * \sa IntImm */ -class Integer : public PrimExpr { +class Integer : public IntImm { public: - Integer() : PrimExpr() {} + Integer() {} /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : PrimExpr(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} /*! * \brief Construct integer from int value. */ - Integer(int value) : PrimExpr(value) {} // NOLINT(*) + Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + /*! + * \brief Construct integer from int imm. + * \param other The other value. + */ + Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. */ - Integer& operator=(const Integer& other) { - data_ = other.data_; + Integer& operator=(const IntImm& other) { + data_ = ObjectRef::GetDataPtr(other); return *this; } - /*! - * \brief Get pointer to the internal value. - * \return the content of the integer. - */ - const IntImmNode* operator->() const { - return static_cast(get()); - } /*! * \brief convert to int64_t */ @@ -173,8 +155,6 @@ class Integer : public PrimExpr { << " Trying to reference a null Integer"; return (*this)->value; } - /*! \brief type indicate the container type */ - using ContainerType = IntImmNode; }; /*! \brief range over one dimension */ diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 2d8f37855856..ff3b340bf1fa 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -30,6 +30,7 @@ #include #include +#include #include "expr.h" #include "ir.h" @@ -82,21 +83,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) { } } -/*! - * \brief Get x as constant uint expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not UIntImm. - */ -inline const uint64_t* as_const_uint(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::UIntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - /*! * \brief Check whether x is a constant integer expression. * \param x The input argument @@ -597,6 +583,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x); */ TVM_DLL PrimExpr trunc(PrimExpr x); +/*! + * \brief Construct a large uint constant by its low 32 bits and high 32bits. + * \param dtype The final data type. + * \param low The lower 32 bits. + * \param high The higher 32 bits. + * \return The constructed expression. + */ +TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ @@ -617,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const PrimExpr& x) { - if (x.as() || x.as()) { + if (x.as()) { return true; } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; - if (val.as() || val.as()) { + if (val.as()) { return true; } } @@ -631,8 +626,6 @@ inline bool is_const(const PrimExpr& x) { inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImmNode* op = a.as()) { - return op->value > 0; } else { return false; } @@ -649,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) { inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); } } return false; @@ -675,15 +664,27 @@ inline bool is_no_op(const Stmt& stmt) { template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { - if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); - if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); + if (t.is_int()) return IntImm(t, static_cast(value)); + if (t.is_uint()) { + // Use IntImm if it is a small integer + uint64_t uval = static_cast(value); + if (uval <= static_cast(std::numeric_limits::max())) { + return IntImm(t, static_cast(value)); + } else { + uint64_t mask = (static_cast(1) << 32U) - 1U; + uint64_t low = uval & mask; + uint64_t high = uval >> 32U; + return LargeUIntImm(t, static_cast(low), static_cast(high)); + } + } if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kCustomBegin)) + if (static_cast(t.code()) >= static_cast(kCustomBegin)) { return ir::FloatImmNode::make(t, static_cast(value)); + } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 84039485ae69..9c14a31be2fe 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -39,23 +39,6 @@ namespace ir { using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; -/*! \brief constant unsigned integer. */ -class UIntImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - uint64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, uint64_t value); - - static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); -}; - /*! \brief Floating point constants. */ class FloatImmNode : public PrimExprNode { public: @@ -1422,6 +1405,16 @@ inline bool IsPragmaKey(const std::string& attr_key) { /*! \brief namespace of TVM Intrinsic functions */ namespace intrinsic { +/*! + * \brief See pesudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { + * return (v1 << 32) | v0; + * } + */ +constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; /*! * \brief See pesudo code * diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 12b34dd26398..12e505ed9ff6 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -131,6 +131,56 @@ class PrimExpr : public BaseExpr { using ContainerType = PrimExprNode; }; +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + IntImm() {} + /*! + * \brief constructor from node. + */ + explicit IntImm(ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL IntImm(DataType dtype, int64_t value); + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = IntImmNode; +}; + /*! * \brief Base node of all non-primitive expressions. * diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 7d57564fd3df..37a1fe4bffb2 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -161,7 +161,6 @@ class ExprFunctor { virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) { @@ -203,7 +202,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const BroadcastNode* op) override; void VisitExpr_(const ShuffleNode* op) override; void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const UIntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; }; @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const BroadcastNode* op) override; PrimExpr VisitExpr_(const ShuffleNode* op) override; PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const UIntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; }; diff --git a/python/tvm/api.py b/python/tvm/api.py index 7395d3524709..4bfe794c14d3 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -92,6 +92,9 @@ def const(value, dtype=None): """ if dtype is None: dtype = _scalar_type_inference(value) + if dtype == "uint64" and value >= (1 << 63): + return _api_internal._LargeUIntImm( + dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 7f36914eb0a6..5067277d32a8 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x - elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 @@ -344,7 +344,7 @@ def _count_flop(exp): if len(source) != 1: raise FlopCalculationError("Found multiple output in the source of reduce op") return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0])) - if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)): + if isinstance(exp, (expr.FloatImm, expr.IntImm)): return 0 if isinstance(exp, expr.Cast): return _count_flop(exp.value) diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 3026914aed20..54001d3338ad 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -155,9 +155,9 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): exp = ir_pass.Simplify(exp) - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, expr.Var): ret.append(elem) - elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elif not isinstance(elem, (expr.IntImm, int)): elem = ir_pass.Simplify(elem) - if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 71c0aecd1f6a..2fd7b78d9d66 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -341,23 +341,6 @@ def __int__(self): return self.value -@register_object -class UIntImm(ConstExpr): - """UInt constant. - - Parameters - ---------- - dtype : str - The data type - - value : int - The constant value. - """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__( - _make.UIntImm, dtype, value) - - @register_object class StringImm(ConstExpr): """String constant. diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 1d5612e67e80..7038f6144db3 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -156,6 +156,6 @@ def max_num_threads(func_id, args): if args.__len__() == 0: res = _tgt.current_target().max_num_threads else: - _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") + _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 06bcbcabe0c3..57d636328816 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -386,7 +386,7 @@ def visit_Subscript(self, node): if isinstance(i, numbers.Integral): arr = arr[i] else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + _internal_assert(isinstance(i, (_expr.IntImm,)), \ "All indices are supposed to be constants") arr = arr[i.value] return arr @@ -413,7 +413,7 @@ def visit_If(self, node): cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven - if isinstance(cond, _expr.UIntImm): + if isinstance(cond, _expr.IntImm): if cond.value: return visit_list_to_block(self.visit, node.body) if node.orelse: diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa141329..a08a380dd767 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -33,7 +33,7 @@ #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7e22d72131ac..e7f4682e7eb2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -931,7 +931,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(axis, (int, tvm.expr.IntImm)): is_symbolic_shape = True break diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index ca4823bc6b83..30ca51592c8f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,8 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(IntImm); -REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 6a8bc58ad7d0..fa7b59d36b88 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -53,6 +53,9 @@ TVM_REGISTER_GLOBAL("_const") } }); +TVM_REGISTER_GLOBAL("_LargeUIntImm") +.set_body_typed(LargeUIntImm); + TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImmNode::make); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 7a3baa678352..e03e5e2387bf 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const PrimExpr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 5f721d7a1f94..90c6e48ded1e 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -737,7 +737,7 @@ VisitExpr_(const DivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -797,7 +797,7 @@ VisitExpr_(const FloorDivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -905,7 +905,7 @@ VisitExpr_(const ModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -975,7 +975,7 @@ VisitExpr_(const FloorModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 55c156d898f9..3b803ecd84a2 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ using ir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ @@ -103,7 +99,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); @@ -117,7 +113,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; @@ -129,7 +125,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -159,7 +155,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, pa->value / pb->value); + return IntImm(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; @@ -185,7 +181,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, pa->value % pb->value); + return IntImm(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; @@ -204,7 +200,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value)); + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -230,7 +226,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value)); + return IntImm(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -247,7 +243,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -258,7 +254,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -268,8 +264,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); return PrimExpr(); } @@ -277,8 +273,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); return PrimExpr(); } @@ -286,8 +282,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); return PrimExpr(); } @@ -295,8 +291,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); return PrimExpr(); } @@ -304,8 +300,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); return PrimExpr(); } @@ -313,17 +309,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); return PrimExpr(); } template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); + const IntImmNode* pa = a.as(); if (pa) { - return UIntImmNode::make(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::UInt(1), !(pa->value)); } return PrimExpr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index a041e40abf46..25d88d3429b6 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl : return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value <= static_cast(kPosInf)) { - return MakeBound(op->value, op->value); - } else { - return Everything(op->dtype); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); @@ -496,7 +488,7 @@ class ConstIntBoundAnalyzer::Impl : */ static std::vector DetectBoundInfo(const PrimExpr& cond) { PVar x, y; - PVar c; + PVar c; // NOTE: canonical form always use <= or < if ((c <= x).Match(cond)) { return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index ceaa976469e8..37d5e9eb5e57 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -384,10 +384,6 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); - } - IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); @@ -476,7 +472,7 @@ class IntervalSetEvaluator : IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); IntervalSet base = Eval(op->base); - PVar stride; + PVar stride; if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 01dd2e8e499e..c81842035c9f 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -109,7 +109,7 @@ class ModularSetAnalyzer::Impl : // Detect useful constraints and use them in the analysis scope. std::function EnterConstraint(const PrimExpr& constraint) { PVar var; - PVar coeff, base; + PVar coeff, base; // pattern match interesting constraints if ((truncmod(var, coeff) == base).Match(constraint) || (floormod(var, coeff) == base).Match(constraint)) { @@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl : return Entry(0, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value < std::numeric_limits::max()) { - return Entry(0, static_cast(op->value)); - } else { - return Everything(); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 733dcf41ce94..a236e65a8312 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -45,7 +45,7 @@ * } * * tvm::Var tx, ty; - * arith::PVar c; + * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr @@ -140,9 +140,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Integer& lhs, const Integer& rhs) const { + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 94d951da51db..e6e1524604ce 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -124,7 +124,7 @@ VisitExpr_(const AddNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -239,7 +239,7 @@ VisitExpr_(const SubNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -438,7 +438,7 @@ VisitExpr_(const MulNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -477,7 +477,7 @@ VisitExpr_(const DivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -700,7 +700,7 @@ VisitExpr_(const ModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -789,7 +789,7 @@ VisitExpr_(const FloorDivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -934,7 +934,7 @@ VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1004,7 +1004,7 @@ VisitExpr_(const MinNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1189,7 +1189,7 @@ VisitExpr_(const MaxNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1362,7 +1362,7 @@ VisitExpr_(const EQNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1; PVar lanes; // vector rule @@ -1416,7 +1416,7 @@ VisitExpr_(const LTNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1597,7 +1597,7 @@ VisitExpr_(const AndNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { @@ -1646,7 +1646,7 @@ VisitExpr_(const OrNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index cf138edd494e..55ed36ca9352 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -256,7 +256,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > Array attr{std::string("_attr_"), FloatImmNode::make(DataType::Float(32), trans(fea.length)), - IntImmNode::make(DataType::Int(32), fea.nest_level), + IntImm(DataType::Int(32), fea.nest_level), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), }; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 777ad6203008..d9b7f7f08d12 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -372,16 +372,17 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->dtype == DataType::UInt(32)) { + +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (dtype == DataType::UInt(32)) { std::ostringstream temp; - temp << op->value << "U"; + temp << val << "U"; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->dtype, os); - os << ")" << op->value; + p->PrintType(dtype, os); + os << ")" << val; } } @@ -408,9 +409,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintConst(op, os, this); -} + void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } @@ -528,6 +527,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << ")"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsic(op, " ^ ", os, this); } else if (op->is_intrinsic(CallNode::bitwise_or)) { diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index cb092c566322..7e5dd4269c94 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -128,7 +128,6 @@ class CodeGenC : void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 7967c1847ac2..cea276d5cb1a 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { CodeGenC::VisitExpr_(op, os); } -void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index cd1ec83360c6..19ca2ee12c6c 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC { // Codegen for immediate values void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 6879fd5f8542..44862cf7a97c 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); + Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } @@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); @@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); @@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return ir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index c04a023aefad..60d8146fc0e6 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); - const uint64_t *num_signature = as_const_uint(op->args[1]); - CHECK(num_signature) << "The second argument should be a uint represents number of arguments, " - << "but " << op->args[1] << " got!\n"; + Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector sig_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < *num_signature) { + if (i - 2 < static_cast(num_signature)) { sig_type.push_back(arg_value.back()->getType()); } } @@ -722,6 +720,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return llvm::ConstantInt::get(LLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; @@ -804,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) { - return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); -} - llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34c3ee723e18..b269f2423fc8 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -106,7 +106,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const VarNode* op) override; llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override; - llvm::Value* VisitExpr_(const UIntImmNode* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override; llvm::Value* VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 03656cc70a46..11bda70fb8cf 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -96,8 +96,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { MakeValue( ir::BroadcastNode::make( ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index b3ab557ee215..1f839f362f40 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index a749424892e2..985f6816a640 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } -spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) { - return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); -} - spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } @@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = op->args[0].as()->value; + uint32_t inst_id = static_cast( + op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); @@ -285,6 +282,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return builder_->UIntImm(builder_->GetSType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3804bda0f2e0..5aa7f9c49910 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -65,7 +65,6 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override; - spirv::Value VisitExpr_(const UIntImmNode* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override; spirv::Value VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d41d96db5165..d96883ed02fd 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : call->args) { cargs.push_back(arg); diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 6f8d96e148c1..bf43f11cce02 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -342,9 +342,9 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type == DataType::UInt(1)) { // bool types. if (*pvalue) { - ib_.Begin(spv::OpConstantTrue).AddSeq(ret); + ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); } else { - ib_.Begin(spv::OpConstantFalse).AddSeq(ret); + ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret); } } else { // Integral/floating-point types. diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index 3843cbb3c6a9..5d25e8634e84 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index eccff6c74c2e..01096ae1dd46 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) { - CHECK(op->value <= std::numeric_limits::max()) - << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); -} - void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { LOG(FATAL) << "Float Imm is not supported"; } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 07989b2062e1..1360cc2d70f1 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -136,7 +136,6 @@ class CodeGenStackVM void VisitExpr_(const RampNode* op) final; void VisitExpr_(const BroadcastNode* op) final; void VisitExpr_(const IntImmNode* op) final; - void VisitExpr_(const UIntImmNode* op) final; void VisitExpr_(const FloatImmNode* op) final; void VisitExpr_(const StringImmNode* op) final; // statment diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7e3d44f26aef..346ec3808919 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << op->value << ")"; -} + void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 89a1ece577f9..33bd0efae8a4 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -117,7 +117,6 @@ class CodeGenHybrid : void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f698a5d1802e..6d89967416b0 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,25 @@ namespace tvm { +IntImm::IntImm(DataType dtype, int64_t value) { + CHECK(dtype.is_scalar()) + << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm can only take scalar."; + if (dtype.is_uint()) { + CHECK_GE(value, 0U); + } + ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("make.IntImm") +.set_body_typed([](DataType dtype, int64_t value) { + return IntImm(dtype, value); +}); + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 34ee4b3159a5..4fffc475a773 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -77,7 +77,6 @@ class AttrFunctor { virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. @@ -113,7 +112,6 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); - ATTR_FUNCTOR_DISPATCH(UIntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); @@ -157,7 +155,6 @@ class AttrsEqualHandler : bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; @@ -198,7 +195,6 @@ class AttrsHashHandler : protected: size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImmNode* lhs) final; - size_t VisitAttr_(const ir::UIntImmNode* lhs) final; size_t VisitAttr_(const ir::FloatImmNode* lhs) final; size_t VisitAttr_(const ir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 1d3e767a5b71..a590f10e78e5 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } - return false; -} - bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; @@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) { - return std::hash()(op->value); -} - size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { return std::hash()(op->value); } diff --git a/src/lang/expr.cc b/src/lang/expr.cc index a7289369bcd4..55dfb89342a8 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -30,7 +30,7 @@ namespace tvm { PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {} + : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} @@ -54,15 +54,6 @@ Range::Range(PrimExpr begin, PrimExpr end) is_zero(begin) ? end : (end - begin))) { } -Integer IntImmNode::make(DataType t, int64_t value) { - CHECK(t.is_int() && t.is_scalar()) - << "ValueError: IntImm can only take scalar."; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return Integer(node); -} - Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index d3875e28c887..bd43d89d89d0 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -35,6 +35,14 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return ir::CastNode::make(t, value); } +PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { + return ir::CallNode::make( + t, ir::intrinsic::tvm_large_uint_imm, + {make_const(DataType::UInt(32), low), + make_const(DataType::UInt(32), high)}, + ir::CallNode::PureIntrinsic); +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; @@ -78,26 +86,25 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) } } - // maximum and min limits PrimExpr max_value(const DataType& dtype) { using namespace ir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::max()); + return IntImm(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { int64_t val = 1; val = (val << (dtype.bits() - 1)) - 1; - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { - return UIntImmNode::make(dtype, std::numeric_limits::max()); + return make_const(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast(dtype.bits())) - 1; - return UIntImmNode::make(dtype, val); + return IntImm(dtype, static_cast(val)); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { @@ -117,14 +124,14 @@ PrimExpr min_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::lowest()); + return IntImm(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() < 64) { int64_t val = 1; val = -(val << (dtype.bits() - 1)); - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { - return UIntImmNode::make(dtype, 0); + return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImmNode::make(dtype, std::numeric_limits::lowest()); @@ -155,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); - } else if (const auto* op = x.as()) { - return ConstPowerHelper(op->value, shift); } else { return false; } } PrimExpr cast(const DataType& t, PrimExpr value) { - using ir::IntImmNode; - using ir::UIntImmNode; using ir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } @@ -184,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) { if (value.dtype() != vtype) { if (const IntImmNode* op = value.as()) { value = make_const(vtype, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { @@ -219,7 +218,7 @@ PrimExpr operator-(PrimExpr a) { using ir::FloatImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); - if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value); + if (pa) return IntImm(a.dtype(), -pa->value); if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } @@ -322,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) { } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - using ir::IntImmNode; - using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); - if (const UIntImmNode* op = cond.as()) { - if (op->value != 0) { - return true_value; - } else { - return false_value; - } - } else if (const IntImmNode* op = cond.as()) { + if (const IntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { @@ -424,7 +415,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -437,7 +428,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -450,7 +441,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); @@ -460,7 +451,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); @@ -470,7 +461,7 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); @@ -494,7 +485,7 @@ PrimExpr abs(PrimExpr x) { using ir::IntImmNode; const IntImmNode* px = x.as(); if (px) { - return ir::IntImmNode::make(x.dtype(), std::abs(px->value)); + return IntImm(x.dtype(), std::abs(px->value)); } return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index ad7f260226bd..f06a6be5e75a 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,14 +31,6 @@ namespace tvm { namespace ir { // constructors -PrimExpr UIntImmNode::make(DataType t, uint64_t value) { - CHECK(t.is_uint() && t.lanes() == 1) - << "ValueError: UIntImm can only take scalar"; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return PrimExpr(node); -} PrimExpr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) @@ -248,7 +240,7 @@ PrimExpr ShuffleNode::make_concat(Array vectors) { int index = 0; for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { - indices.push_back(IntImmNode::make(DataType::Int(32), index++)); + indices.push_back(IntImm(DataType::Int(32), index++)); } } return make(vectors, indices); @@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) { } // Printers -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(" << op->dtype << ")" << op->value; - }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_REGISTER_NODE_TYPE(FloatImmNode); TVM_REGISTER_NODE_TYPE(IntImmNode); -TVM_REGISTER_NODE_TYPE(UIntImmNode); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(VarNode); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2c04de3710fa..0f350d2d732e 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - UIntImmNode::make(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - UIntImmNode::make(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - UIntImmNode::make(DataType::UInt(16), dtype.lanes())); + IntImm(DataType::UInt(16), dtype.lanes())); asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop)); // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), @@ -193,7 +193,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, // mark alignment of external bufs init_nest_.emplace_back(AttrStmtNode::make( vptr, ir::attr::storage_alignment, - IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop)); + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); @@ -206,7 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->shape[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_shape, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } // strides field @@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr svalue = cast( stype, LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -251,7 +251,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, field_name << v_strides->name_hint << '[' << k << ']'; PrimExpr value = cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -270,7 +270,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6eacb145b29b..8c441510c51d 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -252,10 +252,6 @@ class IRDeepCompare : CompareValue(op->value, other.as()->value); } - void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 67acec674630..857206f8dd9f 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode); DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} -void ExprVisitor::VisitExpr_(const UIntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} @@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 7b760fa4a672..5aba355b7003 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator { if (const IntImmNode* op = a.as()) { return op->value == b.as()->value; } - if (const UIntImmNode* op = a.as()) { - return op->value == b.as()->value; - } return false; } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index ed8be8bb39fc..5684f4ef785f 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -173,7 +173,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const MaxNode* op) final { using namespace arith; PVar x, y; - PVar c; + PVar c; auto e = GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index a0b07c293b05..d509169df0b1 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const UIntImmNode *size_of_args = call->args[0].as(); + const IntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 8e7f1d86da74..01a97b7878be 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -129,8 +129,8 @@ class BuiltinLower : public StmtExprMutator { {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), - IntImmNode::make(DataType::Int(32), op->dtype.code()), - IntImmNode::make(DataType::Int(32), op->dtype.bits())}, + IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index d5c73a2e8a75..5df36d0b2423 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -69,8 +69,8 @@ LoweredFunc MakeAPI(Stmt body, // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, - IntImmNode::make(DataType::Int(32), i), - IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); PrimExpr res = CallNode::make( @@ -117,7 +117,7 @@ LoweredFunc MakeAPI(Stmt body, seq_init.emplace_back(LetStmtNode::make( tcode, LoadNode::make( DataType::Int(32), v_packed_arg_type_ids, - IntImmNode::make(DataType::Int(32), i), const_true(1)), + IntImm(DataType::Int(32), i), const_true(1)), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 224a81c12396..9fb19cc4b308 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor { return false; } bool VisitExpr_(const VarNode* op) final { return false; } - bool VisitExpr_(const UIntImmNode* op) final { return false; } bool VisitExpr_(const IntImmNode* op) final { return false; } bool VisitExpr_(const FloatImmNode* op) final { return false; } bool VisitExpr_(const StringImmNode* op) final { return false; } diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index bb57fe8c37d3..956f27c9319d 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -462,7 +462,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -575,7 +575,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -765,7 +765,7 @@ class ThreadIdxMutator : public StmtExprMutator { op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImmNode::make(DataType::Int(32), 0); + PrimExpr zero = IntImm(DataType::Int(32), 0); return zero; } if (op->name_hint == "threadIdx.y") { @@ -934,7 +934,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr stride = strides[strides.size()-2]; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); PrimExpr src = CallNode::make(value->dtype, @@ -984,7 +984,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr dst = it3->second; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); dst = CallNode::make(DataType::Handle(), @@ -1089,7 +1089,7 @@ class TensorCoreIRMutator : public StmtExprMutator { Array strides; for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, shape[j]); } @@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } strides.push_back(make_const(DataType::Int(32), 1)); - PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0); + PrimExpr elem_offset = IntImm(DataType::Int(32), 0); CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index b2c50f7a8bd2..26ad59189671 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator { // constant folding. PrimExpr extent = ir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); - const UIntImmNode *v2 = extent.as(); int value = -1; if (v1 != nullptr) { value = static_cast(v1->value); } - if (v2 != nullptr) { - value = static_cast(v2->value); - } return value; } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 00c40b2565bb..5ee4ce30c96d 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -88,7 +88,7 @@ Array GetShape(const Array& shape) { if (pval != nullptr) { CHECK_LE(pval[0], std::numeric_limits::max()); CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval)); + res.push_back(IntImm(DataType::Int(32), *pval)); } else if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { @@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { // set inputs for (auto param : prim_func->params) { int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state)); + cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); if (state & kNeedInputData) { for (auto t : param_data_[param]) { cache_node->inputs.push_back(t); @@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { auto ret_type = call_node->checked_type(); Array out_ndims; if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } else { auto rtype = ret_type.as(); // TODO(@icemelon): Allow recursive tuple @@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { for (size_t i = 0; i < rtype->fields.size(); ++i) { auto ttype = rtype->fields[i].as(); CHECK(ttype); - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } } // Call shape function diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f66cce6b7b82..9966d9cc55ef 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); shape.push_back( - tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i])); + tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorTypeNode::make(shape, dtype); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 25650c7766cb..400a6bea22ed 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -857,10 +857,6 @@ class PrettyPrinter : return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::UIntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); - } - Doc VisitAttr_(const ir::FloatImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4d3a4b9589ee..b5383cd3339b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -852,7 +852,7 @@ bool ArgWhereRel(const Array& types, const auto& input_rank = input_shape.size(); std::vector result_shape; result_shape.push_back(Any::make()); - result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank)); + result_shape.push_back(IntImm(DataType::Int(32), input_rank)); reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32))); return true; } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 594669343f62..30a9a5c80402 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode { } bool Assert(const IndexExpr& cond) final { - if (const uint64_t* pdiff = as_const_uint(cond)) { + if (const int64_t* pdiff = as_const_int(cond)) { return pdiff[0]; } return true; diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 378a5e3728f4..2e332413c1f6 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -47,14 +47,10 @@ static inline Array get_shape(const Type& type) { static inline const int32_t GetQmin(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* min_value = as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); return static_cast(min_value[0]); - } else if (dtype.is_uint()) { - auto* min_value = as_const_uint(tvm::min_value(dtype)); - CHECK(min_value != nullptr); - return static_cast(min_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning @@ -64,14 +60,10 @@ static inline const int32_t GetQmin(const DataType& dtype) { static inline const int32_t GetQmax(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* max_value = as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); return static_cast(max_value[0]); - } else if (dtype.is_uint()) { - auto* max_value = as_const_uint(tvm::max_value(dtype)); - CHECK(max_value != nullptr); - return static_cast(max_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5392eaeac1e8..193f2f206c06 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -127,10 +127,10 @@ TEST(Pattern, Basic) { } } -TEST(Pattern, Integer) { +TEST(Pattern, IntImm) { using namespace tvm; tvm::Var tx, ty; - arith::PVar c; + arith::PVar c; arith::PVar v; { // We can match integer and Var, both of which are diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 45ecf9539337..5a10618fb269 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -18,6 +18,32 @@ from tvm.contrib import util import numpy as np +def test_large_uint_imm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + n = 12 + num_thread = 2 + + A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + xo, xi = s[A].split(A.op.axis[0], factor=num_thread) + s[A].bind(xi, tvm.thread_axis("threadIdx.x")) + s[A].bind(xo, tvm.thread_axis("blockIdx.x")) + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + return + f = tvm.build(s, [A], device) + # launch the kernel. + a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy()[0] == value + 3 + + check_target("cuda") + check_target("vulkan") + + def test_add_pipeline(): n = tvm.var('n') A = tvm.placeholder((n,), name='A') @@ -112,4 +138,5 @@ def check_module_save(device, host="stackvm"): if __name__ == "__main__": + test_large_uint_imm() test_add_pipeline() diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 0e595cd79c97..4920206ee019 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -88,6 +88,25 @@ def test_llvm_lookup_intrin(): fcode = tvm.build(func, None, "llvm") +def test_llvm_large_uintimm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + + def check_llvm(): + if not tvm.module.enabled("llvm"): + return + f = tvm.build(s, [A], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy() == value + 3 + + check_llvm() + + def test_llvm_add_pipeline(): nn = 1024 n = tvm.convert(nn) @@ -645,6 +664,7 @@ def vectorizer(op): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) if __name__ == "__main__": + test_llvm_large_uintimm() test_llvm_import() test_alignment() test_rank_zero() diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c3c40cf740ad..5f1facb2b45f 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm,)) return val.value ctx = tvm.context(target, 0) diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index fe329494e24e..c4187858a8a8 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -38,16 +38,11 @@ def test_expr_constructor(): assert x.value == 2 assert x.dtype == "int64" - x = tvm.expr.UIntImm("uint16", 2) - assert isinstance(x, tvm.expr.UIntImm) - assert x.value == 2 - assert x.dtype == "uint16" - x = tvm.expr.StringImm("xyza") assert isinstance(x, tvm.expr.StringImm) assert x.value == "xyza" - x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1)) + x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1)) assert isinstance(x, tvm.expr.Cast) assert x.dtype == "float32" assert x.value.value == 1 diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec..ac2ee6d88cc5 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -29,7 +29,7 @@ def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): + if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y): raise ValueError("check error: %s vs %s " % (x, y)) tmod = tvm.truncmod diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 43ac3a29cd7c..e6de76f20881 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -43,8 +43,7 @@ using namespace tvm; */ inline bool IsConstInt(PrimExpr expr) { return - expr->IsInstance() || - expr->IsInstance(); + expr->IsInstance(); } /*! @@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) { * \return The integer value. */ inline int64_t GetConstInt(PrimExpr expr) { - if (expr->IsInstance()) { - return expr.as()->value; - } - if (expr->IsInstance()) { - return expr.as()->value; + if (expr->IsInstance()) { + return expr.as()->value; } LOG(ERROR) << "expr must be a constant integer"; return -1; diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 8f32a297d719..02d082b8b342 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -92,9 +92,9 @@ def get_const_int(expr): """ if isinstance(expr, Integral): return expr - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -136,9 +136,9 @@ def equal_const_int(expr, value): """ if isinstance(expr, Integral): return expr == value - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): return False return expr.value == value @@ -160,9 +160,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, tvm.expr.Var): ret.append(elem) - elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): + elif not isinstance(elem, (tvm.expr.IntImm, int)): elem = tvm.ir_pass.Simplify(elem) - if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(elem, tvm.expr.IntImm): ret.append(elem) else: ret.append(get_const_int(elem))