Skip to content

Commit

Permalink
[RUNTIME] Improved Packed FFI for optional. (apache#5478)
Browse files Browse the repository at this point in the history
Allows Optional<NDArray> and module to be passed with the right type code.
  • Loading branch information
tqchen authored and trevor-m committed Jun 18, 2020
1 parent b289f3d commit 1e29992
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 17 deletions.
38 changes: 21 additions & 17 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1346,16 +1346,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// We use type traits to eliminate un-necessary checks.
template<typename T>
inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
using TObjectRef = typename std::remove_reference<T>::type;
using ContainerType = typename std::remove_reference<T>::type::ContainerType;
if (value.defined()) {
Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
values_[i].v_handle = NDArray::FFIGetHandle(value);
type_codes_[i] = kTVMNDArrayHandle;
} else if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
} else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
Expand All @@ -1375,12 +1375,12 @@ template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
return type_code_ == kTVMNDArrayHandle &&
TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
}
if (std::is_base_of<Module, TObjectRef>::value) {
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
Expand All @@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
*static_cast<Object**>(value_.v_handle));
}
return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) ||
(type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
Expand All @@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;

if (type_code_ == kTVMNullptr) {
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
Expand All @@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<Module, TObjectRef>::value) {
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
Expand All @@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value &&
} else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
} else if (std::is_base_of<TObjectRef, Module>::value &&
} else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
Expand All @@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {

template<typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
using ContainerType = typename TObjectRef::ContainerType;
const Object* ptr = other.get();
if (ptr != nullptr) {
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_)));
}
if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_)));
}
Expand Down
10 changes: 10 additions & 0 deletions tests/cpp/build_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) {
runtime::Module mod = (*graph_runtime)(
json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);

// test FFI for module.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});

test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));


PackedFunc set_input = mod.GetFunction("set_input", false);
PackedFunc run = mod.GetFunction("run", false);
PackedFunc get_output = mod.GetFunction("get_output", false);
Expand Down
12 changes: 12 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,18 @@ TEST(Optional, PackedCall) {
CHECK(packedfunc("xyz", false).operator String() == "xyz");
CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz");
CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr);

// test FFI convention.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});
String s = "xyz";
auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0});
test_ffi(Optional<NDArray>(nd), static_cast<int>(kTVMNDArrayHandle));
test_ffi(Optional<String>(s), static_cast<int>(kTVMObjectRValueRefArg));
test_ffi(s, static_cast<int>(kTVMObjectHandle));
test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
}

int main(int argc, char** argv) {
Expand Down

0 comments on commit 1e29992

Please sign in to comment.