Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

feat(security): client_negotiation handle mechanism selected response #612

Merged
merged 22 commits into from
Sep 4, 2020
Merged
1 change: 1 addition & 0 deletions include/dsn/utility/error_code.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,5 @@ DEFINE_ERR_CODE(ERR_UNAUTHENTICATED)
DEFINE_ERR_CODE(ERR_KRB5_INTERNAL)

DEFINE_ERR_CODE(ERR_SASL_INTERNAL)
DEFINE_ERR_CODE(ERR_NOT_COMPLEMENTED)
levy5307 marked this conversation as resolved.
Show resolved Hide resolved
} // namespace dsn
43 changes: 37 additions & 6 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo
on_recv_mechanisms(response);
break;
case negotiation_status::type::SASL_SELECT_MECHANISMS:
// TBD(zlw)
on_mechanism_selected(response);
break;
case negotiation_status::type::SASL_INITIATE:
case negotiation_status::type::SASL_CHALLENGE_RESP:
Expand All @@ -79,11 +79,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo

void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)
{
if (resp.status != negotiation_status::type::SASL_LIST_MECHANISMS_RESP) {
dwarn_f("{}: get message({}) while expect({})",
_name,
enum_to_string(resp.status),
enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS_RESP));
if (!check_status(resp.status, negotiation_status::type::SASL_LIST_MECHANISMS_RESP)) {
fail_negotiation();
return;
}
Expand Down Expand Up @@ -111,6 +107,41 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)
select_mechanism(match_mechanism);
}

void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
{
if (!check_status(resp.status, negotiation_status::type::SASL_SELECT_MECHANISMS_RESP)) {
fail_negotiation();
return;
}

// init client sasl
auto err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: initiaze sasl client failed, error = {}, reason = {}",
levy5307 marked this conversation as resolved.
Show resolved Hide resolved
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}

// start client sasl, and send `SASL_INITIATE` to `server_negotiation` if everything is ok
std::string start_output;
err_s = _sasl->start(_selected_mechanism, "", start_output);
if (err_s.is_ok() || ERR_NOT_COMPLEMENTED == err_s.code()) {
auto req = dsn::make_unique<negotiation_request>();
_status = req->status = negotiation_status::type::SASL_INITIATE;
req->msg = start_output;
send(std::move(req));
} else {
dwarn_f("{}: start sasl client failed, error = {}, reason = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
}
}

void client_negotiation::select_mechanism(const std::string &mechanism)
{
_selected_mechanism = mechanism;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class client_negotiation : public negotiation
private:
void handle_response(error_code err, const negotiation_response &&response);
void on_recv_mechanisms(const negotiation_response &resp);
void on_mechanism_selected(const negotiation_response &resp);

void list_mechanisms();
void select_mechanism(const std::string &mechanism);
Expand Down
15 changes: 15 additions & 0 deletions src/runtime/security/negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
#include "negotiation.h"
#include "client_negotiation.h"
#include "server_negotiation.h"
#include "negotiation_utils.h"

#include <dsn/utility/flags.h>
#include <dsn/utility/smart_pointers.h>
#include <dsn/dist/fmt_logging.h>

namespace dsn {
namespace security {
Expand Down Expand Up @@ -48,5 +50,18 @@ void negotiation::fail_negotiation()
_session->on_failure(true);
}

bool negotiation::check_status(negotiation_status::type status,
negotiation_status::type expect_status)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expected_status may be better.

{
if (status != expect_status) {
dwarn_f("{}: get message({}) while expect({})",
_name,
enum_to_string(status),
enum_to_string(expect_status));
return false;
}

return true;
}
} // namespace security
} // namespace dsn
5 changes: 5 additions & 0 deletions src/runtime/security/negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class negotiation
virtual void start() = 0;
bool negotiation_succeed() const { return _status == negotiation_status::type::SASL_SUCC; }
void fail_negotiation();
// check whether the status is equal to expect_status
// ret value:
// true: status == expect_status
// false: status != expect_status
bool check_status(negotiation_status::type status, negotiation_status::type expect_status);

protected:
// The ownership of the negotiation instance is held by rpc_session.
Expand Down
27 changes: 23 additions & 4 deletions src/runtime/security/sasl_client_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

#include "sasl_client_wrapper.h"

#include <sasl/sasl.h>
#include <dsn/utility/flags.h>
#include <dsn/utility/fail_point.h>

