diff --git a/src/runtime/security/server_negotiation.cpp b/src/runtime/security/server_negotiation.cpp index f7044d06d9..3268386bbf 100644 --- a/src/runtime/security/server_negotiation.cpp +++ b/src/runtime/security/server_negotiation.cpp @@ -38,9 +38,19 @@ void server_negotiation::start() void server_negotiation::handle_request(negotiation_rpc rpc) { - if (_status == negotiation_status::type::SASL_LIST_MECHANISMS) { + switch (_status) { + case negotiation_status::type::SASL_LIST_MECHANISMS: on_list_mechanisms(rpc); - return; + break; + case negotiation_status::type::SASL_LIST_MECHANISMS_RESP: + on_select_mechanism(rpc); + break; + case negotiation_status::type::SASL_SELECT_MECHANISMS_RESP: + case negotiation_status::type::SASL_CHALLENGE: + // TBD(zlw) + break; + default: + fail_negotiation(rpc, "wrong status"); } } @@ -61,6 +71,47 @@ void server_negotiation::on_list_mechanisms(negotiation_rpc rpc) return; } +void server_negotiation::on_select_mechanism(negotiation_rpc rpc) +{ + const negotiation_request &request = rpc.request(); + if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) { + _selected_mechanism = request.msg; + if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) { + std::string error_msg = + fmt::format("the mechanism of {} is not supported", _selected_mechanism); + dwarn_f("{}", error_msg); + fail_negotiation(rpc, error_msg); + return; + } + + error_s err_s = do_sasl_server_init(); + if (!err_s.is_ok()) { + dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}", + _name, + err_s.code().to_string(), + err_s.description()); + fail_negotiation(rpc, err_s.description()); + return; + } + + negotiation_response &response = rpc.response(); + _status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP; + } else { + dwarn_f("{}: got message({}) while expect({})", + _name, + enum_to_string(request.status), + negotiation_status::type::SASL_SELECT_MECHANISMS); + fail_negotiation(rpc, "invalid_client_message_status"); + return; + } +} + +error_s server_negotiation::do_sasl_server_init() +{ + // TBD(zlw) + return error_s::make(ERR_OK); +} + void server_negotiation::fail_negotiation(negotiation_rpc rpc, const std::string &reason) { negotiation_response &response = rpc.response(); diff --git a/src/runtime/security/server_negotiation.h b/src/runtime/security/server_negotiation.h index 9337efc28a..928bb22c34 100644 --- a/src/runtime/security/server_negotiation.h +++ b/src/runtime/security/server_negotiation.h @@ -19,6 +19,8 @@ #include "negotiation.h" +#include + namespace dsn { namespace security { extern const std::set supported_mechanisms; @@ -33,6 +35,8 @@ class server_negotiation : public negotiation private: void on_list_mechanisms(negotiation_rpc rpc); + void on_select_mechanism(negotiation_rpc rpc); + error_s do_sasl_server_init(); void fail_negotiation(negotiation_rpc rpc, const std::string &reason); };