Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

Commit

Permalink
feat(security): add unit tests for client_negotiation (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
levy5307 authored Sep 3, 2020
1 parent 8ca22de commit b322067
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 13 deletions.
4 changes: 2 additions & 2 deletions include/dsn/tool-api/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ class rpc_session : public ref_counter
bool unlink_message_for_send();
virtual void send(uint64_t signature) = 0;
void on_send_completed(uint64_t signature = 0);
void on_failure(bool is_write = false);
void on_success();
virtual void on_failure(bool is_write = false);
virtual void on_success();

protected:
///
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/rpc/network.sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class sim_client_session : public rpc_session
virtual void do_read(int sz) override {}

virtual void close() override {}

virtual void on_failure(bool is_write = false) override {}

virtual void on_success() override {}
};

class sim_server_session : public rpc_session
Expand All @@ -73,6 +77,10 @@ class sim_server_session : public rpc_session

virtual void close() override {}

virtual void on_failure(bool is_write = false) override {}

virtual void on_success() override {}

private:
rpc_session_ptr _client;
};
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)

if (match_mechanism.empty()) {
dwarn_f("server only support mechanisms of ({}), can't find expected ({})",
resp_string,
boost::join(supported_mechanisms, ","));
boost::join(supported_mechanisms, ","),
resp_string);
fail_negotiation();
return;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class client_negotiation : public negotiation
void select_mechanism(const std::string &mechanism);
void send(std::unique_ptr<negotiation_request> request);
void succ_negotiation();

friend class client_negotiation_test;
};

} // namespace security
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/security/sasl_server_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <sasl/sasl.h>
#include <dsn/utility/flags.h>
#include <dsn/utility/fail_point.h>

namespace dsn {
namespace security {
Expand All @@ -27,6 +28,9 @@ DSN_DECLARE_string(service_name);

error_s sasl_server_wrapper::init()
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_init",
[](dsn::string_view) { return error_s::make(ERR_OK); });

int sasl_err = sasl_server_new(
FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, nullptr, 0, &_conn);
return wrap_error(sasl_err);
Expand Down
4 changes: 1 addition & 3 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) {
_selected_mechanism = request.msg;
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
std::string error_msg =
fmt::format("the mechanism of {} is not supported", _selected_mechanism);
dwarn_f("{}", error_msg);
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
return;
}
Expand Down
116 changes: 116 additions & 0 deletions src/runtime/test/client_negotiation_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "runtime/security/negotiation_utils.h"
#include "runtime/security/client_negotiation.h"
#include "runtime/rpc/network.sim.h"

#include <gtest/gtest.h>
#include <dsn/utility/flags.h>

namespace dsn {
namespace security {
class client_negotiation_test : public testing::Test
{
public:
client_negotiation_test()
{
std::unique_ptr<tools::sim_network_provider> sim_net(
new tools::sim_network_provider(nullptr, nullptr));
_sim_session = sim_net->create_client_session(rpc_address("localhost", 10086));
_client_negotiation = make_unique<client_negotiation>(_sim_session);
}

void on_recv_mechanism(const negotiation_response &resp)
{
_client_negotiation->on_recv_mechanisms(resp);
}

void handle_response(error_code err, const negotiation_response &resp)
{
_client_negotiation->handle_response(err, std::move(resp));
}

const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; }

negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; }

// _sim_session is used for holding the sim_rpc_session which is created in ctor,
// in case it is released. Because negotiation keeps only a raw pointer.
rpc_session_ptr _sim_session;
std::unique_ptr<client_negotiation> _client_negotiation;
};

TEST_F(client_negotiation_test, on_recv_mechanisms)
{
struct
{
negotiation_status::type resp_status;
std::string resp_msg;
std::string selected_mechanism;
} tests[] = {{negotiation_status::type::SASL_SELECT_MECHANISMS, "GSSAPI", ""},
{negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1", ""},
{negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1, TEST2", ""},
{negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "TEST1, GSSAPI", "GSSAPI"},
{negotiation_status::type::SASL_LIST_MECHANISMS_RESP, "GSSAPI", "GSSAPI"}};

RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
negotiation_response resp;
resp.status = test.resp_status;
resp.msg = test.resp_msg;
on_recv_mechanism(resp);

ASSERT_EQ(get_selected_mechanism(), test.selected_mechanism);
}
}
}

