diff --git a/examples/turmoil-provider/Cargo.toml b/examples/turmoil-provider/Cargo.toml new file mode 100644 index 0000000000..9b0ce09360 --- /dev/null +++ b/examples/turmoil-provider/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "turmoil-provider" +version = "0.1.0" +authors = ["AWS s2n"] +edition = "2021" + +[dependencies] +s2n-quic = { version = "1", path = "../../quic/s2n-quic", features = ["provider-event-tracing", "unstable-provider-io-turmoil"] } +tokio = { version = "1", features = ["full"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +turmoil = { version = "0.5.2" } + +[workspace] +members = ["."] diff --git a/examples/turmoil-provider/rust-toolchain b/examples/turmoil-provider/rust-toolchain new file mode 100644 index 0000000000..166b660e8e --- /dev/null +++ b/examples/turmoil-provider/rust-toolchain @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.66.0" +components = ["rustc", "clippy", "rustfmt"] diff --git a/examples/turmoil-provider/src/lib.rs b/examples/turmoil-provider/src/lib.rs new file mode 100644 index 0000000000..b5297aa7de --- /dev/null +++ b/examples/turmoil-provider/src/lib.rs @@ -0,0 +1,161 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(test)] + +use s2n_quic::{ + client::Connect, + provider::{event, io}, + Client, Server, +}; +use std::net::SocketAddr; +use turmoil::{lookup, Builder, Result}; + +/// NOTE: this certificate is to be used for demonstration purposes only! +pub static CERT_PEM: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../quic/s2n-quic-core/certs/cert.pem" +)); +/// NOTE: this certificate is to be used for demonstration purposes only! +pub static KEY_PEM: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../quic/s2n-quic-core/certs/key.pem" +)); + +#[test] +fn lossy_handshake() -> Result { + let mut sim = Builder::new() + .simulation_duration(core::time::Duration::from_secs(20)) + .build(); + + sim.host("server", || async move { + let io = io::turmoil::Builder::default() + .with_address(bind_to(443))? + .build()?; + + let mut server = Server::builder() + .with_io(io)? + .with_tls((CERT_PEM, KEY_PEM))? + .with_event(events())? + .start()?; + + while let Some(mut connection) = server.accept().await { + tokio::spawn(async move { + eprintln!("Connection accepted from {:?}", connection.remote_addr()); + + while let Ok(Some(mut stream)) = connection.accept_bidirectional_stream().await { + tokio::spawn(async move { + eprintln!("Stream opened from {:?}", stream.connection().remote_addr()); + + // echo any data back to the stream + while let Ok(Some(data)) = stream.receive().await { + stream.send(data).await.expect("stream should be open"); + } + }); + } + }); + } + + Ok(()) + }); + + sim.client("client", async move { + let io = io::turmoil::Builder::default() + .with_address(bind_to(1234))? + .build()?; + + let client = Client::builder() + .with_io(io)? + .with_tls(CERT_PEM)? + .with_event(events())? + .start()?; + + // drop packets for 1 second + drop_for(1); + + // even though we're dropping packets, the connection still goes through + let server_addr: SocketAddr = (lookup("server"), 443).into(); + let mut connection = client + .connect(Connect::new(server_addr).with_server_name("localhost")) + .await?; + + // drop packets for 5 seconds + drop_for(5); + + // even though we're dropping packets, the stream should still complete + let mut stream = connection.open_bidirectional_stream().await?; + stream.send(vec![1, 2, 3].into()).await?; + stream.finish()?; + + let response = stream.receive().await?.unwrap(); + assert_eq!(&response[..], &[1, 2, 3]); + + Ok(()) + }); + + sim.run()?; + + Ok(()) +} + +pub fn events() -> event::tracing::Provider { + use std::sync::Once; + + static TRACING: Once = Once::new(); + + // make sure this only gets initialized once + TRACING.call_once(|| { + use tokio::time::Instant; + + struct TokioUptime { + epoch: Instant, + } + + impl Default for TokioUptime { + fn default() -> Self { + Self { + epoch: Instant::now(), + } + } + } + + impl tracing_subscriber::fmt::time::FormatTime for TokioUptime { + fn format_time( + &self, + w: &mut tracing_subscriber::fmt::format::Writer, + ) -> std::fmt::Result { + write!(w, "{:?}", self.epoch.elapsed()) + } + } + + let format = tracing_subscriber::fmt::format() + .with_level(false) // don't include levels in formatted output + .with_timer(TokioUptime::default()) + .with_ansi(false) + .compact(); // Use a less verbose output format. + + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::new("trace")) + .event_format(format) + .with_test_writer() + .init(); + }); + + event::tracing::Provider::default() +} + +fn bind_to(port: u16) -> SocketAddr { + (std::net::Ipv4Addr::UNSPECIFIED, port).into() +} + +fn drop_for(secs: u64) { + turmoil::partition("client", "server"); + tokio::spawn(async move { + sleep_ms(secs * 1000).await; + turmoil::repair("client", "server"); + }); +} + +async fn sleep_ms(millis: u64) { + tokio::time::sleep(core::time::Duration::from_millis(millis)).await +} diff --git a/quic/s2n-quic-platform/Cargo.toml b/quic/s2n-quic-platform/Cargo.toml index b4a8a96bf1..978d3939a2 100644 --- a/quic/s2n-quic-platform/Cargo.toml +++ b/quic/s2n-quic-platform/Cargo.toml @@ -29,6 +29,7 @@ pin-project = { version = "1", optional = true } s2n-quic-core = { version = "=0.18.1", path = "../s2n-quic-core", default-features = false } socket2 = { version = "0.4", features = ["all"], optional = true } tokio = { version = "1", default-features = false, features = ["macros", "net", "rt", "time"], optional = true } +turmoil = { version = "0.5.2", optional = true } zeroize = { version = "1", default-features = false } [target.'cfg(unix)'.dependencies] diff --git a/quic/s2n-quic-platform/src/io.rs b/quic/s2n-quic-platform/src/io.rs index 3d2dbf5ff6..f6e42e3e56 100644 --- a/quic/s2n-quic-platform/src/io.rs +++ b/quic/s2n-quic-platform/src/io.rs @@ -8,3 +8,6 @@ pub mod tokio; #[cfg(any(test, feature = "io-testing"))] pub mod testing; + +#[cfg(feature = "turmoil")] +pub mod turmoil; diff --git a/quic/s2n-quic-platform/src/io/tokio.rs b/quic/s2n-quic-platform/src/io/tokio.rs index 7b38f48b37..eeb6208252 100644 --- a/quic/s2n-quic-platform/src/io/tokio.rs +++ b/quic/s2n-quic-platform/src/io/tokio.rs @@ -17,7 +17,7 @@ use tokio::{net::UdpSocket, runtime::Handle}; pub type PathHandle = socket::Handle; mod clock; -use clock::Clock; +pub(crate) use clock::Clock; impl crate::socket::std::Socket for UdpSocket { type Error = io::Error; diff --git a/quic/s2n-quic-platform/src/io/turmoil.rs b/quic/s2n-quic-platform/src/io/turmoil.rs new file mode 100644 index 0000000000..03f9dc315a --- /dev/null +++ b/quic/s2n-quic-platform/src/io/turmoil.rs @@ -0,0 +1,467 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::select::{self, Select}; +use crate::{buffer::default as buffer, io::tokio::Clock, socket::std as socket}; +use s2n_quic_core::{ + endpoint::Endpoint, + event::{self, EndpointPublisher as _}, + inet::SocketAddress, + path::MaxMtu, + time::Clock as ClockTrait, +}; +use std::{convert::TryInto, io, io::ErrorKind}; +use tokio::runtime::Handle; +use turmoil::net::UdpSocket; + +impl crate::socket::std::Socket for UdpSocket { + type Error = io::Error; + + fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Option), Self::Error> { + let (len, addr) = self.try_recv_from(buf)?; + Ok((len, Some(addr.into()))) + } + + fn send_to(&self, buf: &[u8], addr: &SocketAddress) -> Result { + let addr: std::net::SocketAddr = (*addr).into(); + self.try_send_to(buf, addr) + } +} + +pub type PathHandle = socket::Handle; + +#[derive(Default)] +pub struct Io { + builder: Builder, +} + +impl Io { + pub fn builder() -> Builder { + Builder::default() + } + + pub fn new(addr: A) -> io::Result { + let builder = Builder::default().with_address(addr)?; + Ok(Self { builder }) + } + + async fn setup>( + self, + mut endpoint: E, + ) -> io::Result<(Instance, SocketAddress)> { + let Builder { + handle: _, + socket, + addr, + max_mtu, + } = self.builder; + + endpoint.set_max_mtu(max_mtu); + + let clock = Clock::default(); + + let socket = if let Some(socket) = socket { + socket + } else if let Some(addr) = addr { + UdpSocket::bind(&*addr).await? + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "missing bind address", + )); + }; + + let rx_buffer = buffer::Buffer::new_with_mtu(max_mtu.into()); + let mut rx = socket::Queue::new(rx_buffer); + + let tx_buffer = buffer::Buffer::new_with_mtu(max_mtu.into()); + let tx = socket::Queue::new(tx_buffer); + + let local_addr: SocketAddress = socket.local_addr()?.into(); + + // tell the queue the local address so it can fill it in on each message + rx.set_local_address(local_addr.into()); + + let instance = Instance { + clock, + socket, + rx, + tx, + endpoint, + }; + + Ok((instance, local_addr)) + } + + pub fn start>( + mut self, + endpoint: E, + ) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> { + let handle = if let Some(handle) = self.builder.handle.take() { + handle + } else { + Handle::try_current().map_err(|err| std::io::Error::new(io::ErrorKind::Other, err))? + }; + + let guard = handle.enter(); + + let task = handle.spawn(async move { + let (instance, _local_addr) = self.setup(endpoint).await.unwrap(); + + if let Err(err) = instance.event_loop().await { + let debug = format!("A fatal IO error occurred ({:?}): {err}", err.kind()); + if cfg!(test) { + panic!("{debug}"); + } else { + eprintln!("{debug}"); + } + } + }); + + drop(guard); + + // TODO this is a potentially async operation - can we get this here? + let local_addr = Default::default(); + + Ok((task, local_addr)) + } +} + +#[derive(Default)] +pub struct Builder { + handle: Option, + socket: Option, + addr: Option>, + max_mtu: MaxMtu, +} + +impl Builder { + #[must_use] + pub fn with_handle(mut self, handle: Handle) -> Self { + self.handle = Some(handle); + self + } + + /// Sets the local address for the runtime to listen on. + /// + /// NOTE: this method is mutually exclusive with `with_socket` + pub fn with_address( + mut self, + addr: A, + ) -> io::Result { + debug_assert!(self.socket.is_none(), "socket has already been set"); + self.addr = Some(Box::new(addr)); + Ok(self) + } + + /// Sets the socket used for sending and receiving for the runtime. + /// + /// NOTE: this method is mutually exclusive with `with_address` + pub fn with_socket(mut self, socket: UdpSocket) -> io::Result { + debug_assert!(self.addr.is_none(), "address has already been set"); + self.socket = Some(socket); + Ok(self) + } + + /// Sets the largest maximum transmission unit (MTU) that can be sent on a path + pub fn with_max_mtu(mut self, max_mtu: u16) -> io::Result { + self.max_mtu = max_mtu + .try_into() + .map_err(|err| io::Error::new(ErrorKind::InvalidInput, format!("{err}")))?; + Ok(self) + } + + pub fn build(self) -> io::Result { + Ok(Io { builder: self }) + } +} + +struct Instance { + clock: Clock, + socket: turmoil::net::UdpSocket, + rx: socket::Queue, + tx: socket::Queue, + endpoint: E, +} + +impl> Instance { + async fn event_loop(self) -> io::Result<()> { + let Self { + clock, + socket, + mut rx, + mut tx, + mut endpoint, + } = self; + + let mut timer = clock.timer(); + + loop { + // Poll for readability if we have free slots available + let rx_interest = rx.free_len() > 0; + let rx_task = async { + if rx_interest { + socket.readable().await + } else { + futures::future::pending().await + } + }; + + // Poll for writablity if we have occupied slots available + let tx_interest = tx.occupied_len() > 0; + let tx_task = async { + if tx_interest { + socket.writable().await + } else { + futures::future::pending().await + } + }; + + let wakeups = endpoint.wakeups(&clock); + // pin the wakeups future so we don't have to move it into the Select future. + tokio::pin!(wakeups); + + let select::Outcome { + rx_result, + tx_result, + timeout_expired, + application_wakeup, + } = if let Ok(res) = Select::new(rx_task, tx_task, &mut wakeups, &mut timer).await { + res + } else { + // The endpoint has shut down + return Ok(()); + }; + + let wakeup_timestamp = clock.get_time(); + let subscriber = endpoint.subscriber(); + let mut publisher = event::EndpointPublisherSubscriber::new( + event::builder::EndpointMeta { + endpoint_type: E::ENDPOINT_TYPE, + timestamp: wakeup_timestamp, + }, + None, + subscriber, + ); + + publisher.on_platform_event_loop_wakeup(event::builder::PlatformEventLoopWakeup { + timeout_expired, + rx_ready: rx_result.is_some(), + tx_ready: tx_result.is_some(), + application_wakeup, + }); + + if tx_result.is_some() { + tx.tx(&socket, &mut publisher)?; + } + + if rx_result.is_some() { + rx.rx(&socket, &mut publisher)?; + endpoint.receive(&mut rx.rx_queue(), &clock); + } + + endpoint.transmit(&mut tx.tx_queue(), &clock); + + let timeout = endpoint.timeout(); + + if let Some(timeout) = timeout { + timer.update(timeout); + } + + let timestamp = clock.get_time(); + let subscriber = endpoint.subscriber(); + let mut publisher = event::EndpointPublisherSubscriber::new( + event::builder::EndpointMeta { + endpoint_type: E::ENDPOINT_TYPE, + timestamp, + }, + None, + subscriber, + ); + + // notify the application that we're going to sleep + let timeout = timeout.map(|t| t.saturating_duration_since(timestamp)); + publisher.on_platform_event_loop_sleep(event::builder::PlatformEventLoopSleep { + timeout, + processing_duration: timestamp.saturating_duration_since(wakeup_timestamp), + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::{ + convert::TryInto, + task::{Context, Poll}, + }; + use s2n_quic_core::{ + endpoint::{self, CloseError}, + event, + inet::SocketAddress, + io::{ + rx::{self, Entry as _}, + tx, + }, + path::Handle as _, + time::{timer::Provider as _, Clock, Duration, Timer, Timestamp}, + }; + use std::collections::BTreeMap; + + struct TestEndpoint { + addr: SocketAddress, + tx_message_id: u32, + rx_messages: BTreeMap, + total_messages: u32, + subscriber: NoopSubscriber, + close_timer: Timer, + } + + impl TestEndpoint { + fn new(addr: SocketAddress) -> Self { + Self { + addr, + tx_message_id: 0, + rx_messages: BTreeMap::new(), + total_messages: 1000, + subscriber: Default::default(), + close_timer: Default::default(), + } + } + } + + #[derive(Debug, Default)] + struct NoopSubscriber; + + impl event::Subscriber for NoopSubscriber { + type ConnectionContext = (); + + fn create_connection_context( + &mut self, + _meta: &event::api::ConnectionMeta, + _info: &event::api::ConnectionInfo, + ) -> Self::ConnectionContext { + } + } + + impl Endpoint for TestEndpoint { + type PathHandle = PathHandle; + type Subscriber = NoopSubscriber; + + const ENDPOINT_TYPE: endpoint::Type = endpoint::Type::Server; + + fn transmit, C: Clock>( + &mut self, + queue: &mut Tx, + _clock: &C, + ) { + while self.tx_message_id < self.total_messages { + let payload = self.tx_message_id.to_be_bytes(); + let addr = PathHandle::from_remote_address(self.addr.into()); + let msg = (addr, payload); + if queue.push(msg).is_ok() { + self.tx_message_id += 1; + } else { + // no more capacity + return; + } + } + } + + fn receive, C: Clock>( + &mut self, + queue: &mut Rx, + clock: &C, + ) { + let now = clock.get_time(); + let local_address = queue.local_address(); + let entries = queue.as_slice_mut(); + let len = entries.len(); + for entry in entries { + if let Some((_header, payload)) = entry.read(&local_address) { + assert_eq!(payload.len(), 4, "invalid payload {:?}", payload); + + let id = (&*payload).try_into().unwrap(); + let id = u32::from_be_bytes(id); + self.rx_messages.insert(id, now); + } + } + queue.finish(len); + } + + fn poll_wakeups( + &mut self, + _cx: &mut Context<'_>, + clock: &C, + ) -> Poll> { + let now = clock.get_time(); + + if self.close_timer.poll_expiration(now).is_ready() { + assert!(self.rx_messages.len() as u32 * 4 > self.total_messages); + return Err(CloseError).into(); + } + + if !self.close_timer.is_armed() + && self.total_messages <= self.tx_message_id + && !self.rx_messages.is_empty() + { + self.close_timer.set(now + Duration::from_millis(100)); + } + + Poll::Pending + } + + fn timeout(&self) -> Option { + self.close_timer.next_expiration() + } + + fn set_max_mtu(&mut self, _max_mtu: MaxMtu) { + // noop + } + + fn subscriber(&mut self) -> &mut Self::Subscriber { + &mut self.subscriber + } + } + + fn bind(port: u16) -> std::net::SocketAddr { + use std::net::Ipv4Addr; + (Ipv4Addr::UNSPECIFIED, port).into() + } + + #[test] + fn sim_test() -> io::Result<()> { + use turmoil::lookup; + + let mut sim = turmoil::Builder::new().build(); + + sim.client("client", async move { + let io = Io::builder().with_address(bind(123))?.build()?; + + let endpoint = TestEndpoint::new((lookup("server"), 456).into()); + + let (task, _) = io.start(endpoint)?; + + task.await?; + + Ok(()) + }); + + sim.client("server", async move { + let io = Io::builder().with_address(bind(456))?.build()?; + + let endpoint = TestEndpoint::new((lookup("client"), 123).into()); + + let (task, _) = io.start(endpoint)?; + + task.await?; + + Ok(()) + }); + + sim.run().unwrap(); + + Ok(()) + } +} diff --git a/quic/s2n-quic/Cargo.toml b/quic/s2n-quic/Cargo.toml index 5a3e5dd8f9..ca04affa06 100644 --- a/quic/s2n-quic/Cargo.toml +++ b/quic/s2n-quic/Cargo.toml @@ -38,6 +38,8 @@ unstable_private_key = ["s2n-quic-tls/unstable_private_key"] unstable-provider-datagram = [] # This feature enables the testing IO provider unstable-provider-io-testing = ["s2n-quic-platform/io-testing"] +# This feature enables the turmoil IO provider +unstable-provider-io-turmoil = ["s2n-quic-platform/turmoil"] # This feature enables the packet interceptor provider, which is invoked on each cleartext packet unstable-provider-packet-interceptor = [] # This feature enables the random provider diff --git a/quic/s2n-quic/src/lib.rs b/quic/s2n-quic/src/lib.rs index 6185f7c6ff..034e51da6c 100644 --- a/quic/s2n-quic/src/lib.rs +++ b/quic/s2n-quic/src/lib.rs @@ -80,6 +80,7 @@ mod tests; feature = "unstable_client_hello", feature = "unstable-provider-datagram", feature = "unstable-provider-io-testing", + feature = "unstable-provider-io-turmoil", feature = "unstable-provider-packet-interceptor", feature = "unstable-provider-random", feature = "unstable-provider-congestion-controller", diff --git a/quic/s2n-quic/src/provider/io.rs b/quic/s2n-quic/src/provider/io.rs index bc200ec3a7..96d1e2c9d2 100644 --- a/quic/s2n-quic/src/provider/io.rs +++ b/quic/s2n-quic/src/provider/io.rs @@ -19,6 +19,9 @@ pub trait Provider: 'static { #[cfg(any(test, feature = "unstable-provider-io-testing"))] pub mod testing; +#[cfg(feature = "unstable-provider-io-turmoil")] +pub mod turmoil; + pub mod tokio; pub use self::tokio as default; diff --git a/quic/s2n-quic/src/provider/io/turmoil.rs b/quic/s2n-quic/src/provider/io/turmoil.rs new file mode 100644 index 0000000000..a107eb1868 --- /dev/null +++ b/quic/s2n-quic/src/provider/io/turmoil.rs @@ -0,0 +1,24 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Provides an implementation of the [`io::Provider`](crate::provider::io::Provider) +//! using the [`Turmoil network simulator`](https://docs.rs/turmoil). + +use s2n_quic_core::{endpoint::Endpoint, inet::SocketAddress}; +use s2n_quic_platform::io::turmoil; +use std::io; + +pub use self::turmoil::{Builder, Io as Provider}; + +impl super::Provider for Provider { + type PathHandle = turmoil::PathHandle; + type Error = io::Error; + + fn start>( + self, + endpoint: E, + ) -> Result { + let (_join_handle, local_addr) = Provider::start(self, endpoint)?; + Ok(local_addr) + } +}