diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc index 6e56e2012ca1..f3928b650090 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc @@ -24,8 +24,8 @@ namespace Extensions { namespace ListenerFilters { namespace TlsInspector { -Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size) - : stats_{ALL_TLS_INSPECTOR_STATS(POOL_COUNTER_PREFIX(scope, "tls_inspector."))}, +Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size, const std::string& stat_prefix) + : stats_{TLS_STATS(POOL_COUNTER_PREFIX(scope, stat_prefix))}, ssl_ctx_(SSL_CTX_new(TLS_with_buffers_method())), max_client_hello_size_(max_client_hello_size) { @@ -42,14 +42,14 @@ Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size) size_t len; if (SSL_early_callback_ctx_extension_get( client_hello, TLSEXT_TYPE_application_layer_protocol_negotiation, &data, &len)) { - Filter* filter = static_cast(SSL_get_app_data(client_hello->ssl)); + TlsFilterBase* filter = static_cast(SSL_get_app_data(client_hello->ssl)); filter->onALPN(data, len); } return ssl_select_cert_success; }); SSL_CTX_set_tlsext_servername_callback( ssl_ctx_.get(), [](SSL* ssl, int* out_alert, void*) -> int { - Filter* filter = static_cast(SSL_get_app_data(ssl)); + TlsFilterBase* filter = static_cast(SSL_get_app_data(ssl)); filter->onServername(SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)); // Return an error to stop the handshake; we have what we wanted already. @@ -63,10 +63,16 @@ bssl::UniquePtr Config::newSsl() { return bssl::UniquePtr{SSL_new(ssl_ thread_local uint8_t Filter::buf_[Config::TLS_MAX_CLIENT_HELLO]; Filter::Filter(const ConfigSharedPtr config) : config_(config), ssl_(config_->newSsl()) { - RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize(), ""); + initializeSsl(config->maxClientHelloSize(), sizeof(buf_), ssl_, + static_cast(this)); +} + +void Filter::initializeSsl(uint32_t maxClientHelloSize, size_t bufSize, + const bssl::UniquePtr& ssl, void* appData) { + RELEASE_ASSERT(bufSize >= maxClientHelloSize, ""); - SSL_set_app_data(ssl_.get(), this); - SSL_set_accept_state(ssl_.get()); + SSL_set_app_data(ssl.get(), appData); + SSL_set_accept_state(ssl.get()); } Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) { @@ -100,6 +106,16 @@ Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) { } void Filter::onALPN(const unsigned char* data, unsigned int len) { + doOnALPN(data, len, + [&](std::vector protocols) { + cb_->socket().setRequestedApplicationProtocols(protocols); + }, + alpn_found_); +} + +void Filter::doOnALPN(const unsigned char* data, unsigned int len, + std::function protocols)> onAlpnCb, + bool& alpn_found) { CBS wire, list; CBS_init(&wire, reinterpret_cast(data), static_cast(len)); if (!CBS_get_u16_length_prefixed(&wire, &list) || CBS_len(&wire) != 0 || CBS_len(&list) < 2) { @@ -115,19 +131,28 @@ void Filter::onALPN(const unsigned char* data, unsigned int len) { } protocols.emplace_back(reinterpret_cast(CBS_data(&name)), CBS_len(&name)); } - cb_->socket().setRequestedApplicationProtocols(protocols); - alpn_found_ = true; + onAlpnCb(protocols); + alpn_found = true; +} + +void Filter::onServername(absl::string_view servername) { + ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", servername); + doOnServername( + servername, config_->stats(), + [&](absl::string_view name) -> void { cb_->socket().setRequestedServerName(name); }, + clienthello_success_); } -void Filter::onServername(absl::string_view name) { +void Filter::doOnServername(absl::string_view name, const TlsStats& stats, + std::function onServernameCb, + bool& clienthello_success) { if (!name.empty()) { - config_->stats().sni_found_.inc(); - cb_->socket().setRequestedServerName(name); - ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", name); + stats.sni_found_.inc(); + onServernameCb(name); } else { - config_->stats().sni_not_found_.inc(); + stats.sni_not_found_.inc(); } - clienthello_success_ = true; + clienthello_success = true; } void Filter::onRead() { @@ -162,7 +187,13 @@ void Filter::onRead() { const uint8_t* data = buf_ + read_; const size_t len = result.rc_ - read_; read_ = result.rc_; - parseClientHello(data, len); + parseClientHello(data, len, ssl_, read_, config_->maxClientHelloSize(), config_->stats(), + [&](bool success) -> void { done(success); }, alpn_found_, + clienthello_success_, + [&]() -> void { + cb_->socket().setDetectedTransportProtocol( + TransportSockets::TransportSocketNames::get().Tls); + }); } } @@ -179,41 +210,44 @@ void Filter::done(bool success) { cb_->continueFilterChain(success); } -void Filter::parseClientHello(const void* data, size_t len) { - // Ownership is passed to ssl_ in SSL_set_bio() +void Filter::parseClientHello(const void* data, size_t len, bssl::UniquePtr& ssl, + uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats, + std::function done, bool& alpn_found, + bool& clienthello_success, std::function onSuccess) { + // Ownership is passed to ssl in SSL_set_bio() bssl::UniquePtr bio(BIO_new_mem_buf(data, len)); // Make the mem-BIO return that there is more data // available beyond it's end BIO_set_mem_eof_return(bio.get(), -1); - SSL_set_bio(ssl_.get(), bio.get(), bio.get()); + SSL_set_bio(ssl.get(), bio.get(), bio.get()); bio.release(); - int ret = SSL_do_handshake(ssl_.get()); + int ret = SSL_do_handshake(ssl.get()); // This should never succeed because an error is always returned from the SNI callback. ASSERT(ret <= 0); - switch (SSL_get_error(ssl_.get(), ret)) { + switch (SSL_get_error(ssl.get(), ret)) { case SSL_ERROR_WANT_READ: - if (read_ == config_->maxClientHelloSize()) { + if (read == maxClientHelloSize) { // We've hit the specified size limit. This is an unreasonably large ClientHello; // indicate failure. - config_->stats().client_hello_too_large_.inc(); + stats.client_hello_too_large_.inc(); done(false); } break; case SSL_ERROR_SSL: - if (clienthello_success_) { - config_->stats().tls_found_.inc(); - if (alpn_found_) { - config_->stats().alpn_found_.inc(); + if (clienthello_success) { + stats.tls_found_.inc(); + if (alpn_found) { + stats.alpn_found_.inc(); } else { - config_->stats().alpn_not_found_.inc(); + stats.alpn_not_found_.inc(); } - cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().Tls); + onSuccess(); } else { - config_->stats().tls_not_found_.inc(); + stats.tls_not_found_.inc(); } done(true); break; diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.h b/source/extensions/filters/listener/tls_inspector/tls_inspector.h index c927c56d1327..6fce0c14a00d 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.h +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.h @@ -19,7 +19,7 @@ namespace TlsInspector { /** * All stats for the TLS inspector. @see stats_macros.h */ -#define ALL_TLS_INSPECTOR_STATS(COUNTER) \ +#define TLS_STATS(COUNTER) \ COUNTER(connection_closed) \ COUNTER(client_hello_too_large) \ COUNTER(read_error) \ @@ -32,10 +32,10 @@ namespace TlsInspector { COUNTER(sni_not_found) /** - * Definition of all stats for the TLS inspector. @see stats_macros.h + * Definition of stats for the TLS. @see stats_macros.h */ -struct TlsInspectorStats { - ALL_TLS_INSPECTOR_STATS(GENERATE_COUNTER_STRUCT) +struct TlsStats { + TLS_STATS(GENERATE_COUNTER_STRUCT) }; /** @@ -43,39 +43,66 @@ struct TlsInspectorStats { */ class Config { public: - Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO); + Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO, + const std::string& stat_prefix = "tls_inspector."); - const TlsInspectorStats& stats() const { return stats_; } + const TlsStats& stats() const { return stats_; } bssl::UniquePtr newSsl(); uint32_t maxClientHelloSize() const { return max_client_hello_size_; } static constexpr size_t TLS_MAX_CLIENT_HELLO = 64 * 1024; private: - TlsInspectorStats stats_; + TlsStats stats_; bssl::UniquePtr ssl_ctx_; const uint32_t max_client_hello_size_; }; typedef std::shared_ptr ConfigSharedPtr; +class TlsFilterBase { +public: + virtual ~TlsFilterBase() {} + +private: + virtual void onALPN(const unsigned char* data, unsigned int len) PURE; + virtual void onServername(absl::string_view name) PURE; + + // Allows callbacks on the SSL_CTX to set fields in this class. + friend class Config; +}; + /** * TLS inspector listener filter. */ -class Filter : public Network::ListenerFilter, Logger::Loggable { +class Filter : public Network::ListenerFilter, + public TlsFilterBase, + Logger::Loggable { public: Filter(const ConfigSharedPtr config); // Network::ListenerFilter Network::FilterStatus onAccept(Network::ListenerFilterCallbacks& cb) override; + static void initializeSsl(uint32_t maxClientHelloSize, size_t bufSize, + const bssl::UniquePtr& ssl, void* appData); + static void parseClientHello(const void* data, size_t len, bssl::UniquePtr& ssl, + uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats, + std::function done, bool& alpn_found, + bool& clienthello_success, std::function onSuccess); + static void doOnServername(absl::string_view name, const TlsStats& stats, + std::function onServernameCb, + bool& clienthello_success_); + static void doOnALPN(const unsigned char* data, unsigned int len, + std::function protocols)> onAlpnCb, + bool& alpn_found); private: - void parseClientHello(const void* data, size_t len); void onRead(); void onTimeout(); void done(bool success); - void onALPN(const unsigned char* data, unsigned int len); - void onServername(absl::string_view name); + // Extensions::ListenerFilters::TlsInspector::TlsFilterBase + void onALPN(const unsigned char* data, unsigned int len) override; + void onServername(absl::string_view name) override; ConfigSharedPtr config_; Network::ListenerFilterCallbacks* cb_; @@ -88,9 +115,6 @@ class Filter : public Network::ListenerFilter, Logger::Loggable