From 2a43201024e2806267081128f5924b0ca0cacd49 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 14 Jan 2020 19:03:26 -0800 Subject: [PATCH] [REFACTOR][IR] Unify IntImm and UIntImm (#4706) * [REFACTOR][IR] Unify IntImm and UIntImm This PR unifies UIntImm and IntImm to simplify the codebase. Unsigned integer constants will also be stored as IntImm. For uint constant that does not fit into int64(rare case), we introduced an intrinsic tvm_big_uint_imm to construct such intgers by its lower and higher 32bits. * [REFACTOR][IR] Remove UIntImm to use IntImm * rename big->large --- include/tvm/attrs.h | 4 -- include/tvm/expr.h | 48 +++++---------- include/tvm/expr_operator.h | 53 ++++++++-------- include/tvm/ir.h | 27 +++----- include/tvm/ir/expr.h | 50 +++++++++++++++ include/tvm/ir_functor_ext.h | 4 -- python/tvm/api.py | 3 + python/tvm/autotvm/task/task.py | 4 +- python/tvm/autotvm/util.py | 8 +-- python/tvm/expr.py | 17 ------ python/tvm/hybrid/calls.py | 2 +- python/tvm/hybrid/parser.py | 4 +- python/tvm/hybrid/util.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 2 +- src/api/api_ir.cc | 2 - src/api/api_lang.cc | 3 + src/arithmetic/analyzer.cc | 6 +- src/arithmetic/canonical_simplify.cc | 8 +-- src/arithmetic/const_fold.h | 61 ++++++++----------- src/arithmetic/const_int_bound.cc | 10 +-- src/arithmetic/int_set.cc | 6 +- src/arithmetic/modular_set.cc | 10 +-- src/arithmetic/pattern_match.h | 6 +- src/arithmetic/rewrite_simplify.cc | 26 ++++---- src/autotvm/touch_extractor.cc | 2 +- src/codegen/codegen_c.cc | 21 ++++--- src/codegen/codegen_c.h | 1 - src/codegen/codegen_opengl.cc | 5 -- src/codegen/codegen_opengl.h | 1 - src/codegen/llvm/codegen_arm.cc | 22 +++---- src/codegen/llvm/codegen_llvm.cc | 18 +++--- src/codegen/llvm/codegen_llvm.h | 1 - src/codegen/llvm/codegen_x86_64.cc | 4 +- src/codegen/llvm/intrin_rule_llvm.h | 8 +-- src/codegen/spirv/codegen_spirv.cc | 13 ++-- src/codegen/spirv/codegen_spirv.h | 1 - src/codegen/spirv/intrin_rule_spirv.cc | 2 +- src/codegen/spirv/ir_builder.cc | 4 +- src/codegen/spirv/ir_builder.h | 4 +- src/codegen/stackvm/codegen_stackvm.cc | 6 -- src/codegen/stackvm/codegen_stackvm.h | 1 - src/contrib/hybrid/codegen_hybrid.cc | 5 +- src/contrib/hybrid/codegen_hybrid.h | 1 - src/ir/expr.cc | 19 ++++++ src/lang/attr_functor.h | 4 -- src/lang/attrs.cc | 11 ---- src/lang/expr.cc | 11 +--- src/lang/expr_operator.cc | 55 +++++++---------- src/lang/ir.cc | 16 +---- src/pass/arg_binder.cc | 16 ++--- src/pass/ir_deep_compare.cc | 4 -- src/pass/ir_functor.cc | 2 - src/pass/lift_attr_scope.cc | 3 - src/pass/lower_intrin.cc | 2 +- src/pass/lower_thread_allreduce.cc | 2 +- src/pass/lower_tvm_builtin.cc | 4 +- src/pass/make_api.cc | 6 +- src/pass/rewrite_unsafe_select.cc | 1 - src/pass/tensor_core.cc | 14 ++--- src/pass/unroll_loop.cc | 4 -- src/relay/backend/compile_engine.cc | 8 +-- src/relay/ir/expr.cc | 2 +- src/relay/ir/pretty_printer.cc | 4 -- src/relay/op/tensor/transform.cc | 2 +- src/relay/pass/type_solver.cc | 2 +- src/relay/qnn/util.h | 12 +--- tests/cpp/pattern_match_test.cc | 4 +- tests/python/unittest/test_codegen_device.py | 27 ++++++++ tests/python/unittest/test_codegen_llvm.py | 20 ++++++ tests/python/unittest/test_hybrid_script.py | 2 +- .../python/unittest/test_lang_constructor.py | 7 +-- tests/python/unittest/test_lang_operator.py | 2 +- topi/include/topi/detail/constant_utils.h | 10 +-- topi/python/topi/util.py | 12 ++-- 74 files changed, 361 insertions(+), 413 deletions(-) diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index ab9a711d28d8..9d9f98e79695 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } @@ -523,8 +521,6 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { *ptr = static_cast(op->value); } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } diff --git a/include/tvm/expr.h b/include/tvm/expr.h index faae303d95dd..62806c667e61 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -115,56 +115,38 @@ class Var : public PrimExpr { using ContainerType = VarNode; }; -class Integer; -/*! \brief ExprNode: constant integer. */ -class IntImmNode : public PrimExprNode { - public: - /*! \brief the Internal value. */ - int64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static Integer make(DataType t, int64_t value); - - static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); -}; - /*! - * \brief Container of constant integer (IntImm). + * \brief Container of constant int that adds more constructors. * * This is used to store and automate type check * attributes that must be constant integer. + * + * \sa IntImm */ -class Integer : public PrimExpr { +class Integer : public IntImm { public: - Integer() : PrimExpr() {} + Integer() {} /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : PrimExpr(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} /*! * \brief Construct integer from int value. */ - Integer(int value) : PrimExpr(value) {} // NOLINT(*) + Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + /*! + * \brief Construct integer from int imm. + * \param other The other value. + */ + Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. */ - Integer& operator=(const Integer& other) { - data_ = other.data_; + Integer& operator=(const IntImm& other) { + data_ = ObjectRef::GetDataPtr(other); return *this; } - /*! - * \brief Get pointer to the internal value. - * \return the content of the integer. - */ - const IntImmNode* operator->() const { - return static_cast(get()); - } /*! * \brief convert to int64_t */ @@ -173,8 +155,6 @@ class Integer : public PrimExpr { << " Trying to reference a null Integer"; return (*this)->value; } - /*! \brief type indicate the container type */ - using ContainerType = IntImmNode; }; /*! \brief range over one dimension */ diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 2d8f37855856..ff3b340bf1fa 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -30,6 +30,7 @@ #include #include +#include #include "expr.h" #include "ir.h" @@ -82,21 +83,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) { } } -/*! - * \brief Get x as constant uint expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not UIntImm. - */ -inline const uint64_t* as_const_uint(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::UIntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - /*! * \brief Check whether x is a constant integer expression. * \param x The input argument @@ -597,6 +583,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x); */ TVM_DLL PrimExpr trunc(PrimExpr x); +/*! + * \brief Construct a large uint constant by its low 32 bits and high 32bits. + * \param dtype The final data type. + * \param low The lower 32 bits. + * \param high The higher 32 bits. + * \return The constructed expression. + */ +TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ @@ -617,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const PrimExpr& x) { - if (x.as() || x.as()) { + if (x.as()) { return true; } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; - if (val.as() || val.as()) { + if (val.as()) { return true; } } @@ -631,8 +626,6 @@ inline bool is_const(const PrimExpr& x) { inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImmNode* op = a.as()) { - return op->value > 0; } else { return false; } @@ -649,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) { inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); } } return false; @@ -675,15 +664,27 @@ inline bool is_no_op(const Stmt& stmt) { template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { - if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); - if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); + if (t.is_int()) return IntImm(t, static_cast(value)); + if (t.is_uint()) { + // Use IntImm if it is a small integer + uint64_t uval = static_cast(value); + if (uval <= static_cast(std::numeric_limits::max())) { + return IntImm(t, static_cast(value)); + } else { + uint64_t mask = (static_cast(1) << 32U) - 1U; + uint64_t low = uval & mask; + uint64_t high = uval >> 32U; + return LargeUIntImm(t, static_cast(low), static_cast(high)); + } + } if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kCustomBegin)) + if (static_cast(t.code()) >= static_cast(kCustomBegin)) { return ir::FloatImmNode::make(t, static_cast(value)); + } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 84039485ae69..9c14a31be2fe 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -39,23 +39,6 @@ namespace ir { using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; -/*! \brief constant unsigned integer. */ -class UIntImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - uint64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, uint64_t value); - - static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); -}; - /*! \brief Floating point constants. */ class FloatImmNode : public PrimExprNode { public: @@ -1422,6 +1405,16 @@ inline bool IsPragmaKey(const std::string& attr_key) { /*! \brief namespace of TVM Intrinsic functions */ namespace intrinsic { +/*! + * \brief See pesudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { + * return (v1 << 32) | v0; + * } + */ +constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; /*! * \brief See pesudo code * diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 12b34dd26398..12e505ed9ff6 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -131,6 +131,56 @@ class PrimExpr : public BaseExpr { using ContainerType = PrimExprNode; }; +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + IntImm() {} + /*! + * \brief constructor from node. + */ + explicit IntImm(ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL IntImm(DataType dtype, int64_t value); + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = IntImmNode; +}; + /*! * \brief Base node of all non-primitive expressions. * diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 7d57564fd3df..37a1fe4bffb2 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -161,7 +161,6 @@ class ExprFunctor { virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) { @@ -203,7 +202,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const BroadcastNode* op) override; void VisitExpr_(const ShuffleNode* op) override; void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const UIntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; }; @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const BroadcastNode* op) override; PrimExpr VisitExpr_(const ShuffleNode* op) override; PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const UIntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; }; diff --git a/python/tvm/api.py b/python/tvm/api.py index 7395d3524709..4bfe794c14d3 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -92,6 +92,9 @@ def const(value, dtype=None): """ if dtype is None: dtype = _scalar_type_inference(value) + if dtype == "uint64" and value >= (1 << 63): + return _api_internal._LargeUIntImm( + dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 7f36914eb0a6..5067277d32a8 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x - elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 @@ -344,7 +344,7 @@ def _count_flop(exp): if len(source) != 1: raise FlopCalculationError("Found multiple output in the source of reduce op") return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0])) - if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)): + if isinstance(exp, (expr.FloatImm, expr.IntImm)): return 0 if isinstance(exp, expr.Cast): return _count_flop(exp.value) diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 3026914aed20..54001d3338ad 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -155,9 +155,9 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): exp = ir_pass.Simplify(exp) - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, expr.Var): ret.append(elem) - elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elif not isinstance(elem, (expr.IntImm, int)): elem = ir_pass.Simplify(elem) - if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 71c0aecd1f6a..2fd7b78d9d66 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -341,23 +341,6 @@ def __int__(self): return self.value -@register_object -class UIntImm(ConstExpr): - """UInt constant. - - Parameters - ---------- - dtype : str - The data type - - value : int - The constant value. - """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__( - _make.UIntImm, dtype, value) - - @register_object class StringImm(ConstExpr): """String constant. diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 1d5612e67e80..7038f6144db3 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -156,6 +156,6 @@ def max_num_threads(func_id, args): if args.__len__() == 0: res = _tgt.current_target().max_num_threads else: - _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") + _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 06bcbcabe0c3..57d636328816 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -386,7 +386,7 @@ def visit_Subscript(self, node): if isinstance(i, numbers.Integral): arr = arr[i] else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + _internal_assert(isinstance(i, (_expr.IntImm,)), \ "All indices are supposed to be constants") arr = arr[i.value] return arr @@ -413,7 +413,7 @@ def visit_If(self, node): cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven - if isinstance(cond, _expr.UIntImm): + if isinstance(cond, _expr.IntImm): if cond.value: return visit_list_to_block(self.visit, node.body) if node.orelse: diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa141329..a08a380dd767 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -33,7 +33,7 @@ #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7e22d72131ac..e7f4682e7eb2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -931,7 +931,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(axis, (int, tvm.expr.IntImm)): is_symbolic_shape = True break diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index ca4823bc6b83..30ca51592c8f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,8 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(IntImm); -REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 6a8bc58ad7d0..fa7b59d36b88 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -53,6 +53,9 @@ TVM_REGISTER_GLOBAL("_const") } }); +TVM_REGISTER_GLOBAL("_LargeUIntImm") +.set_body_typed(LargeUIntImm); + TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImmNode::make); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 7a3baa678352..e03e5e2387bf 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const PrimExpr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 5f721d7a1f94..90c6e48ded1e 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -737,7 +737,7 @@ VisitExpr_(const DivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -797,7 +797,7 @@ VisitExpr_(const FloorDivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -905,7 +905,7 @@ VisitExpr_(const ModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -975,7 +975,7 @@ VisitExpr_(const FloorModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 55c156d898f9..3b803ecd84a2 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ using ir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ @@ -103,7 +99,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); @@ -117,7 +113,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; @@ -129,7 +125,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -159,7 +155,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, pa->value / pb->value); + return IntImm(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; @@ -185,7 +181,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, pa->value % pb->value); + return IntImm(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; @@ -204,7 +200,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value)); + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -230,7 +226,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value)); + return IntImm(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -247,7 +243,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -258,7 +254,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -268,8 +264,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); return PrimExpr(); } @@ -277,8 +273,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); return PrimExpr(); } @@ -286,8 +282,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); return PrimExpr(); } @@ -295,8 +291,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); return PrimExpr(); } @@ -304,8 +300,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); return PrimExpr(); } @@ -313,17 +309,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); return PrimExpr(); } template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); + const IntImmNode* pa = a.as(); if (pa) { - return UIntImmNode::make(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::UInt(1), !(pa->value)); } return PrimExpr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index a041e40abf46..25d88d3429b6 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl : return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value <= static_cast(kPosInf)) { - return MakeBound(op->value, op->value); - } else { - return Everything(op->dtype); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); @@ -496,7 +488,7 @@ class ConstIntBoundAnalyzer::Impl : */ static std::vector DetectBoundInfo(const PrimExpr& cond) { PVar x, y; - PVar c; + PVar c; // NOTE: canonical form always use <= or < if ((c <= x).Match(cond)) { return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index ceaa976469e8..37d5e9eb5e57 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -384,10 +384,6 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); - } - IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); @@ -476,7 +472,7 @@ class IntervalSetEvaluator : IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); IntervalSet base = Eval(op->base); - PVar stride; + PVar stride; if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 01dd2e8e499e..c81842035c9f 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -109,7 +109,7 @@ class ModularSetAnalyzer::Impl : // Detect useful constraints and use them in the analysis scope. std::function EnterConstraint(const PrimExpr& constraint) { PVar var; - PVar coeff, base; + PVar coeff, base; // pattern match interesting constraints if ((truncmod(var, coeff) == base).Match(constraint) || (floormod(var, coeff) == base).Match(constraint)) { @@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl : return Entry(0, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value < std::numeric_limits::max()) { - return Entry(0, static_cast(op->value)); - } else { - return Everything(); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 733dcf41ce94..a236e65a8312 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -45,7 +45,7 @@ * } * * tvm::Var tx, ty; - * arith::PVar c; + * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr @@ -140,9 +140,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Integer& lhs, const Integer& rhs) const { + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 94d951da51db..e6e1524604ce 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -124,7 +124,7 @@ VisitExpr_(const AddNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -239,7 +239,7 @@ VisitExpr_(const SubNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -438,7 +438,7 @@ VisitExpr_(const MulNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -477,7 +477,7 @@ VisitExpr_(const DivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -700,7 +700,7 @@ VisitExpr_(const ModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -789,7 +789,7 @@ VisitExpr_(const FloorDivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -934,7 +934,7 @@ VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1004,7 +1004,7 @@ VisitExpr_(const MinNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1189,7 +1189,7 @@ VisitExpr_(const MaxNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1362,7 +1362,7 @@ VisitExpr_(const EQNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1; PVar lanes; // vector rule @@ -1416,7 +1416,7 @@ VisitExpr_(const LTNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1597,7 +1597,7 @@ VisitExpr_(const AndNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { @@ -1646,7 +1646,7 @@ VisitExpr_(const OrNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index cf138edd494e..55ed36ca9352 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -256,7 +256,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > Array attr{std::string("_attr_"), FloatImmNode::make(DataType::Float(32), trans(fea.length)), - IntImmNode::make(DataType::Int(32), fea.nest_level), + IntImm(DataType::Int(32), fea.nest_level), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), }; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 777ad6203008..d9b7f7f08d12 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -372,16 +372,17 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->dtype == DataType::UInt(32)) { + +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (dtype == DataType::UInt(32)) { std::ostringstream temp; - temp << op->value << "U"; + temp << val << "U"; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->dtype, os); - os << ")" << op->value; + p->PrintType(dtype, os); + os << ")" << val; } } @@ -408,9 +409,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintConst(op, os, this); -} + void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } @@ -528,6 +527,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << ")"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsic(op, " ^ ", os, this); } else if (op->is_intrinsic(CallNode::bitwise_or)) { diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index cb092c566322..7e5dd4269c94 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -128,7 +128,6 @@ class CodeGenC : void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 7967c1847ac2..cea276d5cb1a 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { CodeGenC::VisitExpr_(op, os); } -void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index cd1ec83360c6..19ca2ee12c6c 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC { // Codegen for immediate values void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 6879fd5f8542..44862cf7a97c 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); + Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } @@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); @@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); @@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return ir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index c04a023aefad..60d8146fc0e6 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); - const uint64_t *num_signature = as_const_uint(op->args[1]); - CHECK(num_signature) << "The second argument should be a uint represents number of arguments, " - << "but " << op->args[1] << " got!\n"; + Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector sig_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < *num_signature) { + if (i - 2 < static_cast(num_signature)) { sig_type.push_back(arg_value.back()->getType()); } } @@ -722,6 +720,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return llvm::ConstantInt::get(LLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; @@ -804,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) { - return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); -} - llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34c3ee723e18..b269f2423fc8 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -106,7 +106,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const VarNode* op) override; llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override; - llvm::Value* VisitExpr_(const UIntImmNode* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override; llvm::Value* VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 03656cc70a46..11bda70fb8cf 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -96,8 +96,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { MakeValue( ir::BroadcastNode::make( ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index b3ab557ee215..1f839f362f40 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index a749424892e2..985f6816a640 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } -spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) { - return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); -} - spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } @@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = op->args[0].as()->value; + uint32_t inst_id = static_cast( + op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); @@ -285,6 +282,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return builder_->UIntImm(builder_->GetSType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3804bda0f2e0..5aa7f9c49910 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -65,7 +65,6 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override; - spirv::Value VisitExpr_(const UIntImmNode* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override; spirv::Value VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d41d96db5165..d96883ed02fd 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : call->args) { cargs.push_back(arg); diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 6f8d96e148c1..bf43f11cce02 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -342,9 +342,9 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type == DataType::UInt(1)) { // bool types. if (*pvalue) { - ib_.Begin(spv::OpConstantTrue).AddSeq(ret); + ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); } else { - ib_.Begin(spv::OpConstantFalse).AddSeq(ret); + ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret); } } else { // Integral/floating-point types. diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index 3843cbb3c6a9..5d25e8634e84 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index eccff6c74c2e..01096ae1dd46 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) { - CHECK(op->value <= std::numeric_limits::max()) - << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); -} - void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { LOG(FATAL) << "Float Imm is not supported"; } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 07989b2062e1..1360cc2d70f1 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -136,7 +136,6 @@ class CodeGenStackVM void VisitExpr_(const RampNode* op) final; void VisitExpr_(const BroadcastNode* op) final; void VisitExpr_(const IntImmNode* op) final; - void VisitExpr_(const UIntImmNode* op) final; void VisitExpr_(const FloatImmNode* op) final; void VisitExpr_(const StringImmNode* op) final; // statment diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7e3d44f26aef..346ec3808919 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << op->value << ")"; -} + void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 89a1ece577f9..33bd0efae8a4 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -117,7 +117,6 @@ class CodeGenHybrid : void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f698a5d1802e..6d89967416b0 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,25 @@ namespace tvm { +IntImm::IntImm(DataType dtype, int64_t value) { + CHECK(dtype.is_scalar()) + << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm can only take scalar."; + if (dtype.is_uint()) { + CHECK_GE(value, 0U); + } + ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("make.IntImm") +.set_body_typed([](DataType dtype, int64_t value) { + return IntImm(dtype, value); +}); + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 34ee4b3159a5..4fffc475a773 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -77,7 +77,6 @@ class AttrFunctor { virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. @@ -113,7 +112,6 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); - ATTR_FUNCTOR_DISPATCH(UIntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); @@ -157,7 +155,6 @@ class AttrsEqualHandler : bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; @@ -198,7 +195,6 @@ class AttrsHashHandler : protected: size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImmNode* lhs) final; - size_t VisitAttr_(const ir::UIntImmNode* lhs) final; size_t VisitAttr_(const ir::FloatImmNode* lhs) final; size_t VisitAttr_(const ir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 1d3e767a5b71..a590f10e78e5 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } - return false; -} - bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; @@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) { - return std::hash()(op->value); -} - size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { return std::hash()(op->value); } diff --git a/src/lang/expr.cc b/src/lang/expr.cc index a7289369bcd4..55dfb89342a8 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -30,7 +30,7 @@ namespace tvm { PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {} + : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} @@ -54,15 +54,6 @@ Range::Range(PrimExpr begin, PrimExpr end) is_zero(begin) ? end : (end - begin))) { } -Integer IntImmNode::make(DataType t, int64_t value) { - CHECK(t.is_int() && t.is_scalar()) - << "ValueError: IntImm can only take scalar."; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return Integer(node); -} - Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index d3875e28c887..bd43d89d89d0 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -35,6 +35,14 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return ir::CastNode::make(t, value); } +PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { + return ir::CallNode::make( + t, ir::intrinsic::tvm_large_uint_imm, + {make_const(DataType::UInt(32), low), + make_const(DataType::UInt(32), high)}, + ir::CallNode::PureIntrinsic); +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; @@ -78,26 +86,25 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) } } - // maximum and min limits PrimExpr max_value(const DataType& dtype) { using namespace ir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::max()); + return IntImm(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { int64_t val = 1; val = (val << (dtype.bits() - 1)) - 1; - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { - return UIntImmNode::make(dtype, std::numeric_limits::max()); + return make_const(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast(dtype.bits())) - 1; - return UIntImmNode::make(dtype, val); + return IntImm(dtype, static_cast(val)); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { @@ -117,14 +124,14 @@ PrimExpr min_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::lowest()); + return IntImm(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() < 64) { int64_t val = 1; val = -(val << (dtype.bits() - 1)); - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { - return UIntImmNode::make(dtype, 0); + return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImmNode::make(dtype, std::numeric_limits::lowest()); @@ -155,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); - } else if (const auto* op = x.as()) { - return ConstPowerHelper(op->value, shift); } else { return false; } } PrimExpr cast(const DataType& t, PrimExpr value) { - using ir::IntImmNode; - using ir::UIntImmNode; using ir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } @@ -184,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) { if (value.dtype() != vtype) { if (const IntImmNode* op = value.as()) { value = make_const(vtype, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { @@ -219,7 +218,7 @@ PrimExpr operator-(PrimExpr a) { using ir::FloatImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); - if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value); + if (pa) return IntImm(a.dtype(), -pa->value); if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } @@ -322,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) { } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - using ir::IntImmNode; - using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); - if (const UIntImmNode* op = cond.as()) { - if (op->value != 0) { - return true_value; - } else { - return false_value; - } - } else if (const IntImmNode* op = cond.as()) { + if (const IntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { @@ -424,7 +415,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -437,7 +428,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -450,7 +441,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); @@ -460,7 +451,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); @@ -470,7 +461,7 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); @@ -494,7 +485,7 @@ PrimExpr abs(PrimExpr x) { using ir::IntImmNode; const IntImmNode* px = x.as(); if (px) { - return ir::IntImmNode::make(x.dtype(), std::abs(px->value)); + return IntImm(x.dtype(), std::abs(px->value)); } return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index ad7f260226bd..f06a6be5e75a 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,14 +31,6 @@ namespace tvm { namespace ir { // constructors -PrimExpr UIntImmNode::make(DataType t, uint64_t value) { - CHECK(t.is_uint() && t.lanes() == 1) - << "ValueError: UIntImm can only take scalar"; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return PrimExpr(node); -} PrimExpr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) @@ -248,7 +240,7 @@ PrimExpr ShuffleNode::make_concat(Array vectors) { int index = 0; for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { - indices.push_back(IntImmNode::make(DataType::Int(32), index++)); + indices.push_back(IntImm(DataType::Int(32), index++)); } } return make(vectors, indices); @@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) { } // Printers -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(" << op->dtype << ")" << op->value; - }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_REGISTER_NODE_TYPE(FloatImmNode); TVM_REGISTER_NODE_TYPE(IntImmNode); -TVM_REGISTER_NODE_TYPE(UIntImmNode); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(VarNode); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2c04de3710fa..0f350d2d732e 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - UIntImmNode::make(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - UIntImmNode::make(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - UIntImmNode::make(DataType::UInt(16), dtype.lanes())); + IntImm(DataType::UInt(16), dtype.lanes())); asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop)); // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), @@ -193,7 +193,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, // mark alignment of external bufs init_nest_.emplace_back(AttrStmtNode::make( vptr, ir::attr::storage_alignment, - IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop)); + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); @@ -206,7 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->shape[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_shape, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } // strides field @@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr svalue = cast( stype, LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -251,7 +251,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, field_name << v_strides->name_hint << '[' << k << ']'; PrimExpr value = cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -270,7 +270,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6eacb145b29b..8c441510c51d 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -252,10 +252,6 @@ class IRDeepCompare : CompareValue(op->value, other.as()->value); } - void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 67acec674630..857206f8dd9f 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode); DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} -void ExprVisitor::VisitExpr_(const UIntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} @@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 7b760fa4a672..5aba355b7003 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator { if (const IntImmNode* op = a.as()) { return op->value == b.as()->value; } - if (const UIntImmNode* op = a.as()) { - return op->value == b.as()->value; - } return false; } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index ed8be8bb39fc..5684f4ef785f 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -173,7 +173,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const MaxNode* op) final { using namespace arith; PVar x, y; - PVar c; + PVar c; auto e = GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index a0b07c293b05..d509169df0b1 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const UIntImmNode *size_of_args = call->args[0].as(); + const IntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 8e7f1d86da74..01a97b7878be 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -129,8 +129,8 @@ class BuiltinLower : public StmtExprMutator { {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), - IntImmNode::make(DataType::Int(32), op->dtype.code()), - IntImmNode::make(DataType::Int(32), op->dtype.bits())}, + IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index d5c73a2e8a75..5df36d0b2423 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -69,8 +69,8 @@ LoweredFunc MakeAPI(Stmt body, // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, - IntImmNode::make(DataType::Int(32), i), - IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); PrimExpr res = CallNode::make( @@ -117,7 +117,7 @@ LoweredFunc MakeAPI(Stmt body, seq_init.emplace_back(LetStmtNode::make( tcode, LoadNode::make( DataType::Int(32), v_packed_arg_type_ids, - IntImmNode::make(DataType::Int(32), i), const_true(1)), + IntImm(DataType::Int(32), i), const_true(1)), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 224a81c12396..9fb19cc4b308 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor { return false; } bool VisitExpr_(const VarNode* op) final { return false; } - bool VisitExpr_(const UIntImmNode* op) final { return false; } bool VisitExpr_(const IntImmNode* op) final { return false; } bool VisitExpr_(const FloatImmNode* op) final { return false; } bool VisitExpr_(const StringImmNode* op) final { return false; } diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index bb57fe8c37d3..956f27c9319d 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -462,7 +462,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -575,7 +575,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -765,7 +765,7 @@ class ThreadIdxMutator : public StmtExprMutator { op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImmNode::make(DataType::Int(32), 0); + PrimExpr zero = IntImm(DataType::Int(32), 0); return zero; } if (op->name_hint == "threadIdx.y") { @@ -934,7 +934,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr stride = strides[strides.size()-2]; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); PrimExpr src = CallNode::make(value->dtype, @@ -984,7 +984,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr dst = it3->second; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); dst = CallNode::make(DataType::Handle(), @@ -1089,7 +1089,7 @@ class TensorCoreIRMutator : public StmtExprMutator { Array strides; for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, shape[j]); } @@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } strides.push_back(make_const(DataType::Int(32), 1)); - PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0); + PrimExpr elem_offset = IntImm(DataType::Int(32), 0); CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index b2c50f7a8bd2..26ad59189671 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator { // constant folding. PrimExpr extent = ir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); - const UIntImmNode *v2 = extent.as(); int value = -1; if (v1 != nullptr) { value = static_cast(v1->value); } - if (v2 != nullptr) { - value = static_cast(v2->value); - } return value; } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 00c40b2565bb..5ee4ce30c96d 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -88,7 +88,7 @@ Array GetShape(const Array& shape) { if (pval != nullptr) { CHECK_LE(pval[0], std::numeric_limits::max()); CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval)); + res.push_back(IntImm(DataType::Int(32), *pval)); } else if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { @@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { // set inputs for (auto param : prim_func->params) { int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state)); + cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); if (state & kNeedInputData) { for (auto t : param_data_[param]) { cache_node->inputs.push_back(t); @@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { auto ret_type = call_node->checked_type(); Array out_ndims; if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } else { auto rtype = ret_type.as(); // TODO(@icemelon): Allow recursive tuple @@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { for (size_t i = 0; i < rtype->fields.size(); ++i) { auto ttype = rtype->fields[i].as(); CHECK(ttype); - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } } // Call shape function diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f66cce6b7b82..9966d9cc55ef 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); shape.push_back( - tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i])); + tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorTypeNode::make(shape, dtype); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 25650c7766cb..400a6bea22ed 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -857,10 +857,6 @@ class PrettyPrinter : return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::UIntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); - } - Doc VisitAttr_(const ir::FloatImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4d3a4b9589ee..b5383cd3339b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -852,7 +852,7 @@ bool ArgWhereRel(const Array& types, const auto& input_rank = input_shape.size(); std::vector result_shape; result_shape.push_back(Any::make()); - result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank)); + result_shape.push_back(IntImm(DataType::Int(32), input_rank)); reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32))); return true; } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 594669343f62..30a9a5c80402 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode { } bool Assert(const IndexExpr& cond) final { - if (const uint64_t* pdiff = as_const_uint(cond)) { + if (const int64_t* pdiff = as_const_int(cond)) { return pdiff[0]; } return true; diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 378a5e3728f4..2e332413c1f6 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -47,14 +47,10 @@ static inline Array get_shape(const Type& type) { static inline const int32_t GetQmin(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* min_value = as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); return static_cast(min_value[0]); - } else if (dtype.is_uint()) { - auto* min_value = as_const_uint(tvm::min_value(dtype)); - CHECK(min_value != nullptr); - return static_cast(min_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning @@ -64,14 +60,10 @@ static inline const int32_t GetQmin(const DataType& dtype) { static inline const int32_t GetQmax(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* max_value = as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); return static_cast(max_value[0]); - } else if (dtype.is_uint()) { - auto* max_value = as_const_uint(tvm::max_value(dtype)); - CHECK(max_value != nullptr); - return static_cast(max_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5392eaeac1e8..193f2f206c06 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -127,10 +127,10 @@ TEST(Pattern, Basic) { } } -TEST(Pattern, Integer) { +TEST(Pattern, IntImm) { using namespace tvm; tvm::Var tx, ty; - arith::PVar c; + arith::PVar c; arith::PVar v; { // We can match integer and Var, both of which are diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 45ecf9539337..5a10618fb269 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -18,6 +18,32 @@ from tvm.contrib import util import numpy as np +def test_large_uint_imm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + n = 12 + num_thread = 2 + + A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + xo, xi = s[A].split(A.op.axis[0], factor=num_thread) + s[A].bind(xi, tvm.thread_axis("threadIdx.x")) + s[A].bind(xo, tvm.thread_axis("blockIdx.x")) + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + return + f = tvm.build(s, [A], device) + # launch the kernel. + a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy()[0] == value + 3 + + check_target("cuda") + check_target("vulkan") + + def test_add_pipeline(): n = tvm.var('n') A = tvm.placeholder((n,), name='A') @@ -112,4 +138,5 @@ def check_module_save(device, host="stackvm"): if __name__ == "__main__": + test_large_uint_imm() test_add_pipeline() diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 0e595cd79c97..4920206ee019 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -88,6 +88,25 @@ def test_llvm_lookup_intrin(): fcode = tvm.build(func, None, "llvm") +def test_llvm_large_uintimm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + + def check_llvm(): + if not tvm.module.enabled("llvm"): + return + f = tvm.build(s, [A], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy() == value + 3 + + check_llvm() + + def test_llvm_add_pipeline(): nn = 1024 n = tvm.convert(nn) @@ -645,6 +664,7 @@ def vectorizer(op): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) if __name__ == "__main__": + test_llvm_large_uintimm() test_llvm_import() test_alignment() test_rank_zero() diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c3c40cf740ad..5f1facb2b45f 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm,)) return val.value ctx = tvm.context(target, 0) diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index fe329494e24e..c4187858a8a8 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -38,16 +38,11 @@ def test_expr_constructor(): assert x.value == 2 assert x.dtype == "int64" - x = tvm.expr.UIntImm("uint16", 2) - assert isinstance(x, tvm.expr.UIntImm) - assert x.value == 2 - assert x.dtype == "uint16" - x = tvm.expr.StringImm("xyza") assert isinstance(x, tvm.expr.StringImm) assert x.value == "xyza" - x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1)) + x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1)) assert isinstance(x, tvm.expr.Cast) assert x.dtype == "float32" assert x.value.value == 1 diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec..ac2ee6d88cc5 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -29,7 +29,7 @@ def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): + if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y): raise ValueError("check error: %s vs %s " % (x, y)) tmod = tvm.truncmod diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 43ac3a29cd7c..e6de76f20881 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -43,8 +43,7 @@ using namespace tvm; */ inline bool IsConstInt(PrimExpr expr) { return - expr->IsInstance() || - expr->IsInstance(); + expr->IsInstance(); } /*! @@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) { * \return The integer value. */ inline int64_t GetConstInt(PrimExpr expr) { - if (expr->IsInstance()) { - return expr.as()->value; - } - if (expr->IsInstance()) { - return expr.as()->value; + if (expr->IsInstance()) { + return expr.as()->value; } LOG(ERROR) << "expr must be a constant integer"; return -1; diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 8f32a297d719..02d082b8b342 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -92,9 +92,9 @@ def get_const_int(expr): """ if isinstance(expr, Integral): return expr - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -136,9 +136,9 @@ def equal_const_int(expr, value): """ if isinstance(expr, Integral): return expr == value - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): return False return expr.value == value @@ -160,9 +160,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, tvm.expr.Var): ret.append(elem) - elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): + elif not isinstance(elem, (tvm.expr.IntImm, int)): elem = tvm.ir_pass.Simplify(elem) - if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(elem, tvm.expr.IntImm): ret.append(elem) else: ret.append(get_const_int(elem))