TEST_F(client_negotiation_test, handle_response)
{
struct
{
error_code resp_err;
negotiation_status::type resp_status;
bool mandatory_auth;
negotiation_status::type neg_status;
} tests[] = {{ERR_TIMEOUT,
negotiation_status::type::SASL_SELECT_MECHANISMS,
false,
negotiation_status::type::SASL_AUTH_FAIL},
{ERR_OK,
negotiation_status::type::SASL_AUTH_DISABLE,
true,
negotiation_status::type::SASL_AUTH_FAIL},
{ERR_OK,
negotiation_status::type::SASL_AUTH_DISABLE,
false,
negotiation_status::type::SASL_SUCC}};

DSN_DECLARE_bool(mandatory_auth);
for (const auto &test : tests) {
negotiation_response resp;
resp.status = test.resp_status;
FLAGS_mandatory_auth = test.mandatory_auth;
handle_response(test.resp_err, resp);

ASSERT_EQ(get_negotiation_status(), test.neg_status);
}
}
} // namespace security
} // namespace dsn
21 changes: 15 additions & 6 deletions src/runtime/test/server_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

#include "runtime/security/server_negotiation.h"
#include "runtime/security/negotiation_utils.h"
#include "runtime/rpc/network.sim.h"

#include <gtest/gtest.h>
#include <dsn/utility/fail_point.h>
#include <runtime/rpc/network.sim.h>

namespace dsn {
namespace security {
Expand All @@ -31,8 +31,9 @@ class server_negotiation_test : public testing::Test
{
std::unique_ptr<tools::sim_network_provider> sim_net(
new tools::sim_network_provider(nullptr, nullptr));
auto sim_session = sim_net->create_client_session(rpc_address("localhost", 10086));
_srv_negotiation = new server_negotiation(sim_session);
_sim_session =
sim_net->create_server_session(rpc_address("localhost", 10086), rpc_session_ptr());
_srv_negotiation = make_unique<server_negotiation>(_sim_session);
}

negotiation_rpc create_negotiation_rpc(negotiation_status::type status, const std::string &msg)
Expand All @@ -47,7 +48,12 @@ class server_negotiation_test : public testing::Test

void on_select_mechanism(negotiation_rpc rpc) { _srv_negotiation->on_select_mechanism(rpc); }

server_negotiation *_srv_negotiation;
negotiation_status::type get_negotiation_status() { return _srv_negotiation->_status; }

// _sim_session is used for holding the sim_rpc_session which is created in ctor,
// in case it is released. Because negotiation keeps only a raw pointer.
rpc_session_ptr _sim_session;
std::unique_ptr<server_negotiation> _srv_negotiation;
};

TEST_F(server_negotiation_test, on_list_mechanisms)
Expand Down Expand Up @@ -75,6 +81,7 @@ TEST_F(server_negotiation_test, on_list_mechanisms)

ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(rpc.response().msg, test.resp_msg);
ASSERT_EQ(get_negotiation_status(), test.nego_status);
}
}
}
Expand All @@ -95,21 +102,23 @@ TEST_F(server_negotiation_test, on_select_mechanism)
},
{negotiation_status::type::SASL_SELECT_MECHANISMS,
"TEST",
negotiation_status::type::INVALID},
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL},
{negotiation_status::type::SASL_INITIATE,
"GSSAPI",
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL}};

fail::setup();
fail::cfg("server_negotiation_sasl_server_init", "return()");
fail::cfg("sasl_server_wrapper_init", "return()");
RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
auto rpc = create_negotiation_rpc(test.req_status, test.req_msg);
on_select_mechanism(rpc);

ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(get_negotiation_status(), test.nego_status);
}
}
fail::teardown();
Expand Down

0 comments on commit b322067

Please sign in to comment.