Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][OP][API-CHANGE] Remove CallNode.call_type in favor of attribute. #5937

Merged
merged 1 commit into from
Jun 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,25 @@ TVM_DLL const Op& fma();
*/
TVM_DLL const Op& call_extern();

/*!
* \brief Call an pure extern C function with given name
* and signature from the types of args in the runtime environment.
*
* Type call_pure_extern(name, args...) {
* return dlsym(name)(args...);
* }
*
* \note This intrinsic does not provide any type checking,
* and is main used for backward compatibility reasons.
* Always consider use pre-registered and typed tvm::Op first.
*/
TVM_DLL const Op& call_pure_extern();

/*!
* \brief Call an LLVM intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_intrin(intrin_id, args...) {
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
Expand All @@ -165,15 +179,27 @@ TVM_DLL const Op& call_extern();
TVM_DLL const Op& call_llvm_intrin();

/*!
* \brief Call an SPIRV GLSL450 intrinsic.
* \brief Call an LLVM pure intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_llvm_pure_intrin();

/*!
* \brief Call an SPIRV pure GLSL450 intrinsic.
*
* Type call_spirv_glsl450(intrin_id, args...) {
* Type call_spirv_pure_glsl450(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_spirv_glsl450();
TVM_DLL const Op& call_spirv_pure_glsl450();

// TODO(tvm-team) revisit the builtins below
// some of them can simply become ops with special codegen attr.
Expand Down
28 changes: 2 additions & 26 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -875,19 +875,6 @@ class Let : public PrimExpr {
*/
class CallNode : public PrimExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
/*! \brief Extern "C" function. */
Extern = 0,
/*! \brief Extern CXX function. */
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
PureIntrinsic = 5
};
/*!
* \brief The operator(function) being invoked
*
Expand All @@ -898,31 +885,22 @@ class CallNode : public PrimExprNode {

/*! \brief The arguments. */
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
equal(call_type, other->call_type);
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(op);
hash_reduce(args);
hash_reduce(call_type);
}

/*! \return Whether call node is pure. */
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }

static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
};
Expand All @@ -933,9 +911,7 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
using CallType = CallNode::CallType;

TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};

Expand Down
16 changes: 8 additions & 8 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}); \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand All @@ -583,10 +583,10 @@ TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);

#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}); \
}

TVM_DECLARE_INTRIN_BINARY(atan2);
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,43 @@ using TGlobalSymbol = String;
*/
using TVectorizable = bool;

/*!
* \brief The effect type of the call.
*/
enum class CallEffectKind : int {
/*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
kExprAnnotation = 0,
/*!
* \brief Pure function that do not interacts
* with any external state.
*/
kPure = 1,
/*!
* \brief Function's that may read from states(e.g. RAM)
*/
kReadState = 2,
/*!
* \brief Function that may read/write from states(e.g. RAM).
*/
kUpdateState = 3,
/*!
* \brief Opaque function, cannot make any assumption
*/
kOpaque = kUpdateState,
/*!
* \brief Special intrinsic to annotate call arguments info
* only valid as a direct argument to a call.
*/
kSpecialCallArg = 4,
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5
};

/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_OP_ATTR_TYPES_H_
4 changes: 2 additions & 2 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import call_pure_intrin
from tvm.tir import call_intrin
from tvm.tir.stmt import For

from .util import _internal_assert
Expand Down Expand Up @@ -148,7 +148,7 @@ def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
return call_intrin(args[0].dtype, 'tir.likely', *args)


def max_num_threads(func_id, args):
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
Expand All @@ -34,8 +34,8 @@

from .function import PrimFunc

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
39 changes: 29 additions & 10 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,23 @@ def asobject(self):
return _ffi_api._OpNE(self.a, self.b)


class IntImmEnum(ObjectGeneric):
"""Lazily evaluate an IntImm in case
the constructor is not available in runtime.

Parameters
----------
value : int
The enum value
"""
def __init__(self, value):
self.value = value

def asobject(self):
"""Convert object."""
return IntImm("int32", self.value)


class PrimExprWithOp(ExprOp, PrimExpr):
"""Helper base class to inherit from PrimExpr."""
# In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
Expand Down Expand Up @@ -959,6 +976,16 @@ def __init__(self, vectors, indices):
_ffi_api.Shuffle, vectors, indices)


class CallEffectKind:
"""Possible kinds of Call effects."""
# only expose up to opaque
ExprAnnotation = IntImmEnum(0)
Pure = IntImmEnum(1)
ReadState = IntImmEnum(2)
UpdateState = IntImmEnum(3)
Opaque = UpdateState


@tvm._ffi.register_object("tir.Call")
class Call(PrimExprWithOp):
"""Call node.
Expand All @@ -974,16 +1001,8 @@ class Call(PrimExprWithOp):

args : list of Expr
The input arguments to the call

call_type : int
The type of the call
"""
Extern = 0
ExternCPlusPlus = 1
PureExtern = 2
Intrinsic = 4
PureIntrinsic = 5
def __init__(self, dtype, op, args, call_type):
def __init__(self, dtype, op, args):
if isinstance(op, str):
if not op.startswith("tir."):
raise ValueError(
Expand All @@ -992,7 +1011,7 @@ def __init__(self, dtype, op, args, call_type):
"certain about the intrinsic name, pass in Op.get(name) instead") % op)
op = Op.get(op)
self.__init_handle_by_constructor__(
_ffi_api.Call, dtype, op, args, call_type)
_ffi_api.Call, dtype, op, args)


@tvm._ffi.register_object("tir.Let")
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,7 @@ def likely(self, expr):
expr : Expr
The expression will likely tag.
"""
return _expr.Call(expr.dtype, "tir.likely", [expr],
_expr.Call.PureIntrinsic)
return _expr.Call(expr.dtype, "tir.likely", [expr])

def get(self):
"""Return the builded IR.
Expand Down
Loading