Skip to content

Commit

Permalink
[RUNTIME] Improve PackedFunc robustness (apache#5517)
Browse files Browse the repository at this point in the history
* [RUNTIME] Improve PackedFunc robustness

- Add static assert to warn about unsupported type deduction.
- Always inline template expansions for PackedFunc calls.

* Fix style issue
  • Loading branch information
tqchen authored and trevor-m committed Jun 18, 2020
1 parent 1131ba7 commit d9b7e72
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 41 deletions.
101 changes: 62 additions & 39 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@
#define TVM_RUNTIME_HEADER_ONLY 0
#endif

// Always inline macro only use in template
// expansion cases where we know inline is important.
#ifdef _MSC_VER
#define TVM_ALWAYS_INLINE __forceinline inline
#else
#define TVM_ALWAYS_INLINE inline __attribute__((always_inline))
#endif

namespace tvm {
namespace runtime {

Expand Down Expand Up @@ -273,7 +281,7 @@ class TypedPackedFunc<R(Args...)> {
* \param args The arguments
* \returns The return value.
*/
inline R operator()(Args ...args) const;
TVM_ALWAYS_INLINE R operator()(Args ...args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
Expand Down Expand Up @@ -1076,11 +1084,15 @@ struct func_signature_helper {
template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...)> {
using FType = R(Args...);
static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
};

template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...) const> {
using FType = R(Args...);
static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
};

/*!
Expand All @@ -1096,12 +1108,16 @@ struct function_signature {
template<typename R, typename ...Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
};

// handle case of function ptr.
template<typename R, typename ...Args>
struct function_signature<R (*)(Args...)> {
using FType = R(Args...);
static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
};
} // namespace detail

Expand All @@ -1114,66 +1130,66 @@ class TVMArgsSetter {
template<typename T,
typename = typename std::enable_if<
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt;
}
void operator()(size_t i, uint64_t value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt;
}
void operator()(size_t i, double value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
}
void operator()(size_t i, std::nullptr_t value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
values_[i].v_handle = value;
type_codes_[i] = kTVMNullptr;
}
void operator()(size_t i, const TVMArgValue& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
values_[i] = value.value_;
type_codes_[i] = value.type_code_;
}
void operator()(size_t i, void* value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
values_[i].v_handle = value;
type_codes_[i] = kTVMOpaqueHandle;
}
void operator()(size_t i, DLTensor* value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
values_[i].v_handle = value;
type_codes_[i] = kTVMDLTensorHandle;
}
void operator()(size_t i, TVMContext value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const {
values_[i].v_ctx = value;
type_codes_[i] = kTVMContext;
}
void operator()(size_t i, DLDataType value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
values_[i].v_type = value;
type_codes_[i] = kTVMDataType;
}
void operator()(size_t i, DataType dtype) const {
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
operator()(i, dtype.operator DLDataType());
}
void operator()(size_t i, const char* value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kTVMStr;
}
// setters for container types
void operator()(size_t i, const std::string& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
values_[i].v_str = value.c_str();
type_codes_[i] = kTVMStr;
}
void operator()(size_t i, const TVMByteArray& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kTVMBytes;
}
void operator()(size_t i, const PackedFunc& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kTVMPackedFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed());
}
void operator()(size_t i, const TVMRetValue& value) const {
Expand All @@ -1191,7 +1207,7 @@ class TVMArgsSetter {
typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>
::type>
void operator()(size_t i, const TObjectRef& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
this->SetObject(i, value);
}

Expand All @@ -1200,7 +1216,7 @@ class TVMArgsSetter {
std::is_base_of<ObjectRef,
typename std::remove_reference<TObjectRef>::type>::value>
::type>
void operator()(size_t i, TObjectRef&& value) const {
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
this->SetObject(i, std::forward<TObjectRef>(value));
}

Expand Down Expand Up @@ -1230,10 +1246,10 @@ namespace detail {
template<typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
TVM_ALWAYS_INLINE static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
// construct a movable argument value
// which allows potential move of argument to the input of F.
unpack_call_dispatcher<R, nleft - 1, index + 1, F>
Expand All @@ -1247,27 +1263,33 @@ struct unpack_call_dispatcher {
template<typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...));
TVM_ALWAYS_INLINE static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
if (std::is_same<RetType, R>::value) {
*rv = f(std::forward<Args>(unpacked_args)...);
} else {
*rv = R(f(std::forward<Args>(unpacked_args)...));
}
}
};

template<int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
TVM_ALWAYS_INLINE static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};

template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
TVM_ALWAYS_INLINE void unpack_call(
const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
Expand All @@ -1280,30 +1302,31 @@ struct unpack_call_by_signature {
template<typename R, typename ...Args>
struct unpack_call_by_signature<R(Args...)> {
template<typename F>
static void run(const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
TVM_ALWAYS_INLINE static void run(
const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
unpack_call<R, sizeof...(Args)>(f, args, rv);
}
};

template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) {
return R(pf(std::forward<Args>(args)...));
}

template<typename R>
struct typed_packed_call_dispatcher {
template<typename ...Args>
static inline R run(const PackedFunc& pf, Args&& ...args) {
TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) {
return pf(std::forward<Args>(args)...);
}
};

template<>
struct typed_packed_call_dispatcher<void> {
template<typename ...Args>
static inline void run(const PackedFunc& pf, Args&& ...args) {
TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) {
pf(std::forward<Args>(args)...);
}
};
Expand Down Expand Up @@ -1334,7 +1357,7 @@ inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
}

template<typename R, typename ...Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
return detail::typed_packed_call_dispatcher<R>
::run(packed_, std::forward<Args>(args)...);
}
Expand Down
5 changes: 4 additions & 1 deletion src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
return ret;
});

TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get);
TVM_REGISTER_GLOBAL("relay.op._GetOp")
.set_body_typed([](std::string name) -> Op {
return Op::Get(name);
});

TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/quantize/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig")
.set_body_typed(QConfig::Current);
.set_body_typed([]() -> QConfig {
return QConfig::Current();
});

TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope")
.set_body_typed(QConfig::EnterQConfigScope);
Expand Down

0 comments on commit d9b7e72

Please sign in to comment.