diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 8a7f11cbdea6..26f24c91f728 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -57,11 +57,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name, - std::string* remote_key) + std::string* remote_key, + std::function flush_writer) : reader_(reader), writer_(writer), name_(name), - remote_key_(remote_key) { + remote_key_(remote_key), + flush_writer_(flush_writer) { this->Clear(); if (*remote_key == "%toinit") { @@ -109,13 +111,21 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { /*! * \brief Enter the io loop until the next event. * \param client_mode Whether we are in the client. + * \param async_server_mode Whether we are in the async server mode. * \param setreturn The function to set the return value encoding. * \return The function to set return values when there is a return event. */ - RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { + RPCCode HandleNextEvent(bool client_mode, + bool async_server_mode, + RPCSession::FEncodeReturn setreturn) { std::swap(client_mode_, client_mode); + std::swap(async_server_mode_, async_server_mode); - while (this->Ready()) { + RPCCode status = RPCCode::kNone; + + while (status == RPCCode::kNone && + state_ != kWaitForAsyncCallback && + this->Ready()) { switch (state_) { case kInitHeader: HandleInitHeader(); break; case kRecvPacketNumBytes: { @@ -133,23 +143,27 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->HandleProcessPacket(setreturn); break; } + case kWaitForAsyncCallback: { + break; + } case kReturnReceived: { this->SwitchToState(kRecvPacketNumBytes); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; + status = RPCCode::kReturn; + break; } case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; + status = RPCCode::kCopyAck; + break; } case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; + status = RPCCode::kShutdown; } } } + + std::swap(async_server_mode_, async_server_mode); std::swap(client_mode_, client_mode); - return RPCCode::kNone; + return status; } /*! \brief Clear all the states in the Handler.*/ @@ -229,6 +243,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { kInitHeader, kRecvPacketNumBytes, kProcessPacket, + kWaitForAsyncCallback, kReturnReceived, kCopyAckReceived, kShutdownReceived @@ -239,6 +254,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { bool init_header_step_{0}; // Whether current handler is client or server mode. bool client_mode_{false}; + // Whether current handler is in the async server mode. + bool async_server_mode_{false}; // Internal arena support::Arena arena_; @@ -249,6 +266,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { CHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; } + // need to actively flush the writer + // so the data get pushed out. + if (state_ == kWaitForAsyncCallback) { + flush_writer_(); + } state_ = state; CHECK(state != kInitHeader) << "cannot switch to init header"; @@ -389,41 +411,50 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Read(&type_hint); size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; - char* data_ptr; auto* sess = GetServingSession(); + // Return Copy Ack with the given data + auto fcopyack = [this](char* data_ptr, size_t num_bytes) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + this->SwitchToState(kRecvPacketNumBytes); + }; + // When session is local, we can directly treat handle // as the cpu pointer without allocating a temp space. if (ctx.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) { - data_ptr = reinterpret_cast(handle) + offset; + char* data_ptr = reinterpret_cast(handle) + offset; + fcopyack(data_ptr, num_bytes); } else { - try { - data_ptr = this->ArenaAlloc(num_bytes); - sess->CopyFromRemote( - reinterpret_cast(handle), offset, - data_ptr, 0, - num_bytes, ctx, type_hint); - // endian aware handling - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + char* data_ptr = this->ArenaAlloc(num_bytes); + + auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack]( + RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + } + fcopyack(data_ptr, num_bytes); } - } catch (const std::runtime_error &e) { - this->ReturnException(e.what()); - this->SwitchToState(kRecvPacketNumBytes); - return; - } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyFromRemote( + reinterpret_cast(handle), offset, + data_ptr, 0, + num_bytes, ctx, type_hint, + on_copy_complete); } - RPCCode code = RPCCode::kCopyAck; - uint64_t packet_nbytes = sizeof(code) + num_bytes; - - // Return Copy Ack - this->Write(packet_nbytes); - this->Write(code); - this->WriteArray(data_ptr, num_bytes); - - this->SwitchToState(kRecvPacketNumBytes); } void HandleCopyToRemote() { @@ -446,9 +477,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { char* dptr = reinterpret_cast(handle) + offset; this->ReadArray(dptr, num_bytes); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); - } + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); } else { char* temp_data = this->ArenaAlloc(num_bytes); this->ReadArray(temp_data, num_bytes); @@ -457,20 +490,23 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes); } - try { - sess->CopyToRemote( + auto on_copy_complete = [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyToRemote( temp_data, 0, reinterpret_cast(handle), offset, - num_bytes, ctx, type_hint); - } catch (const std::runtime_error &e) { - this->ReturnException(e.what()); - this->SwitchToState(kRecvPacketNumBytes); - return; - } + num_bytes, ctx, type_hint, + on_copy_complete); } - - this->ReturnVoid(); - this->SwitchToState(kRecvPacketNumBytes); } // Handle for packed call. @@ -480,16 +516,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Read(&call_handle); TVMArgs args = RecvPackedSeq(); - try { - GetServingSession()->CallFunc( - reinterpret_cast(call_handle), - args.values, args.type_codes, args.size(), - [this](TVMArgs ret) { this->ReturnPackedSeq(ret); }); - } catch (const std::runtime_error& e) { - this->ReturnException(e.what()); - } - - this->SwitchToState(kRecvPacketNumBytes); + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncCallFunc( + reinterpret_cast(call_handle), + args.values, args.type_codes, args.size(), + [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnPackedSeq(args); + } + this->SwitchToState(kRecvPacketNumBytes); + }); } void HandleInitServer() { @@ -512,35 +550,39 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { << " server protocol=" << server_protocol_ver << ", client protocol=" << client_protocol_ver; + std::string constructor_name; + TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0); + if (args.size() == 0) { + constructor_name = "rpc.LocalSession"; serving_session_ = std::make_shared(); } else { - std::string constructor_name = args[0]; - auto* fconstructor = Registry::Get(constructor_name); - CHECK(fconstructor != nullptr) - << " Cannot find session constructor " << constructor_name; - TVMRetValue con_ret; - - try { - fconstructor->CallPacked( - TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret); - } catch (const dmlc::Error& e) { - LOG(FATAL) << "Server[" << name_ << "]:" - << " Error caught from session constructor " << constructor_name - << ":\n" << e.what(); - } + constructor_name = args[0].operator std::string(); + constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1); + } + + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) + << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; - CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) - << "Server[" << name_ << "]:" - << " Constructor " << constructor_name - << " need to return an RPCModule"; - Module mod = con_ret; - std::string tkey = mod->type_key(); - CHECK_EQ(tkey, "rpc") - << "Constructor " << constructor_name << " to return an RPCModule"; - serving_session_ = RPCModuleGetSession(mod); + try { + fconstructor->CallPacked(constructor_args, &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name + << ":\n" << e.what(); } + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name + << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); } catch (const std::runtime_error &e) { this->ReturnException(e.what()); @@ -549,6 +591,28 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->SwitchToState(kRecvPacketNumBytes); } + void HandleSyscallStreamSync() { + TVMArgs args = RecvPackedSeq(); + try { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncStreamWait( + ctx, handle, [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnVoid(); + } + this->SwitchToState(kRecvPacketNumBytes); + }); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + // Handler for special syscalls that have a specific RPCCode. template void SysCallHandler(F f) { @@ -572,6 +636,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { RPCSession* GetServingSession() const { CHECK(serving_session_ != nullptr) << "Need to call InitRemoteSession first before any further actions"; + CHECK(!serving_session_->IsAsync() || async_server_mode_) + << "Cannot host an async session in a non-Event driven server"; + return serving_session_.get(); } // Utility functions @@ -598,10 +665,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { std::string name_; // remote key std::string* remote_key_; + // function to flush the writer. + std::function flush_writer_; }; RPCCode RPCEndpoint::HandleUntilReturnEvent( - bool client_mode, RPCSession::FEncodeReturn setreturn) { + bool client_mode, + RPCSession::FEncodeReturn setreturn) { RPCCode code = RPCCode::kCallFunc; while (code != RPCCode::kReturn && code != RPCCode::kShutdown && @@ -624,15 +694,26 @@ RPCCode RPCEndpoint::HandleUntilReturnEvent( } } } - code = handler_->HandleNextEvent(client_mode, setreturn); + code = handler_->HandleNextEvent(client_mode, false, setreturn); } return code; } void RPCEndpoint::Init() { + // callback to flush the writer. + auto flush_writer = [this]() { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + if (n == 0) break; + } + }; + // Event handler handler_ = std::make_shared( - &reader_, &writer_, name_, &remote_key_); + &reader_, &writer_, name_, &remote_key_, flush_writer); + // Quick function to for syscall remote. syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { std::lock_guard lock(mutex_); @@ -711,7 +792,7 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even RPCCode code = RPCCode::kNone; if (in_bytes.length() != 0) { reader_.Write(in_bytes.c_str(), in_bytes.length()); - code = handler_->HandleNextEvent(false, [](TVMArgs) {}); + code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); } if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { writer_.ReadWithCallback([this](const void *data, size_t size) { @@ -894,12 +975,6 @@ void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); } -void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle); -} - void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { void* from = args[0]; uint64_t from_offset = args[1]; @@ -935,12 +1010,14 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break; case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break; case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break; + case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; default: LOG(FATAL) << "Unknown event " << static_cast(code); } - CHECK_EQ(state_, kRecvPacketNumBytes); + if (state_ != kWaitForAsyncCallback) { + CHECK_EQ(state_, kRecvPacketNumBytes); + } } /*! diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 351a9896d899..9d1fb7246524 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -31,21 +31,15 @@ namespace runtime { RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { - PackedFunc pf = this->GetFunctionInternal(name); - // return raw handl because the remote need to explicitly manage it. - if (pf != nullptr) return new PackedFunc(pf); - return nullptr; + if (auto* fp = tvm::runtime::Registry::Get(name)) { + // return raw handle because the remote need to explicitly manage it. + return new PackedFunc(*fp); + } else { + return nullptr; + } } -void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - const FEncodeReturn& encode_return) { - auto* pf = static_cast(func); - TVMRetValue rv; - - pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); +void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { int rv_tcode = rv.type_code(); // return value encoding. @@ -84,6 +78,17 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, } } +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& encode_return) { + auto* pf = static_cast(func); + TVMRetValue rv; + pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + this->EncodeReturn(std::move(rv), encode_return); +} + void LocalSession::CopyToRemote(void* from, size_t from_offset, void* to, @@ -134,15 +139,6 @@ DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { return DeviceAPI::Get(ctx, allow_missing); } -PackedFunc LocalSession::GetFunctionInternal(const std::string& name) { - auto* fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - return *fp; - } else { - return nullptr; - } -} - TVM_REGISTER_GLOBAL("rpc.LocalSession") .set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index 3b6e7d8ea6f0..ff0caa4ead43 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "rpc_session.h" namespace tvm { @@ -40,13 +41,13 @@ namespace runtime { class LocalSession : public RPCSession { public: // function overrides - PackedFuncHandle GetFunction(const std::string& name) final; + PackedFuncHandle GetFunction(const std::string& name) override; void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, int num_args, - const FEncodeReturn& fencode_return) final; + const FEncodeReturn& fencode_return) override; void CopyToRemote(void* from, size_t from_offset, @@ -54,7 +55,7 @@ class LocalSession : public RPCSession { size_t to_offset, size_t nbytes, TVMContext ctx_to, - DLDataType type_hint) final; + DLDataType type_hint) override; void CopyFromRemote(void* from, size_t from_offset, @@ -62,23 +63,23 @@ class LocalSession : public RPCSession { size_t to_offset, size_t nbytes, TVMContext ctx_from, - DLDataType type_hint) final; + DLDataType type_hint) override; - void FreeHandle(void* handle, int type_code) final; + void FreeHandle(void* handle, int type_code) override; - DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final; + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override; - bool IsLocalSession() const final { + bool IsLocalSession() const override { return true; } protected: /*! - * \brief Internal implementation of GetFunction. - * \param name The name of the function. - * \return The corresponding PackedFunc. + * \brief internal encode return fucntion. + * \param rv The return value. + * \param encode_return The encoding function. */ - virtual PackedFunc GetFunctionInternal(const std::string& name); + void EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return); }; } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index dd0afa0145d2..d07aa740692a 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -30,6 +30,93 @@ namespace tvm { namespace runtime { +bool RPCSession::IsAsync() const { + return false; +} + +void RPCSession::SendException(FAsyncCallback callback, const char* msg) { + TVMValue value; + value.v_str = msg; + int32_t tcode = kTVMStr; + callback(RPCCode::kException, TVMArgs(&value, &tcode, 1)); +} + +void RPCSession::AsyncCallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + FAsyncCallback callback) { + try { + this->CallFunc(func, arg_values, arg_type_codes, num_args, + [&callback](TVMArgs args) { + callback(RPCCode::kReturn, args); + }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } +} + + +void RPCSession::AsyncCopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + + try { + this->CopyToRemote(local_from, local_from_offset, + remote_to, remote_to_offset, + nbytes, remote_ctx_to, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } +} + +void RPCSession::AsyncCopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + + try { + this->CopyFromRemote(remote_from, remote_from_offset, + local_to, local_to_offset, + nbytes, remote_ctx_from, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } +} + +void RPCSession::AsyncStreamWait(TVMContext ctx, + TVMStreamHandle stream, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + + try { + this->GetDeviceAPI(ctx)->StreamSync(ctx, stream); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } +} + + class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index e7e4433b1867..7ea1eb9003ac 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -30,6 +30,7 @@ #include #include #include +#include "rpc_protocol.h" namespace tvm { namespace runtime { @@ -58,7 +59,6 @@ class RPCSession { * \brief Callback to send an encoded return values via encode_args. * * \param encode_args The arguments that we can encode the return values into. - * \param ret_tcode The actual remote type code of the return value. * * Encoding convention (as list of arguments): * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. @@ -69,6 +69,14 @@ class RPCSession { */ using FEncodeReturn = std::function; + /*! + * \brief Callback to send an encoded return values via encode_args. + * + * \param status The return status, can be RPCCode::kReturn or RPCCode::kException. + * \param encode_args The arguments that we can encode the return values into. + */ + using FAsyncCallback = std::function; + /*! \brief Destructor.*/ virtual ~RPCSession() {} @@ -189,6 +197,98 @@ class RPCSession { */ virtual bool IsLocalSession() const = 0; + // Asynchrous variant of API + // These APIs are used by the RPC server to allow sessions that + // have special implementations for the async functions. + // + // In the async APIs, an exception is returned by the passing + // async_error=true, encode_args=[error_msg]. + + /*! + * \brief Whether the session is async. + * + * If the session is not async, its Aync implementations + * simply calls into the their synchronize counterparts, + * and the callback is guaranteed to be called before the async function finishes. + * + * \return the async state. + * + * \note We can only use async session in an Event driven RPC server. + */ + virtual bool IsAsync() const; + + /*! + * \brief Asynchrously call func. + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * + * \param callback The callback to pass the return value or exception. + */ + virtual void AsyncCallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + FAsyncCallback callback); + + /*! + * \brief Asynchrous version of CopyToRemote. + * + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param remote_ctx_to The target context. + * \param type_hint Hint of content data type. + * + * \param on_complete The callback to signal copy complete. + * \note All the allocated memory in local_from, and remote_to + * must stay alive until on_compelete is called. + */ + virtual void AsyncCopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint, + FAsyncCallback on_complete); + + /*! + * \brief Asynchrous version of CopyFromRemote. + * + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param remote_ctx_from The source context in the remote. + * \param type_hint Hint of content data type. + * + * \param on_complete The callback to signal copy complete. + * \note All the allocated memory in remote_from, and local_to + * must stay alive until on_compelete is called. + */ + virtual void AsyncCopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint, + FAsyncCallback on_complete); + /*! + * \brief Asynchrously wait for all events in ctx, stream compeletes. + * \param ctx The device context. + * \param stream The stream to wait on. + * \param on_complete The callback to signal copy complete. + */ + virtual void AsyncStreamWait(TVMContext ctx, + TVMStreamHandle stream, + FAsyncCallback on_compelte); + /*! * \return The session table index of the session. */ @@ -203,6 +303,13 @@ class RPCSession { */ static std::shared_ptr Get(int table_index); + protected: + /*! + * \brief Send an exception to the callback. + * \param msg The exception message. + */ + void SendException(FAsyncCallback callback, const char* msg); + private: /*! \brief index of this session in RPC session table */ int table_index_{0};