Skip to content

Commit

Permalink
Fix ConfirmComplete with handshake de-duplication
Browse files Browse the repository at this point in the history
Previously, we only stored a single Waker, which meant that when the
underlying s2n-quic connection/handle was returned to applications
multiple times it ended up only waking one of the application tasks.

This modifies the ConfirmComplete logic such that we now store as many
Wakers as needed (using Tokio's watch channel) and wake all of them on
changes.
  • Loading branch information
Mark-Simulacrum committed Aug 27, 2024
1 parent b4ba62d commit ec8b7c0
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 90 deletions.
2 changes: 1 addition & 1 deletion quic/s2n-quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ s2n-quic-rustls = { version = "=0.44.1", path = "../s2n-quic-rustls", optional =
s2n-quic-tls = { version = "=0.44.1", path = "../s2n-quic-tls", optional = true }
s2n-quic-tls-default = { version = "=0.44.1", path = "../s2n-quic-tls-default", optional = true }
s2n-quic-transport = { version = "=0.44.1", path = "../s2n-quic-transport" }
tokio = { version = "1", default-features = false }
tokio = { version = "1", default-features = false, features = ["sync"] }
zerocopy = { version = "0.7", optional = true, features = ["derive"] }
zeroize = { version = "1", optional = true, default-features = false }

Expand Down
90 changes: 43 additions & 47 deletions quic/s2n-quic/src/provider/dc/confirm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0

use crate::Connection;
use core::task::{Context, Poll, Waker};
use s2n_quic_core::{
connection,
connection::Error,
Expand All @@ -13,6 +12,7 @@ use s2n_quic_core::{
},
};
use std::io;
use tokio::sync::watch;

/// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc
/// waits for the dc handshake to complete
Expand All @@ -21,58 +21,54 @@ impl ConfirmComplete {
/// Blocks the task until the provided connection has either completed the dc handshake or closed
/// with an error
pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> {
core::future::poll_fn(|cx| {
conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx))
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
})
.await
let mut receiver = conn
.query_event_context_mut(|context: &mut ConfirmContext| context.sender.subscribe())
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

loop {
match &*receiver.borrow_and_update() {
// if we're ready or have errored then let the application know
State::Ready => return Ok(()),
State::Failed(error) => return Err((*error).into()),
State::Waiting(_) => {}
}

if receiver.changed().await.is_err() {
return Err(io::Error::new(
io::ErrorKind::Other,
"never reached terminal state",
));
}
}
}
}

#[derive(Default)]
pub struct ConfirmContext {
waker: Option<Waker>,
state: State,
sender: watch::Sender<State>,
}

impl Default for ConfirmContext {
fn default() -> Self {
let (sender, _receiver) = watch::channel(State::default());
Self { sender }
}
}

