diff --git a/src/runtime/security/client_negotiation.cpp b/src/runtime/security/client_negotiation.cpp index 306273569c..089b277663 100644 --- a/src/runtime/security/client_negotiation.cpp +++ b/src/runtime/security/client_negotiation.cpp @@ -70,7 +70,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo break; case negotiation_status::type::SASL_INITIATE: case negotiation_status::type::SASL_CHALLENGE_RESP: - // TBD(zlw) + on_challenge(response); break; default: fail_negotiation(); @@ -142,6 +142,33 @@ void client_negotiation::on_mechanism_selected(const negotiation_response &resp) } } +void client_negotiation::on_challenge(const negotiation_response &challenge) +{ + if (challenge.status == negotiation_status::type::SASL_CHALLENGE) { + std::string response_msg; + auto err = _sasl->step(challenge.msg, response_msg); + if (!err.is_ok() && err.code() != ERR_SASL_INCOMPLETE) { + dwarn_f("{}: negotiation failed, reason = {}", _name, err.description()); + fail_negotiation(); + return; + } + + auto req = dsn::make_unique(); + _status = req->status = negotiation_status::type::SASL_CHALLENGE_RESP; + req->msg = response_msg; + send(std::move(req)); + return; + } + + if (challenge.status == negotiation_status::type::SASL_SUCC) { + succ_negotiation(); + return; + } + + dwarn_f("{}: recv wrong negotiation msg type: {}", _name, enum_to_string(challenge.status)); + fail_negotiation(); +} + void client_negotiation::select_mechanism(const std::string &mechanism) { _selected_mechanism = mechanism; diff --git a/src/runtime/security/client_negotiation.h b/src/runtime/security/client_negotiation.h index 8a4dd7f38e..0560d885c8 100644 --- a/src/runtime/security/client_negotiation.h +++ b/src/runtime/security/client_negotiation.h @@ -33,6 +33,7 @@ class client_negotiation : public negotiation void handle_response(error_code err, const negotiation_response &&response); void on_recv_mechanisms(const negotiation_response &resp); void on_mechanism_selected(const negotiation_response &resp); + void on_challenge(const negotiation_response &resp); void list_mechanisms(); void select_mechanism(const std::string &mechanism); diff --git a/src/runtime/security/sasl_client_wrapper.cpp b/src/runtime/security/sasl_client_wrapper.cpp index 07871c6828..4bcb9244d3 100644 --- a/src/runtime/security/sasl_client_wrapper.cpp +++ b/src/runtime/security/sasl_client_wrapper.cpp @@ -59,8 +59,17 @@ error_s sasl_client_wrapper::start(const std::string &mechanism, error_s sasl_client_wrapper::step(const std::string &input, std::string &output) { - // TBD(zlw) - return error_s::make(ERR_OK); + FAIL_POINT_INJECT_F("sasl_client_wrapper_step", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); + + const char *msg = nullptr; + unsigned msg_len = 0; + int sasl_err = sasl_client_step(_conn, input.c_str(), input.length(), nullptr, &msg, &msg_len); + + output.assign(msg, msg_len); + return wrap_error(sasl_err); } } // namespace security } // namespace dsn diff --git a/src/runtime/test/client_negotiation_test.cpp b/src/runtime/test/client_negotiation_test.cpp index 7df6334953..4ef695e54b 100644 --- a/src/runtime/test/client_negotiation_test.cpp +++ b/src/runtime/test/client_negotiation_test.cpp @@ -51,6 +51,8 @@ class client_negotiation_test : public testing::Test _client_negotiation->on_mechanism_selected(resp); } + void on_challenge(const negotiation_response &resp) { _client_negotiation->on_challenge(resp); } + const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; } negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; } @@ -164,5 +166,40 @@ TEST_F(client_negotiation_test, on_mechanism_selected) } } } + +TEST_F(client_negotiation_test, on_challenge) +{ + struct + { + std::string sasl_step_result; + negotiation_status::type resp_status; + negotiation_status::type neg_status; + } tests[] = { + {"ERR_OK", + negotiation_status::type::SASL_CHALLENGE, + negotiation_status::type::SASL_CHALLENGE_RESP}, + {"ERR_SASL_INCOMPLETE", + negotiation_status::type::SASL_CHALLENGE, + negotiation_status::type::SASL_CHALLENGE_RESP}, + {"ERR_TIMEOUT", + negotiation_status::type::SASL_CHALLENGE, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", negotiation_status::type::SASL_SUCC, negotiation_status::type::SASL_SUCC}}; + + RPC_MOCKING(negotiation_rpc) + { + for (const auto &test : tests) { + fail::setup(); + fail::cfg("sasl_client_wrapper_step", "return(" + test.sasl_step_result + ")"); + + negotiation_response resp; + resp.status = test.resp_status; + on_challenge(resp); + ASSERT_EQ(get_negotiation_status(), test.neg_status); + + fail::teardown(); + } + } +} } // namespace security } // namespace dsn