From d9b7e7224217c1fd21339a51d2a071728ccd74ad Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 6 May 2020 11:50:09 -0700 Subject: [PATCH] [RUNTIME] Improve PackedFunc robustness (#5517) * [RUNTIME] Improve PackedFunc robustness - Add static assert to warn about unsupported type deduction. - Always inline template expansions for PackedFunc calls. * Fix style issue --- include/tvm/runtime/packed_func.h | 101 ++++++++++++++++++------------ src/ir/op.cc | 5 +- src/relay/quantize/quantize.cc | 4 +- 3 files changed, 69 insertions(+), 41 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index cf6d5fab0e19..0726292234fd 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -45,6 +45,14 @@ #define TVM_RUNTIME_HEADER_ONLY 0 #endif +// Always inline macro only use in template +// expansion cases where we know inline is important. +#ifdef _MSC_VER +#define TVM_ALWAYS_INLINE __forceinline inline +#else +#define TVM_ALWAYS_INLINE inline __attribute__((always_inline)) +#endif + namespace tvm { namespace runtime { @@ -273,7 +281,7 @@ class TypedPackedFunc { * \param args The arguments * \returns The return value. */ - inline R operator()(Args ...args) const; + TVM_ALWAYS_INLINE R operator()(Args ...args) const; /*! * \brief convert to PackedFunc * \return the internal PackedFunc @@ -1076,11 +1084,15 @@ struct func_signature_helper { template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; /*! @@ -1096,12 +1108,16 @@ struct function_signature { template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; // handle case of function ptr. template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; } // namespace detail @@ -1114,66 +1130,66 @@ class TVMArgsSetter { template::value>::type> - void operator()(size_t i, T value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } - void operator()(size_t i, uint64_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); CHECK_LE(value, static_cast(std::numeric_limits::max())); type_codes_[i] = kDLInt; } - void operator()(size_t i, double value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } - void operator()(size_t i, std::nullptr_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { values_[i].v_handle = value; type_codes_[i] = kTVMNullptr; } - void operator()(size_t i, const TVMArgValue& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { values_[i] = value.value_; type_codes_[i] = value.type_code_; } - void operator()(size_t i, void* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMOpaqueHandle; } - void operator()(size_t i, DLTensor* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMDLTensorHandle; } - void operator()(size_t i, TVMContext value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const { values_[i].v_ctx = value; type_codes_[i] = kTVMContext; } - void operator()(size_t i, DLDataType value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { values_[i].v_type = value; type_codes_[i] = kTVMDataType; } - void operator()(size_t i, DataType dtype) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { operator()(i, dtype.operator DLDataType()); } - void operator()(size_t i, const char* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kTVMStr; } // setters for container types - void operator()(size_t i, const std::string& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kTVMStr; } - void operator()(size_t i, const TVMByteArray& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - void operator()(size_t i, const PackedFunc& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMPackedFuncHandle; } template - void operator()(size_t i, const TypedPackedFunc& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { @@ -1191,7 +1207,7 @@ class TVMArgsSetter { typename = typename std::enable_if< std::is_base_of::value> ::type> - void operator()(size_t i, const TObjectRef& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { this->SetObject(i, value); } @@ -1200,7 +1216,7 @@ class TVMArgsSetter { std::is_base_of::type>::value> ::type> - void operator()(size_t i, TObjectRef&& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { this->SetObject(i, std::forward(value)); } @@ -1230,10 +1246,10 @@ namespace detail { template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. unpack_call_dispatcher @@ -1247,27 +1263,33 @@ struct unpack_call_dispatcher { template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { - *rv = R(f(std::forward(unpacked_args)...)); + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { + using RetType = decltype(f(std::forward(unpacked_args)...)); + if (std::is_same::value) { + *rv = f(std::forward(unpacked_args)...); + } else { + *rv = R(f(std::forward(unpacked_args)...)); + } } }; template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { f(std::forward(unpacked_args)...); } }; template -inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { +TVM_ALWAYS_INLINE void unpack_call( + const F& f, const TVMArgs& args, TVMRetValue* rv) { CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); unpack_call_dispatcher::run(f, args, rv); @@ -1280,22 +1302,23 @@ struct unpack_call_by_signature { template struct unpack_call_by_signature { template - static void run(const F& f, - const TVMArgs& args, - TVMRetValue* rv) { + TVM_ALWAYS_INLINE static void run( + const F& f, + const TVMArgs& args, + TVMRetValue* rv) { unpack_call(f, args, rv); } }; template -inline R call_packed(const PackedFunc& pf, Args&& ...args) { +TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) { return R(pf(std::forward(args)...)); } template struct typed_packed_call_dispatcher { template - static inline R run(const PackedFunc& pf, Args&& ...args) { + TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) { return pf(std::forward(args)...); } }; @@ -1303,7 +1326,7 @@ struct typed_packed_call_dispatcher { template<> struct typed_packed_call_dispatcher { template - static inline void run(const PackedFunc& pf, Args&& ...args) { + TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) { pf(std::forward(args)...); } }; @@ -1334,7 +1357,7 @@ inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { } template -inline R TypedPackedFunc::operator()(Args... args) const { +TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... args) const { return detail::typed_packed_call_dispatcher ::run(packed_, std::forward(args)...); } diff --git a/src/ir/op.cc b/src/ir/op.cc index b024165c1a4c..bd8a6e22f70e 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -148,7 +148,10 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames") return ret; }); -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); +TVM_REGISTER_GLOBAL("relay.op._GetOp") +.set_body_typed([](std::string name) -> Op { + return Op::Get(name); +}); TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 631d8c0fdf58..431e18b95356 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -135,7 +135,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") -.set_body_typed(QConfig::Current); +.set_body_typed([]() -> QConfig { + return QConfig::Current(); +}); TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") .set_body_typed(QConfig::EnterQConfigScope);