impl ConfirmContext {
/// Updates the state on the context
fn update(&mut self, state: State) {
self.state = state;

// notify the application that the state was updated
self.wake();
}

/// Polls the context for handshake completion
fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
match self.state {
// if we're ready or have errored then let the application know
State::Ready => Poll::Ready(Ok(())),
State::Failed(error) => Poll::Ready(Err(error.into())),
State::Waiting(_) => {
// store the waker so we can notify the application of state updates
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

/// notify the application of a state update
fn wake(&mut self) {
if let Some(waker) = self.waker.take() {
waker.wake();
}
self.sender.send_replace(state);
}
}

impl Drop for ConfirmContext {
// make sure the application is notified that we're closing the connection
fn drop(&mut self) {
if matches!(self.state, State::Waiting(_)) {
self.state = State::Failed(connection::Error::unspecified());
}
self.wake();
self.sender.send_modify(|state| {
if matches!(state, State::Waiting(_)) {
*state = State::Failed(connection::Error::unspecified());
}
});
}
}

Expand Down Expand Up @@ -107,14 +103,14 @@ impl Subscriber for ConfirmComplete {
meta: &ConnectionMeta,
event: &events::ConnectionClosed,
) {
ensure!(matches!(context.state, State::Waiting(_)));

match (&meta.endpoint_type, event.error, &context.state) {
(
EndpointType::Server { .. },
Error::Closed { .. },
State::Waiting(Some(DcState::PathSecretsReady { .. })),
) => {
ensure!(matches!(*context.sender.borrow(), State::Waiting(_)));
let is_ready = matches!(
*context.sender.borrow(),
State::Waiting(Some(DcState::PathSecretsReady { .. }))
);

match (&meta.endpoint_type, event.error, is_ready) {
(EndpointType::Server { .. }, Error::Closed { .. }, true) => {
// The client may close the connection immediately after the dc handshake completes,
// before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS.
// Since the server has already moved into the PathSecretsReady state, this can be considered
Expand All @@ -132,7 +128,7 @@ impl Subscriber for ConfirmComplete {
_meta: &ConnectionMeta,
event: &events::DcStateChanged,
) {
ensure!(matches!(context.state, State::Waiting(_)));
ensure!(matches!(*context.sender.borrow(), State::Waiting(_)));

match event.state {
DcState::NoVersionNegotiated { .. } => context.update(State::Failed(
Expand Down
96 changes: 54 additions & 42 deletions quic/s2n-quic/src/tests/dc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn dc_handshake_self_test() -> Result<()> {
.with_tls(certificates::CERT_PEM)?
.with_dc(MockDcEndpoint::new(&CLIENT_TOKENS))?;

self_test(server, client, None, None)?;
self_test(server, client, true, None, None)?;

Ok(())
}
Expand Down Expand Up @@ -114,7 +114,7 @@ fn dc_mtls_handshake_self_test() -> Result<()> {
.with_tls(client_tls)?
.with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?;

self_test(server, client, None, None)?;
self_test(server, client, true, None, None)?;

Ok(())
}
Expand Down Expand Up @@ -143,7 +143,7 @@ fn dc_mtls_handshake_auth_failure_self_test() -> Result<()> {
}
.into();

self_test(server, client, Some(expected_client_error), None)?;
self_test(server, client, true, Some(expected_client_error), None)?;

Ok(())
}
Expand Down Expand Up @@ -181,6 +181,7 @@ fn dc_mtls_handshake_server_not_supported_self_test() -> Result<()> {
self_test(
server,
client,
true,
Some(connection::Error::invalid_configuration(
"peer does not support specified dc versions",
)),
Expand Down Expand Up @@ -228,6 +229,7 @@ fn dc_mtls_handshake_client_not_supported_self_test() -> Result<()> {
self_test(
server,
client,
false,
Some(expected_client_error),
Some(connection::Error::invalid_configuration(
"peer does not support specified dc versions",
Expand Down Expand Up @@ -266,7 +268,7 @@ fn dc_possible_secret_control_packet(
.with_dc(dc_endpoint)?
.with_packet_interceptor(RandomShort::default())?;

let (client_events, _server_events) = self_test(server, client, None, None)?;
let (client_events, _server_events) = self_test(server, client, true, None, None)?;

assert_eq!(
1,
Expand Down Expand Up @@ -297,6 +299,7 @@ fn dc_possible_secret_control_packet(
fn self_test<S: ServerProviders, C: ClientProviders>(
server: server::Builder<S>,
client: client::Builder<C>,
client_has_dc: bool,
expected_client_error: Option<connection::Error>,
expected_server_error: Option<connection::Error>,
) -> Result<(DcRecorder, DcRecorder)> {
Expand All @@ -318,18 +321,21 @@ fn self_test<S: ServerProviders, C: ClientProviders>(

let addr = server.local_addr()?;

let expected_count = 1 + client_has_dc as usize;
spawn(async move {
if let Some(mut conn) = server.accept().await {
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;

if let Some(error) = expected_server_error {
assert_eq!(error, convert_io_result(result).unwrap());

if expected_client_error.is_some() {
conn.close(SERVER_CLOSE_ERROR_CODE.into());
for _ in 0..expected_count {
if let Some(mut conn) = server.accept().await {
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;

if let Some(error) = expected_server_error {
assert_eq!(error, convert_io_result(result).unwrap());

if expected_client_error.is_some() {
conn.close(SERVER_CLOSE_ERROR_CODE.into());
}
} else {
assert!(result.is_ok());
}
} else {
assert!(result.is_ok());
}
}
});
Expand All @@ -340,35 +346,41 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
.with_random(Random::with_seed(456))?
.start()?;

let client_events = client_events.clone();

primary::spawn(async move {
let connect = Connect::new(addr).with_server_name("localhost");
let mut conn = client.connect(connect).await.unwrap();
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;

if let Some(error) = expected_client_error {
assert_eq!(error, convert_io_result(result).unwrap());

if expected_server_error.is_some() {
conn.close(CLIENT_CLOSE_ERROR_CODE.into());
// wait for the server to assert the expected error before dropping
delay(Duration::from_millis(100)).await;
for _ in 0..expected_count {
primary::spawn({
let client = client.clone();
let client_events = client_events.clone();
async move {
let connect = Connect::new(addr)
.with_server_name("localhost")
.with_deduplicate(client_has_dc);
let mut conn = client.connect(connect).await.unwrap();
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;

if let Some(error) = expected_client_error {
assert_eq!(error, convert_io_result(result).unwrap());

if expected_server_error.is_some() {
conn.close(CLIENT_CLOSE_ERROR_CODE.into());
// wait for the server to assert the expected error before dropping
delay(Duration::from_millis(100)).await;
}
} else {
assert!(result.is_ok());
let client_events = client_events
.dc_state_changed_events()
.lock()
.unwrap()
.clone();
assert_dc_complete(&client_events);
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
// before the client closes the connection. This is only necessary to confirm the `dc::State`
// on the server moves to `DcState::Complete`
delay(Duration::from_millis(100)).await;
}
}
} else {
assert!(result.is_ok());
let client_events = client_events
.dc_state_changed_events()
.lock()
.unwrap()
.clone();
assert_dc_complete(&client_events);
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
// before the client closes the connection. This is only necessary to confirm the `dc::State`
// on the server moves to `DcState::Complete`
delay(Duration::from_millis(100)).await;
}
});
});
}

Ok(addr)
})
Expand Down

0 comments on commit ec8b7c0

Please sign in to comment.