From b3220675df17b116075f0719cce684735b0a0847 Mon Sep 17 00:00:00 2001 From: zhao liwei Date: Thu, 3 Sep 2020 14:33:43 +0800 Subject: [PATCH] feat(security): add unit tests for client_negotiation (#608) --- include/dsn/tool-api/network.h | 4 +- src/runtime/rpc/network.sim.h | 8 ++ src/runtime/security/client_negotiation.cpp | 4 +- src/runtime/security/client_negotiation.h | 2 + src/runtime/security/sasl_server_wrapper.cpp | 4 + src/runtime/security/server_negotiation.cpp | 4 +- src/runtime/test/client_negotiation_test.cpp | 116 +++++++++++++++++++ src/runtime/test/server_negotiation_test.cpp | 21 +++- 8 files changed, 150 insertions(+), 13 deletions(-) create mode 100644 src/runtime/test/client_negotiation_test.cpp diff --git a/include/dsn/tool-api/network.h b/include/dsn/tool-api/network.h index 645afc8c76..2e5f6f7c07 100644 --- a/include/dsn/tool-api/network.h +++ b/include/dsn/tool-api/network.h @@ -257,8 +257,8 @@ class rpc_session : public ref_counter bool unlink_message_for_send(); virtual void send(uint64_t signature) = 0; void on_send_completed(uint64_t signature = 0); - void on_failure(bool is_write = false); - void on_success(); + virtual void on_failure(bool is_write = false); + virtual void on_success(); protected: /// diff --git a/src/runtime/rpc/network.sim.h b/src/runtime/rpc/network.sim.h index 4573f3c54b..fde275eb72 100644 --- a/src/runtime/rpc/network.sim.h +++ b/src/runtime/rpc/network.sim.h @@ -55,6 +55,10 @@ class sim_client_session : public rpc_session virtual void do_read(int sz) override {} virtual void close() override {} + + virtual void on_failure(bool is_write = false) override {} + + virtual void on_success() override {} }; class sim_server_session : public rpc_session @@ -73,6 +77,10 @@ class sim_server_session : public rpc_session virtual void close() override {} + virtual void on_failure(bool is_write = false) override {} + + virtual void on_success() override {} + private: rpc_session_ptr _client; }; diff --git a/src/runtime/security/client_negotiation.cpp b/src/runtime/security/client_negotiation.cpp index f9f64d85de..2b22bc0ef5 100644 --- a/src/runtime/security/client_negotiation.cpp +++ b/src/runtime/security/client_negotiation.cpp @@ -102,8 +102,8 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp) if (match_mechanism.empty()) { dwarn_f("server only support mechanisms of ({}), can't find expected ({})", - resp_string, - boost::join(supported_mechanisms, ",")); + boost::join(supported_mechanisms, ","), + resp_string); fail_negotiation(); return; } diff --git a/src/runtime/security/client_negotiation.h b/src/runtime/security/client_negotiation.h index a45b4084d3..158c127311 100644 --- a/src/runtime/security/client_negotiation.h +++ b/src/runtime/security/client_negotiation.h @@ -37,6 +37,8 @@ class client_negotiation : public negotiation void select_mechanism(const std::string &mechanism); void send(std::unique_ptr request); void succ_negotiation(); + + friend class client_negotiation_test; }; } // namespace security diff --git a/src/runtime/security/sasl_server_wrapper.cpp b/src/runtime/security/sasl_server_wrapper.cpp index ce327bbd7f..6f4f070c87 100644 --- a/src/runtime/security/sasl_server_wrapper.cpp +++ b/src/runtime/security/sasl_server_wrapper.cpp @@ -19,6 +19,7 @@ #include #include +#include namespace dsn { namespace security { @@ -27,6 +28,9 @@ DSN_DECLARE_string(service_name); error_s sasl_server_wrapper::init() { + FAIL_POINT_INJECT_F("sasl_server_wrapper_init", + [](dsn::string_view) { return error_s::make(ERR_OK); }); + int sasl_err = sasl_server_new( FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, nullptr, 0, &_conn); return wrap_error(sasl_err); diff --git a/src/runtime/security/server_negotiation.cpp b/src/runtime/security/server_negotiation.cpp index 6fe70c7b7b..130a538d53 100644 --- a/src/runtime/security/server_negotiation.cpp +++ b/src/runtime/security/server_negotiation.cpp @@ -81,9 +81,7 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc) if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) { _selected_mechanism = request.msg; if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) { - std::string error_msg = - fmt::format("the mechanism of {} is not supported", _selected_mechanism); - dwarn_f("{}", error_msg); + dwarn_f("the mechanism of {} is not supported", _selected_mechanism); fail_negotiation(); return; } diff --git a/src/runtime/test/client_negotiation_test.cpp b/src/runtime/test/client_negotiation_test.cpp new file mode 100644 index 0000000000..5158c45a75 --- /dev/null +++ b/src/runtime/test/client_negotiation_test.cpp @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "runtime/security/negotiation_utils.h" +#include "runtime/security/client_negotiation.h" +#include "runtime/rpc/network.sim.h" + +#include +#include + +namespace dsn { +namespace security { +class client_negotiation_test : public testing::Test +{ +public: + client_negotiation_test() + { + std::unique_ptr sim_net( + new tools::sim_network_provider(nullptr, nullptr)); + _sim_session = sim_net->create_client_session(rpc_address("localhost", 10086)); + _client_negotiation = make_unique(_sim_session); + } + + void on_recv_mechanism(const negotiation_response &resp) + { + _client_negotiation->on_recv_mechanisms(resp); + } + + void handle_response(error_code err, const negotiation_response &resp) + { + _client_negotiation->handle_response(err, std::move(resp)); + } + + const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; } + + negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; } + + // _sim_session is used for holding the sim_rpc_session which is created in ctor, + // in case it is released. Because negotiation keeps only a raw pointer. + rpc_session_ptr _sim_session; + std::unique_ptr _client_negotiation; +}; + +TEST_F(client_negotiation_test, on_recv_mechanisms) +{ + struct + { + negotiation_status::type resp_status; + std::string resp_msg; + std::string selected_mechanism; + } tests[] = {{negotiation_status::type::SASL_SELECT_MECHANISMS, "GSSAPI", ""}, + {negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1", ""}, + {negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1, TEST2", ""}, + {negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1, GSSAPI", "GSSAPI"}, + {negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "GSSAPI", "GSSAPI"}}; + + RPC_MOCKING(negotiation_rpc) + { + for (const auto &test : tests) { + negotiation_response resp; + resp.status = test.resp_status; + resp.msg = test.resp_msg; + on_recv_mechanism(resp); + + ASSERT_EQ(get_selected_mechanism(), test.selected_mechanism); + } + } +} + +TEST_F(client_negotiation_test, handle_response) +{ + struct + { + error_code resp_err; + negotiation_status::type resp_status; + bool mandatory_auth; + negotiation_status::type neg_status; + } tests[] = {{ERR_TIMEOUT, + negotiation_status::type::SASL_SELECT_MECHANISMS, + false, + negotiation_status::type::SASL_AUTH_FAIL}, + {ERR_OK, + negotiation_status::type::SASL_AUTH_DISABLE, + true, + negotiation_status::type::SASL_AUTH_FAIL}, + {ERR_OK, + negotiation_status::type::SASL_AUTH_DISABLE, + false, + negotiation_status::type::SASL_SUCC}}; + + DSN_DECLARE_bool(mandatory_auth); + for (const auto &test : tests) { + negotiation_response resp; + resp.status = test.resp_status; + FLAGS_mandatory_auth = test.mandatory_auth; + handle_response(test.resp_err, resp); + + ASSERT_EQ(get_negotiation_status(), test.neg_status); + } +} +} // namespace security +} // namespace dsn diff --git a/src/runtime/test/server_negotiation_test.cpp b/src/runtime/test/server_negotiation_test.cpp index b70286bf9e..fbf2847725 100644 --- a/src/runtime/test/server_negotiation_test.cpp +++ b/src/runtime/test/server_negotiation_test.cpp @@ -17,10 +17,10 @@ #include "runtime/security/server_negotiation.h" #include "runtime/security/negotiation_utils.h" +#include "runtime/rpc/network.sim.h" #include #include -#include namespace dsn { namespace security { @@ -31,8 +31,9 @@ class server_negotiation_test : public testing::Test { std::unique_ptr sim_net( new tools::sim_network_provider(nullptr, nullptr)); - auto sim_session = sim_net->create_client_session(rpc_address("localhost", 10086)); - _srv_negotiation = new server_negotiation(sim_session); + _sim_session = + sim_net->create_server_session(rpc_address("localhost", 10086), rpc_session_ptr()); + _srv_negotiation = make_unique(_sim_session); } negotiation_rpc create_negotiation_rpc(negotiation_status::type status, const std::string &msg) @@ -47,7 +48,12 @@ class server_negotiation_test : public testing::Test void on_select_mechanism(negotiation_rpc rpc) { _srv_negotiation->on_select_mechanism(rpc); } - server_negotiation *_srv_negotiation; + negotiation_status::type get_negotiation_status() { return _srv_negotiation->_status; } + + // _sim_session is used for holding the sim_rpc_session which is created in ctor, + // in case it is released. Because negotiation keeps only a raw pointer. + rpc_session_ptr _sim_session; + std::unique_ptr _srv_negotiation; }; TEST_F(server_negotiation_test, on_list_mechanisms) @@ -75,6 +81,7 @@ TEST_F(server_negotiation_test, on_list_mechanisms) ASSERT_EQ(rpc.response().status, test.resp_status); ASSERT_EQ(rpc.response().msg, test.resp_msg); + ASSERT_EQ(get_negotiation_status(), test.nego_status); } } } @@ -95,14 +102,15 @@ TEST_F(server_negotiation_test, on_select_mechanism) }, {negotiation_status::type::SASL_SELECT_MECHANISMS, "TEST", - negotiation_status::type::INVALID}, + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, {negotiation_status::type::SASL_INITIATE, "GSSAPI", negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}}; fail::setup(); - fail::cfg("server_negotiation_sasl_server_init", "return()"); + fail::cfg("sasl_server_wrapper_init", "return()"); RPC_MOCKING(negotiation_rpc) { for (const auto &test : tests) { @@ -110,6 +118,7 @@ TEST_F(server_negotiation_test, on_select_mechanism) on_select_mechanism(rpc); ASSERT_EQ(rpc.response().status, test.resp_status); + ASSERT_EQ(get_negotiation_status(), test.nego_status); } } fail::teardown();