Skip to content

Commit

Permalink
[RPC] Enhance RPC Protocol to support TVM Object (#15631)
Browse files Browse the repository at this point in the history
This PR introduces object support in TVM RPC protocol by introducing three
new interfaces in `rpc_reference.h`:
- `uint64_t GetObjectBytes(Object* obj)`, which is a required
  implementation that returns the length of the object during serialization;
- `void WriteObject(Object* obj)` used to serialize an object to a
  writable channel;
- `void ReadObject(int* type_code, TVMValue* value)`, which deserializes
  a TVM Object from a channel.

To serialize an object, a recommended paradigm is to write its
`type_index` first, and then its content. For example, `ShapeTuple` can
be serialized as:

```C++
// pseudocode
void WriteObject(Object* obj) {
  if (obj is ShapeTuple) {
    this->Write<uint32_t>(type_index of ShapeTuple);
    this->Write<int32_t>(obj->ndim);
    this->WriteArray<int64_t>(obj->shape);
  } else {
    throw Unsupported;
  }
}

uint64_t GetObjectBytes(Object* obj) {
  uint64_t result = 0;
  if (obj is ShapeTuple) {
    result += sizeof(uint32_t); # for `type_index`
    result += sizeof(int32_t);  # for `ndim`
    result += sizeof(int64_t) * obj->ndim; # for content of the shape
  } else {
    throw Unsupported;
  }
  return result;
}
```

To deserialize an object, similar to serialization, the recommended
approach paradigm is to read `type_index` and disptch based on it.

Caveat on deserialization: RPC Reference itself does not own or allocate
any memory to store objects, meaning extra logic is usually required in
`ReadObject` to keep their liveness.
  • Loading branch information
junrushao committed Aug 27, 2023
1 parent d3f9d3d commit c57da13
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ struct TypeIndex {
kRuntimeShapeTuple = 6,
/*! \brief runtime::PackedFunc. */
kRuntimePackedFunc = 7,
/*! \brief runtime::DRef */
kRuntimeDiscoDRef = 8,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/minrpc/minrpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class MinRPCReturns : public MinRPCReturnInterface {
io_->Exit(static_cast<int>(code));
}

void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
uint64_t GetObjectBytes(void* obj) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
return 0;
}

template <typename T>
void Write(const T& data) {
static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
Expand Down Expand Up @@ -748,6 +754,10 @@ class MinRPCServer {
return ReadRawBytes(data, sizeof(T) * count);
}

void ReadObject(int* tcode, TVMValue* value) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
}

private:
void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/minrpc/minrpc_server_logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class MinRPCSniffer {
return ReadRawBytes(data, sizeof(T) * count);
}

void ReadObject(int* tcode, TVMValue* value) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
}

private:
bool ReadRawBytes(void* data, size_t size) {
uint8_t* buf = reinterpret_cast<uint8_t*>(data);
Expand Down
13 changes: 13 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
namespace tvm {
namespace runtime {

// Forward declare TVM Object to use `Object*` in RPC protocol.
class Object;

/*! \brief The current RPC procotol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";

Expand Down Expand Up @@ -194,6 +197,8 @@ struct RPCReference {
num_bytes_ += sizeof(T) * num;
}

void WriteObject(Object* obj) { num_bytes_ += channel_->GetObjectBytes(obj); }

void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); }

uint64_t num_bytes() const { return num_bytes_; }
Expand Down Expand Up @@ -364,6 +369,10 @@ struct RPCReference {
channel->WriteArray(bytes->data, len);
break;
}
case kTVMObjectHandle: {
channel->WriteObject(static_cast<Object*>(value.v_handle));
break;
}
default: {
channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
break;
Expand Down Expand Up @@ -461,6 +470,10 @@ struct RPCReference {
value.v_handle = ReceiveDLTensor(channel);
break;
}
case kTVMObjectHandle: {
channel->ReadObject(&tcodes[i], &value);
break;
}
default: {
channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
break;
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
this->Write(cdata);
}

void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
uint64_t GetObjectBytes(void* obj) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
return 0;
}

void ReadObject(int* tcode, TVMValue* value) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
}

void MessageDone() {
// Unused here, implemented for microTVM framing layer.
}
Expand Down

0 comments on commit c57da13

Please sign in to comment.