Skip to content

Commit

Permalink
[RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc
Browse files Browse the repository at this point in the history
This PR introduces RValue reference support the PackedFunc calling convention to address the above issue.
Specifically, when an argument is a r-value reference, we will use a assign a different type code(`kObjectRValueRefArg`),
and pass `Object**`  (the address to the Object pointer) instead through the values array.
The callee can choose to move out this Object pointer and set the original Object pointer from the caller side to be nullptr.

We also add an experimental move support to the python side(marked as _move so to indicate the dev nature).
This enhancement will enable copy on write optimizations through out the TVM stack.
  • Loading branch information
tqchen committed Apr 8, 2020
1 parent 53a4ad3 commit b2d7e84
Show file tree
Hide file tree
Showing 25 changed files with 374 additions and 96 deletions.
34 changes: 18 additions & 16 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class PrimExpr : public BaseExpr {

private:
// Internal function for conversion.
friend class runtime::TVMPODValue_;
friend class runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
};

Expand Down Expand Up @@ -450,22 +450,24 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
template<>
struct PackedFuncValueConverter<PrimExpr> {
// common rule for both RetValue and ArgValue.
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
return PrimExpr(val.operator int());
}
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}

TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_
3 changes: 2 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ typedef enum {
kTVMStr = 11U,
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down Expand Up @@ -289,7 +290,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code);

/*!
* \brief C type of packed function.
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

#include <cstring>
#include <initializer_list>
Expand Down Expand Up @@ -582,6 +583,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
}
}

template<>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}

static String From(const TVMRetValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
};

} // namespace runtime
} // namespace tvm

Expand Down
16 changes: 16 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ class ObjectPtr {
data_->IncRef();
}
}
/*!
* \brief Move an ObjectPtr from an RValueRef argument.
* \param ref The rvalue reference.
* \return the moved result.
*/
static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) {
ObjectPtr<T> ptr;
ptr.data_ = *ref;
*ref = nullptr;
return ptr;
}
// friend classes
friend class Object;
friend class ObjectRef;
Expand All @@ -489,6 +500,7 @@ class ObjectPtr {
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class TVMArgValue;
friend class TVMMovableArgValue_;
template <typename RelayRefType, typename ObjType>
friend RelayRefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
Expand Down Expand Up @@ -550,6 +562,10 @@ class ObjectRef {
bool unique() const {
return data_.unique();
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const {
return data_.use_count();
}
/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
Expand Down
Loading

0 comments on commit b2d7e84

Please sign in to comment.