diff --git a/quic/s2n-quic-core/src/connection/limits.rs b/quic/s2n-quic-core/src/connection/limits.rs index e1bef82805..20c1417ea2 100644 --- a/quic/s2n-quic-core/src/connection/limits.rs +++ b/quic/s2n-quic-core/src/connection/limits.rs @@ -304,3 +304,25 @@ impl Limiter for Limits { *self } } + +#[cfg(test)] +mod tests { + use super::*; + + // Local max data limits should be <= u32::MAX + #[test] + fn limit_validation() { + let mut data = u32::MAX as u64 + 1; + let limits = Limits::default(); + assert!(limits.with_data_window(data).is_err()); + assert!(limits.with_bidirectional_local_data_window(data).is_err()); + assert!(limits.with_bidirectional_remote_data_window(data).is_err()); + assert!(limits.with_unidirectional_data_window(data).is_err()); + + data = u32::MAX as u64; + assert!(limits.with_data_window(data).is_ok()); + assert!(limits.with_bidirectional_local_data_window(data).is_ok()); + assert!(limits.with_bidirectional_remote_data_window(data).is_ok()); + assert!(limits.with_unidirectional_data_window(data).is_ok()); + } +} diff --git a/quic/s2n-quic-core/src/path/mod.rs b/quic/s2n-quic-core/src/path/mod.rs index 9392ec6d2d..135119d3ee 100644 --- a/quic/s2n-quic-core/src/path/mod.rs +++ b/quic/s2n-quic-core/src/path/mod.rs @@ -100,6 +100,18 @@ pub trait Handle: 'static + Copy + Send + fmt::Debug { /// Returns `true` if the two handles are strictly equal to each other, i.e. /// byte-for-byte. fn strict_eq(&self, other: &Self) -> bool; + + /// Depending on the current value of `self`, fields from `other` may be copied to increase the + /// fidelity of the value. + /// + /// This is especially useful for clients that initiate a connection only based on the remote + /// IP and port. They likely wouldn't know the IP address of the local socket. Once a response + /// is received from the server, the IP information will be known at this point and the handle + /// can be updated with the new information. + /// + /// Implementations should try to limit the cost of updating by checking the current value to + /// see if it needs updating. + fn maybe_update(&mut self, other: &Self); } macro_rules! impl_addr { @@ -174,6 +186,11 @@ impl Handle for RemoteAddress { fn strict_eq(&self, other: &Self) -> bool { PartialEq::eq(self, other) } + + #[inline] + fn maybe_update(&mut self, _other: &Self) { + // nothing to update + } } #[derive(Clone, Copy, Debug, Eq)] @@ -221,6 +238,14 @@ impl Handle for Tuple { fn strict_eq(&self, other: &Self) -> bool { PartialEq::eq(self, other) } + + #[inline] + fn maybe_update(&mut self, other: &Self) { + // once we discover our path, update the address local address + if self.local_address.port() == 0 { + self.local_address = other.local_address; + } + } } #[derive(Clone, Copy, Debug, PartialEq)] diff --git a/quic/s2n-quic-core/src/xdp/path.rs b/quic/s2n-quic-core/src/xdp/path.rs index 4b25c81725..fdd4f45395 100644 --- a/quic/s2n-quic-core/src/xdp/path.rs +++ b/quic/s2n-quic-core/src/xdp/path.rs @@ -107,4 +107,12 @@ impl Handle for Tuple { fn strict_eq(&self, other: &Self) -> bool { PartialEq::eq(self, other) } + + #[inline] + fn maybe_update(&mut self, other: &Self) { + // once we discover our path, update the address full address + if self.local_address.port == 0 { + *self = *other; + } + } } diff --git a/quic/s2n-quic-platform/src/message/msg.rs b/quic/s2n-quic-platform/src/message/msg.rs index 273116d0db..6d69e28903 100644 --- a/quic/s2n-quic-platform/src/message/msg.rs +++ b/quic/s2n-quic-platform/src/message/msg.rs @@ -114,6 +114,14 @@ impl path::Handle for Handle { fn strict_eq(&self, other: &Self) -> bool { PartialEq::eq(self, other) } + + #[inline] + fn maybe_update(&mut self, other: &Self) { + // once we discover our path, update the address local address + if self.local_address.port() == 0 { + self.local_address = other.local_address; + } + } } impl_message_delegate!(Message, 0, msghdr); diff --git a/quic/s2n-quic-transport/src/path/manager.rs b/quic/s2n-quic-transport/src/path/manager.rs index 16945a1f2e..7e307f83a2 100644 --- a/quic/s2n-quic-transport/src/path/manager.rs +++ b/quic/s2n-quic-transport/src/path/manager.rs @@ -258,6 +258,9 @@ impl Manager { return Err(DatagramDropReason::InvalidSourceConnectionId); } + // update the address if it was resolved + path.handle.maybe_update(path_handle); + let unblocked = path.on_bytes_received(datagram.payload_len); return Ok((id, unblocked)); } diff --git a/quic/s2n-quic-transport/src/path/mod.rs b/quic/s2n-quic-transport/src/path/mod.rs index 511009fe17..b7351272bd 100644 --- a/quic/s2n-quic-transport/src/path/mod.rs +++ b/quic/s2n-quic-transport/src/path/mod.rs @@ -551,9 +551,8 @@ impl Path { // a default un-specified value; therefore only the remote_address is used // to compare Paths. fn eq_by_handle(&self, handle: &Config::PathHandle) -> bool { - if Config::ENDPOINT_TYPE.is_client() { - // TODO: https://github.com/aws/s2n-quic/issues/954 - // Possibly research a strategy to populate the local_address for Client endpoint + if self.handle.local_address().port() == 0 { + // Only compare the remote address if we haven't updated the local yet s2n_quic_core::path::Handle::eq(&self.handle.remote_address(), &handle.remote_address()) } else { self.handle.eq(handle) diff --git a/quic/s2n-quic/src/tests.rs b/quic/s2n-quic/src/tests.rs index 9037e92f7a..36a5ba0a6a 100644 --- a/quic/s2n-quic/src/tests.rs +++ b/quic/s2n-quic/src/tests.rs @@ -5,30 +5,27 @@ use crate::{ client::Connect, provider::{ self, - event::{events::PacketSent, ConnectionInfo, ConnectionMeta, Subscriber}, + event::{ + events::{MtuUpdated, MtuUpdatedCause, PacketSent, RecoveryMetrics}, + ConnectionInfo, ConnectionMeta, Subscriber, + }, io::testing::{rand, spawn, test, time::delay, Model}, - limits::Limits, packet_interceptor::Loss, }, Client, Server, }; +use bytes::Bytes; +use s2n_quic_core::{crypto::tls::testing::certificates, stream::testing::Data}; +use s2n_quic_platform::io::testing::{network::Packet, primary, TxRecorder}; use std::{ net::SocketAddr, sync::{Arc, Mutex}, time::Duration, }; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; mod setup; -use bytes::Bytes; -use s2n_quic_core::{ - crypto::tls::testing::certificates, - event::api::{MtuUpdated, MtuUpdatedCause}, - stream::testing::Data, -}; -use s2n_quic_platform::io::testing::{network::Packet, primary, TxRecorder}; - use setup::*; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[test] fn client_server_test() { @@ -304,38 +301,36 @@ fn local_stream_open_notify_test() { } macro_rules! event_recorder { - ($sub:ident, $con:ident, $event:ty, $method:ident) => { + ($sub:ident, $event:ty, $method:ident) => { + event_recorder!($sub, $event, $method, $event, { + |event: &$event, storage: &mut Vec<$event>| storage.push(event.clone()) + }); + }; + ($sub:ident, $event:ty, $method:ident, $storage:ty, $store:expr) => { + #[derive(Clone, Default)] struct $sub { - events: Arc>>, + events: Arc>>, } impl $sub { fn new() -> Self { - $sub { - events: Arc::new(Mutex::new(Vec::new())), - } + Self::default() } - fn events(&self) -> Arc>> { + fn events(&self) -> Arc>> { self.events.clone() } } - struct $con { - events: Arc>>, - } - impl Subscriber for $sub { - type ConnectionContext = $con; + type ConnectionContext = $sub; fn create_connection_context( &mut self, _meta: &ConnectionMeta, _info: &ConnectionInfo, ) -> Self::ConnectionContext { - $con { - events: self.events.clone(), - } + self.clone() } fn $method( @@ -344,24 +339,27 @@ macro_rules! event_recorder { _meta: &ConnectionMeta, event: &$event, ) { + let store = $store; let mut buffer = context.events.lock().unwrap(); - buffer.push(event.clone()); + store(event, &mut buffer); } } }; } +event_recorder!(PacketSentRecorder, PacketSent, on_packet_sent); +event_recorder!(MtuUpdatedRecorder, MtuUpdated, on_mtu_updated); event_recorder!( - PacketSentRecorder, - PacketSentRecorderContext, - PacketSent, - on_packet_sent -); -event_recorder!( - MtuUpdatedRecorder, - MtuUpdatedRecorderContext, - MtuUpdated, - on_mtu_updated + PathUpdatedRecorder, + RecoveryMetrics, + on_recovery_metrics, + SocketAddr, + |event: &RecoveryMetrics, storage: &mut Vec| { + let addr: SocketAddr = event.path.local_addr.to_string().parse().unwrap(); + if storage.last().map_or(true, |prev| *prev != addr) { + storage.push(addr); + } + } ); #[test] @@ -573,19 +571,39 @@ fn mtu_blackhole() { assert_eq!(1200, events.lock().unwrap().last().unwrap().mtu); } -// Local max data limits should be <= u32::MAX +/// Ensures that the client's local path handle is updated after it receives a packet from the +/// server +/// +/// See https://github.com/aws/s2n-quic/issues/954 #[test] -fn limit_validation() { - let mut data = u32::MAX as u64 + 1; - let limits = Limits::default(); - assert!(limits.with_data_window(data).is_err()); - assert!(limits.with_bidirectional_local_data_window(data).is_err()); - assert!(limits.with_bidirectional_remote_data_window(data).is_err()); - assert!(limits.with_unidirectional_data_window(data).is_err()); - - data = u32::MAX as u64; - assert!(limits.with_data_window(data).is_ok()); - assert!(limits.with_bidirectional_local_data_window(data).is_ok()); - assert!(limits.with_bidirectional_remote_data_window(data).is_ok()); - assert!(limits.with_unidirectional_data_window(data).is_ok()); +fn client_path_handle_update() { + let model = Model::default(); + + let subscriber = PathUpdatedRecorder::new(); + let events = subscriber.events(); + + test(model, |handle| { + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_tls(SERVER_CERTS)? + .start()?; + let client = Client::builder() + .with_io(handle.builder().build().unwrap())? + .with_tls(certificates::CERT_PEM)? + .with_event(subscriber)? + .start()?; + let addr = start_server(server)?; + start_client(client, addr, Data::new(1000))?; + Ok(addr) + }) + .unwrap(); + + let events_handle = events.lock().unwrap(); + + // initially, the client address should be unknown + assert_eq!(events_handle[0], "0.0.0.0:0".parse().unwrap()); + // after receiving a packet, the client port should be the first available ephemeral port + assert_eq!(events_handle[1], "1.0.0.1:49153".parse().unwrap()); + // there should only be a single update to the path handle + assert_eq!(events_handle.len(), 2); }