namespace dsn {
namespace security {
Expand All @@ -26,16 +28,33 @@ DSN_DECLARE_string(service_name);

error_s sasl_client_wrapper::init()
{
// TBD(zlw)
return error_s::make(ERR_OK);
FAIL_POINT_INJECT_F("sasl_client_wrapper_init", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

int sasl_err = sasl_client_new(
FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, 0, &_conn);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
{
// TBD(zlw)
return error_s::make(ERR_OK);
FAIL_POINT_INJECT_F("sasl_client_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

const char *msg = nullptr;
unsigned msg_len = 0;
const char *client_mech = nullptr;
int sasl_err =
sasl_client_start(_conn, mechanism.c_str(), nullptr, &msg, &msg_len, &client_mech);

output.assign(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::step(const std::string &input, std::string &output)
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/sasl_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ error_s sasl_wrapper::wrap_error(int sasl_err)
case SASL_OK:
return error_s::make(ERR_OK);
case SASL_CONTINUE:
return error_s::make(ERR_NOT_IMPLEMENTED);
return error_s::make(ERR_NOT_COMPLEMENTED);
case SASL_FAIL: // Generic failure (encompasses missing krb5 credentials).
case SASL_BADAUTH: // Authentication failure.
case SASL_BADMAC: // Decode failure.
Expand Down
58 changes: 25 additions & 33 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,52 +60,44 @@ void server_negotiation::handle_request(negotiation_rpc rpc)

void server_negotiation::on_list_mechanisms(negotiation_rpc rpc)
{
if (rpc.request().status == negotiation_status::type::SASL_LIST_MECHANISMS) {
std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
} else {
ddebug_f("{}: got message({}) while expect({})",
_name,
enum_to_string(rpc.request().status),
enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS));
if (!check_status(rpc.request().status, negotiation_status::type::SASL_LIST_MECHANISMS)) {
fail_negotiation();
return;
}
return;

std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
}

void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
{
const negotiation_request &request = rpc.request();
if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) {
_selected_mechanism = request.msg;
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
return;
}
if (!check_status(rpc.request().status, negotiation_status::type::SASL_SELECT_MECHANISMS)) {
fail_negotiation();
return;
}

error_s err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}
_selected_mechanism = request.msg;
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
return;
}

negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP;
} else {
dwarn_f("{}: got message({}) while expect({})",
error_s err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}",
_name,
enum_to_string(request.status),
negotiation_status::type::SASL_SELECT_MECHANISMS);
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}

negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP;
Comment on lines +77 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good refactor!

}
} // namespace security
} // namespace dsn
52 changes: 52 additions & 0 deletions src/runtime/test/client_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>
#include <dsn/utility/flags.h>
#include <dsn/utility/fail_point.h>

namespace dsn {
namespace security {
Expand All @@ -45,6 +46,11 @@ class client_negotiation_test : public testing::Test
_client_negotiation->handle_response(err, std::move(resp));
}

void on_mechanism_selected(const negotiation_response &resp)
{
_client_negotiation->on_mechanism_selected(resp);
}

const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; }

negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; }
Expand Down Expand Up @@ -112,5 +118,51 @@ TEST_F(client_negotiation_test, handle_response)
ASSERT_EQ(get_negotiation_status(), test.neg_status);
}
}

TEST_F(client_negotiation_test, on_mechanism_selected)
{
struct
{
std::string sasl_init_return_value;
std::string sasl_start_return_value;
negotiation_status::type resp_status;
negotiation_status::type neg_status;
} tests[] = {{"ERR_OK",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_INITIATE},
{"ERR_OK",
"ERR_NOT_COMPLEMENTED",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_INITIATE},
{"ERR_OK",
"ERR_TIMEOUT",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_TIMEOUT",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_OK",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS,
negotiation_status::type::SASL_AUTH_FAIL}};

RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
fail::setup();
fail::cfg("sasl_client_wrapper_init", "return(" + test.sasl_init_return_value + ")");
fail::cfg("sasl_client_wrapper_start", "return(" + test.sasl_start_return_value + ")");

negotiation_response resp;
resp.status = test.resp_status;
on_mechanism_selected(resp);
ASSERT_EQ(get_negotiation_status(), test.neg_status);

fail::teardown();
}
}
}
} // namespace security
} // namespace dsn