Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

Commit

Permalink
fix(security): fix bug in negotiation_service::on_negotiation_request…
Browse files Browse the repository at this point in the history
… when rpc_session is closed (#652)
  • Loading branch information
levy5307 authored Nov 4, 2020
1 parent bf94cbd commit 1863cf9
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 54 deletions.
7 changes: 4 additions & 3 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "client_negotiation.h"
#include "negotiation_utils.h"
#include "negotiation_manager.h"

#include <boost/algorithm/string/join.hpp>
#include <dsn/dist/fmt_logging.h>
Expand All @@ -29,7 +30,7 @@ namespace security {
DSN_DECLARE_bool(mandatory_auth);
extern const std::set<std::string> supported_mechanisms;

client_negotiation::client_negotiation(rpc_session *session) : negotiation(session)
client_negotiation::client_negotiation(rpc_session_ptr session) : negotiation(session)
{
_name = fmt::format("CLIENT_NEGOTIATION(SERVER={})", _session->remote_address().to_string());
}
Expand Down Expand Up @@ -179,8 +180,8 @@ void client_negotiation::send(negotiation_status::type status, const blob &msg)
req->msg = msg;

negotiation_rpc rpc(std::move(req), RPC_NEGOTIATION);
rpc.call(_session->remote_address(), nullptr, [this, rpc](error_code err) mutable {
handle_response(err, std::move(rpc.response()));
rpc.call(_session->remote_address(), nullptr, [rpc](error_code err) mutable {
negotiation_manager::on_negotiation_response(err, rpc);
});
}

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ namespace security {
class client_negotiation : public negotiation
{
public:
client_negotiation(rpc_session *session);
client_negotiation(rpc_session_ptr session);

void start();
void handle_response(error_code err, const negotiation_response &&response);

private:
void handle_response(error_code err, const negotiation_response &&response);
void on_recv_mechanisms(const negotiation_response &resp);
void on_mechanism_selected(const negotiation_response &resp);
void on_challenge(const negotiation_response &resp);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#include "kinit_context.h"
#include "sasl_init.h"
#include "negotiation_service.h"
#include "negotiation_manager.h"

#include <dsn/dist/fmt_logging.h>
#include <dsn/utility/flags.h>
Expand Down
6 changes: 2 additions & 4 deletions src/runtime/security/negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typedef rpc_holder<negotiation_request, negotiation_response> negotiation_rpc;
class negotiation
{
public:
negotiation(rpc_session *session)
negotiation(rpc_session_ptr session)
: _session(session), _status(negotiation_status::type::INVALID)
{
_sasl = create_sasl_wrapper(_session->is_client());
Expand All @@ -49,9 +49,7 @@ class negotiation
bool check_status(negotiation_status::type status, negotiation_status::type expected_status);

protected:
// The ownership of the negotiation instance is held by rpc_session.
// So negotiation keeps only a raw pointer.
rpc_session *_session;
rpc_session_ptr _session;
std::string _name;
negotiation_status::type _status;
std::string _selected_mechanism;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
// specific language governing permissions and limitations
// under the License.

#include "negotiation_service.h"
#include "negotiation_manager.h"
#include "negotiation_utils.h"
#include "server_negotiation.h"
#include "client_negotiation.h"

#include <dsn/utility/flags.h>
#include <dsn/tool-api/zlocks.h>
#include <dsn/dist/failure_detector/fd.code.definition.h>
#include <dsn/dist/fmt_logging.h>

namespace dsn {
namespace security {
Expand All @@ -38,80 +40,100 @@ inline bool in_white_list(task_code code)
return is_negotiation_message(code) || fd::is_failure_detector_message(code);
}

negotiation_map negotiation_service::_negotiations;
utils::rw_lock_nr negotiation_service::_lock;
negotiation_map negotiation_manager::_negotiations;
utils::rw_lock_nr negotiation_manager::_lock;

negotiation_service::negotiation_service() : serverlet("negotiation_service") {}
negotiation_manager::negotiation_manager() : serverlet("negotiation_manager") {}

void negotiation_service::open_service()
void negotiation_manager::open_service()
{
register_rpc_handler_with_rpc_holder(
RPC_NEGOTIATION, "Negotiation", &negotiation_service::on_negotiation_request);
RPC_NEGOTIATION, "Negotiation", &negotiation_manager::on_negotiation_request);
}

void negotiation_service::on_negotiation_request(negotiation_rpc rpc)
void negotiation_manager::on_negotiation_request(negotiation_rpc rpc)
{
dassert(!rpc.dsn_request()->io_session->is_client(),
"only server session receive negotiation request");
"only server session receives negotiation request");

// reply SASL_AUTH_DISABLE if auth is not enable
if (!security::FLAGS_enable_auth) {
rpc.response().status = negotiation_status::type::SASL_AUTH_DISABLE;
return;
}

server_negotiation *srv_negotiation = nullptr;
{
utils::auto_read_lock l(_lock);
srv_negotiation =
static_cast<server_negotiation *>(_negotiations[rpc.dsn_request()->io_session].get());
std::shared_ptr<negotiation> nego = get_negotiation(rpc);
if (nullptr != nego) {
server_negotiation *srv_negotiation = static_cast<server_negotiation *>(nego.get());
srv_negotiation->handle_request(rpc);
}
}

dassert(srv_negotiation != nullptr,
"negotiation is null for msg: {}",
rpc.dsn_request()->rpc_code().to_string());
srv_negotiation->handle_request(rpc);
void negotiation_manager::on_negotiation_response(error_code err, negotiation_rpc rpc)
{
dassert(rpc.dsn_request()->io_session->is_client(),
"only client session receives negotiation response");

std::shared_ptr<negotiation> nego = get_negotiation(rpc);
if (nullptr != nego) {
client_negotiation *cli_negotiation = static_cast<client_negotiation *>(nego.get());
cli_negotiation->handle_response(err, std::move(rpc.response()));
}
}

void negotiation_service::on_rpc_connected(rpc_session *session)
void negotiation_manager::on_rpc_connected(rpc_session *session)
{
std::unique_ptr<negotiation> nego = security::create_negotiation(session->is_client(), session);
std::shared_ptr<negotiation> nego = security::create_negotiation(session->is_client(), session);
nego->start();
{
utils::auto_write_lock l(_lock);
_negotiations[session] = std::move(nego);
}
}

void negotiation_service::on_rpc_disconnected(rpc_session *session)
void negotiation_manager::on_rpc_disconnected(rpc_session *session)
{
{
utils::auto_write_lock l(_lock);
_negotiations.erase(session);
}
}

bool negotiation_service::on_rpc_recv_msg(message_ex *msg)
bool negotiation_manager::on_rpc_recv_msg(message_ex *msg)
{
return !FLAGS_mandatory_auth || in_white_list(msg->rpc_code()) ||
msg->io_session->is_negotiation_succeed();
}

bool negotiation_service::on_rpc_send_msg(message_ex *msg)
bool negotiation_manager::on_rpc_send_msg(message_ex *msg)
{
// if try_pend_message return true, it means the msg is pended to the resend message queue
return !FLAGS_mandatory_auth || in_white_list(msg->rpc_code()) ||
!msg->io_session->try_pend_message(msg);
}

std::shared_ptr<negotiation> negotiation_manager::get_negotiation(negotiation_rpc rpc)
{
utils::auto_read_lock l(_lock);
auto it = _negotiations.find(rpc.dsn_request()->io_session);
if (it == _negotiations.end()) {
ddebug_f("negotiation was removed for msg: {}, {}",
rpc.dsn_request()->rpc_code().to_string(),
rpc.remote_address().to_string());
return nullptr;
}

return it->second;
}

void init_join_point()
{
rpc_session::on_rpc_session_connected.put_back(negotiation_service::on_rpc_connected,
rpc_session::on_rpc_session_connected.put_back(negotiation_manager::on_rpc_connected,
"security");
rpc_session::on_rpc_session_disconnected.put_back(negotiation_service::on_rpc_disconnected,
rpc_session::on_rpc_session_disconnected.put_back(negotiation_manager::on_rpc_disconnected,
"security");
rpc_session::on_rpc_recv_message.put_native(negotiation_service::on_rpc_recv_msg);
rpc_session::on_rpc_send_message.put_native(negotiation_service::on_rpc_send_msg);
rpc_session::on_rpc_recv_message.put_native(negotiation_manager::on_rpc_recv_msg);
rpc_session::on_rpc_send_message.put_native(negotiation_manager::on_rpc_send_msg);
}
} // namespace security
} // namespace dsn
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,27 @@

namespace dsn {
namespace security {
typedef std::unordered_map<rpc_session *, std::unique_ptr<negotiation>> negotiation_map;
typedef std::unordered_map<rpc_session *, std::shared_ptr<negotiation>> negotiation_map;

class negotiation_service : public serverlet<negotiation_service>,
public utils::singleton<negotiation_service>
class negotiation_manager : public serverlet<negotiation_manager>,
public utils::singleton<negotiation_manager>
{
public:
static void on_rpc_connected(rpc_session *session);
static void on_rpc_disconnected(rpc_session *session);
static bool on_rpc_recv_msg(message_ex *msg);
static bool on_rpc_send_msg(message_ex *msg);
static void on_negotiation_response(error_code err, negotiation_rpc rpc);

void open_service();

private:
negotiation_service();
negotiation_manager();
void on_negotiation_request(negotiation_rpc rpc);
friend class utils::singleton<negotiation_service>;
friend class negotiation_service_test;
static std::shared_ptr<negotiation> get_negotiation(negotiation_rpc rpc);

friend class utils::singleton<negotiation_manager>;
friend class negotiation_manager_test;

static utils::rw_lock_nr _lock; // [
static negotiation_map _negotiations;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace security {
DSN_DECLARE_string(service_fqdn);
DSN_DECLARE_string(service_name);

server_negotiation::server_negotiation(rpc_session *session) : negotiation(session)
server_negotiation::server_negotiation(rpc_session_ptr session) : negotiation(session)
{
_name = fmt::format("SERVER_NEGOTIATION(CLIENT={})", _session->remote_address().to_string());
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extern const std::set<std::string> supported_mechanisms;
class server_negotiation : public negotiation
{
public:
server_negotiation(rpc_session *session);
server_negotiation(rpc_session_ptr session);

void start();
void handle_request(negotiation_rpc rpc);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/service_api_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
#include "runtime/rpc/rpc_engine.h"
#include "runtime/task/task_engine.h"
#include "utils/coredump.h"
#include "runtime/security/negotiation_service.h"
#include "runtime/security/negotiation_manager.h"

namespace dsn {
namespace security {
Expand Down Expand Up @@ -562,7 +562,7 @@ service_app *service_app::new_service_app(const std::string &type,
service_app::service_app(const dsn::service_app_info *info) : _info(info), _started(false)
{
security::negotiation_service::instance().open_service();
security::negotiation_manager::instance().open_service();
}
const service_app_info &service_app::info() const { return *_info; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include "runtime/security/negotiation_service.h"
#include "runtime/security/negotiation_manager.h"
#include "runtime/security/negotiation_utils.h"
#include "runtime/rpc/network.sim.h"

Expand All @@ -29,7 +29,7 @@ namespace security {
DSN_DECLARE_bool(enable_auth);
DSN_DECLARE_bool(mandatory_auth);

class negotiation_service_test : public testing::Test
class negotiation_manager_test : public testing::Test
{
public:
negotiation_rpc create_fake_rpc()
Expand All @@ -52,21 +52,21 @@ class negotiation_service_test : public testing::Test

void on_negotiation_request(negotiation_rpc rpc)
{
negotiation_service::instance().on_negotiation_request(rpc);
negotiation_manager::instance().on_negotiation_request(rpc);
}

bool on_rpc_recv_msg(message_ex *msg)
{
return negotiation_service::instance().on_rpc_recv_msg(msg);
return negotiation_manager::instance().on_rpc_recv_msg(msg);
}

bool on_rpc_send_msg(message_ex *msg)
{
return negotiation_service::instance().on_rpc_send_msg(msg);
return negotiation_manager::instance().on_rpc_send_msg(msg);
}
};

TEST_F(negotiation_service_test, disable_auth)
TEST_F(negotiation_manager_test, disable_auth)
{
RPC_MOCKING(negotiation_rpc)
{
Expand All @@ -78,7 +78,7 @@ TEST_F(negotiation_service_test, disable_auth)
}
}

TEST_F(negotiation_service_test, on_rpc_recv_msg)
TEST_F(negotiation_manager_test, on_rpc_recv_msg)
{
struct
{
Expand Down Expand Up @@ -107,7 +107,7 @@ TEST_F(negotiation_service_test, on_rpc_recv_msg)
}
}

TEST_F(negotiation_service_test, on_rpc_send_msg)
TEST_F(negotiation_manager_test, on_rpc_send_msg)
{
struct
{
Expand Down

0 comments on commit 1863cf9

Please sign in to comment.