From 9cc9b5d4f940ace68d1745833382ef91b1e4e458 Mon Sep 17 00:00:00 2001 From: zhao liwei Date: Wed, 27 Jan 2021 14:38:35 +0800 Subject: [PATCH] refactor: remove the dependency of tls memory for some member functions in message_ex (#739) --- include/dsn/tool-api/rpc_message.h | 1 + include/dsn/utility/transient_memory.h | 11 +------ src/runtime/rpc/rpc_message.cpp | 20 +++++------- src/runtime/test/rpc_message.cpp | 27 ++++++++++++++++ src/utils/transient_memory.cpp | 44 +------------------------- 5 files changed, 38 insertions(+), 65 deletions(-) diff --git a/include/dsn/tool-api/rpc_message.h b/include/dsn/tool-api/rpc_message.h index 6faea89b64..a375aad282 100644 --- a/include/dsn/tool-api/rpc_message.h +++ b/include/dsn/tool-api/rpc_message.h @@ -176,6 +176,7 @@ class message_ex : public ref_counter, public extensible_object /// The returned message: /// - msg->buffers[0] = message_header /// - msg->buffers[1] = data + /// NOTE: the reference counter of returned message_ex is not added in this function DSN_API static message_ex *create_receive_message_with_standalone_header(const blob &data); /// copy message without client information, it will not reply diff --git a/include/dsn/utility/transient_memory.h b/include/dsn/utility/transient_memory.h index 196ecdbb8a..6796ba0d43 100644 --- a/include/dsn/utility/transient_memory.h +++ b/include/dsn/utility/transient_memory.h @@ -80,13 +80,4 @@ void tls_trans_mem_init(size_t default_per_block_bytes); // "tls_trans_mem_next" should be used together with "tls_trans_mem_commit" void tls_trans_mem_next(void **ptr, size_t *sz, size_t min_size); void tls_trans_mem_commit(size_t use_size); - -// allocate a blob, the size is "sz" -blob tls_trans_mem_alloc_blob(size_t sz); - -// allocate memory -void *tls_trans_malloc(size_t sz); - -// free memory, ptr shouldn't be null -void tls_trans_free(void *ptr); -} +} // namespace dsn diff --git a/src/runtime/rpc/rpc_message.cpp b/src/runtime/rpc/rpc_message.cpp index 9784e7de6b..49b7b763fc 100644 --- a/src/runtime/rpc/rpc_message.cpp +++ b/src/runtime/rpc/rpc_message.cpp @@ -148,13 +148,11 @@ message_ex *message_ex::create_received_request(dsn::task_code code, message_ex *message_ex::create_receive_message_with_standalone_header(const blob &data) { message_ex *msg = new message_ex(); - std::shared_ptr header_holder( - static_cast(dsn::tls_trans_malloc(sizeof(message_header))), - [](char *c) { dsn::tls_trans_free(c); }); - msg->header = reinterpret_cast(header_holder.get()); - memset(static_cast(msg->header), 0, sizeof(message_header)); + size_t header_size = sizeof(message_header); + std::string str(header_size, '\0'); + msg->header = reinterpret_cast(const_cast(str.data())); - msg->buffers.emplace_back(blob(std::move(header_holder), sizeof(message_header))); + msg->buffers.emplace_back(blob::create_from_bytes(std::move(str))); msg->buffers.push_back(data); msg->header->body_length = data.length(); @@ -168,13 +166,11 @@ message_ex *message_ex::create_receive_message_with_standalone_header(const blob message_ex *message_ex::copy_message_no_reply(const message_ex &old_msg) { message_ex *msg = new message_ex(); - std::shared_ptr header_holder( - static_cast(dsn::tls_trans_malloc(sizeof(message_header))), - [](char *c) { dsn::tls_trans_free(c); }); - msg->header = reinterpret_cast(header_holder.get()); - memset(static_cast(msg->header), 0, sizeof(message_header)); - msg->buffers.emplace_back(blob(std::move(header_holder), sizeof(message_header))); + size_t header_size = sizeof(message_header); + std::string str(header_size, '\0'); + msg->header = reinterpret_cast(const_cast(str.data())); + msg->buffers.emplace_back(blob::create_from_bytes(std::move(str))); if (old_msg.buffers.size() == 1) { // if old_msg only has header, consider its header as data msg->buffers.emplace_back(old_msg.buffers[0]); diff --git a/src/runtime/test/rpc_message.cpp b/src/runtime/test/rpc_message.cpp index 9dc28b4b34..5da64f4df9 100644 --- a/src/runtime/test/rpc_message.cpp +++ b/src/runtime/test/rpc_message.cpp @@ -195,3 +195,30 @@ TEST(rpc_message, restore_read) msg->restore_read(); } } + +TEST(rpc_message, create_receive_message_with_standalone_header) +{ + auto data = blob::create_from_bytes("10086"); + + message_ptr msg = message_ex::create_receive_message_with_standalone_header(data); + ASSERT_EQ(msg->buffers.size(), 2); + ASSERT_EQ(0, strcmp(msg->buffers[1].data(), data.data())); + ASSERT_EQ(msg->header->body_length, data.length()); +} + +TEST(rpc_message, copy_message_no_reply) +{ + auto data = blob::create_from_bytes("10086"); + message_ptr old_msg = message_ex::create_receive_message_with_standalone_header(data); + old_msg->local_rpc_code = RPC_CODE_FOR_TEST; + + auto msg = message_ex::copy_message_no_reply(*old_msg); + ASSERT_EQ(msg->buffers.size(), old_msg->buffers.size()); + ASSERT_EQ(0, strcmp(msg->buffers[1].data(), old_msg->buffers[1].data())); + ASSERT_EQ(msg->header->body_length, old_msg->header->body_length); + ASSERT_EQ(msg->local_rpc_code, old_msg->local_rpc_code); + + // add_ref was called in message_ex::copy_message_no_reply for msg + // so we only need to call release_ref here. + msg->release_ref(); +} diff --git a/src/utils/transient_memory.cpp b/src/utils/transient_memory.cpp index 6137bd47b5..dbee2f6021 100644 --- a/src/utils/transient_memory.cpp +++ b/src/utils/transient_memory.cpp @@ -98,46 +98,4 @@ void tls_trans_mem_commit(size_t use_size) tls_trans_memory.remain_bytes -= use_size; tls_trans_memory.committed = true; } - -blob tls_trans_mem_alloc_blob(size_t sz) -{ - void *ptr; - size_t sz2; - tls_trans_mem_next(&ptr, &sz2, sz); - - ::dsn::blob buffer((*::dsn::tls_trans_memory.block), - (int)((char *)(ptr) - ::dsn::tls_trans_memory.block->get()), - (int)sz); - - tls_trans_mem_commit(sz); - return buffer; -} - -void *tls_trans_malloc(size_t sz) -{ - sz += sizeof(std::shared_ptr) + sizeof(uint32_t); - void *ptr; - size_t sz2; - tls_trans_mem_next(&ptr, &sz2, sz); - - // add ref - new (ptr) std::shared_ptr(*::dsn::tls_trans_memory.block); - - // add magic - *(uint32_t *)((char *)(ptr) + sizeof(std::shared_ptr)) = 0xdeadbeef; - - tls_trans_mem_commit(sz); - - return (void *)((char *)(ptr) + sizeof(std::shared_ptr) + sizeof(uint32_t)); -} - -void tls_trans_free(void *ptr) -{ - ptr = (void *)((char *)ptr - sizeof(uint32_t)); - // invalid transient memory block - assert(*(uint32_t *)(ptr) == 0xdeadbeef); - - ptr = (void *)((char *)ptr - sizeof(std::shared_ptr)); - ((std::shared_ptr *)(ptr))->~shared_ptr(); -} -} +} // namespace dsn