From a34b8ccc2b6bffe147e9e07bdb4a11ab262a62d4 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 6 May 2020 07:50:41 -0700 Subject: [PATCH] [RPC] Fix the multihop cpu case (#5522) --- src/runtime/rpc/rpc_endpoint.cc | 27 ++++++++++++++--------- src/runtime/rpc/rpc_local_session.h | 4 ++++ src/runtime/rpc/rpc_session.h | 12 ++++++++++ tests/python/unittest/test_runtime_rpc.py | 9 ++++++-- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 916ecaee8a78..8a7f11cbdea6 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -390,20 +390,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; char* data_ptr; + auto* sess = GetServingSession(); - if (ctx.device_type == kDLCPU) { + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && + sess->IsLocalSession() && + DMLC_IO_NO_ENDIAN_SWAP) { data_ptr = reinterpret_cast(handle) + offset; - // endian aware handling - if (!DMLC_IO_NO_ENDIAN_SWAP) { - char* temp = this->ArenaAlloc(num_bytes); - std::memcpy(temp, data_ptr, num_bytes); - dmlc::ByteSwap(temp, elem_bytes, num_bytes / elem_bytes); - data_ptr = temp; - } } else { try { data_ptr = this->ArenaAlloc(num_bytes); - GetServingSession()->CopyFromRemote( + sess->CopyFromRemote( reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, ctx, type_hint); @@ -440,8 +438,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Read(&type_hint); size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + auto* sess = GetServingSession(); - if (ctx.device_type == kDLCPU) { + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession()) { char* dptr = reinterpret_cast(handle) + offset; this->ReadArray(dptr, num_bytes); @@ -457,7 +458,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } try { - GetServingSession()->CopyToRemote( + sess->CopyToRemote( temp_data, 0, reinterpret_cast(handle), offset, num_bytes, ctx, type_hint); @@ -1046,6 +1047,10 @@ class RPCClientSession : public RPCSession, return this; } + bool IsLocalSession() const final { + return false; + } + private: std::shared_ptr endpoint_; }; diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index ebb3ea11c50e..3b6e7d8ea6f0 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -68,6 +68,10 @@ class LocalSession : public RPCSession { DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final; + bool IsLocalSession() const final { + return true; + } + protected: /*! * \brief Internal implementation of GetFunction. diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index a715e7b0b20c..e7e4433b1867 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -177,6 +177,18 @@ class RPCSession { */ virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0; + /*! + * \brief Whether the session is a local session and we can directly + * the data handle returned by the session and treat it as pointer + * to the local memory. + * + * This information is useful for RPC server to directly copy into the + * local memory without creating a temporary buffer. + * + * \return Whether it is a local session. + */ + virtual bool IsLocalSession() const = 0; + /*! * \return The session table index of the session. */ diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 4e7921b0ae7a..58cd8e2ea873 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -167,8 +167,13 @@ def test_rpc_remote_module(): B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) def check_remote(remote): if not tvm.runtime.enabled("llvm"):