Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

listener-filter: refactor tls_inspector so its logic could be shared with custom filters #4331

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 64 additions & 30 deletions source/extensions/filters/listener/tls_inspector/tls_inspector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace Extensions {
namespace ListenerFilters {
namespace TlsInspector {

Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size)
: stats_{ALL_TLS_INSPECTOR_STATS(POOL_COUNTER_PREFIX(scope, "tls_inspector."))},
Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size, const std::string& stat_prefix)
: stats_{TLS_STATS(POOL_COUNTER_PREFIX(scope, stat_prefix))},
ssl_ctx_(SSL_CTX_new(TLS_with_buffers_method())),
max_client_hello_size_(max_client_hello_size) {

Expand All @@ -42,14 +42,14 @@ Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size)
size_t len;
if (SSL_early_callback_ctx_extension_get(
client_hello, TLSEXT_TYPE_application_layer_protocol_negotiation, &data, &len)) {
Filter* filter = static_cast<Filter*>(SSL_get_app_data(client_hello->ssl));
TlsFilterBase* filter = static_cast<TlsFilterBase*>(SSL_get_app_data(client_hello->ssl));
filter->onALPN(data, len);
}
return ssl_select_cert_success;
});
SSL_CTX_set_tlsext_servername_callback(
ssl_ctx_.get(), [](SSL* ssl, int* out_alert, void*) -> int {
Filter* filter = static_cast<Filter*>(SSL_get_app_data(ssl));
TlsFilterBase* filter = static_cast<TlsFilterBase*>(SSL_get_app_data(ssl));
filter->onServername(SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name));

// Return an error to stop the handshake; we have what we wanted already.
Expand All @@ -63,10 +63,16 @@ bssl::UniquePtr<SSL> Config::newSsl() { return bssl::UniquePtr<SSL>{SSL_new(ssl_
thread_local uint8_t Filter::buf_[Config::TLS_MAX_CLIENT_HELLO];

Filter::Filter(const ConfigSharedPtr config) : config_(config), ssl_(config_->newSsl()) {
RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize(), "");
initializeSsl(config->maxClientHelloSize(), sizeof(buf_), ssl_,
static_cast<TlsFilterBase*>(this));
}

void Filter::initializeSsl(uint32_t maxClientHelloSize, size_t bufSize,
const bssl::UniquePtr<SSL>& ssl, void* appData) {
RELEASE_ASSERT(bufSize >= maxClientHelloSize, "");

SSL_set_app_data(ssl_.get(), this);
SSL_set_accept_state(ssl_.get());
SSL_set_app_data(ssl.get(), appData);
SSL_set_accept_state(ssl.get());
}

Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) {
Expand Down Expand Up @@ -100,6 +106,16 @@ Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) {
}

void Filter::onALPN(const unsigned char* data, unsigned int len) {
doOnALPN(data, len,
[&](std::vector<absl::string_view> protocols) {
cb_->socket().setRequestedApplicationProtocols(protocols);
},
alpn_found_);
}

void Filter::doOnALPN(const unsigned char* data, unsigned int len,
std::function<void(std::vector<absl::string_view> protocols)> onAlpnCb,
bool& alpn_found) {
CBS wire, list;
CBS_init(&wire, reinterpret_cast<const uint8_t*>(data), static_cast<size_t>(len));
if (!CBS_get_u16_length_prefixed(&wire, &list) || CBS_len(&wire) != 0 || CBS_len(&list) < 2) {
Expand All @@ -115,19 +131,28 @@ void Filter::onALPN(const unsigned char* data, unsigned int len) {
}
protocols.emplace_back(reinterpret_cast<const char*>(CBS_data(&name)), CBS_len(&name));
}
cb_->socket().setRequestedApplicationProtocols(protocols);
alpn_found_ = true;
onAlpnCb(protocols);
alpn_found = true;
}

void Filter::onServername(absl::string_view servername) {
ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", servername);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep this similar debug log in tcp_proxy

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be useful for debug to track that the SNI in both filters

doOnServername(
servername, config_->stats(),
[&](absl::string_view name) -> void { cb_->socket().setRequestedServerName(name); },
clienthello_success_);
}

void Filter::onServername(absl::string_view name) {
void Filter::doOnServername(absl::string_view name, const TlsStats& stats,
std::function<void(absl::string_view name)> onServernameCb,
bool& clienthello_success) {
if (!name.empty()) {
config_->stats().sni_found_.inc();
cb_->socket().setRequestedServerName(name);
ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", name);
stats.sni_found_.inc();
onServernameCb(name);
} else {
config_->stats().sni_not_found_.inc();
stats.sni_not_found_.inc();
}
clienthello_success_ = true;
clienthello_success = true;
}

void Filter::onRead() {
Expand Down Expand Up @@ -162,7 +187,13 @@ void Filter::onRead() {
const uint8_t* data = buf_ + read_;
const size_t len = result.rc_ - read_;
read_ = result.rc_;
parseClientHello(data, len);
parseClientHello(data, len, ssl_, read_, config_->maxClientHelloSize(), config_->stats(),
[&](bool success) -> void { done(success); }, alpn_found_,
clienthello_success_,
[&]() -> void {
cb_->socket().setDetectedTransportProtocol(
TransportSockets::TransportSocketNames::get().Tls);
});
}
}

Expand All @@ -179,41 +210,44 @@ void Filter::done(bool success) {
cb_->continueFilterChain(success);
}

void Filter::parseClientHello(const void* data, size_t len) {
// Ownership is passed to ssl_ in SSL_set_bio()
void Filter::parseClientHello(const void* data, size_t len, bssl::UniquePtr<SSL>& ssl,
uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats,
std::function<void(bool)> done, bool& alpn_found,
bool& clienthello_success, std::function<void()> onSuccess) {
// Ownership is passed to ssl in SSL_set_bio()
bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(data, len));

// Make the mem-BIO return that there is more data
// available beyond it's end
BIO_set_mem_eof_return(bio.get(), -1);

SSL_set_bio(ssl_.get(), bio.get(), bio.get());
SSL_set_bio(ssl.get(), bio.get(), bio.get());
bio.release();

int ret = SSL_do_handshake(ssl_.get());
int ret = SSL_do_handshake(ssl.get());

// This should never succeed because an error is always returned from the SNI callback.
ASSERT(ret <= 0);
switch (SSL_get_error(ssl_.get(), ret)) {
switch (SSL_get_error(ssl.get(), ret)) {
case SSL_ERROR_WANT_READ:
if (read_ == config_->maxClientHelloSize()) {
if (read == maxClientHelloSize) {
// We've hit the specified size limit. This is an unreasonably large ClientHello;
// indicate failure.
config_->stats().client_hello_too_large_.inc();
stats.client_hello_too_large_.inc();
done(false);
}
break;
case SSL_ERROR_SSL:
if (clienthello_success_) {
config_->stats().tls_found_.inc();
if (alpn_found_) {
config_->stats().alpn_found_.inc();
if (clienthello_success) {
stats.tls_found_.inc();
if (alpn_found) {
stats.alpn_found_.inc();
} else {
config_->stats().alpn_not_found_.inc();
stats.alpn_not_found_.inc();
}
cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().Tls);
onSuccess();
} else {
config_->stats().tls_not_found_.inc();
stats.tls_not_found_.inc();
}
done(true);
break;
Expand Down
52 changes: 38 additions & 14 deletions source/extensions/filters/listener/tls_inspector/tls_inspector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace TlsInspector {
/**
* All stats for the TLS inspector. @see stats_macros.h
*/
#define ALL_TLS_INSPECTOR_STATS(COUNTER) \
#define TLS_STATS(COUNTER) \
COUNTER(connection_closed) \
COUNTER(client_hello_too_large) \
COUNTER(read_error) \
Expand All @@ -32,50 +32,77 @@ namespace TlsInspector {
COUNTER(sni_not_found)

/**
* Definition of all stats for the TLS inspector. @see stats_macros.h
* Definition of stats for the TLS. @see stats_macros.h
*/
struct TlsInspectorStats {
ALL_TLS_INSPECTOR_STATS(GENERATE_COUNTER_STRUCT)
struct TlsStats {
TLS_STATS(GENERATE_COUNTER_STRUCT)
};

/**
* Global configuration for TLS inspector.
*/
class Config {
public:
Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO);
Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO,
const std::string& stat_prefix = "tls_inspector.");

const TlsInspectorStats& stats() const { return stats_; }
const TlsStats& stats() const { return stats_; }
bssl::UniquePtr<SSL> newSsl();
uint32_t maxClientHelloSize() const { return max_client_hello_size_; }

static constexpr size_t TLS_MAX_CLIENT_HELLO = 64 * 1024;

private:
TlsInspectorStats stats_;
TlsStats stats_;
bssl::UniquePtr<SSL_CTX> ssl_ctx_;
const uint32_t max_client_hello_size_;
};

typedef std::shared_ptr<Config> ConfigSharedPtr;

class TlsFilterBase {
public:
virtual ~TlsFilterBase() {}

private:
virtual void onALPN(const unsigned char* data, unsigned int len) PURE;
virtual void onServername(absl::string_view name) PURE;

// Allows callbacks on the SSL_CTX to set fields in this class.
friend class Config;
};

/**
* TLS inspector listener filter.
*/
class Filter : public Network::ListenerFilter, Logger::Loggable<Logger::Id::filter> {
class Filter : public Network::ListenerFilter,
public TlsFilterBase,
Logger::Loggable<Logger::Id::filter> {
public:
Filter(const ConfigSharedPtr config);

// Network::ListenerFilter
Network::FilterStatus onAccept(Network::ListenerFilterCallbacks& cb) override;
static void initializeSsl(uint32_t maxClientHelloSize, size_t bufSize,
const bssl::UniquePtr<SSL>& ssl, void* appData);
static void parseClientHello(const void* data, size_t len, bssl::UniquePtr<SSL>& ssl,
uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats,
std::function<void(bool)> done, bool& alpn_found,
bool& clienthello_success, std::function<void()> onSuccess);
static void doOnServername(absl::string_view name, const TlsStats& stats,
std::function<void(absl::string_view name)> onServernameCb,
bool& clienthello_success_);
static void doOnALPN(const unsigned char* data, unsigned int len,
std::function<void(std::vector<absl::string_view> protocols)> onAlpnCb,
bool& alpn_found);

private:
void parseClientHello(const void* data, size_t len);
void onRead();
void onTimeout();
void done(bool success);
void onALPN(const unsigned char* data, unsigned int len);
void onServername(absl::string_view name);
// Extensions::ListenerFilters::TlsInspector::TlsFilterBase
void onALPN(const unsigned char* data, unsigned int len) override;
void onServername(absl::string_view name) override;

ConfigSharedPtr config_;
Network::ListenerFilterCallbacks* cb_;
Expand All @@ -88,9 +115,6 @@ class Filter : public Network::ListenerFilter, Logger::Loggable<Logger::Id::filt
bool clienthello_success_{false};

static thread_local uint8_t buf_[Config::TLS_MAX_CLIENT_HELLO];

// Allows callbacks on the SSL_CTX to set fields in this class.
friend class Config;
};

} // namespace TlsInspector
Expand Down