diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 96526ccfcfb2..464ce6c143c5 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -152,11 +152,25 @@ TVM_DLL const Op& fma(); */ TVM_DLL const Op& call_extern(); +/*! + * \brief Call an pure extern C function with given name + * and signature from the types of args in the runtime environment. + * + * Type call_pure_extern(name, args...) { + * return dlsym(name)(args...); + * } + * + * \note This intrinsic does not provide any type checking, + * and is main used for backward compatibility reasons. + * Always consider use pre-registered and typed tvm::Op first. + */ +TVM_DLL const Op& call_pure_extern(); + /*! * \brief Call an LLVM intrinsic with a given intrinsic id * and signature from the types of args in the runtime environment. * - * Type call_llvm_intrin(intrin_id, args...) { + * Type call_llvm_pure_intrin(intrin_id, args...) { * return dlsym(name)(args...); * } * @@ -165,15 +179,27 @@ TVM_DLL const Op& call_extern(); TVM_DLL const Op& call_llvm_intrin(); /*! - * \brief Call an SPIRV GLSL450 intrinsic. + * \brief Call an LLVM pure intrinsic with a given intrinsic id + * and signature from the types of args in the runtime environment. + * + * Type call_llvm_pure_intrin(intrin_id, args...) { + * return dlsym(name)(args...); + * } + * + * \note This op does not provide any type checking. + */ +TVM_DLL const Op& call_llvm_pure_intrin(); + +/*! + * \brief Call an SPIRV pure GLSL450 intrinsic. * - * Type call_spirv_glsl450(intrin_id, args...) { + * Type call_spirv_pure_glsl450(intrin_id, args...) { * return dlsym(name)(args...); * } * * \note This op does not provide any type checking. */ -TVM_DLL const Op& call_spirv_glsl450(); +TVM_DLL const Op& call_spirv_pure_glsl450(); // TODO(tvm-team) revisit the builtins below // some of them can simply become ops with special codegen attr. diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index f0e6d898d7d8..100d163d8b12 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -875,19 +875,6 @@ class Let : public PrimExpr { */ class CallNode : public PrimExprNode { public: - /*! \brief Possible types of calls. */ - enum CallType : int { - /*! \brief Extern "C" function. */ - Extern = 0, - /*! \brief Extern CXX function. */ - ExternCPlusPlus = 1, - /*! \brief Extern "C" without side-effect. */ - PureExtern = 2, - /*! \brief Intrinsic functions. */ - Intrinsic = 4, - /*! \brief Intrinsic functions that are pure. */ - PureIntrinsic = 5 - }; /*! * \brief The operator(function) being invoked * @@ -898,31 +885,22 @@ class CallNode : public PrimExprNode { /*! \brief The arguments. */ Array args; - /*! \brief Type of calls. */ - CallType call_type; - void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("op", &op); v->Visit("args", &args); - v->Visit("call_type", &call_type); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) && - equal(call_type, other->call_type); + return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); hash_reduce(op); hash_reduce(args); - hash_reduce(call_type); } - /*! \return Whether call node is pure. */ - bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } - static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); }; @@ -933,9 +911,7 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - using CallType = CallNode::CallType; - - TVM_DLL Call(DataType dtype, RelayExpr op, Array args, CallType call_type); + TVM_DLL Call(DataType dtype, RelayExpr op, Array args); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); }; diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 09eb33c54b4c..34cb52f90211 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -553,10 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x) { \ - static const Op& op = Op::Get("tir." #OpName); \ - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ + static const Op& op = Op::Get("tir." #OpName); \ + return tir::Call(x.dtype(), op, {x}); \ } TVM_DECLARE_INTRIN_UNARY(exp); @@ -583,10 +583,10 @@ TVM_DECLARE_INTRIN_UNARY(acosh); TVM_DECLARE_INTRIN_UNARY(asinh); TVM_DECLARE_INTRIN_UNARY(atanh); -#define TVM_DECLARE_INTRIN_BINARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \ - static const Op& op = Op::Get("tir." #OpName); \ - return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \ +#define TVM_DECLARE_INTRIN_BINARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \ + static const Op& op = Op::Get("tir." #OpName); \ + return tir::Call(x.dtype(), op, {x, y}); \ } TVM_DECLARE_INTRIN_BINARY(atan2); diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index d7c13500d90e..ec7fc172cde8 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -43,6 +43,43 @@ using TGlobalSymbol = String; */ using TVectorizable = bool; +/*! + * \brief The effect type of the call. + */ +enum class CallEffectKind : int { + /*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */ + kExprAnnotation = 0, + /*! + * \brief Pure function that do not interacts + * with any external state. + */ + kPure = 1, + /*! + * \brief Function's that may read from states(e.g. RAM) + */ + kReadState = 2, + /*! + * \brief Function that may read/write from states(e.g. RAM). + */ + kUpdateState = 3, + /*! + * \brief Opaque function, cannot make any assumption + */ + kOpaque = kUpdateState, + /*! + * \brief Special intrinsic to annotate call arguments info + * only valid as a direct argument to a call. + */ + kSpecialCallArg = 4, + /*! + * \brief Embed opaque information in the Expr, cannot be codegen. + */ + kEmbedInfo = 5 +}; + +/*! \brief Use integer to record the kind. */ +using TCallEffectKind = Integer; + } // namespace tir } // namespace tvm #endif // TVM_TIR_OP_ATTR_TYPES_H_ diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index a119c20754f4..78ed1dce3a44 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -22,7 +22,7 @@ from tvm.ir.container import Array from tvm import target as _tgt from tvm.tir import expr as _expr -from tvm.tir import call_pure_intrin +from tvm.tir import call_intrin from tvm.tir.stmt import For from .util import _internal_assert @@ -148,7 +148,7 @@ def likely(func_id, args): _internal_assert(args.__len__() == 1, \ "Only one expression can be likely") _internal_assert(func_id == "likely", "This function cannot be directly invoked!") - return call_pure_intrin(args[0].dtype, 'tir.likely', *args) + return call_intrin(args[0].dtype, 'tir.likely', *args) def max_num_threads(func_id, args): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 90ccde4f1647..9dbdc07b4a46 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -24,8 +24,8 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let -from .expr import IterVar, Any +from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle +from .expr import Call, CallEffectKind, Let, IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt @@ -34,8 +34,8 @@ from .function import PrimFunc -from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, all, any, min_value, max_value, trace +from .op import call_packed, call_intrin, call_pure_extern, call_extern +from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 386badf3e8aa..c8f151ee2154 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -266,6 +266,23 @@ def asobject(self): return _ffi_api._OpNE(self.a, self.b) +class IntImmEnum(ObjectGeneric): + """Lazily evaluate an IntImm in case + the constructor is not available in runtime. + + Parameters + ---------- + value : int + The enum value + """ + def __init__(self, value): + self.value = value + + def asobject(self): + """Convert object.""" + return IntImm("int32", self.value) + + class PrimExprWithOp(ExprOp, PrimExpr): """Helper base class to inherit from PrimExpr.""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ @@ -959,6 +976,16 @@ def __init__(self, vectors, indices): _ffi_api.Shuffle, vectors, indices) +class CallEffectKind: + """Possible kinds of Call effects.""" + # only expose up to opaque + ExprAnnotation = IntImmEnum(0) + Pure = IntImmEnum(1) + ReadState = IntImmEnum(2) + UpdateState = IntImmEnum(3) + Opaque = UpdateState + + @tvm._ffi.register_object("tir.Call") class Call(PrimExprWithOp): """Call node. @@ -974,16 +1001,8 @@ class Call(PrimExprWithOp): args : list of Expr The input arguments to the call - - call_type : int - The type of the call """ - Extern = 0 - ExternCPlusPlus = 1 - PureExtern = 2 - Intrinsic = 4 - PureIntrinsic = 5 - def __init__(self, dtype, op, args, call_type): + def __init__(self, dtype, op, args): if isinstance(op, str): if not op.startswith("tir."): raise ValueError( @@ -992,7 +1011,7 @@ def __init__(self, dtype, op, args, call_type): "certain about the intrinsic name, pass in Op.get(name) instead") % op) op = Op.get(op) self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, op, args, call_type) + _ffi_api.Call, dtype, op, args) @tvm._ffi.register_object("tir.Let") diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 089127c6f0ff..20180d1be45d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -379,8 +379,7 @@ def likely(self, expr): expr : Expr The expression will likely tag. """ - return _expr.Call(expr.dtype, "tir.likely", [expr], - _expr.Call.PureIntrinsic) + return _expr.Call(expr.dtype, "tir.likely", [expr]) def get(self): """Return the builded IR. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 6826241ac1a6..cbbd59fe4eaf 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -29,10 +29,8 @@ def _pack_buffer(buf): """Build intrinsics that packs the buffer. """ assert buf.shape - shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape, - Call.Intrinsic) - strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides, - Call.Intrinsic) if buf.strides else 0 + shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape) + strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0 pack_args = [buf.data, shape, strides, @@ -40,7 +38,7 @@ def _pack_buffer(buf): const(0, dtype=buf.dtype), buf.elem_offset] return Call("handle", Op.get("tir.tvm_stack_make_array"), - pack_args, Call.Intrinsic) + pack_args) def call_packed(*args): """Build expression by call an external packed function. @@ -68,11 +66,11 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic) + "int32", Op.get("tir.tvm_call_packed"), call_args) -def call_pure_intrin(dtype, func_name, *args): - """Build expression by calling a pure intrinsic function. +def call_intrin(dtype, func_name, *args): + """Build expression by calling an intrinsic function. Intrinsics can be overloaded with multiple data types via the intrinsic translation rule. @@ -93,16 +91,12 @@ def call_pure_intrin(dtype, func_name, *args): call : PrimExpr The call expression. """ - args = convert(args) return Call( - dtype, func_name, convert(args), Call.PureIntrinsic) + dtype, func_name, convert(args)) -def call_intrin(dtype, func_name, *args): - """Build expression by calling an intrinsic function. - - Intrinsics can be overloaded with multiple data types via - the intrinsic translation rule. +def call_pure_extern(dtype, func_name, *args): + """Build expression by calling a pure extern function. Parameters ---------- @@ -110,7 +104,7 @@ def call_intrin(dtype, func_name, *args): The data type of the result. func_name: str - The intrinsic function name. + The extern function name. args : list Positional arguments. @@ -120,13 +114,12 @@ def call_intrin(dtype, func_name, *args): call : PrimExpr The call expression. """ - args = convert(args) return Call( - dtype, func_name, convert(args), Call.Intrinsic) + dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args)) -def call_pure_extern(dtype, func_name, *args): - """Build expression by calling a pure extern function. +def call_extern(dtype, func_name, *args): + """Build expression by calling a extern function. Parameters ---------- @@ -145,34 +138,39 @@ def call_pure_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern) + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args)) -def call_extern(dtype, func_name, *args): - """Build expression by calling a extern function. +def call_llvm_intrin(dtype, name, *args): + """Build expression by calling a llvm intrinsic function Parameters ---------- dtype : str - The data type of the result. + The data type of the result. - func_name: str - The extern function name. + name : str + The name of the llvm intrinsic function. args : list - Positional arguments. + Poistional arguments. Returns ------- call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern) + # pylint: disable=import-outside-toplevel + from tvm.target import codegen + llvm_id = codegen.llvm_lookup_intrinsic_id(name) + assert llvm_id != 0, "%s is not an LLVM intrinsic" % name + return call_intrin( + dtype, Op.get("tir.call_llvm_intrin"), + tvm.tir.const(llvm_id, 'uint32'), *args) -def call_llvm_intrin(dtype, name, *args): - """Build expression by calling an llvm intrinsic function +def call_llvm_pure_intrin(dtype, name, *args): + """Build expression by calling a pure llvm intrinsic function Parameters ---------- @@ -194,8 +192,9 @@ def call_llvm_intrin(dtype, name, *args): from tvm.target import codegen llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name - return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"), - tvm.tir.const(llvm_id, 'uint32'), *args) + return call_intrin( + dtype, Op.get("tir.call_llvm_pure_intrin"), + tvm.tir.const(llvm_id, 'uint32'), *args) def any(*args): @@ -279,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic) + args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args) @@ -328,7 +327,7 @@ def exp(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp", x) + return call_intrin(x.dtype, "tir.exp", x) def exp2(x): @@ -344,7 +343,7 @@ def exp2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp2", x) + return call_intrin(x.dtype, "tir.exp2", x) def exp10(x): @@ -360,7 +359,7 @@ def exp10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp10", x) + return call_intrin(x.dtype, "tir.exp10", x) def erf(x): @@ -376,7 +375,7 @@ def erf(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.erf", x) + return call_intrin(x.dtype, "tir.erf", x) def tanh(x): @@ -392,7 +391,7 @@ def tanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.tanh", x) + return call_intrin(x.dtype, "tir.tanh", x) def sigmoid(x): @@ -408,7 +407,7 @@ def sigmoid(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sigmoid", x) + return call_intrin(x.dtype, "tir.sigmoid", x) def log(x): @@ -424,7 +423,7 @@ def log(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log", x) + return call_intrin(x.dtype, "tir.log", x) def log2(x): @@ -440,7 +439,7 @@ def log2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log2", x) + return call_intrin(x.dtype, "tir.log2", x) def log10(x): @@ -456,7 +455,7 @@ def log10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log10", x) + return call_intrin(x.dtype, "tir.log10", x) def log1p(x): @@ -472,7 +471,7 @@ def log1p(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log1p", x) + return call_intrin(x.dtype, "tir.log1p", x) def tan(x): @@ -488,7 +487,7 @@ def tan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.tan", x) + return call_intrin(x.dtype, "tir.tan", x) def cos(x): @@ -504,7 +503,7 @@ def cos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.cos", x) + return call_intrin(x.dtype, "tir.cos", x) def cosh(x): @@ -520,7 +519,7 @@ def cosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.cosh", x) + return call_intrin(x.dtype, "tir.cosh", x) def acos(x): @@ -536,7 +535,7 @@ def acos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.acos", x) + return call_intrin(x.dtype, "tir.acos", x) def acosh(x): @@ -552,7 +551,7 @@ def acosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.acosh", x) + return call_intrin(x.dtype, "tir.acosh", x) def sin(x): @@ -568,7 +567,7 @@ def sin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sin", x) + return call_intrin(x.dtype, "tir.sin", x) def sinh(x): @@ -584,7 +583,7 @@ def sinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sinh", x) + return call_intrin(x.dtype, "tir.sinh", x) def asin(x): @@ -600,7 +599,7 @@ def asin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.asin", x) + return call_intrin(x.dtype, "tir.asin", x) def asinh(x): @@ -616,7 +615,7 @@ def asinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.asinh", x) + return call_intrin(x.dtype, "tir.asinh", x) def atan(x): @@ -632,7 +631,7 @@ def atan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.atan", x) + return call_intrin(x.dtype, "tir.atan", x) def atanh(x): @@ -648,7 +647,7 @@ def atanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.atanh", x) + return call_intrin(x.dtype, "tir.atanh", x) def atan2(x1, x2): @@ -667,7 +666,7 @@ def atan2(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2) + return call_intrin(x1.dtype, "tir.atan2", x1, x2) def sqrt(x): @@ -683,7 +682,7 @@ def sqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sqrt", x) + return call_intrin(x.dtype, "tir.sqrt", x) def rsqrt(x): @@ -699,7 +698,7 @@ def rsqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.rsqrt", x) + return call_intrin(x.dtype, "tir.rsqrt", x) def floor(x): @@ -824,7 +823,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2) + return call_intrin(x1.dtype, "tir.nextafter", x1, x2) def hypot(x1, x2): @@ -843,7 +842,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2) + return call_intrin(x1.dtype, "tir.hypot", x1, x2) def copysign(x1, x2): @@ -862,7 +861,7 @@ def copysign(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2) + return call_intrin(x1.dtype, "tir.copysign", x1, x2) def ldexp(x1, x2): @@ -881,7 +880,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2) + return call_intrin(x1.dtype, "tir.ldexp", x1, x2) def isnan(x): @@ -964,7 +963,7 @@ def popcount(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.popcount", x) + return call_intrin(x.dtype, "tir.popcount", x) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. @@ -981,7 +980,7 @@ def fmod(x, y): z : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.fmod", x, y) + return call_intrin(x.dtype, "tir.fmod", x, y) def if_then_else(cond, t, f): diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index c367d0c9f9d8..259fcd90c63d 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -146,7 +146,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type); + return Call(op->dtype, op->op, {cond, true_value, false_value}); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index de8425146bbf..81a4d616d432 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -679,7 +679,7 @@ class PCallExpr : public Pattern> { #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ + return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ }; \ @@ -695,25 +695,23 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ - } \ - static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), GetOp(), args); \ + } \ + static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { - static PrimExpr Eval(Array args) { - return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); - } + static PrimExpr Eval(Array args) { return tir::Call(args[1].dtype(), GetOp(), args); } static const Op& GetOp() { return tir::builtin::if_then_else(); } }; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 0d5d654c3f6e..b65ae91c6393 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -238,7 +238,8 @@ void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN PrintExpr(op->args[0], os); os << " else "; PrintExpr(op->args[2], os); - } else if (op->op.same_as(builtin::call_extern())) { + } else if (op->op.same_as(builtin::call_pure_extern()) || + op->op.same_as(builtin::call_extern())) { StringImm fname = Downcast(op->args[0]); os << fname << "("; for (size_t i = 1; i < op->args.size(); i++) { diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index a11de012d391..7ab26fae785f 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -326,23 +326,6 @@ Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { return doc; } -inline const char* CallType2String(CallNode::CallType t) { - switch (t) { - case CallNode::Extern: - return "extern"; - case CallNode::ExternCPlusPlus: - return "extern_cpp"; - case CallNode::PureExtern: - return "pure_extern"; - case CallNode::Intrinsic: - return "intrin"; - case CallNode::PureIntrinsic: - return "pure_intrin"; - } - LOG(FATAL) << "Unknown CallType"; - return "Unknown"; -} - Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; if (auto* ptr_op = op->op.as()) { @@ -357,8 +340,7 @@ Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { for (const auto& arg : op->args) { args.push_back(Print(arg)); } - doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) - << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")"; + doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")"; return doc; } diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 37855fb39179..31fadf1ce5ac 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -29,53 +29,53 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchPureExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") .set_body([](const TVMArgs& args, TVMRetValue* rv) { @@ -87,7 +87,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") *rv = one / sqrt(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchPureExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") .set_body([](const TVMArgs& args, TVMRetValue* rv) { diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 36e553900d00..359c5b9580b5 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -55,7 +55,7 @@ struct Direct { // Call pure extern function. template -inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { +inline void DispatchPureExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); @@ -72,7 +72,7 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { for (auto arg : call->args) { new_args.push_back(arg); } - *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern); + *rv = Call(call->dtype, tir::builtin::call_pure_extern(), new_args); } else { *rv = e; } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 13ce59d54b82..5e5a94b50064 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -46,7 +46,7 @@ class CodeGenARM final : public CodeGenCPU { }; llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { - if (op->op.same_as(builtin_call_llvm_intrin_)) { + if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); @@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); } // Popcount lowering rule: @@ -94,16 +94,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = - tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = - tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); if (call->dtype.bits() == 16) { return vcnt16; } @@ -113,8 +111,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = - tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); if (call->dtype.bits() == 32) { return vcnt32; } @@ -124,7 +121,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 49f14c31d07f..99a23c64d402 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -738,7 +738,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { - if (op->op.same_as(builtin_call_llvm_intrin_)) { + if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; @@ -1077,7 +1077,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_)) { + if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic CHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 2bfe047038b0..9e7b56a6d7ae 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -326,7 +326,10 @@ class CodeGenLLVM : public ExprFunctor, // global symbol table. OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); const Op& builtin_call_extern_ = builtin::call_extern(); + const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); + const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + /*! \brief Helper struct for debug infos. */ struct DebugInfo { std::unique_ptr di_builder_; diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 5d269fa4d513..6f3d4f7cb25e 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -90,7 +90,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { DTypeToLLVMType(DataType::Float(32, from.lanes())), { MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), - {op->value}, tir::CallNode::PureIntrinsic)), + {op->value})), MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), @@ -102,11 +102,10 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); if (from.lanes() >= 8 && has_f16c) { - return CallVectorIntrin( - ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, - DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), - {op->value}, tir::CallNode::PureIntrinsic))}); + return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8, + DTypeToLLVMType(DataType::Float(32, from.lanes())), + {MakeValue(tir::Call(DataType::Int(16, from.lanes()), + tir::builtin::reinterpret(), {op->value}))}); } #endif } diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index cc9437d25b7e..1a6775e92e12 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -50,8 +50,7 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = - tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs); } template @@ -66,7 +65,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index a0ffe11da27a..0e332940339c 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -32,7 +32,7 @@ namespace tvm { namespace codegen { -inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { +inline void DispatchPureExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; using namespace tir; const CallNode* call = e.as(); @@ -52,54 +52,54 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { for (auto arg : call->args) { new_args.push_back(arg); } - *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); + *rv = Call(call->dtype, builtin::call_pure_extern(), new_args); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchPureExternLibDevice); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 07520ae08cc8..22ebf9b192aa 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -32,7 +32,7 @@ namespace tvm { namespace codegen { -inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { +inline void DispatchPureExternOCML(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; using namespace tir; const CallNode* call = e.as(); @@ -51,7 +51,7 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { new_args.push_back(arg); } - *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); + *rv = Call(call->dtype, builtin::call_pure_extern(), new_args); } inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { @@ -66,10 +66,10 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); - PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(), - {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern); - PrimExpr self = Call(DataType::Int(32), builtin::call_extern(), - {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern); + PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), + {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); + PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), + {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); // compute lane to get from PrimExpr width = call->args[3]; @@ -87,9 +87,8 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } - PrimExpr res = - Call(var.dtype(), builtin::call_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern); + PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(), + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); *rv = res; } @@ -108,49 +107,49 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(Dispatc TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchPureExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ffeaba06d701..05582fb07d6a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -575,7 +575,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_)) { + if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { CHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 9346f87cb3bb..87a4a2944130 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -262,6 +262,7 @@ class CodeGenC : public ExprFunctor, OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); + const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); private: /*! \brief whether to print in SSA form */ diff --git a/src/target/source/intrin_rule_aocl.cc b/src/target/source/intrin_rule_aocl.cc index 0cafd0255a86..69279a041413 100644 --- a/src/target/source/intrin_rule_aocl.cc +++ b/src/target/source/intrin_rule_aocl.cc @@ -27,49 +27,49 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchPureExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 53a2799e2725..9ffceb68e278 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -110,7 +110,7 @@ struct CUDAWarpIntrinsic { static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) { Call call = args[0]; - *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern); + *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); } template @@ -121,53 +121,52 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - *rv = - Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, CallNode::PureExtern); + *rv = Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchPureExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") .set_body(DispatchCUDAShuffle); @@ -181,28 +180,32 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") .set_body(DispatchCUDAWarpActiveMask); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchPureExtern); // Register low-level builtin ops. // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. TVM_REGISTER_OP("tir.cuda.__shfl_sync") .set_num_inputs(4) .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") .set_num_inputs(4) .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") .set_num_inputs(4) .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tir.cuda.__activemask") .set_num_inputs(0) .set_attr("TGlobalSymbol", "__activemask") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("cuda.need_warp_shuffle", true); } // namespace intrin diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 00fb9f9a95de..80a10312c011 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -27,45 +27,45 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchPureExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 82eabdd96dfe..7f81e335ec8d 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -29,45 +29,45 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchPureExtern); // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension @@ -80,7 +80,7 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern); + *rv = Call(call->dtype, builtin::call_pure_extern(), opencl_args); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index fb01d6566dab..da9bc79452ed 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -27,43 +27,43 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchPureExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 6c12343c81ec..ff3bc7d5f227 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -237,7 +237,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::call_spirv_glsl450())) { + if (op->op.same_as(builtin::call_spirv_pure_glsl450())) { CHECK_GE(op->args.size(), 2U); uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; @@ -317,13 +317,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else { - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype; - } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype; - } else { - LOG(FATAL) << "Unresolved call type " << op->call_type; - } + LOG(FATAL) << "Unresolved call " << op->op; return spirv::Value(); } } diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 1b9d2e4e410d..ea575ca83866 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -44,8 +44,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs, - tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index f6254121b7cb..e2479d8f133e 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -95,34 +95,32 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); - if (op->call_type == CallNode::CallType::PureIntrinsic) { - if (op->op.same_as(op_exp_)) { - return Mul(Mutate(op->args[0]), expr); - } else if (op->op.same_as(op_log_)) { - return Div(Mutate(op->args[0]), op->args[0]); - } else if (op->op.same_as(op_sigmoid_)) { - return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); - } else if (op->op.same_as(op_sqrt_)) { - return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); - } else if (op->op.same_as(op_tanh_)) { - return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); - } else if (op->op.same_as(op_pow_)) { - auto x = op->args[0], y = op->args[1]; - return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); - } else if (op->op.same_as(op_fabs_)) { - auto type = op->args[0].dtype(); - return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), - FloatImm(type, 1.0), FloatImm(type, -1.0))); - } else if (op->op.same_as(op_if_then_else_)) { - Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return Call(op->dtype, op->op, new_args, op->call_type); - } else if (piecewise_const.count(op->op)) { - return FloatImm(expr.dtype(), 0.0); - } else { - LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op; - } + if (op->op.same_as(op_exp_)) { + return Mul(Mutate(op->args[0]), expr); + } else if (op->op.same_as(op_log_)) { + return Div(Mutate(op->args[0]), op->args[0]); + } else if (op->op.same_as(op_sigmoid_)) { + return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); + } else if (op->op.same_as(op_sqrt_)) { + return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); + } else if (op->op.same_as(op_tanh_)) { + return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); + } else if (op->op.same_as(op_pow_)) { + auto x = op->args[0], y = op->args[1]; + return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); + } else if (op->op.same_as(op_fabs_)) { + auto type = op->args[0].dtype(); + return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0), + FloatImm(type, -1.0))); + } else if (op->op.same_as(op_if_then_else_)) { + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return Call(op->dtype, op->op, new_args); + } else if (piecewise_const.count(op->op)) { + return FloatImm(expr.dtype(), 0.0); + } else { + LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op; + return PrimExpr(); } - NOT_IMPLEMENTED; } PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index b4725c571782..21343ecfe1b1 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -277,10 +277,9 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, if (attr->dim_align_factor != 0) { Array tuple = {static_cast(i), attr->dim_align_factor, attr->dim_align_offset}; - realize = tir::AttrStmt( - t, tir::attr::buffer_dim_align, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), - realize); + realize = + tir::AttrStmt(t, tir::attr::buffer_dim_align, + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize); } } } diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index eeaab301ad03..427be320e844 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -196,8 +196,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), - freduce_args, CallNode::Intrinsic)); + Stmt reduce_body = + Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args)); reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 01019e43e61c..d789938a51b1 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -153,7 +153,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, tuple.push_back(buffer->shape[k]); } ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret); + Call(DataType::Handle(), builtin::tvm_tuple(), tuple), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 714e8859229d..f6f00584aa76 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -152,9 +152,9 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, tuple.push_back(region[i]->min); tuple.push_back(region[i]->extent); } - input_bind_nest.emplace_back(AttrStmt( - bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + input_bind_nest.emplace_back( + AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); } // output binding @@ -176,9 +176,9 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } } - output_bind_nest.emplace_back(AttrStmt( - bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + output_bind_nest.emplace_back( + AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); } // Check variable remap diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index dd978a430e4b..d48bf78f10d8 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -368,9 +368,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, tuple.push_back(r->min); tuple.push_back(r->extent); } - input_bind_nest.emplace_back(AttrStmt( - bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + input_bind_nest.emplace_back( + AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -388,9 +388,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; Array bind_spec{buffer, tensor}; - output_bind_nest.emplace_back(AttrStmt( - bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + output_bind_nest.emplace_back( + AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); } // Check variable remap std::unordered_map vmap; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 67121b881a33..be1bdd99f901 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -850,14 +850,12 @@ class TensorCoreIRMutator : public StmtExprMutator { return Evaluate( Call(DataType::Handle(), builtin::tvm_bmma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); } else { return Evaluate( Call(DataType::Handle(), builtin::tvm_mma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); } }; @@ -881,8 +879,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto fill_fragment_call = [this, &op](const Buffer& buffer) { return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value}, - CallNode::Intrinsic)); + buffer->elem_offset, op->value})); }; ObjectPtr buffer_node = make_object(); @@ -903,8 +900,7 @@ class TensorCoreIRMutator : public StmtExprMutator { ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}, - CallNode::Extern); + PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}); auto pload = dst.as(); PrimExpr matrix_major; @@ -922,8 +918,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + buffer->elem_offset, src, stride, matrix_major})); }; ObjectPtr buffer_node = make_object(); @@ -943,16 +938,14 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = - Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern); + dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}); auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")}, - CallNode::Intrinsic)); + buffer->elem_offset, dst, stride, StringImm("col_major")})); }; ObjectPtr buffer_node = make_object(); @@ -1067,7 +1060,7 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args); Array node = {buffer, tensor}; return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/tir/analysis/side_effect.cc b/src/tir/analysis/side_effect.cc index b5fb328bf2b9..923cda3e41ea 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tir/analysis/side_effect.cc @@ -21,9 +21,11 @@ * \file side_effect.cc * \brief side effect analysis */ +#include #include #include #include +#include namespace tvm { namespace tir { @@ -36,11 +38,19 @@ class ExprSideEffect : public ExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (!op->is_pure()) { + static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); + + if (auto* ptr_op = op->op.as()) { + auto effect_kind = op_call_effect[GetRef(ptr_op)]; + if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) { + has_side_effect_ = true; + return; + } else { + ExprVisitor::VisitExpr_(op); + } + } else { has_side_effect_ = true; return; - } else { - ExprVisitor::VisitExpr_(op); } } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 6cccfa0fcebf..e9f65eed166a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -377,7 +377,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic); + return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 4b20351e2053..b4bb98441618 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -698,7 +698,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, RelayExpr op, Array args, CallType call_type) { +Call::Call(DataType dtype, RelayExpr op, Array args) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } @@ -707,12 +707,11 @@ Call::Call(DataType dtype, RelayExpr op, Array args, CallType call_typ node->dtype = dtype; node->op = std::move(op); node->args = std::move(args); - node->call_type = call_type; data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, int call_type) { + .set_body_typed([](DataType type, RelayExpr op, Array args) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); @@ -722,7 +721,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return Call(type, op, prim_expr_args, static_cast(call_type)); + return Call(type, op, prim_expr_args); }); TVM_REGISTER_NODE_TYPE(CallNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 98b9fd02c09c..afc128bbb73f 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return Call(op->dtype, op->op, args, op->call_type); + return Call(op->dtype, op->op, args); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 7b4ac7e2732c..296f49207cce 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include namespace tvm { @@ -566,10 +567,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PrimExpr TypeAnnotation(DataType dtype) { static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic); + return tir::Call(dtype, op, {}); } -TVM_REGISTER_OP("tir.type_annotation"); +TVM_REGISTER_OP("tir.type_annotation") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 8efcf3ff4925..d23662c78d37 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -38,117 +38,191 @@ namespace builtin { } \ TVM_REGISTER_OP("tir." #OpName) -TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(reinterpret) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_num_inputs(1); -TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr("TVectorizable", true); +TIR_DEFINE_BUILTIN_FUNC(likely) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) + .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(bitwise_and) .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(bitwise_or) .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(bitwise_xor) .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(bitwise_not) .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(shift_left) .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(shift_right) .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); -TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2); +TIR_DEFINE_BUILTIN_FUNC(large_uint_imm) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_BUILTIN_FUNC(address_of) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(if_then_else) + .set_num_inputs(3) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3); +TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(popcount) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TVectorizable", true); -TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(fma) + .set_num_inputs(3) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TVectorizable", true); -TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(call_extern) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr("TVectorizable", true); +TIR_DEFINE_BUILTIN_FUNC(call_pure_extern) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(call_extern); +TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin); +TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450); +TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(prefetch); +TIR_DEFINE_BUILTIN_FUNC(prefetch).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5); +TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr) + .set_num_inputs(5) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg)); -TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0); +TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle) + .set_num_inputs(0) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg)); -TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0); +TIR_DEFINE_BUILTIN_FUNC(tvm_context_id) + .set_num_inputs(0) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); -TIR_DEFINE_BUILTIN_FUNC(tvm_tuple); +TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr("TCallEffectKind", + Integer(CallEffectKind::kEmbedInfo)); -TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3); +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get) + .set_num_inputs(3) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); -TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4); +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set) + .set_num_inputs(4) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kUpdateState)); -TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0); +TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error) + .set_num_inputs(0) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2); +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape); +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6); +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) + .set_num_inputs(6) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // When num_inputs are not set, the function is assumed to be variable length. -TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // TODO(tvm-team) revisit storage sync once we have a good memory hierachy structure. -TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync); +TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle); +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up); +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down); +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask); +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit); +TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce); +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync); +TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); -TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync); +TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync); +TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment); +TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync); +TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(vectorhigh); +TIR_DEFINE_BUILTIN_FUNC(vectorhigh) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(vectorlow); +TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(vectorcombine); +TIR_DEFINE_BUILTIN_FUNC(vectorcombine) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace builtin } // namespace tir diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index f8049eace356..0f67126be3e2 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -38,10 +38,14 @@ namespace tvm { using namespace tir; // macro to register an unary op -#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1) +#define TIR_REGISTER_PURE_UNARY_OP(OpName) \ + TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr( \ + "TCallEffectKind", Integer(CallEffectKind::kPure)) // macro to register an binary op -#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2) +#define TIR_REGISTER_PURE_BINARY_OP(OpName) \ + TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr( \ + "TCallEffectKind", Integer(CallEffectKind::kPure)) runtime::DataType GetRuntimeDataType(const Type& type) { if (auto* n = type.as()) { @@ -83,8 +87,7 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { // LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { return tir::Call(t, tir::builtin::large_uint_imm(), - {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, - tir::CallNode::PureIntrinsic); + {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}); } // The public function with a quick checking path. @@ -262,7 +265,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) { // reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::builtin::reinterpret(), {value}); } // operator+ @@ -387,17 +390,15 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) } return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), - {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); + {cond, true_value, false_value}); } // likely PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}); } -TVM_REGISTER_OP("tir.likely").set_num_inputs(1); - // operator> PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); @@ -481,7 +482,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { } }); - return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}); } // shift left @@ -500,7 +501,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}); } // bitwise and @@ -512,7 +513,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}); } // bitwise_or @@ -524,7 +525,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}); } // bitwise_xor @@ -536,17 +537,15 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}); } // bitwie_not PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}); } -TVM_REGISTER_OP("tir.bitwise_not"); - TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; }); // pow @@ -554,10 +553,10 @@ PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; static auto op = Op::Get("tir.pow"); - return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x, y}); } -TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr("TVectorizable", true); +TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr("TVectorizable", true); // abs PrimExpr abs(PrimExpr x) { @@ -575,7 +574,7 @@ PrimExpr abs(PrimExpr x) { return FloatImm(x.dtype(), std::fabs(fx->value)); } static auto op = Op::Get("tir.fabs"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } else if (x.dtype().is_uint()) { return x; } else { @@ -600,10 +599,9 @@ PrimExpr isnan(PrimExpr x) { } static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))}, - tir::CallNode::PureIntrinsic); + return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))}); } else { - return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, op, {x}); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -611,8 +609,6 @@ PrimExpr isnan(PrimExpr x) { } } -TIR_REGISTER_PURE_UNARY_OP("tir.isnan"); - // isinf PrimExpr isinf(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); @@ -685,7 +681,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); - return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x, y}); } TIR_REGISTER_PURE_UNARY_OP("tir.fmod"); @@ -699,7 +695,7 @@ PrimExpr floor(PrimExpr x) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); static auto op = Op::Get("tir.floor"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr("TVectorizable", true); @@ -713,7 +709,7 @@ PrimExpr ceil(PrimExpr x) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); static auto op = Op::Get("tir.ceil"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr("TVectorizable", true); @@ -727,7 +723,7 @@ PrimExpr round(PrimExpr x) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); static auto op = Op::Get("tir.round"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr("TVectorizable", true); @@ -741,7 +737,7 @@ PrimExpr nearbyint(PrimExpr x) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); static auto op = Op::Get("tir.nearbyint"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); @@ -757,7 +753,7 @@ PrimExpr trunc(PrimExpr x) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } static auto op = Op::Get("tir.trunc"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), op, {x}); } TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr("TVectorizable", true); @@ -787,8 +783,6 @@ TIR_REGISTER_PURE_UNARY_OP("tir.log1p"); TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr("TVectorizable", true); - TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr("TVectorizable", true); TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr("TVectorizable", true); diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc index 1c540e3a650a..adabae9e75f7 100644 --- a/src/tir/op/runtime.cc +++ b/src/tir/op/runtime.cc @@ -29,11 +29,13 @@ namespace tir { TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace") .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace"); + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace") .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace"); + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 80c526827ad5..b88d2980b770 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -204,8 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back( LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); - PrimExpr is_null = - Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic); + PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 9722d1100a7e..4a44b85684b2 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -189,13 +189,13 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. - return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic); + return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}); } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic); + auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 092a7cdeca98..eb9ef32cb4f6 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -196,8 +196,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return { - Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}, CallNode::Intrinsic))}; + return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))}; } const std::unordered_set& touched_; @@ -334,8 +333,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { PrimExpr min = r->min; PrimExpr extent = r->extent; return Evaluate(Call(DataType::Int(32), Op::Get(func), - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, - CallNode::Intrinsic)); + {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent})); } // Write barrier name bool read_barrier_{false}; @@ -558,13 +556,11 @@ class CoProcInstDepDetector : public StmtVisitor { Stmt MakePush(int from, int to) { return Evaluate(Call(DataType::Int(32), sync_push_op_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)})); } Stmt MakePop(int from, int to) { return Evaluate(Call(DataType::Int(32), sync_pop_op_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)})); } // sync states. SyncState first_state_, last_state_, curr_state_; diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 7180dd29d903..d5405790a15a 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -231,8 +231,7 @@ class VTInjector : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}, - op->call_type); + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::tvm_context_id())) { return allow_share_ ? GetRef(op) : var_; } else { diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 758923b15af9..2f9d70659f4d 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -87,7 +87,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, builtin::TVMStructFieldKind kind) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic); + return Call(dtype, builtin::tvm_struct_get(), args); } /*! @@ -99,8 +99,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { return Call(DataType::Handle(), builtin::address_of(), {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + const_true(dtype.lanes()))}); } /*! @@ -115,7 +114,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } return Call(DataType::Handle(), builtin::address_of(), - {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); + {Load(dtype, handle, offset, const_true(dtype.lanes()))}); } /*! @@ -129,7 +128,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args)); } /*! diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index d38cb7b36042..5ec4fe303052 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -51,17 +51,13 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { - // NOTE: call_type will eventually be deprecated and the information - // will be folded into Op's attr - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - if (auto* ptr_op = op->op.as()) { - // Still use legacy string based rewriting - // TODO(tvm-team): migrate the pattern application from global function look up - // to an OpAttrMap - std::string name = ptr_op->name; - PrimExpr r = ApplyPattern(name, GetRef(op)); - if (r.defined()) return r; - } + if (auto* ptr_op = op->op.as()) { + // Still use legacy string based rewriting + // TODO(tvm-team): migrate the pattern application from global function look up + // to an OpAttrMap + std::string name = ptr_op->name; + PrimExpr r = ApplyPattern(name, GetRef(op)); + if (r.defined()) return r; } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -238,7 +234,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index dab8d5a78d02..04b89534e818 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -242,8 +242,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var mask_var("mask", DataType::UInt(32)); { PrimExpr pred = const_true(1); - PrimExpr mask = - Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); + PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); seq.emplace_back(Store(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. @@ -464,8 +463,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)})); } // Emit warp shuffle calls. @@ -475,7 +473,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; - return Call(val.dtype(), op, args, CallNode::Intrinsic); + return Call(val.dtype(), op, args); } // Check if this is a reduction on threadIdx.x and its extent matches diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e6182301a335..f07170489298 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -41,7 +41,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic); + return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); } // Calculate the statistics of packed function. @@ -103,11 +103,9 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = - Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic)); + Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}, - CallNode::PureIntrinsic), + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error), op->body}); Stmt alloca = LetStmt( @@ -115,14 +113,12 @@ class BuiltinLower : public StmtExprMutator { Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}, - CallNode::Extern), + IntImm(DataType::Int(32), op->dtype.bits())}), body); PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), op->buffer_var}, - CallNode::Extern); + cast(DataType::Int(32), device_id_), op->buffer_var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); body = AttrStmt(op->buffer_var, attr::storage_alignment, @@ -245,8 +241,7 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args, - CallNode::Intrinsic); + return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args); } PrimExpr MakeCallTracePacked(const CallNode* op) { @@ -287,8 +282,7 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. op->args[args_size - 1]}; - return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args, - CallNode::Intrinsic); + return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); } private: diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 3e7d13b2ff6e..72423e0cdfb5 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -250,10 +250,9 @@ class WarpAccessRewriter : protected StmtExprMutator { << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); - PrimExpr mask = - Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); + PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), - {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); + {mask, load_value, group, width_, warp_size_}); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9bb5fc6b5971..bfcf0b7f11ae 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -86,7 +86,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. if (api_type != t) { res = Cast(t, res); @@ -191,8 +191,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), - {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, - CallNode::Intrinsic)); + {StringImm(runtime::symbol::tvm_set_device), device_type, device_id})); body = SeqStmt({set_device, body}); } } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index e5535369c39e..f1286d773c2d 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -43,11 +44,16 @@ class UnsafeExprDetector : public ExprFunctor { } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); - } else if (op->is_pure()) { - for (PrimExpr e : op->args) { - if (VisitExpr(e)) return true; + } else if (auto* ptr_op = op->op.as()) { + auto effect_kind = op_call_effect_[GetRef(ptr_op)]; + if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) { + for (PrimExpr e : op->args) { + if (VisitExpr(e)) return true; + } + return false; + } else { + return true; } - return false; } else { return true; } @@ -94,6 +100,8 @@ class UnsafeExprDetector : public ExprFunctor { bool BinaryOp(const T* op) { return VisitExpr(op->a) || VisitExpr(op->b); } + + OpAttrMap op_call_effect_ = Op::GetAttrMap("TCallEffectKind"); }; class UnsafeSelectRewriter : public StmtExprMutator { @@ -106,7 +114,7 @@ class UnsafeSelectRewriter : public StmtExprMutator { if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { return Call(op->dtype, builtin::if_then_else(), - {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); + {op->condition, op->true_value, op->false_value}); } else { return expr; } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c35caf54db4a..f339c565959a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -238,8 +238,7 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return Evaluate( - Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } // target ir module diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 30805508144d..8eb43f8ebc84 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -321,10 +321,8 @@ class StorageFlattener : public StmtExprMutator { stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); - PrimExpr address = - Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic); - PrimExpr prefetch = - Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic); + PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); + PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d7a258cffe30..09d96510f7f0 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -404,8 +404,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index cdd9377e00d6..a38be3c8c6fe 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -211,7 +211,7 @@ class ThreadSyncInserter : public StmtExprMutator { barrier = MakeGlobalBarrier(); } else { barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); + {StringImm(sync_scope_.to_string())})); } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); @@ -299,8 +299,7 @@ class ThreadSyncInserter : public StmtExprMutator { Stmt InitGlobalBarrier(const AttrStmtNode* op) { CHECK(op != nullptr); Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; - Stmt prep = - Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic)); + Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; @@ -309,8 +308,7 @@ class ThreadSyncInserter : public StmtExprMutator { } } rw_stats_.clear(); - Stmt kinit = Evaluate( - Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic)); + Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {})); body = SeqStmt({kinit, body}); body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); @@ -334,8 +332,7 @@ class ThreadSyncInserter : public StmtExprMutator { CHECK_EQ(num_work_dim_, thread_extents_.size()); } return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, - CallNode::Intrinsic)); + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); } // data structure. StorageScope sync_scope_; diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 1a2ec502f605..e015990847e5 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -214,7 +214,7 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type); + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); } } // Call @@ -239,7 +239,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype, op->op, new_args, op->call_type); + return Call(op->dtype, op->op, new_args); } } else { int lane = 0; @@ -248,7 +248,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type); + return Call(op->dtype.with_lanes(lane), op->op, new_args); } } } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index ce50ed0c45f7..de06a0e7189f 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -193,8 +193,8 @@ TEST(IRF, StmtMutator) { } { - auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}, - CallNode::Extern)); + auto body = + Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1})); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[1].same_as(x)); } diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 9882a3b854af..e12f970b2724 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -204,7 +204,7 @@ def test_reduce_combiner_simplify(): # Test that components with side effects are not removed dummy = tvm.ir.GlobalVar("dummy") - side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic) + side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 18a98eed0673..698dd74b0786 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -98,7 +98,7 @@ def test_reinterpret(): nn = 1024 n = tvm.runtime.convert(nn) A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B') + B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B') s = te.create_schedule(B.op) def check_c(): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index a6a231564033..911ffb44f353 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -29,12 +29,12 @@ def test_llvm_intrin(): n = tvm.runtime.convert(4) A = ib.pointer("float32", name="A") args = [ - tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]), + tvm.tir.call_intrin("handle", "tir.address_of", A[0]), 0, 3, 1 ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic))) + "int32", "tir.prefetch", args))) body = ib.get() mod = tvm.IRModule.from_expr( @@ -65,7 +65,7 @@ def test_llvm_overloaded_intrin(): def use_llvm_intrinsic(A, C): ib = tvm.tir.ir_builder.create() L = A.vload((0,0)) - I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz', + I = tvm.tir.call_llvm_pure_intrin('int32', 'llvm.ctlz', tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1')) S = C.vstore((0,0), I) ib.emit(S) @@ -124,7 +124,7 @@ def test_llvm_lookup_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8x8", name="A") z = tvm.tir.const(0, 'int32') - x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z]) + x = tvm.tir.call_llvm_pure_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z]) ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr( diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index d2c504badd4a..578e32f92859 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -112,12 +112,11 @@ def test_expr_constructor(): assert x.vectors[0] == a assert x.indices[0].value == 0 - x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern) + x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a]) assert isinstance(x, tvm.tir.Call) assert x.dtype == "float32" assert x.op.name == "tir.call_extern" assert x.args[1] == a - assert x.call_type == tvm.tir.Call.Extern v = te.var("aa") x = tvm.tir.Let(v, 1, v) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 39acb3aecafa..ab730cd63d1e 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -171,19 +171,19 @@ def test_all(): def test_bitwise(): x = te.var('x') y = te.var('y') - assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32)' + assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32)' + assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32)' + assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32)' + assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32)' + assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32)' + assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32)' + assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32)' + assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32)' + assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32)' assert str(10 % x) == 'floormod(10, x: int32)' - assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")' + assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32)' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -240,10 +240,10 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool)' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool)' z = te.var('z', 'int32') assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 55a6819aeced..599ddba41015 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -115,19 +115,19 @@ def get_target(): def test_legalize(): def to32(v): uint32_v = topi.cast(v, "uint32") - uint32_v = tvm.tir.call_pure_intrin( + uint32_v = tvm.tir.call_intrin( "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")) - return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v) + return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v) def to16(v): - uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v) - rounding_bias = tvm.tir.call_pure_intrin( + uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v) + rounding_bias = tvm.tir.call_intrin( "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) - rounding_bias = tvm.tir.call_pure_intrin( + rounding_bias = tvm.tir.call_intrin( "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") uint32_v = uint32_v + rounding_bias - uint32_v = tvm.tir.call_pure_intrin( + uint32_v = tvm.tir.call_intrin( "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) return topi.cast(uint32_v, 'uint16') diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index d7a25ca0156e..288695891952 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) + "handle", "tir.tvm_thread_context", [ctx]) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 4964039a4c14..be725d60ad9e 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -36,7 +36,7 @@ def get_vthread(name): bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.tir.call_extern("int32", "Run", bbuffer.access_ptr("r"), - tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id"))) + tvm.tir.call_intrin("int32", "tir.tvm_context_id"))) C[i * nthread + tx] = B[i] + 1 return ib.get() diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 7068b95bec6c..5349818a2790 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -112,12 +112,12 @@ inline Array make_extern(const Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + auto shape = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + strides = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape); } else { strides = 0; } @@ -127,8 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) { make_const(DataType::Int(32), static_cast(buf->shape.size())), make_const(buf->dtype, 0), buf->elem_offset}; - return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args, - tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args); } /*! @@ -141,8 +140,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args, - tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args); } } // namespace detail diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 0ec7e4d212bf..9b418d03900c 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -310,8 +310,7 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te return compute( x->shape, [&](const Array& i) { - return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)}, - tvm::tir::CallNode::PureIntrinsic); + return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)}); }, name, tag); } diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index f035251a8c29..e76b374f32e3 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -231,21 +231,21 @@ def _instr(index): cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_) else: cnts = tvm.tir.popcount(w_ & x_) - upper_half = tvm.tir.call_pure_intrin( + upper_half = tvm.tir.call_intrin( half_dtype, 'tir.vectorhigh', cnts) - lower_half = tvm.tir.call_pure_intrin( + lower_half = tvm.tir.call_intrin( half_dtype, 'tir.vectorlow', cnts) cnts8[i] = upper_half + lower_half for i in range(m//2): - cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_2, cnts8[i*2], cnts8[i*2+1]) + cnts4[i] = tvm.tir.call_llvm_pure_intrin( + half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): - cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_2, cnts4[i*2], cnts4[i*2+1]) - cnts = tvm.tir.call_pure_intrin( + cnts2[i] = tvm.tir.call_llvm_pure_intrin( + half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) + cnts = tvm.tir.call_intrin( full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) - out = tvm.tir.call_llvm_intrin( + out = tvm.tir.call_llvm_pure_intrin( return_dtype, vpadalu, args_2, zz.vload(0, return_dtype), shifted_cnts) else: # ki == 8 @@ -257,15 +257,15 @@ def _instr(index): else: cnts8[i] = tvm.tir.popcount(w_ & x_) for i in range(m//2): - cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_2, cnts8[i*2], cnts8[i*2+1]) + cnts4[i] = tvm.tir.call_llvm_pure_intrin( + half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): - cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_2, cnts4[i*2], cnts4[i*2+1]) - cnts = tvm.tir.call_pure_intrin( + cnts2[i] = tvm.tir.call_llvm_pure_intrin( + half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) + cnts = tvm.tir.call_intrin( full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) - out = tvm.tir.call_llvm_intrin( + out = tvm.tir.call_llvm_pure_intrin( return_dtype, vpadalu, args_2, zz.vload(0, return_dtype), shifted_cnts) irb.emit(zz.vstore(0, out)) diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index 6ef2548af1be..dfa2f05e7960 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -425,21 +425,22 @@ def _instr(index): dtype_c = '%s32x%d' % (dtype, int32_lanes) a_int8 = ins[0].vload([0], dtype_a) - re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8) + re_int32 = tvm.tir.call_intrin('%s32' % dtype, 'tir.reinterpret', a_int8) # broadcast a vec_ai32 = re_int32.astype(dtype_c) - vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32) + vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], dtype_b) vec_c = outs[0].vload([0], dtype_c) inst = 'udot' if dtype == 'uint' else 'sdot' inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % ( inst, int32_lanes, int32_lanes * num_int8_elements) - vdot = tvm.tir.call_llvm_intrin(dtype_c, - inst, - tvm.tir.const(2, 'uint32'), - vec_c, vec_a, vec_b) + vdot = tvm.tir.call_llvm_pure_intrin( + dtype_c, + inst, + tvm.tir.const(2, 'uint32'), + vec_c, vec_a, vec_b) ib.emit(outs[0].vstore(0, vdot)) return ib.get() diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index c98d7e99d3ee..9e3200a0c418 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -38,10 +38,10 @@ def cuda_atomic_add_rule(op): tvm.target.intrin.register_intrin_rule( "cuda", "atomic_add", cuda_atomic_add_rule, override=True) -tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False) +tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) def atomic_add(x, y): - return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y) + return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) def get_valid_counts_ir(data, valid_count, out, out_indices, @@ -114,7 +114,7 @@ def get_valid_counts_ir(data, valid_count, out, out_indices, with ib.if_scope( tvm.tir.all(data[tid * elem_length + score_index] > score_threshold, tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))): - atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of", + atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count) with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = data[tid * elem_length + k] diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index 5b7e0905de63..14143845d98f 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -186,8 +186,7 @@ def argsort_ir(data_buf, out_index_buf): index_out[offset] = index_out[offset + 1] index_out[offset + 1] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', - tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic)) + tvm.runtime.convert(['shared']))) return ib.get() @@ -247,8 +246,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(iou > nms_threshold): p_out[base_idx + i] = True ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', - tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic)) + tvm.runtime.convert(['shared']))) return ib.get() diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 7181d5721684..a8d1572acbde 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -116,8 +116,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): indices_out[base_idx + tid * axis_mul_after] = \ tvm.tir.generic.cast(tid, indices_out.dtype) ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', - tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic)) + tvm.runtime.convert(['shared']))) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -144,8 +143,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): indices_out[offset] = indices_out[offset + axis_mul_after] indices_out[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', - tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic)) + tvm.runtime.convert(['shared']))) return ib.get() @@ -236,8 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[offset] = output[offset + axis_mul_after] output[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', - tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic)) + tvm.runtime.convert(['shared']))) return ib.get() diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index 31de70e92f18..17c0b36b70cf 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -88,19 +88,21 @@ def _instr(index): return ib.get() a_int8 = ins[0].vload([0], "uint8x4") - re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8) + re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') - vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32) + vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], "int8x64") vec_one = tvm.tir.const(1, "int16x32") - pair_reduction = tvm.tir.call_llvm_intrin('int16x32', - 'llvm.x86.avx512.pmaddubs.w.512', - tvm.tir.const(0, 'uint32'), - vec_a, vec_b) - quad_reduction = tvm.tir.call_llvm_intrin('int32x16', - 'llvm.x86.avx512.pmaddw.d.512', - tvm.tir.const(0, 'uint32'), - pair_reduction, vec_one) + pair_reduction = tvm.tir.call_llvm_pure_intrin( + 'int16x32', + 'llvm.x86.avx512.pmaddubs.w.512', + tvm.tir.const(0, 'uint32'), + vec_a, vec_b) + quad_reduction = tvm.tir.call_llvm_pure_intrin( + 'int32x16', + 'llvm.x86.avx512.pmaddw.d.512', + tvm.tir.const(0, 'uint32'), + pair_reduction, vec_one) if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: @@ -174,16 +176,17 @@ def _instr(index): return ib.get() a_int8 = ins[0].vload([0], "uint8x2") - re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8) + re_int16 = tvm.tir.call_intrin('int16', 'tir.reinterpret', a_int8) vec_ai16 = re_int16.astype('int16x32') - vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16) + vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai16) for i in range(4): vec_b = ins[1].vload([i*32, 0], "int8x64") - pair_reduction = tvm.tir.call_llvm_intrin('int16x32', - 'llvm.x86.avx512.pmaddubs.w.512', - tvm.tir.const(0, 'uint32'), - vec_a, vec_b) + pair_reduction = tvm.tir.call_llvm_pure_intrin( + 'int16x32', + 'llvm.x86.avx512.pmaddubs.w.512', + tvm.tir.const(0, 'uint32'), + vec_a, vec_b) if index == 0: ib.emit(outs[0].vstore([i*32], pair_reduction)) else: @@ -254,7 +257,7 @@ def _instr(index): return ib.get() a_int8 = ins[0].vload([0], "uint8x4") - re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8) + re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') vec_b = ins[1].vload([0, 0], "int8x64") @@ -262,24 +265,27 @@ def _instr(index): llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name) if llvm_id != 0: # VNNI is available for current LLVM version - vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b) + vec_bi32 = tvm.tir.call_intrin('int32x16', 'tir.reinterpret', vec_b) vec_zero = tvm.tir.const(0, "int32x16") - quad_reduction = tvm.tir.call_llvm_intrin('int32x16', - 'llvm.x86.avx512.vpdpbusd.512', - tvm.tir.const(0, 'uint32'), - vec_zero, - vec_ai32, vec_bi32) + quad_reduction = tvm.tir.call_llvm_pure_intrin( + 'int32x16', + 'llvm.x86.avx512.vpdpbusd.512', + tvm.tir.const(0, 'uint32'), + vec_zero, + vec_ai32, vec_bi32) else: # Fall back to the normal AVX512 - vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32) + vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32) vec_one = tvm.tir.const(1, "int16x32") - pair_reduction = tvm.tir.call_llvm_intrin('int16x32', - 'llvm.x86.avx512.pmaddubs.w.512', - tvm.tir.const(0, 'uint32'), - vec_a, vec_b) - quad_reduction = tvm.tir.call_llvm_intrin('int32x16', - 'llvm.x86.avx512.pmaddw.d.512', - tvm.tir.const(0, 'uint32'), - pair_reduction, vec_one) + pair_reduction = tvm.tir.call_llvm_pure_intrin( + 'int16x32', + 'llvm.x86.avx512.pmaddubs.w.512', + tvm.tir.const(0, 'uint32'), + vec_a, vec_b) + quad_reduction = tvm.tir.call_llvm_pure_intrin( + 'int32x16', + 'llvm.x86.avx512.pmaddw.d.512', + tvm.tir.const(0, 'uint32'), + pair_reduction, vec_one) if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py index 65bfd4c38681..4a4ff96e2d31 100644 --- a/tutorials/language/intrin_math.py +++ b/tutorials/language/intrin_math.py @@ -135,7 +135,7 @@ def my_cuda_math_rule(op): def mylog(x): """customized log intrinsic function""" - return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x) + return tvm.tir.call_intrin(x.dtype, "tir.mylog", x) def my_cuda_mylog_rule(op): @@ -148,7 +148,7 @@ def my_cuda_mylog_rule(op): return op # new op registration is triggered by registering an attribute of the op -tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True) +tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure) tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True) n = te.var("n") diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 947c583ed55f..3e6a0c538b12 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -79,8 +79,7 @@ def __init__(self, env): self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp") ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle") self.command_handle = tvm.tir.Call( - "handle", "tir.tvm_thread_context", [ctx], - tvm.tir.Call.Intrinsic) + "handle", "tir.tvm_thread_context", [ctx]) self.DEBUG_NO_SYNC = False env._dev_ctx = self self.gemm = intrin.gemm(env, env.mock_mode) @@ -316,12 +315,15 @@ def coproc_dep_pop(op): # register a dummy into to trigger registration of the ops # change the info to lowering rule later. -tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False) -tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False) -tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False) +tvm.ir.register_op_attr("tir.vta.coproc_sync", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) +tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) +tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) +tvm.ir.register_op_attr("tir.vta.uop_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush") + tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle") +tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) def _init_env(): diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index e92b178a5be6..d9f47f1f71ec 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -298,7 +298,7 @@ def _do_fold(stmt): if _match_pragma(stmt, "coproc_sync"): success[0] = True sync = tvm.tir.Call( - "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic) + "int32", "vta.coproc_sync", []) return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) if _match_pragma(stmt, "trim_loop"): op = stmt.body