diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c56600730eda..cf6d5fab0e19 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -1346,16 +1346,16 @@ inline R TypedPackedFunc::operator()(Args... args) const { // We use type traits to eliminate un-necessary checks. template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { - using TObjectRef = typename std::remove_reference::type; + using ContainerType = typename std::remove_reference::type::ContainerType; if (value.defined()) { Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && + } else if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; @@ -1375,12 +1375,12 @@ template inline bool TVMPODValue_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMNDArrayHandle && TVMArrayHandleToObjectHandle( static_cast(value_.v_handle))->IsInstance(); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMModuleHandle && static_cast(value_.v_handle)->IsInstance(); } @@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const { *static_cast(value_.v_handle)); } return - (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) || - (std::is_base_of::value && type_code_ == kTVMModuleHandle) || + (std::is_base_of::value && + type_code_ == kTVMNDArrayHandle) || + (std::is_base_of::value && + type_code_ == kTVMModuleHandle) || (type_code_ == kTVMObjectHandle && ObjectTypeChecker::Check(static_cast(value_.v_handle))); } @@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; + if (type_code_ == kTVMNullptr) { CHECK(TObjectRef::_type_is_nullable) << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = NDArray::FFIDataFromHandle( @@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class ObjectPtr data = NDArray::FFIDataFromHandle( static_cast(value_.v_handle)); return TObjectRef(data); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMModuleHandle) { // Casting to a base class that Module can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); @@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { + using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(NDArray(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(Module(std::move(other.data_))); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 4913731f1bd3..6ea0b212e6a3 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) { runtime::Module mod = (*graph_runtime)( json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + // test FFI for module. + auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + + test_ffi(runtime::Module(mod), static_cast(kTVMModuleHandle)); + test_ffi(Optional(mod), static_cast(kTVMModuleHandle)); + + PackedFunc set_input = mod.GetFunction("set_input", false); PackedFunc run = mod.GetFunction("run", false); PackedFunc get_output = mod.GetFunction("get_output", false); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index c67df63e6e7e..c89f8153ee37 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -468,6 +468,18 @@ TEST(Optional, PackedCall) { CHECK(packedfunc("xyz", false).operator String() == "xyz"); CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); + + // test FFI convention. + auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + String s = "xyz"; + auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0}); + test_ffi(Optional(nd), static_cast(kTVMNDArrayHandle)); + test_ffi(Optional(s), static_cast(kTVMObjectRValueRefArg)); + test_ffi(s, static_cast(kTVMObjectHandle)); + test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } int main(int argc, char** argv) {