From ec8b7c09d6194b95b08064b9836e1b9c5811fe7a Mon Sep 17 00:00:00 2001 From: Mark Rousskov Date: Mon, 26 Aug 2024 23:12:30 +0000 Subject: [PATCH] Fix ConfirmComplete with handshake de-duplication Previously, we only stored a single Waker, which meant that when the underlying s2n-quic connection/handle was returned to applications multiple times it ended up only waking one of the application tasks. This modifies the ConfirmComplete logic such that we now store as many Wakers as needed (using Tokio's watch channel) and wake all of them on changes. --- quic/s2n-quic/Cargo.toml | 2 +- quic/s2n-quic/src/provider/dc/confirm.rs | 90 +++++++++++----------- quic/s2n-quic/src/tests/dc.rs | 96 +++++++++++++----------- 3 files changed, 98 insertions(+), 90 deletions(-) diff --git a/quic/s2n-quic/Cargo.toml b/quic/s2n-quic/Cargo.toml index de89c9f6ef..98ac649f27 100644 --- a/quic/s2n-quic/Cargo.toml +++ b/quic/s2n-quic/Cargo.toml @@ -77,7 +77,7 @@ s2n-quic-rustls = { version = "=0.44.1", path = "../s2n-quic-rustls", optional = s2n-quic-tls = { version = "=0.44.1", path = "../s2n-quic-tls", optional = true } s2n-quic-tls-default = { version = "=0.44.1", path = "../s2n-quic-tls-default", optional = true } s2n-quic-transport = { version = "=0.44.1", path = "../s2n-quic-transport" } -tokio = { version = "1", default-features = false } +tokio = { version = "1", default-features = false, features = ["sync"] } zerocopy = { version = "0.7", optional = true, features = ["derive"] } zeroize = { version = "1", optional = true, default-features = false } diff --git a/quic/s2n-quic/src/provider/dc/confirm.rs b/quic/s2n-quic/src/provider/dc/confirm.rs index c76ab3bd67..6a3c9e6333 100644 --- a/quic/s2n-quic/src/provider/dc/confirm.rs +++ b/quic/s2n-quic/src/provider/dc/confirm.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use crate::Connection; -use core::task::{Context, Poll, Waker}; use s2n_quic_core::{ connection, connection::Error, @@ -13,6 +12,7 @@ use s2n_quic_core::{ }, }; use std::io; +use tokio::sync::watch; /// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc /// waits for the dc handshake to complete @@ -21,58 +21,54 @@ impl ConfirmComplete { /// Blocks the task until the provided connection has either completed the dc handshake or closed /// with an error pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> { - core::future::poll_fn(|cx| { - conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx)) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? - }) - .await + let mut receiver = conn + .query_event_context_mut(|context: &mut ConfirmContext| context.sender.subscribe()) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + loop { + match &*receiver.borrow_and_update() { + // if we're ready or have errored then let the application know + State::Ready => return Ok(()), + State::Failed(error) => return Err((*error).into()), + State::Waiting(_) => {} + } + + if receiver.changed().await.is_err() { + return Err(io::Error::new( + io::ErrorKind::Other, + "never reached terminal state", + )); + } + } } } -#[derive(Default)] pub struct ConfirmContext { - waker: Option, - state: State, + sender: watch::Sender, +} + +impl Default for ConfirmContext { + fn default() -> Self { + let (sender, _receiver) = watch::channel(State::default()); + Self { sender } + } } impl ConfirmContext { /// Updates the state on the context fn update(&mut self, state: State) { - self.state = state; - - // notify the application that the state was updated - self.wake(); - } - - /// Polls the context for handshake completion - fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - match self.state { - // if we're ready or have errored then let the application know - State::Ready => Poll::Ready(Ok(())), - State::Failed(error) => Poll::Ready(Err(error.into())), - State::Waiting(_) => { - // store the waker so we can notify the application of state updates - self.waker = Some(cx.waker().clone()); - Poll::Pending - } - } - } - - /// notify the application of a state update - fn wake(&mut self) { - if let Some(waker) = self.waker.take() { - waker.wake(); - } + self.sender.send_replace(state); } } impl Drop for ConfirmContext { // make sure the application is notified that we're closing the connection fn drop(&mut self) { - if matches!(self.state, State::Waiting(_)) { - self.state = State::Failed(connection::Error::unspecified()); - } - self.wake(); + self.sender.send_modify(|state| { + if matches!(state, State::Waiting(_)) { + *state = State::Failed(connection::Error::unspecified()); + } + }); } } @@ -107,14 +103,14 @@ impl Subscriber for ConfirmComplete { meta: &ConnectionMeta, event: &events::ConnectionClosed, ) { - ensure!(matches!(context.state, State::Waiting(_))); - - match (&meta.endpoint_type, event.error, &context.state) { - ( - EndpointType::Server { .. }, - Error::Closed { .. }, - State::Waiting(Some(DcState::PathSecretsReady { .. })), - ) => { + ensure!(matches!(*context.sender.borrow(), State::Waiting(_))); + let is_ready = matches!( + *context.sender.borrow(), + State::Waiting(Some(DcState::PathSecretsReady { .. })) + ); + + match (&meta.endpoint_type, event.error, is_ready) { + (EndpointType::Server { .. }, Error::Closed { .. }, true) => { // The client may close the connection immediately after the dc handshake completes, // before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS. // Since the server has already moved into the PathSecretsReady state, this can be considered @@ -132,7 +128,7 @@ impl Subscriber for ConfirmComplete { _meta: &ConnectionMeta, event: &events::DcStateChanged, ) { - ensure!(matches!(context.state, State::Waiting(_))); + ensure!(matches!(*context.sender.borrow(), State::Waiting(_))); match event.state { DcState::NoVersionNegotiated { .. } => context.update(State::Failed( diff --git a/quic/s2n-quic/src/tests/dc.rs b/quic/s2n-quic/src/tests/dc.rs index a49846030c..e2f450feb4 100644 --- a/quic/s2n-quic/src/tests/dc.rs +++ b/quic/s2n-quic/src/tests/dc.rs @@ -69,7 +69,7 @@ fn dc_handshake_self_test() -> Result<()> { .with_tls(certificates::CERT_PEM)? .with_dc(MockDcEndpoint::new(&CLIENT_TOKENS))?; - self_test(server, client, None, None)?; + self_test(server, client, true, None, None)?; Ok(()) } @@ -114,7 +114,7 @@ fn dc_mtls_handshake_self_test() -> Result<()> { .with_tls(client_tls)? .with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?; - self_test(server, client, None, None)?; + self_test(server, client, true, None, None)?; Ok(()) } @@ -143,7 +143,7 @@ fn dc_mtls_handshake_auth_failure_self_test() -> Result<()> { } .into(); - self_test(server, client, Some(expected_client_error), None)?; + self_test(server, client, true, Some(expected_client_error), None)?; Ok(()) } @@ -181,6 +181,7 @@ fn dc_mtls_handshake_server_not_supported_self_test() -> Result<()> { self_test( server, client, + true, Some(connection::Error::invalid_configuration( "peer does not support specified dc versions", )), @@ -228,6 +229,7 @@ fn dc_mtls_handshake_client_not_supported_self_test() -> Result<()> { self_test( server, client, + false, Some(expected_client_error), Some(connection::Error::invalid_configuration( "peer does not support specified dc versions", @@ -266,7 +268,7 @@ fn dc_possible_secret_control_packet( .with_dc(dc_endpoint)? .with_packet_interceptor(RandomShort::default())?; - let (client_events, _server_events) = self_test(server, client, None, None)?; + let (client_events, _server_events) = self_test(server, client, true, None, None)?; assert_eq!( 1, @@ -297,6 +299,7 @@ fn dc_possible_secret_control_packet( fn self_test( server: server::Builder, client: client::Builder, + client_has_dc: bool, expected_client_error: Option, expected_server_error: Option, ) -> Result<(DcRecorder, DcRecorder)> { @@ -318,18 +321,21 @@ fn self_test( let addr = server.local_addr()?; + let expected_count = 1 + client_has_dc as usize; spawn(async move { - if let Some(mut conn) = server.accept().await { - let result = dc::ConfirmComplete::wait_ready(&mut conn).await; - - if let Some(error) = expected_server_error { - assert_eq!(error, convert_io_result(result).unwrap()); - - if expected_client_error.is_some() { - conn.close(SERVER_CLOSE_ERROR_CODE.into()); + for _ in 0..expected_count { + if let Some(mut conn) = server.accept().await { + let result = dc::ConfirmComplete::wait_ready(&mut conn).await; + + if let Some(error) = expected_server_error { + assert_eq!(error, convert_io_result(result).unwrap()); + + if expected_client_error.is_some() { + conn.close(SERVER_CLOSE_ERROR_CODE.into()); + } + } else { + assert!(result.is_ok()); } - } else { - assert!(result.is_ok()); } } }); @@ -340,35 +346,41 @@ fn self_test( .with_random(Random::with_seed(456))? .start()?; - let client_events = client_events.clone(); - - primary::spawn(async move { - let connect = Connect::new(addr).with_server_name("localhost"); - let mut conn = client.connect(connect).await.unwrap(); - let result = dc::ConfirmComplete::wait_ready(&mut conn).await; - - if let Some(error) = expected_client_error { - assert_eq!(error, convert_io_result(result).unwrap()); - - if expected_server_error.is_some() { - conn.close(CLIENT_CLOSE_ERROR_CODE.into()); - // wait for the server to assert the expected error before dropping - delay(Duration::from_millis(100)).await; + for _ in 0..expected_count { + primary::spawn({ + let client = client.clone(); + let client_events = client_events.clone(); + async move { + let connect = Connect::new(addr) + .with_server_name("localhost") + .with_deduplicate(client_has_dc); + let mut conn = client.connect(connect).await.unwrap(); + let result = dc::ConfirmComplete::wait_ready(&mut conn).await; + + if let Some(error) = expected_client_error { + assert_eq!(error, convert_io_result(result).unwrap()); + + if expected_server_error.is_some() { + conn.close(CLIENT_CLOSE_ERROR_CODE.into()); + // wait for the server to assert the expected error before dropping + delay(Duration::from_millis(100)).await; + } + } else { + assert!(result.is_ok()); + let client_events = client_events + .dc_state_changed_events() + .lock() + .unwrap() + .clone(); + assert_dc_complete(&client_events); + // wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent + // before the client closes the connection. This is only necessary to confirm the `dc::State` + // on the server moves to `DcState::Complete` + delay(Duration::from_millis(100)).await; + } } - } else { - assert!(result.is_ok()); - let client_events = client_events - .dc_state_changed_events() - .lock() - .unwrap() - .clone(); - assert_dc_complete(&client_events); - // wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent - // before the client closes the connection. This is only necessary to confirm the `dc::State` - // on the server moves to `DcState::Complete` - delay(Duration::from_millis(100)).await; - } - }); + }); + } Ok(addr) })