diff --git a/common/s2n-codec/Cargo.toml b/common/s2n-codec/Cargo.toml index f57e3cf767..e54493f149 100644 --- a/common/s2n-codec/Cargo.toml +++ b/common/s2n-codec/Cargo.toml @@ -12,7 +12,8 @@ exclude = ["corpus.tar.gz"] [features] default = ["std", "bytes"] -std = [] +alloc = [] +std = ["alloc"] testing = ["std", "generator"] checked_range_unsafe = [] generator = ["bolero-generator"] diff --git a/common/s2n-codec/src/encoder/value.rs b/common/s2n-codec/src/encoder/value.rs index 6ef447d69b..fb29c4fe64 100644 --- a/common/s2n-codec/src/encoder/value.rs +++ b/common/s2n-codec/src/encoder/value.rs @@ -46,6 +46,14 @@ pub trait EncoderValue: Sized { len.encode(encoder); self.encode(encoder); } + + #[cfg(feature = "alloc")] + fn encode_to_vec(&self) -> alloc::vec::Vec { + let len = self.encoding_size(); + let mut buffer = alloc::vec![0u8; len]; + self.encode(&mut crate::EncoderBuffer::new(&mut buffer)); + buffer + } } macro_rules! encoder_value_byte { diff --git a/common/s2n-codec/src/lib.rs b/common/s2n-codec/src/lib.rs index 33018d1c61..a00265e8e5 100644 --- a/common/s2n-codec/src/lib.rs +++ b/common/s2n-codec/src/lib.rs @@ -3,6 +3,9 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] +#[cfg(feature = "alloc")] +extern crate alloc; + #[cfg(any(feature = "testing", test))] #[macro_use] pub mod testing; diff --git a/quic/s2n-quic-core/Cargo.toml b/quic/s2n-quic-core/Cargo.toml index 946b32bf1d..14f38ec035 100644 --- a/quic/s2n-quic-core/Cargo.toml +++ b/quic/s2n-quic-core/Cargo.toml @@ -12,7 +12,7 @@ exclude = ["corpus.tar.gz"] [features] default = ["alloc", "std"] -alloc = ["atomic-waker", "bytes", "crossbeam-utils"] +alloc = ["atomic-waker", "bytes", "crossbeam-utils", "s2n-codec/alloc"] std = ["alloc", "once_cell"] testing = ["std", "generator", "s2n-codec/testing", "checked-counters", "insta", "futures-test"] generator = ["bolero-generator"] diff --git a/quic/s2n-quic-core/src/crypto/tls/null.rs b/quic/s2n-quic-core/src/crypto/tls/null.rs index b4ded762c9..7e511646b6 100644 --- a/quic/s2n-quic-core/src/crypto/tls/null.rs +++ b/quic/s2n-quic-core/src/crypto/tls/null.rs @@ -85,7 +85,7 @@ impl crypto::tls::Endpoint for Endpoint { &mut self, transport_parameters: &Params, ) -> Self::Session { - let params = encode_transport_parameters(transport_parameters); + let params = transport_parameters.encode_to_vec().into(); Session::Server(server::Session::Init { transport_parameters: params, }) @@ -99,7 +99,7 @@ impl crypto::tls::Endpoint for Endpoint { ) -> Self::Session { assert_eq!(server_name, LOCALHOST); - let params = encode_transport_parameters(transport_parameters); + let params = transport_parameters.encode_to_vec().into(); Session::Client(client::Session::Init { transport_parameters: params, }) @@ -156,14 +156,6 @@ impl tls::Session for Session { } } -/// Encodes transport parameters into a byte vec -fn encode_transport_parameters(params: &Params) -> Bytes { - let len = params.encoding_size(); - let mut buffer = vec![0; len]; - params.encode(&mut s2n_codec::EncoderBuffer::new(&mut buffer)); - buffer.into() -} - static FIN: Bytes = Bytes::from_static(b"FIN"); static NULL: Bytes = Bytes::from_static(b"NULL"); @@ -191,7 +183,7 @@ pub mod client { } => { context.send_initial(core::mem::take(transport_parameters)); - context.on_server_name(LOCALHOST.clone()).unwrap(); + context.on_server_name(LOCALHOST.clone())?; *self = Self::WaitingInitial {}; } @@ -201,9 +193,7 @@ pub mod client { None => return Poll::Pending, }; - context - .on_handshake_keys(key::NoCrypto, key::NoCrypto) - .unwrap(); + context.on_handshake_keys(key::NoCrypto, key::NoCrypto)?; // notify the server we're done context.send_handshake(FIN.clone()); @@ -215,19 +205,17 @@ pub mod client { return Poll::Pending; } - context.on_application_protocol(NULL.clone()).unwrap(); + context.on_application_protocol(NULL.clone())?; - context - .on_one_rtt_keys( - key::NoCrypto, - key::NoCrypto, - tls::ApplicationParameters { - transport_parameters: params, - }, - ) - .unwrap(); + context.on_one_rtt_keys( + key::NoCrypto, + key::NoCrypto, + tls::ApplicationParameters { + transport_parameters: params, + }, + )?; - context.on_handshake_complete().unwrap(); + context.on_handshake_complete()?; *self = Self::Complete; @@ -267,24 +255,20 @@ pub mod server { }; context.send_initial(core::mem::take(transport_parameters)); - context - .on_handshake_keys(key::NoCrypto, key::NoCrypto) - .unwrap(); + context.on_handshake_keys(key::NoCrypto, key::NoCrypto)?; context.send_handshake(FIN.clone()); - context.on_application_protocol(NULL.clone()).unwrap(); + context.on_application_protocol(NULL.clone())?; - context - .on_one_rtt_keys( - key::NoCrypto, - key::NoCrypto, - tls::ApplicationParameters { - transport_parameters: &client_params, - }, - ) - .unwrap(); + context.on_one_rtt_keys( + key::NoCrypto, + key::NoCrypto, + tls::ApplicationParameters { + transport_parameters: &client_params, + }, + )?; - context.on_server_name(LOCALHOST.clone()).unwrap(); + context.on_server_name(LOCALHOST.clone())?; *self = Self::WaitingComplete; } @@ -294,7 +278,7 @@ pub mod server { } *self = Self::Complete; - context.on_handshake_complete().unwrap(); + context.on_handshake_complete()?; return Ok(()).into(); } @@ -446,6 +430,9 @@ mod key { mod tests { use super::*; use crate::crypto::tls::testing::Pair; + use bolero::check; + use bytes::{BufMut, Bytes, BytesMut}; + use std::collections::VecDeque; #[test] fn null_test() { @@ -460,4 +447,46 @@ mod tests { pair.finish(); } + + #[test] + fn fuzz_test() { + let mut server = Endpoint::default(); + let mut client = Endpoint::default(); + + check!().for_each(|mut bytes| { + // replaces a single buffer with fuzz bytes + let mut replace_bytes = |chunks: &mut VecDeque| { + for chunk in chunks { + let len = chunk.len().min(bytes.len()); + let (data, remaining) = bytes.split_at(len); + bytes = remaining; + let mut replacement = BytesMut::with_capacity(chunk.len()); + replacement.put_slice(data); + replacement.put_bytes(0, chunk.len() - data.len()); + assert_eq!(chunk.len(), replacement.len()); + *chunk = replacement.freeze(); + } + }; + + let mut pair = Pair::new(&mut server, &mut client, LOCALHOST.clone()); + + while pair.is_handshaking() { + if pair.poll_start().is_err() { + break; + } + + // replace all of the buffers with fuzz bytes + replace_bytes(&mut pair.server.context.initial.rx); + replace_bytes(&mut pair.server.context.initial.tx); + replace_bytes(&mut pair.server.context.handshake.rx); + replace_bytes(&mut pair.server.context.handshake.tx); + replace_bytes(&mut pair.server.context.application.rx); + replace_bytes(&mut pair.server.context.application.tx); + + if pair.poll_finish(None).is_err() { + break; + } + } + }); + } } diff --git a/quic/s2n-quic-core/src/crypto/tls/testing.rs b/quic/s2n-quic-core/src/crypto/tls/testing.rs index 27f33ffca3..92baeaa8d5 100644 --- a/quic/s2n-quic-core/src/crypto/tls/testing.rs +++ b/quic/s2n-quic-core/src/crypto/tls/testing.rs @@ -10,16 +10,18 @@ use crate::{ tls, CryptoSuite, HeaderKey, Key, }, endpoint, transport, + transport::parameters::{ClientTransportParameters, ServerTransportParameters}, }; use alloc::sync::Arc; use bytes::Bytes; use core::{ fmt, + marker::PhantomData, sync::atomic::{AtomicBool, Ordering}, task::{Poll, Waker}, }; use futures_test::task::new_count_waker; -use s2n_codec::EncoderValue; +use s2n_codec::{DecoderBuffer, DecoderValue, EncoderValue}; use std::{collections::VecDeque, fmt::Debug}; pub mod certificates { @@ -105,13 +107,19 @@ impl CryptoSuite for Session { } #[derive(Debug)] -pub struct TlsEndpoint { +pub struct TlsEndpoint +where + for<'a> Params: DecoderValue<'a>, +{ pub session: S, - pub context: Context, + pub context: Context, } -impl TlsEndpoint { - fn new(session: S, context: Context) -> Self { +impl TlsEndpoint +where + for<'a> Params: DecoderValue<'a>, +{ + fn new(session: S, context: Context) -> Self { Self { session, context } } } @@ -119,13 +127,28 @@ impl TlsEndpoint { /// A pair of TLS sessions and contexts being driven to completion #[derive(Debug)] pub struct Pair { - pub server: TlsEndpoint, - pub client: TlsEndpoint, + pub server: TlsEndpoint, + pub client: TlsEndpoint, pub server_name: ServerName, } -const TEST_SERVER_TRANSPORT_PARAMS: &[u8] = &[1, 2, 3]; -const TEST_CLIENT_TRANSPORT_PARAMS: &[u8] = &[3, 2, 1]; +fn server_params() -> Bytes { + ServerTransportParameters { + initial_max_data: 123.try_into().unwrap(), + ..Default::default() + } + .encode_to_vec() + .into() +} + +fn client_params() -> Bytes { + ClientTransportParameters { + initial_max_data: 456.try_into().unwrap(), + ..Default::default() + } + .encode_to_vec() + .into() +} impl Pair { pub fn new( @@ -139,13 +162,12 @@ impl Pair { { use crate::crypto::InitialKey; - let server = server_endpoint.new_server_session(&TEST_SERVER_TRANSPORT_PARAMS); + let server = server_endpoint.new_server_session(&&server_params()[..]); let mut server_context = Context::new(endpoint::Type::Server, ServerState::WaitingClientHello); server_context.initial.crypto = Some(S::InitialKey::new_server(server_name.as_bytes())); - let client = - client_endpoint.new_client_session(&TEST_CLIENT_TRANSPORT_PARAMS, server_name.clone()); + let client = client_endpoint.new_client_session(&&client_params()[..], server_name.clone()); let mut client_context = Context::new(endpoint::Type::Client, ClientState::ClientHelloSent); client_context.initial.crypto = Some(C::InitialKey::new_client(server_name.as_bytes())); @@ -164,8 +186,14 @@ impl Pair { /// Continues progress of the handshake pub fn poll( &mut self, - client_hello_cb_done: Option>, + client_hello_cb_done: Option<&Arc>, ) -> Result<(), transport::Error> { + self.poll_start()?; + self.poll_finish(client_hello_cb_done)?; + Ok(()) + } + + pub fn poll_start(&mut self) -> Result<(), transport::Error> { match self.client.session.poll(&mut self.client.context) { Poll::Ready(res) => res?, Poll::Pending => (), @@ -174,9 +202,18 @@ impl Pair { Poll::Ready(res) => res?, Poll::Pending => (), } + Ok(()) + } + + pub fn poll_finish( + &mut self, + client_hello_cb_done: Option<&Arc>, + ) -> Result<(), transport::Error> { self.client.context.transfer(&mut self.server.context); + #[cfg(not(fuzzing))] eprintln!("1/2 RTT"); + if let Some(client_hello_cb_done) = client_hello_cb_done { // If the server is processing the async client hello callback, then return early // and poll it until it completes @@ -245,12 +282,12 @@ impl Pair { assert_eq!( self.client.context.transport_parameters.as_ref().unwrap(), - TEST_SERVER_TRANSPORT_PARAMS, + &server_params(), "client did not receive the server transport parameters" ); assert_eq!( self.server.context.transport_parameters.as_ref().unwrap(), - TEST_CLIENT_TRANSPORT_PARAMS, + &client_params(), "server did not receive the client transport parameters" ); assert_eq!( @@ -331,7 +368,10 @@ impl ServerState { } /// Harness to ensure a TLS implementation adheres to the session contract -pub struct Context { +pub struct Context +where + for<'a> Params: DecoderValue<'a>, +{ pub initial: Space, pub handshake: Space, pub application: Space, @@ -343,9 +383,13 @@ pub struct Context { endpoint: endpoint::Type, pub state: State, waker: Waker, + params: PhantomData, } -impl fmt::Debug for Context { +impl fmt::Debug for Context +where + for<'a> Params: DecoderValue<'a>, +{ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Context") .field("initial", &self.initial) @@ -361,7 +405,10 @@ impl fmt::Debug for Context { } } -impl Context { +impl Context +where + for<'a> Params: DecoderValue<'a>, +{ fn new(endpoint: endpoint::Type, state: State) -> Self { let (waker, _wake_counter) = new_count_waker(); Self { @@ -376,18 +423,25 @@ impl Context { endpoint, state, waker, + params: PhantomData, } } /// Transfers incoming and outgoing buffers between two contexts - pub fn transfer(&mut self, other: &mut Context) { + pub fn transfer(&mut self, other: &mut Context) + where + for<'a> OP: DecoderValue<'a>, + { self.initial.transfer(&mut other.initial); self.handshake.transfer(&mut other.handshake); self.application.transfer(&mut other.application); } /// Finishes the test and asserts consistency - pub fn finish(&self, other: &Context) { + pub fn finish(&self, other: &Context) + where + for<'a> OP: DecoderValue<'a>, + { self.assert_done(); other.assert_done(); @@ -423,10 +477,21 @@ impl Context { assert!(self.transport_parameters.is_some()); } - fn on_application_params(&mut self, params: tls::ApplicationParameters) { + fn on_application_params( + &mut self, + params: tls::ApplicationParameters, + ) -> Result<(), transport::Error> { + // make sure the parameters parse correctly + let buffer = DecoderBuffer::new(params.transport_parameters); + let _ = buffer + .decode::() + .map_err(|_| transport::Error::FRAME_ENCODING_ERROR)?; + self.transport_parameters = Some(Bytes::copy_from_slice(params.transport_parameters)); + Ok(()) } + #[cfg(not(fuzzing))] fn log(&self, event: &str) { eprintln!( "{:?}: {:?}: {}: {}", @@ -436,6 +501,11 @@ impl Context { event, ); } + + #[cfg(fuzzing)] + fn log(&self, event: &str) { + let _ = event; + } } pub struct Space { @@ -558,7 +628,10 @@ fn protect_unprotect(protect: &P, unprotect: &U, tag ); } -impl tls::Context for Context { +impl tls::Context for Context +where + for<'a> Params: DecoderValue<'a>, +{ fn on_handshake_keys( &mut self, key: C::HandshakeKey, @@ -585,7 +658,7 @@ impl tls::Context for Context { ); self.log("0-rtt keys"); self.zero_rtt_crypto = Some((key, header_key)); - self.on_application_params(params); + self.on_application_params(params)?; Ok(()) } @@ -601,7 +674,7 @@ impl tls::Context for Context { ); self.log("1-rtt keys"); self.application.crypto = Some((key, header_key)); - self.on_application_params(params); + self.on_application_params(params)?; Ok(()) } diff --git a/quic/s2n-quic-rustls/Cargo.toml b/quic/s2n-quic-rustls/Cargo.toml index 21836a1607..077dff06b7 100644 --- a/quic/s2n-quic-rustls/Cargo.toml +++ b/quic/s2n-quic-rustls/Cargo.toml @@ -14,7 +14,7 @@ exclude = ["corpus.tar.gz"] bytes = { version = "1", default-features = false } rustls = { version = "0.20", features = ["quic"] } rustls-pemfile = "1" -s2n-codec = { version = "=0.6.1", path = "../../common/s2n-codec", default-features = false } +s2n-codec = { version = "=0.6.1", path = "../../common/s2n-codec", default-features = false, features = ["alloc"] } s2n-quic-core = { version = "=0.25.0", path = "../s2n-quic-core", default-features = false, features = ["alloc"] } s2n-quic-crypto = { version = "=0.25.0", path = "../s2n-quic-crypto", default-features = false } diff --git a/quic/s2n-quic-rustls/src/client.rs b/quic/s2n-quic-rustls/src/client.rs index 52d3a3e3ed..0bb922aa8e 100644 --- a/quic/s2n-quic-rustls/src/client.rs +++ b/quic/s2n-quic-rustls/src/client.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{certificate, encode_transport_parameters, session::Session}; +use crate::{certificate, session::Session}; use core::convert::TryFrom; use rustls::{quic, ClientConfig}; use s2n_codec::EncoderValue; @@ -65,7 +65,7 @@ impl tls::Endpoint for Client { //= https://www.rfc-editor.org/rfc/rfc9001#section-8.2 //# Endpoints MUST send the quic_transport_parameters extension; - let transport_parameters = encode_transport_parameters(transport_parameters); + let transport_parameters = transport_parameters.encode_to_vec(); let rustls_server_name = rustls::ServerName::try_from(server_name.as_ref()).expect("invalid server name"); diff --git a/quic/s2n-quic-rustls/src/lib.rs b/quic/s2n-quic-rustls/src/lib.rs index 2c2f4c6607..10930b7615 100644 --- a/quic/s2n-quic-rustls/src/lib.rs +++ b/quic/s2n-quic-rustls/src/lib.rs @@ -23,16 +23,6 @@ static PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::vers /// The supported version of quic const QUIC_VERSION: rustls::quic::Version = rustls::quic::Version::V1; -/// Encodes transport parameters into a byte vec -pub(crate) fn encode_transport_parameters( - params: &Params, -) -> Vec { - let len = params.encoding_size(); - let mut buffer = vec![0; len]; - params.encode(&mut s2n_codec::EncoderBuffer::new(&mut buffer)); - buffer -} - #[test] fn client_server_test() { use s2n_quic_core::crypto::tls::{self, testing::certificates::*}; diff --git a/quic/s2n-quic-rustls/src/server.rs b/quic/s2n-quic-rustls/src/server.rs index afb5ee9993..3357a39aba 100644 --- a/quic/s2n-quic-rustls/src/server.rs +++ b/quic/s2n-quic-rustls/src/server.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{certificate, encode_transport_parameters, session::Session}; +use crate::{certificate, session::Session}; use rustls::{quic, ServerConfig}; use s2n_codec::EncoderValue; use s2n_quic_core::{application::ServerName, crypto::tls}; @@ -55,7 +55,7 @@ impl tls::Endpoint for Server { //= https://www.rfc-editor.org/rfc/rfc9001#section-8.2 //# Endpoints MUST send the quic_transport_parameters extension; - let transport_parameters = encode_transport_parameters(transport_parameters); + let transport_parameters = transport_parameters.encode_to_vec(); let session = rustls::ServerConnection::new_quic( self.config.clone(), diff --git a/quic/s2n-quic-tls/src/tests.rs b/quic/s2n-quic-tls/src/tests.rs index b042ec33eb..b900f04348 100644 --- a/quic/s2n-quic-tls/src/tests.rs +++ b/quic/s2n-quic-tls/src/tests.rs @@ -407,7 +407,7 @@ fn run_result( let mut pair = tls::testing::Pair::new(server, client, "localhost".into()); while pair.is_handshaking() { - pair.poll(client_hello_cb_done.clone())?; + pair.poll(client_hello_cb_done.as_ref())?; } pair.finish();