Skip to content

Commit

Permalink
client ssl filter: fix connection handling (#235)
Browse files Browse the repository at this point in the history
Now that we handle SSL client auth on a new connection (correct) we
need to wait until the handshake is complete before we do the auth
checks.
  • Loading branch information
mattklein123 authored Nov 19, 2016
1 parent 7e57daa commit 4090ba2
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 14 deletions.
17 changes: 14 additions & 3 deletions source/common/filter/auth/client_ssl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,33 @@ Network::FilterStatus Instance::onNewConnection() {
if (!read_callbacks_->connection().ssl()) {
config_->stats().auth_no_ssl_.inc();
return Network::FilterStatus::Continue;
} else {
// Otherwise we need to wait for handshake to be complete before proceeding.
return Network::FilterStatus::StopIteration;
}
}

void Instance::onEvent(uint32_t events) {
if (!(events & Network::ConnectionEvent::Connected)) {
return;
}

ASSERT(read_callbacks_->connection().ssl());
if (config_->ipWhiteList().contains(read_callbacks_->connection().remoteAddress())) {
config_->stats().auth_ip_white_list_.inc();
return Network::FilterStatus::Continue;
read_callbacks_->continueReading();
return;
}

if (!config_->allowedPrincipals().allowed(
read_callbacks_->connection().ssl()->sha256PeerCertificateDigest())) {
config_->stats().auth_digest_no_match_.inc();
read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
return;
}

config_->stats().auth_digest_match_.inc();
return Network::FilterStatus::Continue;
read_callbacks_->continueReading();
}

} // Client Ssl
Expand Down
7 changes: 6 additions & 1 deletion source/common/filter/auth/client_ssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ typedef std::shared_ptr<Config> ConfigPtr;
/**
* A client SSL auth filter instance. One per connection.
*/
class Instance : public Network::ReadFilter {
class Instance : public Network::ReadFilter, public Network::ConnectionCallbacks {
public:
Instance(ConfigPtr config) : config_(config) {}

Expand All @@ -106,8 +106,13 @@ class Instance : public Network::ReadFilter {
Network::FilterStatus onNewConnection() override;
void initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) override {
read_callbacks_ = &callbacks;
read_callbacks_->connection().addConnectionCallbacks(*this);
}

// Network::ConnectionCallbacks
void onBufferChange(Network::ConnectionBufferType, uint64_t, int64_t) override {}
void onEvent(uint32_t events) override;

private:
ConfigPtr config_;
Network::ReadFilterCallbacks* read_callbacks_{};
Expand Down
4 changes: 3 additions & 1 deletion source/common/ssl/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Network::ConnectionImpl::PostIoAction ConnectionImpl::doHandshake() {

handshake_complete_ = true;
raiseEvents(Network::ConnectionEvent::Connected);
return PostIoAction::KeepOpen;

// It's possible that we closed during the handshake callback.
return state() == State::Open ? PostIoAction::KeepOpen : PostIoAction::Close;
} else {
int err = SSL_get_error(ssl_.get(), rc);
conn_log_debug("handshake error: {}", *this, err);
Expand Down
33 changes: 25 additions & 8 deletions test/common/filter/auth/client_ssl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "test/test_common/utility.h"

using testing::_;
using testing::InSequence;
using testing::Invoke;
using testing::Return;
using testing::ReturnNew;
Expand Down Expand Up @@ -49,6 +50,7 @@ class ClientSslAuthFilterTest : public testing::Test {
}

void createAuthFilter() {
filter_callbacks_.connection_.callbacks_.clear();
instance_.reset(new Instance(config_));
instance_->initializeReadFilterCallbacks(filter_callbacks_);
}
Expand Down Expand Up @@ -91,7 +93,7 @@ TEST_F(ClientSslAuthFilterTest, NoCluster) {
EXPECT_THROW(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_), EnvoyException);
}

TEST_F(ClientSslAuthFilterTest, Basic) {
TEST_F(ClientSslAuthFilterTest, NoSsl) {
setup();
Buffer::OwnedImpl dummy("hello");

Expand All @@ -100,15 +102,27 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value());
}

TEST_F(ClientSslAuthFilterTest, Ssl) {
InSequence s;

setup();
Buffer::OwnedImpl dummy("hello");

// Create a new filter for an SSL connection, with no backing auth data yet.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_));
ON_CALL(filter_callbacks_.connection_, ssl()).WillByDefault(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("192.168.1.1")));
EXPECT_CALL(ssl_, sha256PeerCertificateDigest()).WillOnce(Return("digest"));
EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush));
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

// Respond.
EXPECT_CALL(*interval_timer_, enableTimer(_));
Expand All @@ -121,26 +135,29 @@ TEST_F(ClientSslAuthFilterTest, Basic) {

// Create a new filter for an SSL connection with an authorized cert.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("192.168.1.1")));
EXPECT_CALL(ssl_, sha256PeerCertificateDigest())
.WillOnce(Return("1b7d42ef0025ad89c1c911d6c10d7e86a4cb7c5863b2980abcbad1895f8b5314"));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
EXPECT_CALL(filter_callbacks_, continueReading());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

// White list case.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).WillOnce(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("1.2.3.4")));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
EXPECT_CALL(filter_callbacks_, continueReading());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.update_success").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_ip_white_list").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_match").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_no_match").value());
Expand Down Expand Up @@ -175,7 +192,6 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
callbacks_->onFailure(Http::AsyncClient::FailureReason::Reset);

// Interval timer fires, cannot obtain async client.
EXPECT_CALL(*interval_timer_, enableTimer(_));
EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_));
EXPECT_CALL(cm_.async_client_, send_(_, _, _))
.WillOnce(
Expand All @@ -185,6 +201,7 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
Http::HeaderMapPtr{new Http::TestHeaderMapImpl{{":status", "503"}}})});
return nullptr;
}));
EXPECT_CALL(*interval_timer_, enableTimer(_));
interval_timer_->callback_();

EXPECT_EQ(4U, stats_store_.counter("auth.clientssl.vpn.update_failure").value());
Expand Down
1 change: 1 addition & 0 deletions test/common/network/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ TEST(ConnectionImplTest, BufferCallbacks) {
EXPECT_CALL(server_callbacks, onBufferChange(ConnectionBufferType::Read, 4, -4)).InSequence(s2);
EXPECT_CALL(server_callbacks, onEvent(ConnectionEvent::LocalClose)).InSequence(s2);

EXPECT_CALL(*read_filter, onNewConnection());
EXPECT_CALL(*read_filter, onData(_))
.WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus {
data.drain(data.length());
Expand Down
3 changes: 2 additions & 1 deletion test/common/network/proxy_protocol_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ TEST_F(ProxyProtocolTest, Basic) {
}));

read_filter_.reset(new MockReadFilter());
EXPECT_CALL(*read_filter_.get(), onData(BufferStringEqual("more data")));
EXPECT_CALL(*read_filter_, onNewConnection());
EXPECT_CALL(*read_filter_, onData(BufferStringEqual("more data")));

dispatcher_.run(Event::Dispatcher::RunType::NonBlock);
accepted_connection->close(ConnectionCloseType::NoFlush);
Expand Down

0 comments on commit 4090ba2

Please sign in to comment.