From fe964505ffae2458c7bf041cf2d04a302ec7caf4 Mon Sep 17 00:00:00 2001 From: Mark Rousskov Date: Mon, 14 Oct 2024 13:41:55 -0400 Subject: [PATCH] feat(s2n-quic): Add the certififcate chain to TlsSession (#2349) This allows users to query the certificate chain when in the TlsSession event callbacks (for dc or for on_tls_exporter_ready). The API added here allocates a Vec> which is plausibly quite wasteful, so the API is for now doc(hidden) to avoid stabilizing it. --- quic/s2n-quic-core/src/crypto/tls.rs | 18 +++ quic/s2n-quic-core/src/crypto/tls/testing.rs | 4 + quic/s2n-quic-core/src/event.rs | 9 ++ quic/s2n-quic-rustls/src/session.rs | 11 ++ quic/s2n-quic-tls/src/session.rs | 10 ++ quic/s2n-quic/src/tests.rs | 2 + quic/s2n-quic/src/tests/chain.rs | 128 +++++++++++++++++++ 7 files changed, 182 insertions(+) create mode 100644 quic/s2n-quic/src/tests/chain.rs diff --git a/quic/s2n-quic-core/src/crypto/tls.rs b/quic/s2n-quic-core/src/crypto/tls.rs index 4ecd875241..3da543aa9e 100644 --- a/quic/s2n-quic-core/src/crypto/tls.rs +++ b/quic/s2n-quic-core/src/crypto/tls.rs @@ -1,6 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +#[cfg(feature = "alloc")] +use alloc::vec::Vec; #[cfg(feature = "alloc")] pub use bytes::{Bytes, BytesMut}; use core::fmt::Debug; @@ -35,6 +37,19 @@ impl TlsExportError { } } +#[derive(Debug)] +#[non_exhaustive] +pub enum ChainError { + #[non_exhaustive] + Failure, +} + +impl ChainError { + pub fn failure() -> Self { + ChainError::Failure + } +} + pub trait TlsSession: Send { /// See and . fn tls_exporter( @@ -45,6 +60,9 @@ pub trait TlsSession: Send { ) -> Result<(), TlsExportError>; fn cipher_suite(&self) -> CipherSuite; + + #[cfg(feature = "alloc")] + fn peer_cert_chain_der(&self) -> Result>, ChainError>; } //= https://www.rfc-editor.org/rfc/rfc9000#section-4 diff --git a/quic/s2n-quic-core/src/crypto/tls/testing.rs b/quic/s2n-quic-core/src/crypto/tls/testing.rs index 63238d0bcf..9f09d32ddf 100644 --- a/quic/s2n-quic-core/src/crypto/tls/testing.rs +++ b/quic/s2n-quic-core/src/crypto/tls/testing.rs @@ -123,6 +123,10 @@ impl TlsSession for Session { fn cipher_suite(&self) -> CipherSuite { CipherSuite::TLS_AES_128_GCM_SHA256 } + + fn peer_cert_chain_der(&self) -> Result>, tls::ChainError> { + Err(tls::ChainError::failure()) + } } #[derive(Debug)] diff --git a/quic/s2n-quic-core/src/event.rs b/quic/s2n-quic-core/src/event.rs index 0d193c4e5f..c53706c4a9 100644 --- a/quic/s2n-quic-core/src/event.rs +++ b/quic/s2n-quic-core/src/event.rs @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{connection, endpoint}; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; use core::{ops::RangeInclusive, time::Duration}; mod generated; @@ -149,6 +151,13 @@ impl<'a> TlsSession<'a> { self.session.tls_exporter(label, context, output) } + // Currently intended only for unstable usage + #[doc(hidden)] + #[cfg(feature = "alloc")] + pub fn peer_cert_chain_der(&self) -> Result>, crate::crypto::tls::ChainError> { + self.session.peer_cert_chain_der() + } + pub fn cipher_suite(&self) -> crate::event::api::CipherSuite { self.session.cipher_suite().into_event() } diff --git a/quic/s2n-quic-rustls/src/session.rs b/quic/s2n-quic-rustls/src/session.rs index 2ba3a75eba..39605c384c 100644 --- a/quic/s2n-quic-rustls/src/session.rs +++ b/quic/s2n-quic-rustls/src/session.rs @@ -58,6 +58,17 @@ impl tls::TlsSession for Session { CipherSuite::Unknown } } + + fn peer_cert_chain_der(&self) -> Result>, tls::ChainError> { + let err = tls::ChainError::failure(); + Ok(self + .connection + .peer_certificates() + .ok_or(err)? + .iter() + .map(|v| v.to_vec()) + .collect()) + } } impl fmt::Debug for Session { diff --git a/quic/s2n-quic-tls/src/session.rs b/quic/s2n-quic-tls/src/session.rs index 6c688bac3d..89e2c939e9 100644 --- a/quic/s2n-quic-tls/src/session.rs +++ b/quic/s2n-quic-tls/src/session.rs @@ -107,6 +107,16 @@ impl tls::TlsSession for Session { fn cipher_suite(&self) -> CipherSuite { self.state.cipher_suite() } + + fn peer_cert_chain_der(&self) -> Result>, tls::ChainError> { + self.connection + .peer_cert_chain() + .map_err(|_| tls::ChainError::failure())? + .iter() + .map(|v| Ok(v?.der()?.to_vec())) + .collect::>, s2n_tls::error::Error>>() + .map_err(|_| tls::ChainError::failure()) + } } impl tls::Session for Session { diff --git a/quic/s2n-quic/src/tests.rs b/quic/s2n-quic/src/tests.rs index 0bb0b95e99..c9823b86f1 100644 --- a/quic/s2n-quic/src/tests.rs +++ b/quic/s2n-quic/src/tests.rs @@ -45,6 +45,8 @@ mod skip_packets; // build options than s2n-tls. We should build the rustls provider with // mTLS enabled and remove the `cfg(target_os("windows"))`. #[cfg(not(target_os = "windows"))] +mod chain; +#[cfg(not(target_os = "windows"))] mod client_handshake_confirm; #[cfg(not(target_os = "windows"))] mod dc; diff --git a/quic/s2n-quic/src/tests/chain.rs b/quic/s2n-quic/src/tests/chain.rs new file mode 100644 index 0000000000..a5bc648954 --- /dev/null +++ b/quic/s2n-quic/src/tests/chain.rs @@ -0,0 +1,128 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! This module shows an example of an event provider that accesses certificate chains +//! from QUIC connections on both client and server. + +use super::*; +use crate::provider::event::events::{self, ConnectionInfo, ConnectionMeta, Subscriber}; + +struct Chain; + +#[derive(Default)] +struct ChainContext { + chain: Option>>, + sender: Option>>>, +} + +impl Subscriber for Chain { + type ConnectionContext = ChainContext; + + #[inline] + fn create_connection_context( + &mut self, + _: &ConnectionMeta, + _info: &ConnectionInfo, + ) -> Self::ConnectionContext { + ChainContext::default() + } + + fn on_tls_exporter_ready( + &mut self, + context: &mut Self::ConnectionContext, + _meta: &ConnectionMeta, + event: &events::TlsExporterReady, + ) { + if let Some(sender) = context.sender.take() { + sender + .blocking_send(event.session.peer_cert_chain_der().unwrap()) + .unwrap(); + } else { + context.chain = Some(event.session.peer_cert_chain_der().unwrap()); + } + } +} + +fn start_server( + mut server: Server, + server_chain: tokio::sync::mpsc::Sender>>, +) -> crate::provider::io::testing::Result { + let server_addr = server.local_addr()?; + + // accept connections and echo back + spawn(async move { + while let Some(mut connection) = server.accept().await { + let chain = connection + .query_event_context_mut(|ctx: &mut ChainContext| { + if let Some(chain) = ctx.chain.take() { + Some(chain) + } else { + ctx.sender = Some(server_chain.clone()); + None + } + }) + .unwrap(); + if let Some(chain) = chain { + server_chain.send(chain).await.unwrap(); + } + } + }); + + Ok(server_addr) +} + +fn tls_test(f: fn(crate::Connection, Vec>) -> C) +where + C: 'static + core::future::Future + Send, +{ + let model = Model::default(); + model.set_delay(Duration::from_millis(50)); + + test(model, |handle| { + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_tls(build_server_mtls_provider(certificates::MTLS_CA_CERT)?)? + .with_event((Chain, tracing_events()))? + .start()?; + let (send, server_chain) = tokio::sync::mpsc::channel(1); + let server_chain = Arc::new(tokio::sync::Mutex::new(server_chain)); + + let addr = start_server(server, send)?; + + let client = Client::builder() + .with_io(handle.builder().build().unwrap())? + .with_tls(build_client_mtls_provider(certificates::MTLS_CA_CERT)?)? + .with_event((Chain, tracing_events()))? + .start()?; + + // show it working for several connections + for _ in 0..10 { + let client = client.clone(); + let server_chain = server_chain.clone(); + primary::spawn(async move { + let connect = Connect::new(addr).with_server_name("localhost"); + let conn = client.connect(connect).await.unwrap(); + delay(Duration::from_millis(100)).await; + let server_chain = server_chain.lock().await.recv().await.unwrap(); + f(conn, server_chain).await; + }); + } + + Ok(addr) + }) + .unwrap(); +} + +#[test] +fn happy_case() { + tls_test(|mut conn, server_chain| async move { + let client_chain = conn + .query_event_context_mut(|ctx: &mut ChainContext| ctx.chain.take().unwrap()) + .unwrap(); + // these are DER-encoded and we lack nice conversion functions, so just assert some simple + // properties. + assert!(server_chain.len() > 1); + assert!(client_chain.len() > 1); + assert_ne!(server_chain, client_chain); + }); +}