Skip to content

Commit

Permalink
[TIR][OP][API-CHANGE] Remove CallNode.call_type in favor of attribute.
Browse files Browse the repository at this point in the history
This is a followup refactor for tir::Call.
Now that we have switched call->name to call->op, the function effect property
can be registered through the op itself, so we no longer need the call_type in the CallNode.

- Introduce CallEffectKind to provide a more fine grained categorization of calls.
- Introduce call_pure_extern and call_llvm_pure_intrin to
  allow us to indicate pure calls in those cases.
- Migrate existing usecases to the new API.
  • Loading branch information
tqchen committed Jun 27, 2020
1 parent 75f2539 commit 1892e0e
Show file tree
Hide file tree
Showing 81 changed files with 784 additions and 705 deletions.
34 changes: 30 additions & 4 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,25 @@ TVM_DLL const Op& fma();
*/
TVM_DLL const Op& call_extern();

/*!
* \brief Call an pure extern C function with given name
* and signature from the types of args in the runtime environment.
*
* Type call_pure_extern(name, args...) {
* return dlsym(name)(args...);
* }
*
* \note This intrinsic does not provide any type checking,
* and is main used for backward compatibility reasons.
* Always consider use pre-registered and typed tvm::Op first.
*/
TVM_DLL const Op& call_pure_extern();

/*!
* \brief Call an LLVM intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_intrin(intrin_id, args...) {
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
Expand All @@ -165,15 +179,27 @@ TVM_DLL const Op& call_extern();
TVM_DLL const Op& call_llvm_intrin();

/*!
* \brief Call an SPIRV GLSL450 intrinsic.
* \brief Call an LLVM pure intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_llvm_pure_intrin();

/*!
* \brief Call an SPIRV pure GLSL450 intrinsic.
*
* Type call_spirv_glsl450(intrin_id, args...) {
* Type call_spirv_pure_glsl450(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_spirv_glsl450();
TVM_DLL const Op& call_spirv_pure_glsl450();

// TODO(tvm-team) revisit the builtins below
// some of them can simply become ops with special codegen attr.
Expand Down
28 changes: 2 additions & 26 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -875,19 +875,6 @@ class Let : public PrimExpr {
*/
class CallNode : public PrimExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
/*! \brief Extern "C" function. */
Extern = 0,
/*! \brief Extern CXX function. */
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
PureIntrinsic = 5
};
/*!
* \brief The operator(function) being invoked
*
Expand All @@ -898,31 +885,22 @@ class CallNode : public PrimExprNode {

/*! \brief The arguments. */
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
equal(call_type, other->call_type);
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(op);
hash_reduce(args);
hash_reduce(call_type);
}

/*! \return Whether call node is pure. */
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }

static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
};
Expand All @@ -933,9 +911,7 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
using CallType = CallNode::CallType;

TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};

Expand Down
16 changes: 8 additions & 8 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}); \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand All @@ -583,10 +583,10 @@ TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);

#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}); \
}

TVM_DECLARE_INTRIN_BINARY(atan2);
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,43 @@ using TGlobalSymbol = String;
*/
using TVectorizable = bool;

/*!
* \brief The effect type of the call.
*/
enum class CallEffectKind : int {
/*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
kExprAnnotation = 0,
/*!
* \brief Pure function that do not interacts
* with any external state.
*/
kPure = 1,
/*!
* \brief Function's that may read from states(e.g. RAM)
*/
kReadState = 2,
/*!
* \brief Function that may read/write from states(e.g. RAM).
*/
kUpdateState = 3,
/*!
* \brief Opaque function, cannot make any assumption
*/
kOpaque = kUpdateState,
/*!
* \brief Special intrinsic to annotate call arguments info
* only valid as a direct argument to a call.
*/
kSpecialCallArg = 4,
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5
};

/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_OP_ATTR_TYPES_H_
4 changes: 2 additions & 2 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import call_pure_intrin
from tvm.tir import call_intrin
from tvm.tir.stmt import For

from .util import _internal_assert
Expand Down Expand Up @@ -148,7 +148,7 @@ def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
return call_intrin(args[0].dtype, 'tir.likely', *args)


def max_num_threads(func_id, args):
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
Expand All @@ -34,8 +34,8 @@

from .function import PrimFunc

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
39 changes: 29 additions & 10 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,23 @@ def asobject(self):
return _ffi_api._OpNE(self.a, self.b)


class IntImmEnum(ObjectGeneric):
"""Lazily evaluate an IntImm in case
the constructor is not available in runtime.
Parameters
----------
value : int
The enum value
"""
def __init__(self, value):
self.value = value

def asobject(self):
"""Convert object."""
return IntImm("int32", self.value)


class PrimExprWithOp(ExprOp, PrimExpr):
"""Helper base class to inherit from PrimExpr."""
# In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
Expand Down Expand Up @@ -959,6 +976,16 @@ def __init__(self, vectors, indices):
_ffi_api.Shuffle, vectors, indices)


class CallEffectKind:
"""Possible kinds of Call effects."""
# only expose up to opaque
ExprAnnotation = IntImmEnum(0)
Pure = IntImmEnum(1)
ReadState = IntImmEnum(2)
UpdateState = IntImmEnum(3)
Opaque = UpdateState


@tvm._ffi.register_object("tir.Call")
class Call(PrimExprWithOp):
"""Call node.
Expand All @@ -974,16 +1001,8 @@ class Call(PrimExprWithOp):
args : list of Expr
The input arguments to the call
call_type : int
The type of the call
"""
Extern = 0
ExternCPlusPlus = 1
PureExtern = 2
Intrinsic = 4
PureIntrinsic = 5
def __init__(self, dtype, op, args, call_type):
def __init__(self, dtype, op, args):
if isinstance(op, str):
if not op.startswith("tir."):
raise ValueError(
Expand All @@ -992,7 +1011,7 @@ def __init__(self, dtype, op, args, call_type):
"certain about the intrinsic name, pass in Op.get(name) instead") % op)
op = Op.get(op)
self.__init_handle_by_constructor__(
_ffi_api.Call, dtype, op, args, call_type)
_ffi_api.Call, dtype, op, args)


@tvm._ffi.register_object("tir.Let")
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,7 @@ def likely(self, expr):
expr : Expr
The expression will likely tag.
"""
return _expr.Call(expr.dtype, "tir.likely", [expr],
_expr.Call.PureIntrinsic)
return _expr.Call(expr.dtype, "tir.likely", [expr])

def get(self):
"""Return the builded IR.
Expand Down
Loading

0 comments on commit 1892e0e

Please sign in to comment.