diff --git a/include/dsn/tool-api/network.h b/include/dsn/tool-api/network.h index 6f8962612a..c261b517ff 100644 --- a/include/dsn/tool-api/network.h +++ b/include/dsn/tool-api/network.h @@ -207,6 +207,7 @@ class rpc_session : public ref_counter static join_point on_rpc_session_connected; static join_point on_rpc_session_disconnected; static join_point on_rpc_recv_message; + static join_point on_rpc_send_message; /*@}*/ public: rpc_session(connection_oriented_network &net, @@ -232,6 +233,11 @@ class rpc_session : public ref_counter bool cancel(message_ex *request); bool delay_recv(int delay_ms); bool on_recv_message(message_ex *msg, int delay_ms); + /// ret value: + /// true - pend succeed + /// false - pend failed + bool try_pend_message(message_ex *msg); + void clear_pending_messages(); /// interfaces for security authentication, /// you can ignore them if you don't enable auth @@ -275,7 +281,10 @@ class rpc_session : public ref_counter volatile session_state _connect_state; bool negotiation_succeed = false; - // TODO(zlw): add send pending message + // when the negotiation of a session isn't succeed, + // all messages are queued in _pending_messages. + // after connected, all of them are moved to "_messages" + std::vector _pending_messages; // messages are sent in batch, firstly all messages are linked together // in a doubly-linked list "_messages". diff --git a/src/runtime/rpc/network.cpp b/src/runtime/rpc/network.cpp index ceb4da883f..53413d1e06 100644 --- a/src/runtime/rpc/network.cpp +++ b/src/runtime/rpc/network.cpp @@ -24,7 +24,6 @@ * THE SOFTWARE. */ -#include "runtime/security/negotiation_utils.h" #include "message_parser_manager.h" #include "runtime/rpc/rpc_engine.h" @@ -40,9 +39,12 @@ namespace dsn { rpc_session::on_rpc_session_disconnected("rpc.session.disconnected"); /*static*/ join_point rpc_session::on_rpc_recv_message("rpc.session.recv.message"); +/*static*/ join_point + rpc_session::on_rpc_send_message("rpc.session.send.message"); rpc_session::~rpc_session() { + clear_pending_messages(); clear_send_queue(false); { @@ -250,9 +252,14 @@ int rpc_session::prepare_parser() void rpc_session::send_message(message_ex *msg) { msg->add_ref(); // released in on_send_completed - msg->io_session = this; + // ignore msg if join point return false + if (dsn_unlikely(!on_rpc_send_message.execute(msg, true))) { + msg->release_ref(); + return; + } + dassert(_parser, "parser should not be null when send"); _parser->prepare_on_send(msg); @@ -262,11 +269,7 @@ void rpc_session::send_message(message_ex *msg) msg->dl.insert_before(&_messages); ++_message_count; - // Attention: here we only allow two cases to send message: - // case 1: session's state is SS_CONNECTED - // case 2: session is sending negotiation message - if ((SS_CONNECTED == _connect_state || security::is_negotiation_message(msg->rpc_code())) && - !_is_sending_next) { + if ((SS_CONNECTED == _connect_state) && !_is_sending_next) { _is_sending_next = true; sig = _message_sent + 1; unlink_message_for_send(); @@ -397,7 +400,7 @@ bool rpc_session::on_recv_message(message_ex *msg, int delay_ms) msg->io_session = this; // ignore msg if join point return false - if (!on_rpc_recv_message.execute(msg, true)) { + if (dsn_unlikely(!on_rpc_recv_message.execute(msg, true))) { delete msg; return false; } @@ -437,12 +440,45 @@ bool rpc_session::on_recv_message(message_ex *msg, int delay_ms) return true; } -void rpc_session::set_negotiation_succeed() +bool rpc_session::try_pend_message(message_ex *msg) +{ + // if negotiation is not succeed, we should pend msg, + // in order to resend it when the negotiation is succeed + if (dsn_unlikely(!negotiation_succeed)) { + utils::auto_lock l(_lock); + if (!negotiation_succeed) { + msg->add_ref(); + _pending_messages.push_back(msg); + return true; + } + } + return false; +} + +void rpc_session::clear_pending_messages() { utils::auto_lock l(_lock); - negotiation_succeed = true; + for (auto msg : _pending_messages) { + msg->release_ref(); + } + _pending_messages.clear(); +} + +void rpc_session::set_negotiation_succeed() +{ + std::vector swapped_pending_msgs; + { + utils::auto_lock l(_lock); + negotiation_succeed = true; + + _pending_messages.swap(swapped_pending_msgs); + } - // todo(zlw): resend pending messages when negotiation is succeed + // resend the pending messages + for (auto msg : swapped_pending_msgs) { + send_message(msg); + msg->release_ref(); + } } bool rpc_session::is_negotiation_succeed() const @@ -451,7 +487,7 @@ bool rpc_session::is_negotiation_succeed() const // Because negotiation_succeed only transfered from false to true. // So if it is true now, it will not change in the later. // But if it is false now, maybe it will change soon. So we should use lock to protect it. - if (negotiation_succeed) { + if (dsn_likely(negotiation_succeed)) { return negotiation_succeed; } else { utils::auto_lock l(_lock); diff --git a/src/runtime/security/negotiation_service.cpp b/src/runtime/security/negotiation_service.cpp index 388a2d0477..ca1568c5ff 100644 --- a/src/runtime/security/negotiation_service.cpp +++ b/src/runtime/security/negotiation_service.cpp @@ -27,13 +27,18 @@ namespace dsn { namespace security { DSN_DECLARE_bool(enable_auth); +inline bool is_negotiation_message(dsn::task_code code) +{ + return code == RPC_NEGOTIATION || code == RPC_NEGOTIATION_ACK; +} + inline bool in_white_list(task_code code) { return is_negotiation_message(code) || fd::is_failure_detector_message(code); } negotiation_map negotiation_service::_negotiations; -zrwlock_nr negotiation_service::_lock; +utils::rw_lock_nr negotiation_service::_lock; negotiation_service::negotiation_service() : serverlet("negotiation_service") {} @@ -56,7 +61,7 @@ void negotiation_service::on_negotiation_request(negotiation_rpc rpc) server_negotiation *srv_negotiation = nullptr; { - zauto_read_lock l(_lock); + utils::auto_read_lock l(_lock); srv_negotiation = static_cast(_negotiations[rpc.dsn_request()->io_session].get()); } @@ -72,7 +77,7 @@ void negotiation_service::on_rpc_connected(rpc_session *session) std::unique_ptr nego = security::create_negotiation(session->is_client(), session); nego->start(); { - zauto_write_lock l(_lock); + utils::auto_write_lock l(_lock); _negotiations[session] = std::move(nego); } } @@ -80,7 +85,7 @@ void negotiation_service::on_rpc_connected(rpc_session *session) void negotiation_service::on_rpc_disconnected(rpc_session *session) { { - zauto_write_lock l(_lock); + utils::auto_write_lock l(_lock); _negotiations.erase(session); } } @@ -90,6 +95,12 @@ bool negotiation_service::on_rpc_recv_msg(message_ex *msg) return in_white_list(msg->rpc_code()) || msg->io_session->is_negotiation_succeed(); } +bool negotiation_service::on_rpc_send_msg(message_ex *msg) +{ + // if try_pend_message return true, it means the msg is pended to the resend message queue + return in_white_list(msg->rpc_code()) || !msg->io_session->try_pend_message(msg); +} + void init_join_point() { rpc_session::on_rpc_session_connected.put_back(negotiation_service::on_rpc_connected, @@ -97,6 +108,7 @@ void init_join_point() rpc_session::on_rpc_session_disconnected.put_back(negotiation_service::on_rpc_disconnected, "security"); rpc_session::on_rpc_recv_message.put_native(negotiation_service::on_rpc_recv_msg); + rpc_session::on_rpc_send_message.put_native(negotiation_service::on_rpc_send_msg); } } // namespace security } // namespace dsn diff --git a/src/runtime/security/negotiation_service.h b/src/runtime/security/negotiation_service.h index 868aa4f343..25462c1eae 100644 --- a/src/runtime/security/negotiation_service.h +++ b/src/runtime/security/negotiation_service.h @@ -20,7 +20,6 @@ #include "server_negotiation.h" #include -#include namespace dsn { namespace security { @@ -33,6 +32,7 @@ class negotiation_service : public serverlet, static void on_rpc_connected(rpc_session *session); static void on_rpc_disconnected(rpc_session *session); static bool on_rpc_recv_msg(message_ex *msg); + static bool on_rpc_send_msg(message_ex *msg); void open_service(); @@ -42,7 +42,7 @@ class negotiation_service : public serverlet, friend class utils::singleton; friend class negotiation_service_test; - static zrwlock_nr _lock; // [ + static utils::rw_lock_nr _lock; // [ static negotiation_map _negotiations; //] }; diff --git a/src/runtime/security/negotiation_utils.h b/src/runtime/security/negotiation_utils.h index 2c0f557a3a..27b8e59bf3 100644 --- a/src/runtime/security/negotiation_utils.h +++ b/src/runtime/security/negotiation_utils.h @@ -52,11 +52,5 @@ inline const char *enum_to_string(negotiation_status::type s) } DEFINE_TASK_CODE_RPC(RPC_NEGOTIATION, TASK_PRIORITY_COMMON, dsn::THREAD_POOL_DEFAULT) - -inline bool is_negotiation_message(dsn::task_code code) -{ - return code == RPC_NEGOTIATION || code == RPC_NEGOTIATION_ACK; -} - } // namespace security } // namespace dsn diff --git a/src/runtime/test/negotiation_service_test.cpp b/src/runtime/test/negotiation_service_test.cpp index f6fb0edf63..4b3d13dc9a 100644 --- a/src/runtime/test/negotiation_service_test.cpp +++ b/src/runtime/test/negotiation_service_test.cpp @@ -58,6 +58,11 @@ class negotiation_service_test : public testing::Test { return negotiation_service::instance().on_rpc_recv_msg(msg); } + + bool on_rpc_send_msg(message_ex *msg) + { + return negotiation_service::instance().on_rpc_send_msg(msg); + } }; TEST_F(negotiation_service_test, disable_auth) @@ -98,5 +103,32 @@ TEST_F(negotiation_service_test, on_rpc_recv_msg) ASSERT_EQ(test.return_value, on_rpc_recv_msg(msg)); } } + +TEST_F(negotiation_service_test, on_rpc_send_msg) +{ + struct + { + task_code rpc_code; + bool negotiation_succeed; + bool return_value; + } tests[] = {{RPC_NEGOTIATION, true, true}, + {RPC_NEGOTIATION_ACK, true, true}, + {fd::RPC_FD_FAILURE_DETECTOR_PING, true, true}, + {fd::RPC_FD_FAILURE_DETECTOR_PING_ACK, true, true}, + {RPC_NEGOTIATION, false, true}, + {RPC_HTTP_SERVICE, true, true}, + {RPC_HTTP_SERVICE, false, false}}; + + for (const auto &test : tests) { + message_ptr msg = dsn::message_ex::create_request(test.rpc_code, 0, 0); + auto sim_session = create_fake_session(); + msg->io_session = sim_session; + if (test.negotiation_succeed) { + sim_session->set_negotiation_succeed(); + } + + ASSERT_EQ(test.return_value, on_rpc_send_msg(msg)); + } +} } // namespace security } // namespace dsn