diff --git a/dc/s2n-quic-dc/src/crypto.rs b/dc/s2n-quic-dc/src/crypto.rs index ffc2ee17ed..e807c3dbbe 100644 --- a/dc/s2n-quic-dc/src/crypto.rs +++ b/dc/s2n-quic-dc/src/crypto.rs @@ -1,44 +1,55 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::credentials::Credentials; pub use bytes::buf::UninitSlice; use core::fmt; pub use s2n_quic_core::packet::KeyPhase; pub mod awslc; -#[cfg(any(test, feature = "testing"))] -pub mod testing; -pub mod encrypt { +pub mod seal { use super::*; - pub trait Key { - fn credentials(&self) -> &Credentials; - + pub trait Application { fn key_phase(&self) -> KeyPhase; fn tag_len(&self) -> usize; /// Encrypt a payload - fn encrypt( + fn encrypt( &self, - nonce: N, + packet_number: u64, header: &[u8], extra_payload: Option<&[u8]>, payload_and_tag: &mut [u8], ); + } - fn retransmission_tag( - &self, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], - ); + pub trait Control { + fn tag_len(&self) -> usize; + + fn sign(&self, header: &[u8], tag: &mut [u8]); + } + + pub mod control { + use super::*; + + /// Marker trait for keys to be used with stream control packets + pub trait Stream: Control { + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ); + } + + /// Marker trait for keys to be used with secret control packets + pub trait Secret: Control {} } } -pub mod decrypt { +pub mod open { use super::*; #[derive(PartialEq, Eq, Clone, Copy, Debug)] @@ -47,6 +58,10 @@ pub mod decrypt { ReplayPotentiallyDetected { gap: Option }, ReplayDefinitelyDetected, InvalidTag, + SingleUseKey, + UnsupportedOperation, + MacOnly, + RotationNotSupported, } impl fmt::Display for Error { @@ -61,6 +76,12 @@ pub mod decrypt { write!(f, "key replay potentially detected: unknown gap") } Self::InvalidTag => "invalid tag".fmt(f), + Self::SingleUseKey => "this key can only be used once".fmt(f), + Self::UnsupportedOperation => { + "this key cannot be used with the given operation".fmt(f) + } + Self::MacOnly => "this key is only capable of generating MACs".fmt(f), + Self::RotationNotSupported => "this key does not support key rotation".fmt(f), } } } @@ -69,16 +90,14 @@ pub mod decrypt { pub type Result = core::result::Result; - pub trait Key { - fn credentials(&self) -> &Credentials; - + pub trait Application { fn tag_len(&self) -> usize; /// Decrypt a payload - fn decrypt( + fn decrypt( &self, key_phase: KeyPhase, - nonce: N, + packet_number: u64, header: &[u8], payload_in: &[u8], tag: &[u8], @@ -86,21 +105,68 @@ pub mod decrypt { ) -> Result; /// Decrypt a payload - fn decrypt_in_place( + fn decrypt_in_place( &self, key_phase: KeyPhase, - nonce: N, + packet_number: u64, header: &[u8], payload_and_tag: &mut [u8], ) -> Result; + } - fn retransmission_tag( - &self, - key_phase: KeyPhase, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], - ); + pub trait Control { + fn tag_len(&self) -> usize; + + fn verify(&self, header: &[u8], tag: &[u8]) -> Result; + } + + pub mod control { + use super::*; + + /// Marker trait for keys to be used with stream control packets + pub trait Stream: Control { + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) -> Result; + } + + pub mod stream { + /// A no-op implementation for reliable transports + #[derive(Clone, Default)] + pub struct Reliable(()); + + impl super::Control for Reliable { + #[inline] + fn tag_len(&self) -> usize { + 16 + } + + #[inline] + fn verify(&self, _header: &[u8], _tag: &[u8]) -> super::Result { + // this method should not be used on reliable transports + Err(super::Error::UnsupportedOperation) + } + } + + impl super::Stream for Reliable { + #[inline] + fn retransmission_tag( + &self, + _original_packet_number: u64, + _retransmission_packet_number: u64, + _tag_out: &mut [u8], + ) -> super::Result { + // this method should not be used on reliable transports + Err(super::Error::UnsupportedOperation) + } + } + } + + /// Marker trait for keys to be used with secret control packets + pub trait Secret: Control {} } } diff --git a/dc/s2n-quic-dc/src/crypto/awslc.rs b/dc/s2n-quic-dc/src/crypto/awslc.rs index c792806178..92b40bd743 100644 --- a/dc/s2n-quic-dc/src/crypto/awslc.rs +++ b/dc/s2n-quic-dc/src/crypto/awslc.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use super::IntoNonce; -use crate::credentials::Credentials; use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; use s2n_quic_core::{assume, packet::KeyPhase}; @@ -10,216 +9,341 @@ pub use aws_lc_rs::aead::{AES_128_GCM, AES_256_GCM}; const TAG_LEN: usize = 16; -#[derive(Debug)] -pub struct EncryptKey { - credentials: Credentials, - key: LessSafeKey, - iv: Iv, -} +pub mod seal { + use super::*; + use crate::crypto::seal; -impl EncryptKey { - #[inline] - pub fn new( - credentials: Credentials, - key: &[u8], - iv: [u8; NONCE_LEN], - algorithm: &'static Algorithm, - ) -> Self { - let key = UnboundKey::new(algorithm, key).unwrap(); - let key = LessSafeKey::new(key); - Self { - credentials, - key, - iv: Iv(iv), - } + #[derive(Debug)] + pub struct Application { + key: LessSafeKey, + iv: Iv, } -} -impl super::encrypt::Key for EncryptKey { - #[inline] - fn credentials(&self) -> &Credentials { - &self.credentials + impl Application { + #[inline] + pub fn new(key: &[u8], iv: [u8; NONCE_LEN], algorithm: &'static Algorithm) -> Self { + let key = UnboundKey::new(algorithm, key).unwrap(); + let key = LessSafeKey::new(key); + Self { key, iv: Iv(iv) } + } + + #[inline] + pub fn algorithm(&self) -> &'static Algorithm { + self.key.algorithm() + } } - #[inline] - fn key_phase(&self) -> KeyPhase { - KeyPhase::Zero + impl seal::Application for Application { + #[inline] + fn key_phase(&self) -> KeyPhase { + KeyPhase::Zero + } + + #[inline(always)] + fn tag_len(&self) -> usize { + debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); + TAG_LEN + } + + #[inline] + fn encrypt( + &self, + packet_number: u64, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + let nonce = self.iv.nonce(packet_number); + let aad = Aad::from(header); + + let extra_in = extra_payload.unwrap_or(&[][..]); + + unsafe { + assume!(payload_and_tag.len() >= self.tag_len() + extra_in.len()); + } + + let inline_len = payload_and_tag.len() - self.tag_len() - extra_in.len(); + + unsafe { + assume!(payload_and_tag.len() >= inline_len); + } + let (in_out, extra_out_and_tag) = payload_and_tag.split_at_mut(inline_len); + + let result = + self.key + .seal_in_place_scatter(nonce, aad, in_out, extra_in, extra_out_and_tag); + + unsafe { + assume!(result.is_ok()); + } + } } - #[inline(always)] - fn tag_len(&self) -> usize { - debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); - TAG_LEN + pub mod control { + use super::{super::control::*, seal}; + + macro_rules! impl_control { + ($name:ident, $tag_len:expr) => { + #[derive(Debug)] + pub struct $name(Key); + + impl $name { + #[inline] + pub fn new(key: &[u8], algorithm: &'static Algorithm) -> Self { + let key = Key::new(*algorithm, key); + Self(key) + } + } + + impl seal::Control for $name { + #[inline] + fn tag_len(&self) -> usize { + $tag_len + } + + #[inline] + fn sign(&self, header: &[u8], tag: &mut [u8]) { + sign(&self.0, $tag_len, header, tag) + } + } + }; + } + + impl_control!(Stream, STREAM_TAG_LEN); + + impl seal::control::Stream for Stream { + #[inline] + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + retransmission_tag( + &self.0, + original_packet_number, + retransmission_packet_number, + tag_out, + ) + } + } + + impl_control!(Secret, SECRET_TAG_LEN); + + impl seal::control::Secret for Secret {} } +} - #[inline] - fn encrypt( - &self, - nonce: N, - header: &[u8], - extra_payload: Option<&[u8]>, - payload_and_tag: &mut [u8], - ) { - let nonce = self.iv.nonce(nonce); - let aad = Aad::from(header); +pub mod open { + use super::*; + use crate::crypto::{ + open::{self, *}, + UninitSlice, + }; + use s2n_quic_core::ensure; + + #[derive(Debug)] + pub struct Application { + key: LessSafeKey, + iv: Iv, + } - let extra_in = extra_payload.unwrap_or(&[][..]); + impl Application { + #[inline] + pub fn new(key: &[u8], iv: [u8; NONCE_LEN], algorithm: &'static Algorithm) -> Self { + let key = UnboundKey::new(algorithm, key).unwrap(); + let key = LessSafeKey::new(key); + Self { key, iv: Iv(iv) } + } + } - unsafe { - assume!(payload_and_tag.len() >= self.tag_len() + extra_in.len()); + impl open::Application for Application { + #[inline] + fn tag_len(&self) -> usize { + debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); + TAG_LEN } - let inline_len = payload_and_tag.len() - self.tag_len() - extra_in.len(); + #[inline] + fn decrypt( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut UninitSlice, + ) -> Result { + ensure!( + key_phase == KeyPhase::Zero, + Err(Error::RotationNotSupported) + ); + debug_assert_eq!(payload_in.len(), payload_out.len()); + + let nonce = self.iv.nonce(packet_number); + let aad = Aad::from(header); + + let payload_out = unsafe { + // SAFETY: the payload is not read by aws-lc, only written to + let ptr = payload_out.as_mut_ptr(); + let len = payload_out.len(); + core::slice::from_raw_parts_mut(ptr, len) + }; - unsafe { - assume!(payload_and_tag.len() >= inline_len); + self.key + .open_separate_gather(nonce, aad, payload_in, tag, payload_out) + .map_err(|_| Error::InvalidTag) } - let (in_out, extra_out_and_tag) = payload_and_tag.split_at_mut(inline_len); - let result = + #[inline] + fn decrypt_in_place( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> Result { + ensure!( + key_phase == KeyPhase::Zero, + Err(Error::RotationNotSupported) + ); + let nonce = self.iv.nonce(packet_number); + let aad = Aad::from(header); + self.key - .seal_in_place_scatter(nonce, aad, in_out, extra_in, extra_out_and_tag); + .open_in_place(nonce, aad, payload_and_tag) + .map_err(|_| Error::InvalidTag)?; - unsafe { - assume!(result.is_ok()); + Ok(()) } } - #[inline] - fn retransmission_tag( - &self, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], - ) { - retransmission_tag( - &self.key, - &self.iv, - original_packet_number, - retransmission_packet_number, - tag_out, - ) - } -} - -#[derive(Debug)] -pub struct DecryptKey { - credentials: Credentials, - key: LessSafeKey, - iv: Iv, -} + pub mod control { + use super::{super::control::*, open}; + + macro_rules! impl_control { + ($name:ident, $tag_len:expr) => { + #[derive(Debug)] + pub struct $name(Key); + + impl $name { + #[inline] + pub fn new(key: &[u8], algorithm: &'static Algorithm) -> Self { + let key = Key::new(*algorithm, key); + Self(key) + } + } + + impl open::Control for $name { + #[inline] + fn tag_len(&self) -> usize { + $tag_len + } + + #[inline] + fn verify(&self, header: &[u8], tag: &[u8]) -> open::Result { + verify(&self.0, $tag_len, header, tag) + } + } + }; + } -impl DecryptKey { - #[inline] - pub fn new( - credentials: Credentials, - key: &[u8], - iv: [u8; NONCE_LEN], - algorithm: &'static Algorithm, - ) -> Self { - let key = UnboundKey::new(algorithm, key).unwrap(); - let key = LessSafeKey::new(key); - Self { - credentials, - key, - iv: Iv(iv), + impl_control!(Stream, STREAM_TAG_LEN); + + impl open::control::Stream for Stream { + #[inline] + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) -> open::Result { + retransmission_tag( + &self.0, + original_packet_number, + retransmission_packet_number, + tag_out, + ); + Ok(()) + } } + + impl_control!(Secret, SECRET_TAG_LEN); + + impl open::control::Secret for Secret {} } } -impl super::decrypt::Key for DecryptKey { - #[inline] - fn credentials(&self) -> &Credentials { - &self.credentials - } +mod control { + use crate::crypto::open; + use aws_lc_rs::hmac; + + pub use hmac::{Algorithm, Key}; + + //= https://datatracker.ietf.org/doc/html/rfc2104#section-5 + //# A well-known practice with message authentication codes is to + //# truncate the output of the MAC and output only part of the bits + //# (e.g., [MM, ANSI]). Preneel and van Oorschot [PV] show some + //# analytical advantages of truncating the output of hash-based MAC + //# functions. The results in this area are not absolute as for the + //# overall security advantages of truncation. It has advantages (less + //# information on the hash result available to an attacker) and + //# disadvantages (less bits to predict for the attacker). Applications + //# of HMAC can choose to truncate the output of HMAC by outputting the t + //# leftmost bits of the HMAC computation for some parameter t (namely, + //# the computation is carried in the normal way as defined in section 2 + //# above but the end result is truncated to t bits). We recommend that + //# the output length t be not less than half the length of the hash + //# output (to match the birthday attack bound) and not less than 80 bits + //# (a suitable lower bound on the number of bits that need to be + //# predicted by an attacker). + pub const STREAM_TAG_LEN: usize = 16; + pub const SECRET_TAG_LEN: usize = crate::packet::secret_control::TAG_LEN; #[inline] - fn tag_len(&self) -> usize { - debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); - TAG_LEN + pub fn sign(key: &Key, expected_tag_len: usize, header: &[u8], tag: &mut [u8]) { + debug_assert_eq!(tag.len(), expected_tag_len); + let out = hmac::sign(key, header); + let out = out.as_ref(); + let len = tag.len().min(out.len()); + tag[..len].copy_from_slice(&out[..len]); } #[inline] - fn decrypt( - &self, - _key_phase: KeyPhase, - nonce: N, + pub fn verify( + key: &Key, + expected_tag_len: usize, header: &[u8], - payload_in: &[u8], tag: &[u8], - payload_out: &mut super::UninitSlice, - ) -> super::decrypt::Result { - debug_assert_eq!(payload_in.len(), payload_out.len()); - - let nonce = self.iv.nonce(nonce); - let aad = Aad::from(header); - - let payload_out = unsafe { - // SAFETY: the payload is not read by aws-lc, only written to - let ptr = payload_out.as_mut_ptr(); - let len = payload_out.len(); - core::slice::from_raw_parts_mut(ptr, len) - }; - - self.key - .open_separate_gather(nonce, aad, payload_in, tag, payload_out) - .map_err(|_| super::decrypt::Error::InvalidTag) - } - - #[inline] - fn decrypt_in_place( - &self, - _key_phase: KeyPhase, - nonce: N, - header: &[u8], - payload_and_tag: &mut [u8], - ) -> super::decrypt::Result { - let nonce = self.iv.nonce(nonce); - let aad = Aad::from(header); + ) -> open::Result<()> { + if tag.len() != expected_tag_len { + return Err(open::Error::InvalidTag); + } - self.key - .open_in_place(nonce, aad, payload_and_tag) - .map_err(|_| super::decrypt::Error::InvalidTag)?; + // instead of using the `hmac::verify` function, we implement our own that controls the + // amount of truncation that happens to the tag. + let out = hmac::sign(key, header); + let out = out.as_ref(); + let len = tag.len().min(out.len()); - Ok(()) + aws_lc_rs::constant_time::verify_slices_are_equal(&tag[..len], &out[..len]) + .map_err(|_| open::Error::InvalidTag) } #[inline] - fn retransmission_tag( - &self, - _key_phase: KeyPhase, + pub fn retransmission_tag( + key: &Key, original_packet_number: u64, retransmission_packet_number: u64, tag_out: &mut [u8], ) { - retransmission_tag( - &self.key, - &self.iv, - original_packet_number, - retransmission_packet_number, - tag_out, - ) - } -} - -#[inline] -fn retransmission_tag( - key: &LessSafeKey, - iv: &Iv, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], -) { - debug_assert_eq!(tag_out.len(), TAG_LEN); - - let nonce = iv.nonce(retransmission_packet_number); - let aad = original_packet_number.to_be_bytes(); - let aad = Aad::from(&aad); - - let tag = key.seal_in_place_separate_tag(nonce, aad, &mut []).unwrap(); - - for (a, b) in tag_out.iter_mut().zip(tag.as_ref()) { - *a ^= b; + let mut v = [0; 16]; + v[..8].copy_from_slice(&original_packet_number.to_be_bytes()); + v[8..].copy_from_slice(&retransmission_packet_number.to_be_bytes()); + let tag = hmac::sign(key, &v); + for (a, b) in tag_out.iter_mut().zip(tag.as_ref()) { + *a ^= b; + } } } diff --git a/dc/s2n-quic-dc/src/crypto/testing.rs b/dc/s2n-quic-dc/src/crypto/testing.rs deleted file mode 100644 index f357bc5794..0000000000 --- a/dc/s2n-quic-dc/src/crypto/testing.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -use super::IntoNonce; -use crate::credentials::Credentials; -use s2n_quic_core::{assume, packet::KeyPhase}; - -#[derive(Clone, Debug)] -pub struct Key { - credentials: Credentials, - tag_len: usize, -} - -impl Key { - #[inline] - pub fn new(credentials: Credentials) -> Self { - Self { - credentials, - tag_len: 16, - } - } -} - -impl super::encrypt::Key for Key { - #[inline] - fn credentials(&self) -> &Credentials { - &self.credentials - } - - #[inline] - fn key_phase(&self) -> KeyPhase { - KeyPhase::Zero - } - - #[inline] - fn tag_len(&self) -> usize { - self.tag_len - } - - #[inline] - fn encrypt( - &self, - _nonce: N, - _header: &[u8], - extra_payload: Option<&[u8]>, - payload_and_tag: &mut [u8], - ) { - if let Some(extra_payload) = extra_payload { - let offset = payload_and_tag.len() - self.tag_len() - extra_payload.len(); - let dest = &mut payload_and_tag[offset..]; - unsafe { - assume!(dest.len() == extra_payload.len() + self.tag_len); - } - let (dest, tag) = dest.split_at_mut(extra_payload.len()); - dest.copy_from_slice(extra_payload); - tag.fill(0); - } - } - - #[inline] - fn retransmission_tag( - &self, - _original_packet_number: u64, - _retransmission_packet_number: u64, - _tag_out: &mut [u8], - ) { - // no-op - } -} - -impl super::decrypt::Key for Key { - #[inline] - fn credentials(&self) -> &Credentials { - &self.credentials - } - - #[inline] - fn tag_len(&self) -> usize { - self.tag_len - } - - #[inline] - fn decrypt( - &self, - _key_phase: KeyPhase, - _nonce: N, - _header: &[u8], - payload_in: &[u8], - _tag: &[u8], - payload_out: &mut bytes::buf::UninitSlice, - ) -> Result<(), super::decrypt::Error> { - payload_out.copy_from_slice(payload_in); - Ok(()) - } - - #[inline] - fn decrypt_in_place( - &self, - _key_phase: KeyPhase, - _nonce: N, - _header: &[u8], - _payload_and_tag: &mut [u8], - ) -> Result<(), super::decrypt::Error> { - Ok(()) - } - - #[inline] - fn retransmission_tag( - &self, - _key_phase: KeyPhase, - _original_packet_number: u64, - _retransmission_packet_number: u64, - _tag_out: &mut [u8], - ) { - // no-op - } -} diff --git a/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs index dd029a1a3a..bbe3fce854 100644 --- a/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs +++ b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs @@ -2,13 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::{decrypt, UninitSlice}, + crypto::{open, UninitSlice}, packet::datagram::{decoder, Tag}, }; use s2n_codec::{decoder_invariant, DecoderBufferMut, DecoderError}; use s2n_quic_core::packet::number::{PacketNumberSpace, SlidingWindow, SlidingWindowError}; -pub use crate::crypto::decrypt::Error; +pub use crate::crypto::open::Error; pub use decoder::Packet; #[derive(Default)] @@ -37,11 +37,11 @@ impl decoder::Validator for TagValidator { } } -pub struct Receiver { +pub struct Receiver { key: K, } -impl Receiver { +impl Receiver { pub fn new(key: K) -> Self { Self { key } } diff --git a/dc/s2n-quic-dc/src/datagram/tunneled/send.rs b/dc/s2n-quic-dc/src/datagram/tunneled/send.rs index 5524e76c7f..a13de83d2c 100644 --- a/dc/s2n-quic-dc/src/datagram/tunneled/send.rs +++ b/dc/s2n-quic-dc/src/datagram/tunneled/send.rs @@ -3,7 +3,8 @@ use crate::{ control, - crypto::encrypt, + credentials::Credentials, + crypto::seal, packet::{self, datagram::encoder}, }; use core::sync::atomic::{AtomicU64, Ordering}; @@ -19,17 +20,19 @@ pub enum Error { pub struct Sender { encrypt_key: E, + credentials: Credentials, packet_number: AtomicU64, } impl Sender where - E: encrypt::Key, + E: seal::Application, { #[inline] - pub fn new(encrypt_key: E) -> Self { + pub fn new(encrypt_key: E, credentials: Credentials) -> Self { Self { encrypt_key, + credentials, packet_number: AtomicU64::new(0), } } @@ -89,6 +92,7 @@ where payload_len, &mut cleartext_payload, &self.encrypt_key, + &self.credentials, ) }; diff --git a/dc/s2n-quic-dc/src/msg/cmsg.rs b/dc/s2n-quic-dc/src/msg/cmsg.rs index 5ccc8a2f18..e74e562c54 100644 --- a/dc/s2n-quic-dc/src/msg/cmsg.rs +++ b/dc/s2n-quic-dc/src/msg/cmsg.rs @@ -43,12 +43,6 @@ impl Receiver { match (cmsg.cmsg_level, cmsg.cmsg_type) { (level, ty) if features::tos::is_match(level, ty) => { if let Some(ecn) = features::tos::decode(value) { - // TODO remove this conversion once we consolidate the s2n-quic-core crates - // convert between the vendored s2n-quic-core types - let ecn = { - let ecn = ecn as u8; - ExplicitCongestionNotification::new(ecn) - }; self.ecn = ecn; } else { continue; diff --git a/dc/s2n-quic-dc/src/msg/send.rs b/dc/s2n-quic-dc/src/msg/send.rs index 116a2a16f0..c4e011e242 100644 --- a/dc/s2n-quic-dc/src/msg/send.rs +++ b/dc/s2n-quic-dc/src/msg/send.rs @@ -290,8 +290,7 @@ impl Message { let mut cmsg_storage = cmsg::Storage::<{ cmsg::ENCODER_LEN }>::default(); let mut cmsg = cmsg_storage.encoder(); if ecn != ExplicitCongestionNotification::NotEct { - // TODO enable this once we consolidate s2n-quic-core crates - // let _ = cmsg.encode_ecn(ecn, &addr); + let _ = cmsg.encode_ecn(ecn, &addr.get()); } if iov.len() > 1 { diff --git a/dc/s2n-quic-dc/src/packet.rs b/dc/s2n-quic-dc/src/packet.rs index d993e9af83..ed143fdcf9 100644 --- a/dc/s2n-quic-dc/src/packet.rs +++ b/dc/s2n-quic-dc/src/packet.rs @@ -76,13 +76,11 @@ impl<'a> s2n_codec::DecoderParameterizedValueMut<'a> for Packet<'a> { Ok((Self::Datagram(packet), decoder)) } Tag::StaleKey(_) => { - let (packet, decoder) = - secret_control::stale_key::Packet::decode(decoder, tag_len)?; + let (packet, decoder) = secret_control::stale_key::Packet::decode(decoder)?; Ok((Self::StaleKey(packet), decoder)) } Tag::ReplayDetected(_) => { - let (packet, decoder) = - secret_control::replay_detected::Packet::decode(decoder, tag_len)?; + let (packet, decoder) = secret_control::replay_detected::Packet::decode(decoder)?; Ok((Self::ReplayDetected(packet), decoder)) } Tag::UnknownPathSecret(_) => { diff --git a/dc/s2n-quic-dc/src/packet/control.rs b/dc/s2n-quic-dc/src/packet/control.rs index e8095a9f9e..2dac612ad3 100644 --- a/dc/s2n-quic-dc/src/packet/control.rs +++ b/dc/s2n-quic-dc/src/packet/control.rs @@ -3,11 +3,8 @@ use super::tag::Common; use core::fmt; -use s2n_quic_core::packet::KeyPhase; use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; -const NONCE_MASK: u64 = 1 << 63; - pub mod decoder; pub mod encoder; @@ -29,7 +26,6 @@ impl fmt::Debug for Tag { f.debug_struct("control::Tag") .field("is_stream", &self.is_stream()) .field("has_application_header", &self.has_application_header()) - .field("key_phase", &self.key_phase()) .finish() } } @@ -37,7 +33,6 @@ impl fmt::Debug for Tag { impl Tag { pub const IS_STREAM_MASK: u8 = 0b0100; pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b0010; - pub const KEY_PHASE_MASK: u8 = 0b0001; pub const MIN: u8 = 0b0101_0000; pub const MAX: u8 = 0b0101_1111; @@ -62,24 +57,6 @@ impl Tag { self.0.get(Self::HAS_APPLICATION_HEADER_MASK) } - #[inline] - pub fn set_key_phase(&mut self, key_phase: KeyPhase) { - let v = match key_phase { - KeyPhase::Zero => false, - KeyPhase::One => true, - }; - self.0.set(Self::KEY_PHASE_MASK, v) - } - - #[inline] - pub fn key_phase(&self) -> KeyPhase { - if self.0.get(Self::KEY_PHASE_MASK) { - KeyPhase::One - } else { - KeyPhase::Zero - } - } - #[inline] fn validate(&self) -> Result<(), s2n_codec::DecoderError> { let range = Self::MIN..=Self::MAX; diff --git a/dc/s2n-quic-dc/src/packet/control/decoder.rs b/dc/s2n-quic-dc/src/packet/control/decoder.rs index 7157b3b54c..aa4152de49 100644 --- a/dc/s2n-quic-dc/src/packet/control/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/control/decoder.rs @@ -3,10 +3,7 @@ use crate::{ credentials::Credentials, - packet::{ - control::{self, Tag}, - stream, WireVersion, - }, + packet::{control::Tag, stream, WireVersion}, }; use s2n_codec::{ decoder_invariant, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, DecoderError, @@ -87,11 +84,6 @@ impl<'a> Packet<'a> { self.stream_id.as_ref() } - #[inline] - pub fn crypto_nonce(&self) -> u64 { - self.packet_number.as_u64() | control::NONCE_MASK - } - #[inline] pub fn packet_number(&self) -> PacketNumber { self.packet_number diff --git a/dc/s2n-quic-dc/src/packet/control/encoder.rs b/dc/s2n-quic-dc/src/packet/control/encoder.rs index bd0b917b33..cd22ae8852 100644 --- a/dc/s2n-quic-dc/src/packet/control/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/control/encoder.rs @@ -2,11 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::encrypt, - packet::{ - control::{Tag, NONCE_MASK}, - stream, WireVersion, - }, + credentials::Credentials, + crypto, + packet::{control::Tag, stream, WireVersion}, }; use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; use s2n_quic_core::{assume, buffer, varint::VarInt}; @@ -22,24 +20,22 @@ pub fn encode( control_data_len: VarInt, control_data: &CD, crypto: &C, + credentials: &Credentials, ) -> usize where H: buffer::reader::Storage, CD: EncoderValue, - C: encrypt::Key, + C: crypto::seal::control::Stream, { debug_assert_ne!(source_control_port, 0); let mut tag = Tag::default(); - tag.set_key_phase(crypto.key_phase()); tag.set_is_stream(stream_id.is_some()); tag.set_has_application_header(*header_len > 0); encoder.encode(&tag); - let nonce = *packet_number | NONCE_MASK; - // encode the credentials being used - encoder.encode(crypto.credentials()); + encoder.encode(credentials); // wire version - we only support `0` currently encoder.encode(&WireVersion::ZERO); @@ -76,12 +72,12 @@ where let slice = encoder.as_mut_slice(); { - let (header, payload_and_tag) = unsafe { + let (header, tag) = unsafe { assume!(slice.len() >= payload_offset); slice.split_at_mut(payload_offset) }; - crypto.encrypt(nonce, header, None, payload_and_tag); + crypto.sign(header, tag); } if cfg!(debug_assertions) { diff --git a/dc/s2n-quic-dc/src/packet/datagram/encoder.rs b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs index 182a98a01b..026b9c5b90 100644 --- a/dc/s2n-quic-dc/src/packet/datagram/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs @@ -3,7 +3,8 @@ use crate::{ credentials, - crypto::encrypt, + credentials::Credentials, + crypto::seal, packet::{datagram::Tag, WireVersion}, }; use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; @@ -65,12 +66,13 @@ pub fn encode( payload_len: super::PayloadLen, payload: &mut P, crypto: &C, + credentials: &Credentials, ) -> usize where H: buffer::reader::Storage, P: buffer::reader::Storage, CD: EncoderValue, - C: encrypt::Key, + C: seal::Application, { let mut tag = super::Tag::default(); tag.set_is_connected(packet_number.is_some()); @@ -84,7 +86,7 @@ where let nonce = *packet_number.unwrap_or(super::PacketNumber::ZERO); // encode the credentials being used - encoder.encode(crypto.credentials()); + encoder.encode(credentials); // wire version - we only support `0` currently encoder.encode(&WireVersion::ZERO); diff --git a/dc/s2n-quic-dc/src/packet/secret_control.rs b/dc/s2n-quic-dc/src/packet/secret_control.rs index 7f15b8eac7..62b4d40061 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control.rs @@ -1,11 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{ - credentials, - crypto::{decrypt, encrypt}, - packet::WireVersion, -}; +use crate::{credentials, crypto::seal, packet::WireVersion}; use s2n_codec::{ decoder_invariant, decoder_value, DecoderBuffer, DecoderBufferMut, DecoderBufferMutResult as Rm, DecoderBufferResult as R, DecoderError, DecoderValue, Encoder, @@ -17,13 +13,13 @@ use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; #[macro_use] mod decoder; mod encoder; -mod nonce; const UNKNOWN_PATH_SECRET: u8 = 0b0110_0000; const STALE_KEY: u8 = 0b0110_0001; const REPLAY_DETECTED: u8 = 0b0110_0010; -pub const MAX_PACKET_SIZE: usize = 50; +pub const MAX_PACKET_SIZE: usize = 64; +pub const TAG_LEN: usize = 16; macro_rules! impl_tag { ($tag:expr) => { @@ -79,16 +75,12 @@ macro_rules! impl_tests { ($ty:ident) => { #[test] fn round_trip_test() { - use crate::crypto::awslc::{DecryptKey, EncryptKey, AES_128_GCM}; + use crate::crypto::awslc::{open, seal}; + use aws_lc_rs::hmac::HMAC_SHA256; - let creds = crate::credentials::Credentials { - id: Default::default(), - key_id: Default::default(), - }; let key = &[0u8; 16]; - let iv = [0u8; 12]; - let encrypt = EncryptKey::new(creds, key, iv, &AES_128_GCM); - let decrypt = DecryptKey::new(creds, key, iv, &AES_128_GCM); + let sealer = seal::control::Secret::new(key, &HMAC_SHA256); + let opener = open::control::Secret::new(key, &HMAC_SHA256); bolero::check!() .with_type::<$ty>() @@ -98,27 +90,22 @@ macro_rules! impl_tests { let mut buffer = [0u8; MAX_PACKET_SIZE]; let len = { let encoder = s2n_codec::EncoderBuffer::new(&mut buffer); - value.encode(encoder, &encrypt) + value.encode(encoder, &sealer) }; { - use decrypt::Key as _; let buffer = s2n_codec::DecoderBufferMut::new(&mut buffer[..len]); - let (decoded, _) = Packet::decode(buffer, decrypt.tag_len()).unwrap(); - let decoded = decoded.authenticate(&decrypt).unwrap(); + let (decoded, _) = Packet::decode(buffer).unwrap(); + let decoded = decoded.authenticate(&opener).unwrap(); assert_eq!(value, decoded); } { - use decrypt::Key as _; let buffer = s2n_codec::DecoderBufferMut::new(&mut buffer[..len]); - let (decoded, _) = crate::packet::secret_control::Packet::decode( - buffer, - decrypt.tag_len(), - ) - .unwrap(); + let (decoded, _) = + crate::packet::secret_control::Packet::decode(buffer).unwrap(); if let crate::packet::secret_control::Packet::$ty(decoded) = decoded { - let decoded = decoded.authenticate(&decrypt).unwrap(); + let decoded = decoded.authenticate(&opener).unwrap(); assert_eq!(value, decoded); } else { panic!("decoded as the wrong packet type"); @@ -133,7 +120,6 @@ pub mod replay_detected; pub mod stale_key; pub mod unknown_path_secret; -pub use nonce::Nonce; pub use replay_detected::ReplayDetected; pub use stale_key::StaleKey; pub use unknown_path_secret::UnknownPathSecret; @@ -147,7 +133,7 @@ pub enum Packet<'a> { impl<'a> Packet<'a> { #[inline] - pub fn decode(buffer: DecoderBufferMut<'a>, crypto_tag_len: usize) -> Rm { + pub fn decode(buffer: DecoderBufferMut<'a>) -> Rm { let tag = buffer.peek_byte(0)?; Ok(match tag { @@ -156,11 +142,11 @@ impl<'a> Packet<'a> { (Self::UnknownPathSecret(packet), buffer) } STALE_KEY => { - let (packet, buffer) = stale_key::Packet::decode(buffer, crypto_tag_len)?; + let (packet, buffer) = stale_key::Packet::decode(buffer)?; (Self::StaleKey(packet), buffer) } REPLAY_DETECTED => { - let (packet, buffer) = replay_detected::Packet::decode(buffer, crypto_tag_len)?; + let (packet, buffer) = replay_detected::Packet::decode(buffer)?; (Self::ReplayDetected(packet), buffer) } _ => return Err(DecoderError::InvariantViolation("invalid tag")), diff --git a/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs index 58d9e31fe9..da8da8d4bd 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs @@ -16,10 +16,9 @@ macro_rules! impl_packet { impl<'a> Packet<'a> { #[inline] - pub fn decode(buffer: DecoderBufferMut<'a>, crypto_tag_len: usize) -> Rm { + pub fn decode(buffer: DecoderBufferMut<'a>) -> Rm { let header_len = decoder::header_len::<$name>(buffer.peek())?; - let ((header, value, crypto_tag), buffer) = - decoder::header(buffer, header_len, crypto_tag_len)?; + let ((header, value, crypto_tag), buffer) = decoder::header(buffer, header_len)?; let packet = Self { header, value, @@ -36,7 +35,7 @@ macro_rules! impl_packet { #[inline] pub fn authenticate(&self, crypto: &C) -> Option<&$name> where - C: decrypt::Key, + C: crate::crypto::open::control::Secret, { let Self { header, @@ -44,17 +43,7 @@ macro_rules! impl_packet { crypto_tag, } = self; - crypto - .decrypt( - // these don't rotate - s2n_quic_core::packet::KeyPhase::Zero, - value.nonce(), - header, - &[], - crypto_tag, - bytes::buf::UninitSlice::new(&mut []), - ) - .ok()?; + crypto.verify(header, crypto_tag).ok()?; Some(value) } @@ -73,11 +62,7 @@ where } #[inline] -pub fn header<'a, T>( - buffer: DecoderBufferMut<'a>, - header_len: usize, - crypto_tag_len: usize, -) -> Rm<'a, (&[u8], T, &[u8])> +pub fn header<'a, T>(buffer: DecoderBufferMut<'a>, header_len: usize) -> Rm<'a, (&[u8], T, &[u8])> where T: DecoderValue<'a>, { @@ -86,7 +71,7 @@ where let (value, _) = header.decode::()?; let header = header.into_less_safe_slice(); - let (crypto_tag, buffer) = buffer.decode_slice(crypto_tag_len)?; + let (crypto_tag, buffer) = buffer.decode_slice(super::TAG_LEN)?; let crypto_tag = crypto_tag.into_less_safe_slice(); Ok(((header, value, crypto_tag), buffer)) diff --git a/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs b/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs index 9d02f27b52..991426197a 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs @@ -1,15 +1,14 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use super::Nonce; -use crate::crypto::encrypt; +use crate::crypto::seal; use s2n_codec::{Encoder, EncoderBuffer}; use s2n_quic_core::assume; #[inline] -pub fn finish(mut encoder: EncoderBuffer, nonce: Nonce, crypto: &C) -> usize +pub fn finish(mut encoder: EncoderBuffer, crypto: &C) -> usize where - C: encrypt::Key, + C: seal::control::Secret, { let header_offset = encoder.len(); @@ -18,12 +17,12 @@ where let packet_len = encoder.len(); let slice = encoder.as_mut_slice(); - let (header, payload_and_tag) = unsafe { + let (header, tag) = unsafe { assume!(slice.len() >= header_offset); slice.split_at_mut(header_offset) }; - crypto.encrypt(nonce, header, None, payload_and_tag); + crypto.sign(header, tag); packet_len } diff --git a/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs b/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs deleted file mode 100644 index a71a6cddfb..0000000000 --- a/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -use super::{REPLAY_DETECTED, STALE_KEY, UNKNOWN_PATH_SECRET}; -use crate::crypto::IntoNonce; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(test, derive(bolero_generator::TypeGenerator))] -pub enum Nonce { - UnknownPathSecret, - StaleKey { - // This is the minimum key ID the server will accept (at the time of sending). - // - // This is used for cases where the server intentionally drops state in a manner that cuts - // out a chunk of not-yet-used key ID space. - min_key_id: u64, - }, - ReplayDetected { - // This is the key ID we rejected. Should only be sent for *definitively* rejected keys: - // use StaleKey if the key ID's status is MaybeReplayed. - // - // The client should enqueue a handshake but it should keep in mind that this might be - // caused by an attacker replaying packets, so maybe impose rate limiting or ignore "really - // old" replay detected packets. - rejected_key_id: u64, - }, -} - -impl IntoNonce for Nonce { - #[inline] - fn into_nonce(self) -> [u8; 12] { - let mut nonce = [0; 12]; - match self { - Self::UnknownPathSecret => { - nonce[0] = UNKNOWN_PATH_SECRET; - } - Self::StaleKey { min_key_id } => { - nonce[0] = STALE_KEY; - nonce[1..9].copy_from_slice(&min_key_id.to_be_bytes()); - } - Self::ReplayDetected { rejected_key_id } => { - nonce[0] = REPLAY_DETECTED; - nonce[1..9].copy_from_slice(&rejected_key_id.to_be_bytes()); - } - } - nonce - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bolero::check; - - /// ensures output nonces are only equal if the messages are equal - #[test] - #[cfg_attr(kani, kani::proof, kani::solver(cadical))] - fn nonce_uniqueness() { - check!().with_type::<(Nonce, Nonce)>().for_each(|(a, b)| { - if a == b { - assert_eq!(a.into_nonce(), b.into_nonce()); - } else { - assert_ne!(a.into_nonce(), b.into_nonce()); - } - }); - } -} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs index 7fb7770431..f511371588 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs @@ -18,21 +18,14 @@ impl ReplayDetected { #[inline] pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &C) -> usize where - C: encrypt::Key, + C: seal::control::Secret, { encoder.encode(&Tag::default()); encoder.encode(&self.credential_id); encoder.encode(&self.wire_version); encoder.encode(&self.rejected_key_id); - encoder::finish(encoder, self.nonce(), crypto) - } - - #[inline] - pub fn nonce(&self) -> Nonce { - Nonce::ReplayDetected { - rejected_key_id: self.rejected_key_id.into(), - } + encoder::finish(encoder, crypto) } #[cfg(test)] diff --git a/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs index 354308a78d..bac9c3641d 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs @@ -18,21 +18,14 @@ impl StaleKey { #[inline] pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &C) -> usize where - C: encrypt::Key, + C: seal::control::Secret, { encoder.encode(&Tag::default()); encoder.encode(&self.credential_id); encoder.encode(&self.wire_version); encoder.encode(&self.min_key_id); - encoder::finish(encoder, self.nonce(), crypto) - } - - #[inline] - pub fn nonce(&self) -> Nonce { - Nonce::StaleKey { - min_key_id: self.min_key_id.into(), - } + encoder::finish(encoder, crypto) } #[cfg(test)] diff --git a/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs index 93d21214a0..da8eb9c61a 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs @@ -6,8 +6,6 @@ use core::mem::size_of; impl_tag!(UNKNOWN_PATH_SECRET); -const STATELESS_RESET_LEN: usize = 16; - #[derive(Clone, Copy, Debug)] pub struct Packet<'a> { #[allow(dead_code)] @@ -17,10 +15,7 @@ pub struct Packet<'a> { } impl<'a> Packet<'a> { - pub fn new_for_test( - id: crate::credentials::Id, - stateless_reset: &[u8; STATELESS_RESET_LEN], - ) -> Packet<'_> { + pub fn new_for_test(id: crate::credentials::Id, stateless_reset: &[u8; TAG_LEN]) -> Packet<'_> { Packet { header: &[], value: UnknownPathSecret { @@ -34,8 +29,7 @@ impl<'a> Packet<'a> { #[inline] pub fn decode(buffer: DecoderBufferMut<'a>) -> Rm { let header_len = decoder::header_len::(buffer.peek())?; - let ((header, value, crypto_tag), buffer) = - decoder::header(buffer, header_len, STATELESS_RESET_LEN)?; + let ((header, value, crypto_tag), buffer) = decoder::header(buffer, header_len)?; let packet = Self { header, value, @@ -50,10 +44,7 @@ impl<'a> Packet<'a> { } #[inline] - pub fn authenticate( - &self, - stateless_reset: &[u8; STATELESS_RESET_LEN], - ) -> Option<&UnknownPathSecret> { + pub fn authenticate(&self, stateless_reset: &[u8; TAG_LEN]) -> Option<&UnknownPathSecret> { aws_lc_rs::constant_time::verify_slices_are_equal(self.crypto_tag, stateless_reset).ok()?; Some(&self.value) } @@ -68,14 +59,10 @@ pub struct UnknownPathSecret { impl UnknownPathSecret { pub const PACKET_SIZE: usize = - size_of::() + size_of::() + size_of::() + STATELESS_RESET_LEN; + size_of::() + size_of::() + size_of::() + TAG_LEN; #[inline] - pub fn encode( - &self, - mut encoder: EncoderBuffer, - stateless_reset_tag: &[u8; STATELESS_RESET_LEN], - ) -> usize { + pub fn encode(&self, mut encoder: EncoderBuffer, stateless_reset_tag: &[u8; TAG_LEN]) -> usize { let before = encoder.len(); encoder.encode(&Tag::default()); encoder.encode(&&self.credential_id[..]); @@ -84,11 +71,6 @@ impl UnknownPathSecret { let after = encoder.len(); after - before } - - #[inline] - pub fn nonce(&self) -> Nonce { - Nonce::UnknownPathSecret - } } impl<'a> DecoderValue<'a> for UnknownPathSecret { @@ -110,10 +92,15 @@ impl<'a> DecoderValue<'a> for UnknownPathSecret { mod tests { use super::*; + #[test] + fn stateless_reset_len() { + assert_eq!(s2n_quic_core::stateless_reset::token::LEN, TAG_LEN); + } + #[test] fn round_trip_test() { bolero::check!() - .with_type::<(UnknownPathSecret, [u8; 16])>() + .with_type::<(UnknownPathSecret, [u8; TAG_LEN])>() .for_each(|(value, stateless_reset)| { let mut buffer = [0u8; UnknownPathSecret::PACKET_SIZE]; let len = { diff --git a/dc/s2n-quic-dc/src/packet/stream.rs b/dc/s2n-quic-dc/src/packet/stream.rs index 32a8cef210..d9465c243d 100644 --- a/dc/s2n-quic-dc/src/packet/stream.rs +++ b/dc/s2n-quic-dc/src/packet/stream.rs @@ -3,7 +3,7 @@ use super::tag::Common; use core::fmt; -use s2n_quic_core::{packet::KeyPhase, probe, varint::VarInt}; +use s2n_quic_core::{packet::KeyPhase, probe}; use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; pub mod decoder; @@ -24,17 +24,6 @@ pub enum PacketSpace { Recovery, } -impl PacketSpace { - #[inline] - pub fn packet_number_into_nonce(&self, packet_number: VarInt) -> u64 { - let mut nonce = packet_number.as_u64(); - if let Self::Recovery = self { - nonce |= 1 << 62; - } - nonce - } -} - impl probe::Arg for PacketSpace { #[inline] fn into_usdt(self) -> isize { diff --git a/dc/s2n-quic-dc/src/packet/stream/decoder.rs b/dc/s2n-quic-dc/src/packet/stream/decoder.rs index 33f41a26cb..ecb43a7cf5 100644 --- a/dc/s2n-quic-dc/src/packet/stream/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/stream/decoder.rs @@ -13,7 +13,7 @@ use core::{fmt, mem::size_of}; use s2n_codec::{ decoder_invariant, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, DecoderError, }; -use s2n_quic_core::{assume, varint::VarInt}; +use s2n_quic_core::{assume, ensure, varint::VarInt}; type PacketNumber = VarInt; @@ -232,72 +232,88 @@ impl<'a> Packet<'a> { } #[inline] - pub fn decrypt( + pub fn decrypt( &mut self, d: &D, + c: &C, payload_out: &mut crypto::UninitSlice, - ) -> Result<(), crypto::decrypt::Error> + ) -> Result<(), crypto::open::Error> where - D: crypto::decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, { let key_phase = self.tag.key_phase(); - let space = self.remove_retransmit(d); - - let nonce = space.packet_number_into_nonce(self.original_packet_number); + let space = self.remove_retransmit(c)?; + let nonce = self.original_packet_number.as_u64(); let header = &self.header; let payload = &self.payload; let auth_tag = &self.auth_tag; - d.decrypt(key_phase, nonce, header, payload, auth_tag, payload_out)?; + match space { + stream::PacketSpace::Stream => { + d.decrypt(key_phase, nonce, header, payload, auth_tag, payload_out)?; + } + stream::PacketSpace::Recovery => { + // recovery/probe packets cannot have payloads + ensure!(payload.is_empty(), Err(crypto::open::Error::MacOnly)); + c.verify(header, auth_tag)?; + } + } Ok(()) } #[inline] - pub fn decrypt_in_place(&mut self, d: &D) -> Result<(), crypto::decrypt::Error> + pub fn decrypt_in_place(&mut self, d: &D, c: &C) -> Result<(), crypto::open::Error> where - D: crypto::decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, { let key_phase = self.tag.key_phase(); - let space = self.remove_retransmit(d); + let space = self.remove_retransmit(c)?; - let nonce = space.packet_number_into_nonce(self.original_packet_number); + let nonce = self.original_packet_number.as_u64(); let header = &self.header; - let payload_len = self.payload.len(); - let payload_ptr = self.payload.as_mut_ptr(); - let tag_len = self.auth_tag.len(); - let tag_ptr = self.auth_tag.as_mut_ptr(); - let payload_and_tag = unsafe { - debug_assert_eq!(payload_ptr.add(payload_len), tag_ptr); - - core::slice::from_raw_parts_mut(payload_ptr, payload_len + tag_len) - }; - - d.decrypt_in_place(key_phase, nonce, header, payload_and_tag)?; + match space { + stream::PacketSpace::Stream => { + let payload_len = self.payload.len(); + let payload_ptr = self.payload.as_mut_ptr(); + let tag_len = self.auth_tag.len(); + let tag_ptr = self.auth_tag.as_mut_ptr(); + let payload_and_tag = unsafe { + debug_assert_eq!(payload_ptr.add(payload_len), tag_ptr); + + core::slice::from_raw_parts_mut(payload_ptr, payload_len + tag_len) + }; + d.decrypt_in_place(key_phase, nonce, header, payload_and_tag)?; + } + stream::PacketSpace::Recovery => { + // recovery/probe packets cannot have payloads + ensure!(self.payload.is_empty(), Err(crypto::open::Error::MacOnly)); + c.verify(header, self.auth_tag)?; + } + } Ok(()) } #[inline] - fn remove_retransmit(&mut self, d: &D) -> stream::PacketSpace + fn remove_retransmit(&mut self, c: &C) -> Result where - D: crypto::decrypt::Key, + C: crypto::open::control::Stream, { - let key_phase = self.tag.key_phase(); let space = self.tag.packet_space(); let original_packet_number = self.original_packet_number; let retransmission_packet_number = self.packet_number; if original_packet_number != retransmission_packet_number { - d.retransmission_tag( - key_phase, + c.retransmission_tag( original_packet_number.as_u64(), - stream::PacketSpace::Recovery - .packet_number_into_nonce(retransmission_packet_number), + retransmission_packet_number.as_u64(), self.auth_tag, - ); + )?; // clear the recovery packet bit, since this is a retransmission self.header[0] &= !super::Tag::IS_RECOVERY_PACKET; @@ -306,9 +322,9 @@ impl<'a> Packet<'a> { let range = offset..offset + size_of::(); self.header[range].copy_from_slice(&[0; size_of::()]); - stream::PacketSpace::Stream + Ok(stream::PacketSpace::Stream) } else { - space + Ok(space) } } @@ -321,7 +337,7 @@ impl<'a> Packet<'a> { key: &K, ) -> Result<(), DecoderError> where - K: crypto::encrypt::Key, + K: crypto::seal::control::Stream, { let buffer = buffer.into_less_safe_slice(); @@ -358,7 +374,7 @@ impl<'a> Packet<'a> { key: &K, ) -> Result<(), DecoderError> where - K: crypto::encrypt::Key, + K: crypto::seal::control::Stream, { Self::retransmit_impl(buffer, space, retransmission_packet_number, key) } @@ -379,7 +395,7 @@ impl<'a> Packet<'a> { key: &K, ) -> Result<(), DecoderError> where - K: crypto::encrypt::Key, + K: crypto::seal::control::Stream, { unsafe { assume!(key.tag_len() >= 16, "tag len needs to be at least 16 bytes"); @@ -404,11 +420,9 @@ impl<'a> Packet<'a> { tag }; - let (credentials, buffer) = buffer.decode::()?; + let (_credentials, buffer) = buffer.decode::()?; let (_wire_version, buffer) = buffer.decode::()?; - debug_assert_eq!(&credentials, key.credentials()); - let (_source_control_port, buffer) = buffer.decode::()?; let (_source_stream_port, buffer) = if tag.has_source_stream_port() { @@ -463,8 +477,7 @@ impl<'a> Packet<'a> { original_packet_number + VarInt::from_u32(prev_value); key.retransmission_tag( original_packet_number.as_u64(), - stream::PacketSpace::Recovery - .packet_number_into_nonce(retransmission_packet_number), + retransmission_packet_number.as_u64(), auth_tag, ); } @@ -473,7 +486,7 @@ impl<'a> Packet<'a> { key.retransmission_tag( original_packet_number.as_u64(), - stream::PacketSpace::Recovery.packet_number_into_nonce(retransmission_packet_number), + retransmission_packet_number.as_u64(), auth_tag, ); diff --git a/dc/s2n-quic-dc/src/packet/stream/encoder.rs b/dc/s2n-quic-dc/src/packet/stream/encoder.rs index 67914ed3cb..2bcba27b13 100644 --- a/dc/s2n-quic-dc/src/packet/stream/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/stream/encoder.rs @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::encrypt, + credentials::Credentials, + crypto::{self, KeyPhase}, packet::{ stream::{self, RelativeRetransmissionOffset, Tag}, WireVersion, @@ -25,7 +26,6 @@ pub fn encode( source_control_port: u16, source_stream_port: Option, stream_id: stream::Id, - packet_space: stream::PacketSpace, packet_number: VarInt, next_expected_control_packet: VarInt, header_len: VarInt, @@ -34,12 +34,172 @@ pub fn encode( control_data: &CD, payload: &mut P, crypto: &C, + credentials: &Credentials, +) -> usize +where + H: buffer::reader::Storage, + P: buffer::Reader, + CD: EncoderValue, + C: crypto::seal::Application, +{ + let packet_space = stream::PacketSpace::Stream; + + let payload_len = encode_header( + &mut encoder, + packet_space, + crypto.key_phase(), + credentials, + source_control_port, + source_stream_port, + stream_id, + packet_number, + next_expected_control_packet, + header_len, + header, + control_data_len, + control_data, + payload, + crypto.tag_len(), + ); + + let nonce = packet_number.as_u64(); + + let payload_offset = encoder.len(); + + let mut last_chunk = Default::default(); + encoder.write_sized(payload_len, |mut dest| { + // the payload result is infallible + last_chunk = payload.infallible_partial_copy_into(&mut dest); + }); + + let last_chunk = if last_chunk.is_empty() { + None + } else { + Some(&*last_chunk) + }; + + encoder.advance_position(crypto.tag_len()); + + let packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + + { + let (header, payload_and_tag) = unsafe { + assume!(slice.len() >= payload_offset); + slice.split_at_mut(payload_offset) + }; + + crypto.encrypt(nonce, header, last_chunk, payload_and_tag); + } + + if cfg!(debug_assertions) { + let decoder = s2n_codec::DecoderBufferMut::new(slice); + let (packet, remaining) = + super::decoder::Packet::decode(decoder, (), crypto.tag_len()).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(packet.payload().len(), payload_len); + assert_eq!(packet.packet_number(), packet_number); + } + + packet_len +} + +#[inline(always)] +pub fn probe( + mut encoder: EncoderBuffer, + source_control_port: u16, + source_stream_port: Option, + stream_id: stream::Id, + packet_number: VarInt, + next_expected_control_packet: VarInt, + header_len: VarInt, + header: &mut H, + control_data_len: VarInt, + control_data: &CD, + payload: &mut P, + crypto: &C, + credentials: &Credentials, +) -> usize +where + H: buffer::reader::Storage, + P: buffer::Reader, + CD: EncoderValue, + C: crypto::seal::control::Stream, +{ + let packet_space = stream::PacketSpace::Recovery; + + let payload_len = encode_header( + &mut encoder, + packet_space, + KeyPhase::Zero, + credentials, + source_control_port, + source_stream_port, + stream_id, + packet_number, + next_expected_control_packet, + header_len, + header, + control_data_len, + control_data, + payload, + crypto.tag_len(), + ); + + debug_assert_eq!(payload_len, 0, "probes should not contain data"); + + let tag_offset = encoder.len(); + + encoder.advance_position(crypto.tag_len()); + + let packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + + { + let (header, tag) = unsafe { + assume!(slice.len() >= tag_offset); + slice.split_at_mut(tag_offset) + }; + + crypto.sign(header, tag); + } + + if cfg!(debug_assertions) { + let decoder = s2n_codec::DecoderBufferMut::new(slice); + let (packet, remaining) = + super::decoder::Packet::decode(decoder, (), crypto.tag_len()).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(packet.payload().len(), payload_len); + assert_eq!(packet.packet_number(), packet_number); + } + + packet_len +} + +#[inline(always)] +fn encode_header( + encoder: &mut EncoderBuffer, + packet_space: stream::PacketSpace, + key_phase: KeyPhase, + credentials: &Credentials, + source_control_port: u16, + source_stream_port: Option, + stream_id: stream::Id, + packet_number: VarInt, + next_expected_control_packet: VarInt, + header_len: VarInt, + header: &mut H, + control_data_len: VarInt, + control_data: &CD, + payload: &mut P, + tag_len: usize, ) -> usize where H: buffer::reader::Storage, P: buffer::Reader, CD: EncoderValue, - C: encrypt::Key, { let stream_offset = payload.current_offset(); let final_offset = payload.final_offset(); @@ -48,7 +208,7 @@ where debug_assert_ne!(source_stream_port, Some(0)); let mut tag = Tag::default(); - tag.set_key_phase(crypto.key_phase()); + tag.set_key_phase(key_phase); tag.set_has_control_data(*control_data_len > 0); tag.set_has_final_offset(final_offset.is_some()); tag.set_has_application_header(*header_len > 0); @@ -56,10 +216,8 @@ where tag.set_packet_space(packet_space); encoder.encode(&tag); - let nonce = packet_space.packet_number_into_nonce(packet_number); - // encode the credentials being used - encoder.encode(crypto.credentials()); + encoder.encode(credentials); // wire version - we only support `0` currently encoder.encode(&WireVersion::ZERO); @@ -99,7 +257,7 @@ where .saturating_sub(header_len.encoding_size()) .saturating_sub(*header_len as usize) .saturating_sub(*control_data_len as usize) - .saturating_sub(crypto.tag_len()); + .saturating_sub(tag_len); // TODO figure out encoding size for the capacity let remaining_payload_capacity = remaining_payload_capacity.saturating_sub(1); @@ -131,43 +289,5 @@ where encoder.encode(control_data); } - let payload_offset = encoder.len(); - - let mut last_chunk = Default::default(); - encoder.write_sized(*payload_len as usize, |mut dest| { - // the payload result is infallible - last_chunk = payload.infallible_partial_copy_into(&mut dest); - }); - - let last_chunk = if last_chunk.is_empty() { - None - } else { - Some(&*last_chunk) - }; - - encoder.advance_position(crypto.tag_len()); - - let packet_len = encoder.len(); - - let slice = encoder.as_mut_slice(); - - { - let (header, payload_and_tag) = unsafe { - assume!(slice.len() >= payload_offset); - slice.split_at_mut(payload_offset) - }; - - crypto.encrypt(nonce, header, last_chunk, payload_and_tag); - } - - if cfg!(debug_assertions) { - let decoder = s2n_codec::DecoderBufferMut::new(slice); - let (packet, remaining) = - super::decoder::Packet::decode(decoder, (), crypto.tag_len()).unwrap(); - assert!(remaining.is_empty()); - assert_eq!(packet.payload().len() as u64, payload_len.as_u64()); - assert_eq!(packet.packet_number(), packet_number); - } - - packet_len + *payload_len as usize } diff --git a/dc/s2n-quic-dc/src/packet/stream/id.rs b/dc/s2n-quic-dc/src/packet/stream/id.rs index 22f46ad14d..3acd7b1910 100644 --- a/dc/s2n-quic-dc/src/packet/stream/id.rs +++ b/dc/s2n-quic-dc/src/packet/stream/id.rs @@ -115,22 +115,6 @@ impl Id { } } -// FIXME: Try to remove this impl. It's probably misleading to allow conversion from a generic u64 -// and might get accidentally used for going from a VarInt key id which is probably not what you -// want. -impl TryFrom for Id { - type Error = s2n_quic_core::varint::VarIntError; - - #[inline] - fn try_from(value: u64) -> Result { - Ok(Self { - key_id: VarInt::new(value)?, - is_reliable: false, - is_bidirectional: false, - }) - } -} - pub const IS_RELIABLE_MASK: u64 = 0b10; pub const IS_BIDIRECTIONAL_MASK: u64 = 0b01; diff --git a/dc/s2n-quic-dc/src/path/secret.rs b/dc/s2n-quic-dc/src/path/secret.rs index fa1edcdb75..d5115b9697 100644 --- a/dc/s2n-quic-dc/src/path/secret.rs +++ b/dc/s2n-quic-dc/src/path/secret.rs @@ -10,5 +10,5 @@ pub mod schedule; mod sender; pub mod stateless_reset; -pub use key::{Opener, Sealer}; +pub use key::{open, seal}; pub use map::Map; diff --git a/dc/s2n-quic-dc/src/path/secret/key.rs b/dc/s2n-quic-dc/src/path/secret/key.rs index 02054db900..8f75c985d0 100644 --- a/dc/s2n-quic-dc/src/path/secret/key.rs +++ b/dc/s2n-quic-dc/src/path/secret/key.rs @@ -1,164 +1,383 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use super::map; -use crate::{ - credentials::Credentials, - crypto::{awslc, decrypt, encrypt, IntoNonce, UninitSlice}, -}; -use core::mem::MaybeUninit; +use super::{map, schedule}; +use core::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use s2n_quic_core::packet::KeyPhase; -use zeroize::Zeroize; -#[derive(Debug)] -pub struct Sealer { - pub(super) sealer: awslc::EncryptKey, -} +pub mod seal { + use super::*; + use crate::crypto::{awslc, seal}; + + pub use awslc::seal::control; -impl encrypt::Key for Sealer { - #[inline] - fn credentials(&self) -> &Credentials { - self.sealer.credentials() + #[derive(Debug)] + pub struct Application { + sealer: awslc::seal::Application, + ku: schedule::SealUpdate, + key_phase: KeyPhase, + encrypted_records: AtomicU64, } - #[inline] - fn key_phase(&self) -> KeyPhase { - KeyPhase::Zero + impl Application { + #[inline] + pub(crate) fn new(sealer: awslc::seal::Application, ku: schedule::SealUpdate) -> Self { + Self { + sealer, + ku, + key_phase: KeyPhase::Zero, + encrypted_records: AtomicU64::new(0), + } + } + + #[inline] + pub fn needs_update(&self) -> bool { + //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6 + //# For AEAD_AES_128_GCM and AEAD_AES_256_GCM, the confidentiality limit + //# is 2^23 encrypted packets; see Appendix B.1. + const LIMIT: u64 = 2u64.pow(23); + + // enqueue key updates 2^16 packets before the limit is hit + const THRESHOLD: u64 = 2u64.pow(16); + + // in debug mode, rotate keys more often in order to surface any issues + const MAX_RECORDS: u64 = if cfg!(debug_assertions) { + 4096 + } else { + LIMIT - THRESHOLD + }; + + self.encrypted_records.load(Ordering::Relaxed) >= MAX_RECORDS + } + + #[inline] + pub fn update(&mut self) { + let (sealer, ku) = self.ku.next(); + self.sealer = sealer; + self.ku = ku; + self.encrypted_records = AtomicU64::new(0); + self.key_phase = self.key_phase.next_phase(); + tracing::debug!(sealer_updated = ?self.key_phase); + } } - #[inline] - fn tag_len(&self) -> usize { - self.sealer.tag_len() + impl seal::Application for Application { + #[inline] + fn key_phase(&self) -> KeyPhase { + self.key_phase + } + + #[inline] + fn tag_len(&self) -> usize { + self.sealer.tag_len() + } + + #[inline] + fn encrypt( + &self, + packet_number: u64, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + self.encrypted_records.fetch_add(1, Ordering::Relaxed); + self.sealer + .encrypt(packet_number, header, extra_payload, payload_and_tag) + } } - #[inline] - fn encrypt( - &self, - nonce: N, - header: &[u8], - extra_payload: Option<&[u8]>, - payload_and_tag: &mut [u8], - ) { - self.sealer - .encrypt(nonce, header, extra_payload, payload_and_tag) + #[derive(Debug)] + pub struct Once { + key: awslc::seal::Application, + sealed: AtomicBool, } - #[inline] - fn retransmission_tag( - &self, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], - ) { - self.sealer.retransmission_tag( - original_packet_number, - retransmission_packet_number, - tag_out, - ) + impl Once { + pub(crate) fn new(key: awslc::seal::Application) -> Self { + Self { + key, + sealed: AtomicBool::new(false), + } + } } -} -#[derive(Debug)] -pub struct Opener { - pub(super) opener: awslc::DecryptKey, - pub(super) dedup: map::Dedup, -} + impl seal::Application for Once { + #[inline] + fn key_phase(&self) -> KeyPhase { + KeyPhase::Zero + } -impl Opener { - /// Disables replay prevention allowing the decryption key to be reused. - /// - /// ## Safety - /// Disabling replay prevention is insecure because it makes it possible for - /// active network attackers to cause a peer to accept previously processed - /// data as new. For example, if a packet contains a mutating request such - /// as adding +1 to a value in a database, an attacker can keep replaying - /// packets to increment the value beyond what the original legitimate - /// sender of the packet intended. - pub unsafe fn disable_replay_prevention(&mut self) { - self.dedup.disable(); + #[inline] + fn tag_len(&self) -> usize { + self.key.tag_len() + } + + #[inline] + fn encrypt( + &self, + packet_number: u64, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + assert!(!self.sealed.swap(true, Ordering::Relaxed)); + self.key + .encrypt(packet_number, header, extra_payload, payload_and_tag) + } } +} - /// Ensures the key has not been used before - #[inline] - fn on_decrypt_success(&self, payload: &mut UninitSlice) -> decrypt::Result { - self.dedup.check(&self.opener).map_err(|e| { - let payload = unsafe { - let ptr = payload.as_mut_ptr() as *mut MaybeUninit; - let len = payload.len(); - core::slice::from_raw_parts_mut(ptr, len) - }; - payload.zeroize(); - e - })?; +pub mod open { + use super::*; + use crate::crypto::{awslc, open, UninitSlice}; + use core::mem::MaybeUninit; + use s2n_quic_core::ensure; + use zeroize::Zeroize; - Ok(()) - } + pub use awslc::open::control; + + macro_rules! with_dedup { + () => { + /// Disables replay prevention allowing the decryption key to be reused. + /// + /// ## Safety + /// Disabling replay prevention is insecure because it makes it possible for + /// active network attackers to cause a peer to accept previously processed + /// data as new. For example, if a packet contains a mutating request such + /// as adding +1 to a value in a database, an attacker can keep replaying + /// packets to increment the value beyond what the original legitimate + /// sender of the packet intended. + pub unsafe fn disable_replay_prevention(&mut self) { + self.dedup.disable(); + } + + /// Ensures the key has not been used before + #[inline] + fn on_decrypt_success(&self, payload: &mut UninitSlice) -> open::Result { + self.dedup.check().map_err(|e| { + let payload = unsafe { + let ptr = payload.as_mut_ptr() as *mut MaybeUninit; + let len = payload.len(); + core::slice::from_raw_parts_mut(ptr, len) + }; + payload.zeroize(); + e + })?; + + Ok(()) + } - #[doc(hidden)] - #[cfg(any(test, feature = "testing"))] - pub fn dedup_check(&self) -> decrypt::Result { - self.dedup.check(&self.opener) + #[doc(hidden)] + #[cfg(any(test, feature = "testing"))] + pub fn dedup_check(&self) -> open::Result { + self.dedup.check() + } + }; } -} -impl decrypt::Key for Opener { - #[inline] - fn credentials(&self) -> &Credentials { - self.opener.credentials() + #[derive(Debug)] + pub struct Application { + openers: [awslc::open::Application; 2], + ku: schedule::OpenUpdate, + // the current expected phase + key_phase: KeyPhase, + dedup: map::Dedup, + needs_update: AtomicBool, } - #[inline] - fn tag_len(&self) -> usize { - self.opener.tag_len() + impl Application { + pub(crate) fn new( + opener: awslc::open::Application, + ku: schedule::OpenUpdate, + dedup: map::Dedup, + ) -> Self { + let (opener2, ku) = ku.next(); + let openers = [opener, opener2]; + Self { + openers, + ku, + key_phase: KeyPhase::Zero, + dedup, + needs_update: AtomicBool::new(false), + } + } + + with_dedup!(); + + #[inline] + pub fn needs_update(&self) -> bool { + self.needs_update.load(Ordering::Relaxed) + } + + #[inline] + pub fn update(&mut self) { + let idx = match self.key_phase { + KeyPhase::Zero => 0, + KeyPhase::One => 1, + }; + let (opener, ku) = self.ku.next(); + self.openers[idx] = opener; + self.ku = ku; + self.key_phase = self.key_phase.next_phase(); + self.needs_update.store(false, Ordering::Relaxed); + tracing::debug!(opener_updated = ?self.key_phase); + } } - #[inline] - fn decrypt( - &self, - key_phase: KeyPhase, - nonce: N, - header: &[u8], - payload_in: &[u8], - tag: &[u8], - payload_out: &mut UninitSlice, - ) -> decrypt::Result { - self.opener - .decrypt(key_phase, nonce, header, payload_in, tag, payload_out)?; - - self.on_decrypt_success(payload_out)?; - - Ok(()) + impl open::Application for Application { + #[inline] + fn tag_len(&self) -> usize { + self.openers[0].tag_len() + } + + #[inline] + fn decrypt( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut UninitSlice, + ) -> open::Result { + let opener = match key_phase { + KeyPhase::Zero => &self.openers[0], + KeyPhase::One => &self.openers[1], + }; + + opener.decrypt( + // the underlying key doesn't perform rotation + KeyPhase::Zero, + packet_number, + header, + payload_in, + tag, + payload_out, + )?; + + self.on_decrypt_success(payload_out)?; + + if key_phase != self.key_phase { + self.needs_update.store(true, Ordering::Relaxed); + } + + Ok(()) + } + + #[inline] + fn decrypt_in_place( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> open::Result { + let opener = match key_phase { + KeyPhase::Zero => &self.openers[0], + KeyPhase::One => &self.openers[1], + }; + + opener.decrypt_in_place( + // the underlying key doesn't perform rotation + KeyPhase::Zero, + packet_number, + header, + payload_and_tag, + )?; + + self.on_decrypt_success(payload_and_tag.into())?; + + if key_phase != self.key_phase { + self.needs_update.store(true, Ordering::Relaxed); + } + + Ok(()) + } } - #[inline] - fn decrypt_in_place( - &self, - key_phase: KeyPhase, - nonce: N, - header: &[u8], - payload_and_tag: &mut [u8], - ) -> decrypt::Result { - self.opener - .decrypt_in_place(key_phase, nonce, header, payload_and_tag)?; + #[derive(Debug)] + pub struct Once { + key: awslc::open::Application, + dedup: map::Dedup, + opened: AtomicBool, + } - self.on_decrypt_success(UninitSlice::new(payload_and_tag))?; + impl Once { + pub(crate) fn new(key: awslc::open::Application, dedup: map::Dedup) -> Self { + Self { + key, + dedup, + opened: AtomicBool::new(false), + } + } - Ok(()) + with_dedup!(); } - #[inline] - fn retransmission_tag( - &self, - key_phase: KeyPhase, - original_packet_number: u64, - retransmission_packet_number: u64, - tag_out: &mut [u8], - ) { - self.opener.retransmission_tag( - key_phase, - original_packet_number, - retransmission_packet_number, - tag_out, - ) + impl open::Application for Once { + #[inline] + fn tag_len(&self) -> usize { + self.key.tag_len() + } + + #[inline] + fn decrypt( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut UninitSlice, + ) -> open::Result { + ensure!( + key_phase == KeyPhase::Zero, + Err(open::Error::RotationNotSupported) + ); + + self.key.decrypt( + key_phase, + packet_number, + header, + payload_in, + tag, + payload_out, + )?; + + self.on_decrypt_success(payload_out)?; + + ensure!( + !self.opened.swap(true, Ordering::Relaxed), + Err(open::Error::SingleUseKey) + ); + + Ok(()) + } + + #[inline] + fn decrypt_in_place( + &self, + key_phase: KeyPhase, + packet_number: u64, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> open::Result { + ensure!( + key_phase == KeyPhase::Zero, + Err(open::Error::RotationNotSupported) + ); + + self.key + .decrypt_in_place(key_phase, packet_number, header, payload_and_tag)?; + + self.on_decrypt_success(payload_and_tag.into())?; + + ensure!( + !self.opened.swap(true, Ordering::Relaxed), + Err(open::Error::SingleUseKey) + ); + + Ok(()) + } } } diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index 406f8c61da..4bea24e8bc 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - receiver, + open, receiver, schedule::{self, Initiator}, - sender, stateless_reset, Opener, Sealer, + seal, sender, stateless_reset, }; use crate::{ credentials::{Credentials, Id}, crypto, packet::{secret_control as control, Packet, WireVersion}, + stream::TransportFeatures, }; use rand::Rng as _; use s2n_codec::EncoderBuffer; @@ -17,6 +18,7 @@ use s2n_quic_core::{ dc::{self, ApplicationParams, DatagramInfo}, ensure, event::api::EndpointType, + varint::VarInt, }; use std::{ fmt, @@ -268,6 +270,22 @@ impl Map { Self { state } } + /// The number of trusted secrets. + pub fn secrets_len(&self) -> usize { + self.state.ids.len() + } + + /// The number of trusted peers. + /// + /// This should be smaller than `secrets_len` (modulo momentary churn). + pub fn peers_len(&self) -> usize { + self.state.peers.len() + } + + pub fn secrets_capacity(&self) -> usize { + self.state.max_capacity + } + pub fn drop_state(&self) { self.state.peers.pin().clear(); self.state.ids.pin().clear(); @@ -278,42 +296,54 @@ impl Map { && !self.state.requested_handshakes.pin().contains(&peer) } - pub fn sealer(&self, peer: SocketAddr) -> Option<(Sealer, ApplicationParams)> { + pub fn seal_once( + &self, + peer: SocketAddr, + ) -> Option<(seal::Once, Credentials, ApplicationParams)> { let peers_guard = self.state.peers.guard(); let state = self.state.peers.get(&peer, &peers_guard)?; state.mark_live(self.state.cleaner.epoch()); - let sealer = state.uni_sealer(); - Some((sealer, state.parameters)) + let (sealer, credentials) = state.uni_sealer(); + Some((sealer, credentials, state.parameters)) } - pub fn opener(&self, credentials: &Credentials, control_out: &mut Vec) -> Option { + pub fn open_once( + &self, + credentials: &Credentials, + control_out: &mut Vec, + ) -> Option { let state = self.pre_authentication(credentials, control_out)?; let opener = state.uni_opener(self.clone(), credentials); Some(opener) } - pub fn pair_for_peer(&self, peer: SocketAddr) -> Option<(Sealer, Opener, ApplicationParams)> { + pub fn pair_for_peer( + &self, + peer: SocketAddr, + features: &TransportFeatures, + ) -> Option<(Bidirectional, ApplicationParams)> { let peers_guard = self.state.peers.guard(); let state = self.state.peers.get(&peer, &peers_guard)?; state.mark_live(self.state.cleaner.epoch()); - let (sealer, opener) = state.bidi_local(); + let keys = state.bidi_local(features); - Some((sealer, opener, state.parameters)) + Some((keys, state.parameters)) } pub fn pair_for_credentials( &self, credentials: &Credentials, + features: &TransportFeatures, control_out: &mut Vec, - ) -> Option<(Sealer, Opener, ApplicationParams)> { + ) -> Option<(Bidirectional, ApplicationParams)> { let state = self.pre_authentication(credentials, control_out)?; let params = state.parameters; - let (sealer, opener) = state.bidi_remote(self.clone(), credentials); + let keys = state.bidi_remote(self.clone(), credentials, features); - Some((sealer, opener, params)) + Some((keys, params)) } /// This can be called from anywhere to ask the map to handle a packet. @@ -481,7 +511,7 @@ impl Map { let provider = Self::new(stateless_reset::Signer::random()); let mut secret = [0; 32]; aws_lc_rs::rand::fill(&mut secret).unwrap(); - let mut stateless_reset = [0; 16]; + let mut stateless_reset = [0; control::TAG_LEN]; aws_lc_rs::rand::fill(&mut stateless_reset).unwrap(); let receiver_shared = receiver::Shared::new(); @@ -524,7 +554,7 @@ impl Map { s2n_quic_core::endpoint::Type::Client, &secret, ); - let sender = sender::State::new([0; 16]); + let sender = sender::State::new([0; control::TAG_LEN]); let receiver = self.state.receiver_shared.clone().new_receiver(); let entry = Entry::new( peer, @@ -633,9 +663,14 @@ impl Entry { secret: schedule::Secret, sender: sender::State, receiver: receiver::State, - parameters: ApplicationParams, + mut parameters: ApplicationParams, rehandshake_time: Duration, ) -> Self { + // clamp max datagram size to a well-known value + parameters.max_datagram_size = parameters + .max_datagram_size + .min(crate::stream::MAX_DATAGRAM_SIZE as _); + assert!(rehandshake_time.as_secs() <= u32::MAX as u64); Self { creation_time: Instant::now(), @@ -661,55 +696,127 @@ impl Entry { self.used_at.store(at_epoch, Ordering::Relaxed); } - fn uni_sealer(&self) -> Sealer { + fn uni_sealer(&self) -> (seal::Once, Credentials) { let key_id = self.sender.next_key_id(); + let credentials = Credentials { + id: *self.secret.id(), + key_id, + }; let sealer = self.secret.application_sealer(key_id); + let sealer = seal::Once::new(sealer); - Sealer { sealer } + (sealer, credentials) } - fn uni_opener(self: Arc, map: Map, credentials: &Credentials) -> Opener { - let opener = self.secret.application_opener(credentials.key_id); + fn uni_opener(self: Arc, map: Map, credentials: &Credentials) -> open::Once { + let key_id = credentials.key_id; + let opener = self.secret.application_opener(key_id); + let dedup = Dedup::new(self, key_id, map); + open::Once::new(opener, dedup) + } - let dedup = Dedup::new(self, map); + fn bidi_local(&self, features: &TransportFeatures) -> Bidirectional { + let key_id = self.sender.next_key_id(); + let initiator = Initiator::Local; + + let application = ApplicationPair::new( + &self.secret, + key_id, + initiator, + // we don't need to dedup locally-initiated openers + Dedup::disabled(), + ); - Opener { opener, dedup } + let control = if features.is_reliable() { + None + } else { + Some(ControlPair::new(&self.secret, key_id, initiator)) + }; + + Bidirectional { + credentials: Credentials { + id: *self.secret.id(), + key_id, + }, + application, + control, + } } - fn bidi_local(&self) -> (Sealer, Opener) { - let key_id = self.sender.next_key_id(); - let (sealer, opener) = self.secret.application_pair(key_id, Initiator::Local); - let sealer = Sealer { sealer }; + fn bidi_remote( + self: &Arc, + map: Map, + credentials: &Credentials, + features: &TransportFeatures, + ) -> Bidirectional { + let key_id = credentials.key_id; + let initiator = Initiator::Remote; + + let application = ApplicationPair::new( + &self.secret, + key_id, + initiator, + // Remote application keys need to be de-duplicated + Dedup::new(self.clone(), key_id, map), + ); - // we don't need to dedup locally-initiated openers - let dedup = Dedup::disabled(); + let control = if features.is_reliable() { + None + } else { + Some(ControlPair::new(&self.secret, key_id, initiator)) + }; - let opener = Opener { opener, dedup }; + Bidirectional { + credentials: *credentials, + application, + control, + } + } - (sealer, opener) + fn rehandshake_time(&self) -> Instant { + self.creation_time + Duration::from_secs(u64::from(self.rehandshake_delta_secs)) } +} - fn bidi_remote(self: Arc, map: Map, credentials: &Credentials) -> (Sealer, Opener) { - let (sealer, opener) = self - .secret - .application_pair(credentials.key_id, Initiator::Remote); - let sealer = Sealer { sealer }; +pub struct Bidirectional { + pub credentials: Credentials, + pub application: ApplicationPair, + pub control: Option, +} - let dedup = Dedup::new(self, map); +pub struct ApplicationPair { + pub sealer: seal::Application, + pub opener: open::Application, +} - let opener = Opener { opener, dedup }; +impl ApplicationPair { + fn new(secret: &schedule::Secret, key_id: VarInt, initiator: Initiator, dedup: Dedup) -> Self { + let (sealer, sealer_ku, opener, opener_ku) = secret.application_pair(key_id, initiator); - (sealer, opener) + let sealer = seal::Application::new(sealer, sealer_ku); + + let opener = open::Application::new(opener, opener_ku, dedup); + + Self { sealer, opener } } +} - fn rehandshake_time(&self) -> Instant { - self.creation_time + Duration::from_secs(u64::from(self.rehandshake_delta_secs)) +pub struct ControlPair { + pub sealer: seal::control::Stream, + pub opener: open::control::Stream, +} + +impl ControlPair { + fn new(secret: &schedule::Secret, key_id: VarInt, initiator: Initiator) -> Self { + let (sealer, opener) = secret.control_pair(key_id, initiator); + + Self { sealer, opener } } } pub struct Dedup { - cell: once_cell::sync::OnceCell, - init: core::cell::Cell, Map)>>, + cell: once_cell::sync::OnceCell, + init: core::cell::Cell, VarInt, Map)>>, } /// SAFETY: `init` cell is synchronized by `OnceCell` @@ -717,17 +824,17 @@ unsafe impl Sync for Dedup {} impl Dedup { #[inline] - fn new(entry: Arc, map: Map) -> Self { + fn new(entry: Arc, key_id: VarInt, map: Map) -> Self { // TODO potentially record a timestamp of when this was created to try and detect long // delays of processing the first packet. Self { cell: Default::default(), - init: core::cell::Cell::new(Some((entry, map))), + init: core::cell::Cell::new(Some((entry, key_id, map))), } } #[inline] - fn disabled() -> Self { + pub(crate) fn disabled() -> Self { Self { cell: once_cell::sync::OnceCell::with_value(Ok(())), init: core::cell::Cell::new(None), @@ -740,20 +847,23 @@ impl Dedup { } #[inline] - pub fn check(&self, c: &impl crypto::decrypt::Key) -> crypto::decrypt::Result { + pub fn check(&self) -> crypto::open::Result { *self.cell.get_or_init(|| { match self.init.take() { - Some((entry, map)) => { - let creds = c.credentials(); + Some((entry, key_id, map)) => { + let creds = &Credentials { + id: *entry.secret.id(), + key_id, + }; match entry.receiver.post_authentication(creds) { Ok(()) => Ok(()), Err(receiver::Error::AlreadyExists) => { map.send_control(&entry, creds, receiver::Error::AlreadyExists); - Err(crypto::decrypt::Error::ReplayDefinitelyDetected) + Err(crypto::open::Error::ReplayDefinitelyDetected) } Err(receiver::Error::Unknown) => { map.send_control(&entry, creds, receiver::Error::Unknown); - Err(crypto::decrypt::Error::ReplayPotentiallyDetected { + Err(crypto::open::Error::ReplayPotentiallyDetected { gap: Some( (*entry.receiver.minimum_unseen_key_id()) // This should never be negative, but saturate anyway to avoid @@ -766,7 +876,7 @@ impl Dedup { } None => { // Dedup has been poisoned! TODO log this - Err(crypto::decrypt::Error::ReplayPotentiallyDetected { gap: None }) + Err(crypto::open::Error::ReplayPotentiallyDetected { gap: None }) } } }) @@ -821,8 +931,7 @@ impl dc::Endpoint for Map { payload: &mut [u8], ) -> bool { let payload = s2n_codec::DecoderBufferMut::new(payload); - // TODO: Is 16 always right? - return match control::Packet::decode(payload, 16) { + match control::Packet::decode(payload) { Ok((packet, tail)) => { // Probably a bug somewhere? There shouldn't be anything trailing in the buffer // after we decode a secret control packet. @@ -834,7 +943,7 @@ impl dc::Endpoint for Map { true } Err(_) => false, - }; + } } } diff --git a/dc/s2n-quic-dc/src/path/secret/map/test.rs b/dc/s2n-quic-dc/src/path/secret/map/test.rs index 0b0e1c0270..1f5e47fa55 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/test.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/test.rs @@ -22,7 +22,7 @@ fn fake_entry(peer: u16) -> Arc { s2n_quic_core::endpoint::Type::Client, &secret, ), - sender::State::new([0; 16]), + sender::State::new([0; control::TAG_LEN]), receiver::State::without_shared(), dc::testing::TEST_APPLICATION_PARAMS, dc::testing::TEST_REHANDSHAKE_PERIOD, @@ -34,6 +34,10 @@ fn cleans_after_delay() { let signer = stateless_reset::Signer::new(b"secret"); let map = Map::new(signer); + // Stop background processing. We expect to manually invoke clean, and a background worker + // might interfere with our state. + map.state.cleaner.stop(); + let first = fake_entry(1); let second = fake_entry(1); let third = fake_entry(1); diff --git a/dc/s2n-quic-dc/src/path/secret/schedule.rs b/dc/s2n-quic-dc/src/path/secret/schedule.rs index bdc5ce484d..dd44989f72 100644 --- a/dc/s2n-quic-dc/src/path/secret/schedule.rs +++ b/dc/s2n-quic-dc/src/path/secret/schedule.rs @@ -2,23 +2,25 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - credentials::{Credentials, Id}, - crypto::awslc::{DecryptKey, EncryptKey}, + credentials::Id, + crypto::awslc::{open, seal}, }; use aws_lc_rs::{ aead::{self, NONCE_LEN}, hkdf::{self, Prk}, + hmac, }; use s2n_quic_core::{dc, varint::VarInt}; pub use s2n_quic_core::endpoint; pub const MAX_KEY_LEN: usize = 32; +const MAX_HMAC_KEY_LEN: usize = 1024 / 8; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(test, derive(bolero_generator::TypeGenerator))] #[allow(non_camel_case_types)] pub enum Ciphersuite { AES_GCM_128_SHA256, - #[allow(dead_code)] AES_GCM_256_SHA384, } @@ -38,6 +40,14 @@ impl Ciphersuite { Self::AES_GCM_256_SHA384 => hkdf::HKDF_SHA384, } } + + #[inline] + pub fn hmac(&self) -> &'static hmac::Algorithm { + match self { + Self::AES_GCM_128_SHA256 => &hmac::HMAC_SHA256, + Self::AES_GCM_256_SHA384 => &hmac::HMAC_SHA384, + } + } } impl hkdf::KeyType for Ciphersuite { @@ -117,7 +127,7 @@ impl Secret { }; let mut id = Id::default(); - v.expand(&[&[16], b" pid"], &mut *id); + v.prk.expand_into(&[&[16], b" pid"], &mut *id); v.id = id; v @@ -133,31 +143,33 @@ impl Secret { &self, key_id: VarInt, initiator: Initiator, - ) -> (EncryptKey, DecryptKey) { - let creds = Credentials { - id: self.id, - key_id, - }; - + ) -> (seal::Application, SealUpdate, open::Application, OpenUpdate) { let ciphersuite = &self.ciphersuite; - let mut out = [0u8; (NONCE_LEN + MAX_KEY_LEN) * 2]; + let mut out = [0u8; (NONCE_LEN + MAX_KEY_LEN) * 2 + MAX_KEY_LEN * 2]; let key_len = hkdf::KeyType::len(ciphersuite); - let out_len = (NONCE_LEN + key_len) * 2; + let out_len = (NONCE_LEN + key_len) * 2 + key_len * 2; + + debug_assert!(out_len <= u16::MAX as usize); + let (out, _) = out.split_at_mut(out_len); - self.expand( + self.prk.expand_into( &[ - &[out_len as u8], + &(out_len as u16).to_be_bytes(), b" bidi", initiator.label(self.endpoint), + b" app", &key_id.to_be_bytes(), ], out, ); + // if the hash is ever broken, it's better to put the "more secret" data at the beginning // - // here we derive + // here we derive: // - // (client_key, server_key, client_iv, server_iv) + // (client_ku, server_ku, client_key, server_key, client_iv, server_iv) + let (client_ku, out) = out.split_at(key_len); + let (server_ku, out) = out.split_at(key_len); let (client_key, out) = out.split_at(key_len); let (server_key, out) = out.split_at(key_len); let (client_iv, server_iv) = out.split_at(NONCE_LEN); @@ -165,41 +177,84 @@ impl Secret { let server_iv = server_iv.try_into().unwrap(); let aead = ciphersuite.aead(); - match self.endpoint { - endpoint::Type::Client => { - let sealer = EncryptKey::new(creds, client_key, client_iv, aead); - let opener = DecryptKey::new(creds, server_key, server_iv, aead); - (sealer, opener) - } - endpoint::Type::Server => { - let sealer = EncryptKey::new(creds, server_key, server_iv, aead); - let opener = DecryptKey::new(creds, client_key, client_iv, aead); - (sealer, opener) - } - } + let (sealer_ku, opener_ku, sealer_key, opener_key, sealer_iv, opener_iv) = + match self.endpoint { + endpoint::Type::Client => ( + client_ku, server_ku, client_key, server_key, client_iv, server_iv, + ), + endpoint::Type::Server => ( + server_ku, client_ku, server_key, client_key, server_iv, client_iv, + ), + }; + + let sealer = seal::Application::new(sealer_key, sealer_iv, aead); + let sealer_ku = SealUpdate::new(sealer_ku, ciphersuite); + let opener = open::Application::new(opener_key, opener_iv, aead); + let opener_ku = OpenUpdate::new(opener_ku, ciphersuite); + (sealer, sealer_ku, opener, opener_ku) } #[inline] - pub fn application_sealer(&self, key_id: VarInt) -> EncryptKey { - let creds = Credentials { - id: self.id, - key_id, + pub fn control_pair( + &self, + key_id: VarInt, + initiator: Initiator, + ) -> (seal::control::Stream, open::control::Stream) { + let ciphersuite = &self.ciphersuite; + let mut out = [0u8; MAX_HMAC_KEY_LEN * 2]; + let key_len = { + // Use the block length for the key, instead of output length for stronger security and to + // avoid padding. + + //= https://www.rfc-editor.org/rfc/rfc2104.html#section-2 + //# The authentication key K can be of any length up to B, the + //# block length of the hash function. Applications that use keys longer + //# than B bytes will first hash the key using H and then use the + //# resultant L byte string as the actual key to HMAC. In any case the + //# minimal recommended length for K is L bytes (as the hash output + //# length). + ciphersuite.hmac().digest_algorithm().block_len() }; + let out_len = key_len * 2; + + debug_assert!(out_len <= u16::MAX as usize); + let (out, _) = out.split_at_mut(out_len); + self.prk.expand_into( + &[ + &(out_len as u16).to_be_bytes(), + b" bidi", + initiator.label(self.endpoint), + b" ctl", + &key_id.to_be_bytes(), + ], + out, + ); + + let (client_key, server_key) = out.split_at(key_len); + let hmac = ciphersuite.hmac(); + + let (sealer_key, opener_key) = match self.endpoint { + endpoint::Type::Client => (client_key, server_key), + endpoint::Type::Server => (server_key, client_key), + }; + + let sealer = seal::control::Stream::new(sealer_key, hmac); + let opener = open::control::Stream::new(opener_key, hmac); + (sealer, opener) + } + + #[inline] + pub fn application_sealer(&self, key_id: VarInt) -> seal::Application { self.derive_application_key(Direction::Send, key_id, |alg, key, iv| { - EncryptKey::new(creds, key, iv, alg) + seal::Application::new(key, iv, alg) }) } #[inline] - pub fn application_opener(&self, key_id: VarInt) -> DecryptKey { - let creds = Credentials { - id: self.id, - key_id, - }; - + pub fn application_opener(&self, key_id: VarInt) -> open::Application { self.derive_application_key(Direction::Receive, key_id, |alg, key, iv| { - DecryptKey::new(creds, key, iv, alg) + open::Application::new(key, iv, alg) }) } @@ -211,10 +266,12 @@ impl Secret { let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN]; let key_len = hkdf::KeyType::len(&self.ciphersuite); let out_len = NONCE_LEN + key_len; + debug_assert!(out_len <= u16::MAX as usize); + let (out, _) = out.split_at_mut(out_len); - self.expand( + self.prk.expand_into( &[ - &[out_len as u8], + &(out_len as u16).to_be_bytes(), b" uni", direction.label(self.endpoint), &key_id.to_be_bytes(), @@ -227,51 +284,58 @@ impl Secret { f(self.ciphersuite.aead(), key, iv) } - pub fn control_sealer(&self) -> EncryptKey { - let creds = Credentials { - id: *self.id(), - key_id: VarInt::ZERO, - }; - - self.derive_control_key(Direction::Send, |alg, key, iv| { - EncryptKey::new(creds, key, iv, alg) - }) + pub fn control_sealer(&self) -> seal::control::Secret { + self.derive_control_key(Direction::Send, seal::control::Secret::new) } - pub fn control_opener(&self) -> DecryptKey { - let creds = Credentials { - id: *self.id(), - key_id: VarInt::ZERO, - }; - - self.derive_control_key(Direction::Receive, |alg, key, iv| { - DecryptKey::new(creds, key, iv, alg) - }) + pub fn control_opener(&self) -> open::control::Secret { + self.derive_control_key(Direction::Receive, open::control::Secret::new) } #[inline] fn derive_control_key(&self, direction: Direction, f: F) -> R where - F: FnOnce(&'static aead::Algorithm, &[u8], [u8; NONCE_LEN]) -> R, + F: FnOnce(&[u8], &'static hmac::Algorithm) -> R, { - let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN]; - let key_len = hkdf::KeyType::len(&self.ciphersuite); - let out_len = NONCE_LEN + key_len; + let mut out = [0u8; MAX_HMAC_KEY_LEN]; + let key_len = { + // Use the block length for the key, instead of output length for stronger security and to + // avoid padding. + + //= https://www.rfc-editor.org/rfc/rfc2104.html#section-2 + //# The authentication key K can be of any length up to B, the + //# block length of the hash function. Applications that use keys longer + //# than B bytes will first hash the key using H and then use the + //# resultant L byte string as the actual key to HMAC. In any case the + //# minimal recommended length for K is L bytes (as the hash output + //# length). + self.ciphersuite.hmac().digest_algorithm().block_len() + }; + + let out_len = key_len; + debug_assert!(out_len <= u16::MAX as usize); + let (out, _) = out.split_at_mut(out_len); - self.expand( - &[&[out_len as u8], b" ctl", direction.label(self.endpoint)], + self.prk.expand_into( + &[ + &(out_len as u16).to_be_bytes(), + b" ctl", + direction.label(self.endpoint), + ], out, ); - // if the hash is ever broken, it's better to put the "more secret" data at the beginning - let (key, iv) = out.split_at(key_len); - let iv = iv.try_into().unwrap(); - f(self.ciphersuite.aead(), key, iv) + f(out, self.ciphersuite.hmac()) } +} + +trait PrkExt { + fn expand_into(&self, label: &[&[u8]], out: &mut [u8]); +} +impl PrkExt for Prk { #[inline] - fn expand(&self, label: &[&[u8]], out: &mut [u8]) { - self.prk - .expand(label, OutLen(out.len())) + fn expand_into(&self, label: &[&[u8]], out: &mut [u8]) { + self.expand(label, OutLen(out.len())) .unwrap() .fill(out) .unwrap(); @@ -287,3 +351,247 @@ impl hkdf::KeyType for OutLen { self.0 } } + +#[derive(Debug)] +pub struct SealUpdate(Updater); + +impl SealUpdate { + #[inline] + pub fn new(secret: &[u8], ciphersuite: &Ciphersuite) -> Self { + Self(Updater::new(secret, ciphersuite)) + } + + #[inline] + pub fn next(&self) -> (seal::Application, SealUpdate) { + self.0.next(|key, iv, updater| { + let key = seal::Application::new(key, iv, updater.ciphersuite.aead()); + (key, Self(updater)) + }) + } +} + +#[derive(Debug)] +pub struct OpenUpdate(Updater); + +impl OpenUpdate { + #[inline] + pub fn new(secret: &[u8], ciphersuite: &Ciphersuite) -> Self { + Self(Updater::new(secret, ciphersuite)) + } + + #[inline] + pub fn next(&self) -> (open::Application, OpenUpdate) { + self.0.next(|key, iv, updater| { + let key = open::Application::new(key, iv, updater.ciphersuite.aead()); + (key, Self(updater)) + }) + } +} + +#[derive(Debug)] +struct Updater { + prk: Prk, + ciphersuite: Ciphersuite, +} + +impl Updater { + #[inline] + fn new(secret: &[u8], ciphersuite: &Ciphersuite) -> Self { + let prk = Prk::new_less_safe(ciphersuite.hkdf(), secret); + let ciphersuite = *ciphersuite; + Self { prk, ciphersuite } + } + + #[inline] + fn next(&self, f: F) -> R + where + F: FnOnce(&[u8], [u8; NONCE_LEN], Updater) -> R, + { + let ciphersuite = &self.ciphersuite; + + let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN * 2]; + let key_len = hkdf::KeyType::len(ciphersuite); + let out_len = NONCE_LEN + key_len * 2; + let (out, _) = out.split_at_mut(out_len); + self.prk + .expand_into(&[&(out_len as u16).to_be_bytes(), b" ku"], out); + + // if the hash is ever broken, it's better to put the "more secret" data at the beginning + // + // here we derive: + // + // (key_update, key, iv) + let (ku, out) = out.split_at(key_len); + let (key, iv) = out.split_at(key_len); + let iv = iv.try_into().unwrap(); + + let ku = Self::new(ku, ciphersuite); + + f(key, iv, ku) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::path::secret::{ + map::Dedup, open::Application as Opener, seal::Application as Sealer, + }; + use bolero::*; + + #[derive(Clone, Copy, Debug, TypeGenerator)] + struct Pair { + ciphersuite: Ciphersuite, + key_id: VarInt, + initiator_is_client: bool, + } + + impl Pair { + fn initiator(&self) -> endpoint::Type { + if self.initiator_is_client { + endpoint::Type::Client + } else { + endpoint::Type::Server + } + } + + fn endpoints(&self) -> (Secret, Secret) { + let secret = &[42; 32]; + let client = Secret::new(self.ciphersuite, 0, endpoint::Type::Client, secret); + let server = Secret::new(self.ciphersuite, 0, endpoint::Type::Server, secret); + (client, server) + } + + fn check_app(self) { + let (client, server) = self.endpoints(); + let (client_i, server_i) = match self.initiator() { + endpoint::Type::Client => (Initiator::Local, Initiator::Remote), + endpoint::Type::Server => (Initiator::Remote, Initiator::Local), + }; + let mut client_app = Application::new(client.application_pair(self.key_id, client_i)); + let mut server_app = Application::new(server.application_pair(self.key_id, server_i)); + + for i in 0..8 { + client_app.send(&server_app).unwrap(); + server_app.send(&client_app).unwrap(); + + // invalid sender/recipient should fail + client_app.send(&client_app).unwrap_err(); + server_app.send(&server_app).unwrap_err(); + + let (sender, receiver) = if i % 2 == 0 { + dbg!("client ku"); + (&mut client_app, &mut server_app) + } else { + dbg!("server ku"); + (&mut server_app, &mut client_app) + }; + + sender.sealer.update(); + sender.send(receiver).unwrap(); + + assert!(receiver.opener.needs_update()); + receiver.opener.update(); + } + } + + fn check_control(self) { + let (client, server) = self.endpoints(); + let (client_i, server_i) = match self.initiator() { + endpoint::Type::Client => (Initiator::Local, Initiator::Remote), + endpoint::Type::Server => (Initiator::Remote, Initiator::Local), + }; + let client_app = Control::new(client.control_pair(self.key_id, client_i)); + let server_app = Control::new(server.control_pair(self.key_id, server_i)); + + client_app.send(&server_app).unwrap(); + server_app.send(&client_app).unwrap(); + + // invalid sender/recipient should fail + client_app.send(&client_app).unwrap_err(); + server_app.send(&server_app).unwrap_err(); + } + } + + struct Application { + sealer: Sealer, + opener: Opener, + } + + impl Application { + fn new( + (sealer, sealer_ku, opener, opener_ku): ( + seal::Application, + SealUpdate, + open::Application, + OpenUpdate, + ), + ) -> Self { + let sealer = Sealer::new(sealer, sealer_ku); + let opener = Opener::new(opener, opener_ku, Dedup::disabled()); + Self { sealer, opener } + } + + fn send(&self, other: &Self) -> crate::crypto::open::Result { + use crate::crypto::{open::Application as _, seal::Application as _}; + + let msg = b"hello"; + let mut buf = [0u8; 5 + 16]; + + let packet_number = 0u64; + let header = &[]; + + let key_phase = self.sealer.key_phase(); + self.sealer + .encrypt(packet_number, header, Some(msg), &mut buf); + + assert_ne!(msg, &buf[..5]); + + other + .opener + .decrypt_in_place(key_phase, packet_number, header, &mut buf)?; + + assert_eq!(msg, &buf[..5]); + + Ok(()) + } + } + + struct Control { + sealer: seal::control::Stream, + opener: open::control::Stream, + } + + impl Control { + fn new((sealer, opener): (seal::control::Stream, open::control::Stream)) -> Self { + Self { sealer, opener } + } + + fn send(&self, other: &Self) -> crate::crypto::open::Result { + use crate::crypto::{open::Control as _, seal::Control as _}; + + let msg = b"hello"; + let mut tag = [0u8; crate::packet::secret_control::TAG_LEN]; + + self.sealer.sign(msg, &mut tag); + + other.opener.verify(msg, &tag)?; + + Ok(()) + } + } + + #[test] + fn application_pair() { + bolero::check!() + .with_type::() + .for_each(|input| input.check_app()) + } + + #[test] + fn control_pair() { + bolero::check!() + .with_type::() + .for_each(|input| input.check_control()) + } +} diff --git a/dc/s2n-quic-dc/src/path/secret/sender.rs b/dc/s2n-quic-dc/src/path/secret/sender.rs index 3f50eb8617..54d6193dda 100644 --- a/dc/s2n-quic-dc/src/path/secret/sender.rs +++ b/dc/s2n-quic-dc/src/path/secret/sender.rs @@ -2,20 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 use super::schedule; -use crate::crypto::awslc::DecryptKey; +use crate::{crypto::awslc::open, packet::secret_control}; use once_cell::sync::OnceCell; use s2n_quic_core::varint::VarInt; use std::sync::atomic::{AtomicU64, Ordering}; +type StatelessReset = [u8; secret_control::TAG_LEN]; + #[derive(Debug)] pub struct State { current_id: AtomicU64, - pub(super) stateless_reset: [u8; 16], - control_secret: OnceCell, + pub(super) stateless_reset: StatelessReset, + control_secret: OnceCell, } impl State { - pub fn new(stateless_reset: [u8; 16]) -> Self { + pub fn new(stateless_reset: StatelessReset) -> Self { Self { current_id: AtomicU64::new(0), stateless_reset, @@ -45,7 +47,7 @@ impl State { } #[inline] - pub fn control_secret(&self, secret: &schedule::Secret) -> &DecryptKey { + pub fn control_secret(&self, secret: &schedule::Secret) -> &open::control::Secret { self.control_secret.get_or_init(|| secret.control_opener()) } @@ -65,7 +67,7 @@ impl State { #[test] #[should_panic = "2^62 integer incremented"] fn sender_does_not_wrap() { - let state = State::new([0; 16]); + let state = State::new([0; secret_control::TAG_LEN]); assert_eq!(*state.next_key_id(), 0); state.current_id.store((1 << 62) - 3, Ordering::Relaxed); @@ -79,7 +81,7 @@ fn sender_does_not_wrap() { #[test] fn update_restarts_sequence() { - let state = State::new([0; 16]); + let state = State::new([0; secret_control::TAG_LEN]); assert_eq!(*state.next_key_id(), 0); state.update_for_stale_key(VarInt::new(3).unwrap()); diff --git a/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs index 26edbdc7a8..213805a6e8 100644 --- a/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs +++ b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs @@ -1,20 +1,19 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use super::schedule; -use crate::credentials::Id; -use aws_lc_rs::hkdf::{Prk, Salt, HKDF_SHA384}; +use crate::{credentials::Id, packet::secret_control::TAG_LEN}; +use aws_lc_rs::hmac; #[derive(Debug)] pub struct Signer { - prk: Prk, + key: hmac::Key, } impl Signer { /// Creates a signer with the given secret pub fn new(secret: &[u8]) -> Self { - let prk = Salt::new(HKDF_SHA384, secret).extract(b"rst"); - Self { prk } + let key = hmac::Key::new(hmac::HMAC_SHA384, secret); + Self { key } } /// Returns a random `Signer` @@ -27,14 +26,11 @@ impl Signer { Self::new(&secret) } - pub fn sign(&self, id: &Id) -> [u8; 16] { - let mut stateless_reset = [0; 16]; + pub fn sign(&self, id: &Id) -> [u8; TAG_LEN] { + let mut stateless_reset = [0; TAG_LEN]; - self.prk - .expand(&[&[16], b"rst ", &**id], schedule::OutLen(16)) - .unwrap() - .fill(&mut stateless_reset) - .unwrap(); + let tag = hmac::sign(&self.key, &**id); + stateless_reset.copy_from_slice(&tag.as_ref()[..TAG_LEN]); stateless_reset } diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs index 941a51e204..a41b639eeb 100644 --- a/dc/s2n-quic-dc/src/stream.rs +++ b/dc/s2n-quic-dc/src/stream.rs @@ -7,6 +7,8 @@ use core::time::Duration; pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30); /// The maximum time a send stream will wait for ACKs from inflight packets pub const DEFAULT_INFLIGHT_TIMEOUT: Duration = Duration::from_secs(5); +/// The maximum length of a single packet written to a stream +pub const MAX_DATAGRAM_SIZE: usize = 1 << 15; // 32k pub mod application; pub mod crypto; diff --git a/dc/s2n-quic-dc/src/stream/crypto.rs b/dc/s2n-quic-dc/src/stream/crypto.rs index fa054b8446..a9693f476f 100644 --- a/dc/s2n-quic-dc/src/stream/crypto.rs +++ b/dc/s2n-quic-dc/src/stream/crypto.rs @@ -1,34 +1,50 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::path::secret::{Map, Opener, Sealer}; -use core::{fmt, sync::atomic::Ordering}; -use crossbeam_epoch::{pin, Atomic}; +use crate::{ + crypto::awslc::{open, seal}, + path::secret::{open::Application as Opener, seal::Application as Sealer, Map}, +}; +use core::fmt; +use std::sync::Mutex; -// TODO support key updates pub struct Crypto { - sealer: Atomic, - opener: Atomic, + app_sealer: Mutex, + app_opener: Mutex, + control_sealer: Option, + control_opener: Option, map: Map, } impl fmt::Debug for Crypto { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Crypto") - .field("sealer", &self.sealer) - .field("opener", &self.opener) + .field("sealer", &self.app_sealer) + .field("opener", &self.app_opener) .finish() } } impl Crypto { #[inline] - pub fn new(sealer: Sealer, opener: Opener, map: &Map) -> Self { - let sealer = Atomic::new(sealer); - let opener = Atomic::new(opener); + pub fn new( + app_sealer: Sealer, + app_opener: Opener, + control: Option<(seal::control::Stream, open::control::Stream)>, + map: &Map, + ) -> Self { + let app_sealer = Mutex::new(app_sealer); + let app_opener = Mutex::new(app_opener); + let (control_sealer, control_opener) = if let Some((s, o)) = control { + (Some(s), Some(o)) + } else { + (None, None) + }; Self { - sealer, - opener, + app_sealer, + app_opener, + control_sealer, + control_opener, map: map.clone(), } } @@ -44,50 +60,44 @@ impl Crypto { } #[inline] - pub fn seal_with(&self, seal: impl FnOnce(&Sealer) -> R) -> R { - let pin = pin(); - let sealer = self.seal_pin(&pin); - seal(sealer) - } + pub fn seal_with( + &self, + seal: impl FnOnce(&Sealer) -> R, + update: impl FnOnce(&mut Sealer), + ) -> R { + let lock = &self.app_sealer; + let mut guard = lock.lock().unwrap(); + let result = seal(&guard); - #[inline] - fn seal_pin<'a>(&self, pin: &'a crossbeam_epoch::Guard) -> &'a Sealer { - let sealer = self.sealer.load(Ordering::Acquire, pin); - unsafe { sealer.deref() } + // update the keys if needed + if guard.needs_update() { + update(&mut guard); + } + + result } #[inline] pub fn open_with(&self, open: impl FnOnce(&Opener) -> R) -> R { - let pin = pin(); - let opener = self.open_pin(&pin); - open(opener) + let lock = &self.app_opener; + let mut guard = lock.lock().unwrap(); + let result = open(&guard); + + // update the keys if needed + if guard.needs_update() { + guard.update(); + } + + result } #[inline] - fn open_pin<'a>(&self, pin: &'a crossbeam_epoch::Guard) -> &'a Opener { - let opener = self.opener.load(Ordering::Acquire, pin); - unsafe { opener.deref() } + pub fn control_sealer(&self) -> Option<&seal::control::Stream> { + self.control_sealer.as_ref() } -} -impl Drop for Crypto { #[inline] - fn drop(&mut self) { - use crossbeam_epoch::Shared; - let pin = pin(); - let sealer = self.sealer.swap(Shared::null(), Ordering::AcqRel, &pin); - let opener = self.opener.swap(Shared::null(), Ordering::AcqRel, &pin); - - // no need to drop either one - if sealer.is_null() && opener.is_null() { - return; - } - - unsafe { - pin.defer_unchecked(move || { - drop(sealer.try_into_owned()); - drop(opener.try_into_owned()); - }) - } + pub fn control_opener(&self) -> Option<&open::control::Stream> { + self.control_opener.as_ref() } } diff --git a/dc/s2n-quic-dc/src/stream/endpoint.rs b/dc/s2n-quic-dc/src/stream/endpoint.rs index 3c9b97ef8f..8d617b458c 100644 --- a/dc/s2n-quic-dc/src/stream/endpoint.rs +++ b/dc/s2n-quic-dc/src/stream/endpoint.rs @@ -2,9 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::encrypt::Key as _, msg, packet, - path::secret::Map, + path::secret::{self, Map}, random::Random, stream::{ application, @@ -44,7 +43,8 @@ where P: Peer, { // derive secrets for the new stream - let Some((sealer, opener, mut parameters)) = map.pair_for_peer(handshake_addr.into()) else { + let Some((crypto, mut parameters)) = map.pair_for_peer(handshake_addr.into(), &peer.features()) + else { // the application didn't perform a handshake with the server before opening the stream return Err(io::Error::new( io::ErrorKind::NotFound, @@ -56,22 +56,20 @@ where parameters = o(parameters); } - // TODO get a flow ID. for now we'll use the sealer credentials - let key_id = sealer.credentials().key_id; + let key_id = crypto.credentials.key_id; let stream_id = packet::stream::Id { key_id, is_reliable: true, is_bidirectional: true, }; - let crypto = shared::Crypto::new(sealer, opener, map); - build_stream( env, peer, stream_id, None, crypto, + map, parameters, None, None, @@ -95,8 +93,8 @@ where { let credentials = &packet.credentials; let mut secret_control = vec![]; - let Some((sealer, opener, mut parameters)) = - map.pair_for_credentials(credentials, &mut secret_control) + let Some((crypto, mut parameters)) = + map.pair_for_credentials(credentials, &peer.features(), &mut secret_control) else { let error = io::Error::new( io::ErrorKind::NotFound, @@ -117,14 +115,13 @@ where // inform the value of what the source_control_port is peer.with_source_control_port(packet.source_control_port); - let crypto = shared::Crypto::new(sealer, opener, map); - let res = build_stream( env, peer, packet.stream_id, packet.source_stream_port, crypto, + map, parameters, handshake, buffer, @@ -150,7 +147,8 @@ fn build_stream( peer: P, stream_id: packet::stream::Id, remote_stream_port: Option, - crypto: shared::Crypto, + crypto: secret::map::Bidirectional, + map: &Map, parameters: dc::ApplicationParams, handshake: Option, recv_buffer: Option<&mut msg::recv::Message>, @@ -219,6 +217,7 @@ where remote_ip: UnsafeCell::new(sockets.remote_addr.ip()), source_control_port: UnsafeCell::new(sockets.source_control_port), application: UnsafeCell::new(application), + credentials: UnsafeCell::new(crypto.credentials), }; let remote_port = sockets.remote_addr.port(); @@ -235,6 +234,18 @@ where } }; + let crypto = { + let secret::map::Bidirectional { + application, + control, + credentials: _, + } = crypto; + + let control = control.map(|c| (c.sealer, c.opener)); + + shared::Crypto::new(application.sealer, application.opener, control, map) + }; + let shared = Arc::new(shared::Shared { receiver: reader, sender: writer.0, diff --git a/dc/s2n-quic-dc/src/stream/environment.rs b/dc/s2n-quic-dc/src/stream/environment.rs index b5e2bc25f8..d9bbbf5b8e 100644 --- a/dc/s2n-quic-dc/src/stream/environment.rs +++ b/dc/s2n-quic-dc/src/stream/environment.rs @@ -3,25 +3,12 @@ use crate::{ clock, - crypto::encrypt::Key as _, - msg, packet, - path::secret::Map, - random::Random, - stream::{ - application, recv, runtime, - send::{self, flow}, - server, shared, socket, TransportFeatures, - }, -}; -use core::{cell::UnsafeCell, future::Future}; -use s2n_quic_core::{ - dc, endpoint, - inet::{ExplicitCongestionNotification, SocketAddress}, - varint::VarInt, + stream::{runtime, socket, TransportFeatures}, }; +use core::future::Future; +use s2n_quic_core::inet::SocketAddress; use s2n_quic_platform::features; -use std::{io, sync::Arc}; -use tracing::{debug_span, Instrument as _}; +use std::io; type Result = core::result::Result; @@ -76,297 +63,4 @@ impl Builder { pub fn clock(&self) -> &E::Clock { self.env.clock() } - - #[inline] - pub fn open_stream

( - &self, - handshake_addr: SocketAddress, - peer: P, - map: &Map, - parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, - ) -> Result - where - P: Peer, - { - // derive secrets for the new stream - let Some((sealer, opener, mut parameters)) = map.pair_for_peer(handshake_addr.into()) - else { - // the application didn't perform a handshake with the server before opening the stream - return Err(io::Error::new( - io::ErrorKind::NotFound, - format!("missing credentials for server: {handshake_addr}"), - )); - }; - - if let Some(o) = parameter_override { - parameters = o(parameters); - } - - // TODO get a flow ID. for now we'll use the sealer credentials - let key_id = sealer.credentials().key_id; - let stream_id = packet::stream::Id { - key_id, - is_reliable: true, - is_bidirectional: true, - }; - - let crypto = shared::Crypto::new(sealer, opener, map); - - self.build_stream( - peer, - stream_id, - None, - crypto, - parameters, - None, - None, - endpoint::Type::Client, - ) - } - - #[inline] - pub fn accept_stream

( - &self, - mut peer: P, - packet: &server::InitialPacket, - handshake: Option, - buffer: Option<&mut msg::recv::Message>, - map: &Map, - parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, - ) -> Result> - where - P: Peer, - { - let credentials = &packet.credentials; - let mut secret_control = vec![]; - let Some((sealer, opener, mut parameters)) = - map.pair_for_credentials(credentials, &mut secret_control) - else { - let error = io::Error::new( - io::ErrorKind::NotFound, - format!("missing credentials for client: {credentials:?}"), - ); - let error = AcceptError { - secret_control, - peer: Some(peer), - error, - }; - return Err(error); - }; - - if let Some(o) = parameter_override { - parameters = o(parameters); - } - - // inform the value of what the source_control_port is - peer.with_source_control_port(packet.source_control_port); - - let crypto = shared::Crypto::new(sealer, opener, map); - - let res = self.build_stream( - peer, - packet.stream_id, - packet.source_stream_port, - crypto, - parameters, - handshake, - buffer, - endpoint::Type::Server, - ); - - match res { - Ok(stream) => Ok(stream), - Err(error) => { - let error = AcceptError { - secret_control, - peer: None, - error, - }; - Err(error) - } - } - } - - #[inline] - fn build_stream

( - &self, - peer: P, - stream_id: packet::stream::Id, - remote_stream_port: Option, - crypto: shared::Crypto, - parameters: dc::ApplicationParams, - handshake: Option, - recv_buffer: Option<&mut msg::recv::Message>, - endpoint_type: endpoint::Type, - ) -> Result - where - P: Peer, - { - let features = peer.features(); - - let sockets = peer.setup(&self.env)?; - - // construct shared reader state - let reader = - recv::shared::State::new(stream_id, ¶meters, handshake, features, recv_buffer); - - let writer = { - let worker = sockets - .write_worker - .map(|socket| (send::state::State::new(stream_id, ¶meters), socket)); - - let (flow_offset, send_quantum, bandwidth) = - if let Some((worker, _socket)) = worker.as_ref() { - let flow_offset = worker.flow_offset(); - let send_quantum = worker.send_quantum_packets(); - let bandwidth = Some(worker.cca.bandwidth()); - - (flow_offset, send_quantum, bandwidth) - } else { - debug_assert!( - features.is_flow_controlled(), - "transports without flow control need background workers" - ); - - let flow_offset = VarInt::MAX; - let send_quantum = 10; - let bandwidth = None; - - (flow_offset, send_quantum, bandwidth) - }; - - let flow = flow::non_blocking::State::new(flow_offset); - - let path = send::path::Info { - max_datagram_size: parameters.max_datagram_size, - send_quantum, - ecn: ExplicitCongestionNotification::Ect0, - next_expected_control_packet: VarInt::ZERO, - }; - - // construct shared writer state - let state = send::shared::State::new(flow, path, bandwidth); - - (state, worker) - }; - - // construct shared common state between readers/writers - let common = { - let application = send::application::state::State { - stream_id, - source_control_port: sockets.source_control_port, - source_stream_port: sockets.source_stream_port, - }; - - let fixed = shared::FixedValues { - remote_ip: UnsafeCell::new(sockets.remote_addr.ip()), - source_control_port: UnsafeCell::new(sockets.source_control_port), - application: UnsafeCell::new(application), - }; - - let remote_port = sockets.remote_addr.port(); - let write_remote_port = remote_stream_port.unwrap_or(remote_port); - - shared::Common { - clock: self.env.clock().clone(), - gso: self.env.gso(), - read_remote_port: remote_port.into(), - write_remote_port: write_remote_port.into(), - last_peer_activity: Default::default(), - fixed, - closed_halves: 0u8.into(), - } - }; - - let shared = Arc::new(shared::Shared { - receiver: reader, - sender: writer.0, - common, - crypto, - }); - - // spawn the read worker - if let Some(socket) = sockets.read_worker { - let shared = shared.clone(); - - let task = async move { - let mut reader = recv::worker::Worker::new(socket, shared, endpoint_type); - - let mut prev_waker: Option = None; - core::future::poll_fn(|cx| { - // update the waker if needed - if prev_waker - .as_ref() - .map_or(true, |prev| !prev.will_wake(cx.waker())) - { - prev_waker = Some(cx.waker().clone()); - reader.update_waker(cx); - } - - // drive the reader to completion - reader.poll(cx) - }) - .await; - }; - - let span = debug_span!("worker::read"); - - if span.is_disabled() { - self.env.spawn_reader(task); - } else { - self.env.spawn_reader(task.instrument(span)); - } - } - - // spawn the write worker - if let Some((worker, socket)) = writer.1 { - let shared = shared.clone(); - - let task = async move { - let mut writer = send::worker::Worker::new( - socket, - Random::default(), - shared, - worker, - endpoint_type, - ); - - let mut prev_waker: Option = None; - core::future::poll_fn(|cx| { - // update the waker if needed - if prev_waker - .as_ref() - .map_or(true, |prev| !prev.will_wake(cx.waker())) - { - prev_waker = Some(cx.waker().clone()); - writer.update_waker(cx); - } - - // drive the writer to completion - writer.poll(cx) - }) - .await; - }; - - let span = debug_span!("worker::write"); - - if span.is_disabled() { - self.env.spawn_writer(task); - } else { - self.env.spawn_writer(task.instrument(span)); - } - } - - let read = recv::application::Builder::new(endpoint_type, self.env.reader_rt()); - let write = send::application::Builder::new(self.env.writer_rt()); - - let stream = application::Builder { - read, - write, - shared, - sockets: sockets.application, - }; - - Ok(stream) - } } diff --git a/dc/s2n-quic-dc/src/stream/processing.rs b/dc/s2n-quic-dc/src/stream/processing.rs index 0dccd2dc04..841840603c 100644 --- a/dc/s2n-quic-dc/src/stream/processing.rs +++ b/dc/s2n-quic-dc/src/stream/processing.rs @@ -1,12 +1,12 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::crypto::decrypt; +use crate::crypto::open; #[derive(Clone, Copy, Debug, thiserror::Error)] pub enum Error { - #[error("packet could not be decrypted")] - Decrypt, + #[error("packet could not be decrypted: {0}")] + Crypto(open::Error), #[error("packet has already been processed")] Duplicate, #[error("the crypto key has been replayed and is invalid")] @@ -15,14 +15,14 @@ pub enum Error { KeyReplayPotentiallyPrevented { gap: Option }, } -impl From for Error { - fn from(value: decrypt::Error) -> Self { +impl From for Error { + fn from(value: open::Error) -> Self { match value { - decrypt::Error::ReplayDefinitelyDetected => Self::KeyReplayPrevented, - decrypt::Error::ReplayPotentiallyDetected { gap } => { + open::Error::ReplayDefinitelyDetected => Self::KeyReplayPrevented, + open::Error::ReplayPotentiallyDetected { gap } => { Self::KeyReplayPotentiallyPrevented { gap } } - decrypt::Error::InvalidTag => Self::Decrypt, + error => Self::Crypto(error), } } } diff --git a/dc/s2n-quic-dc/src/stream/recv/error.rs b/dc/s2n-quic-dc/src/stream/recv/error.rs index f95b2a42b2..8bb22bc59b 100644 --- a/dc/s2n-quic-dc/src/stream/recv/error.rs +++ b/dc/s2n-quic-dc/src/stream/recv/error.rs @@ -2,20 +2,86 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::decrypt, + credentials, + crypto::open, packet::{self, stream}, stream::TransportFeatures, }; +use core::{fmt, panic::Location}; use s2n_quic_core::{buffer, frame}; +#[derive(Clone, Copy)] +pub struct Error { + kind: Kind, + location: &'static Location<'static>, +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Error") + .field("kind", &self.kind) + .field("crate", &"s2n-quic-dc") + .field("file", &self.file()) + .field("line", &self.location.line()) + .finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Self { kind, location } = self; + let file = self.file(); + let line = location.line(); + write!(f, "[s2n-quic-dc::{file}:{line}]: {kind}") + } +} + +impl std::error::Error for Error {} + +impl Error { + #[track_caller] + #[inline] + pub fn new(kind: Kind) -> Self { + Self { + kind, + location: Location::caller(), + } + } + + #[inline] + pub fn kind(&self) -> &Kind { + &self.kind + } + + #[inline] + fn file(&self) -> &'static str { + self.location + .file() + .trim_start_matches(concat!(env!("CARGO_MANIFEST_DIR"), "/src/")) + } +} + +impl From for Error { + #[track_caller] + #[inline] + fn from(kind: Kind) -> Self { + Self::new(kind) + } +} + #[derive(Clone, Copy, Debug, thiserror::Error)] -pub enum Error { +pub enum Kind { #[error("could not decode packet")] Decode, - #[error("could not decrypt packet")] - Decrypt, + #[error("could not decrypt packet: {0}")] + Crypto(open::Error), #[error("packet has already been processed")] Duplicate, + #[error("the packet was for another credential ({actual:?}) but was handled by {expected:?}")] + CredentialMismatch { + expected: credentials::Id, + actual: credentials::Id, + }, #[error("the packet was for another stream ({actual}) but was handled by {expected}")] StreamMismatch { expected: stream::Id, @@ -47,15 +113,23 @@ pub enum Error { UnexpectedPacket { packet: packet::Kind }, } -impl From for Error { - fn from(value: decrypt::Error) -> Self { +impl Kind { + #[inline] + #[track_caller] + pub(crate) fn err(self) -> Error { + Error::new(self) + } +} + +impl From for Error { + #[track_caller] + fn from(value: open::Error) -> Self { match value { - decrypt::Error::ReplayDefinitelyDetected => Self::KeyReplayPrevented, - decrypt::Error::ReplayPotentiallyDetected { gap } => { - Self::KeyReplayMaybePrevented { gap } - } - decrypt::Error::InvalidTag => Self::Decrypt, + open::Error::ReplayDefinitelyDetected => Kind::KeyReplayPrevented, + open::Error::ReplayPotentiallyDetected { gap } => Kind::KeyReplayMaybePrevented { gap }, + error => Kind::Crypto(error), } + .err() } } @@ -69,45 +143,51 @@ impl Error { } !matches!( - self, - Self::Decode | Self::Decrypt | Self::Duplicate | Self::StreamMismatch { .. } + self.kind(), + Kind::Decode + | Kind::Crypto(_) + | Kind::Duplicate + | Kind::CredentialMismatch { .. } + | Kind::StreamMismatch { .. } ) } #[inline] pub(super) fn connection_close(&self) -> Option> { use s2n_quic_core::transport; - match self { - Error::Decode - | Error::Decrypt - | Error::Duplicate - | Error::StreamMismatch { .. } - | Error::UnexpectedPacket { .. } - | Error::UnexpectedRetransmission => { + match self.kind() { + Kind::Decode + | Kind::Crypto(_) + | Kind::Duplicate + | Kind::CredentialMismatch { .. } + | Kind::StreamMismatch { .. } + | Kind::UnexpectedPacket { .. } + | Kind::UnexpectedRetransmission => { // return protocol violation for the errors that are only fatal for reliable // transports Some(transport::Error::PROTOCOL_VIOLATION.into()) } - Error::IdleTimeout => None, - Error::MaxDataExceeded => Some(transport::Error::FLOW_CONTROL_ERROR.into()), - Error::InvalidFin | Error::TruncatedTransport => { + Kind::IdleTimeout => None, + Kind::MaxDataExceeded => Some(transport::Error::FLOW_CONTROL_ERROR.into()), + Kind::InvalidFin | Kind::TruncatedTransport => { Some(transport::Error::FINAL_SIZE_ERROR.into()) } - Error::OutOfOrder { .. } => Some(transport::Error::STREAM_STATE_ERROR.into()), - Error::OutOfRange => Some(transport::Error::STREAM_LIMIT_ERROR.into()), + Kind::OutOfOrder { .. } => Some(transport::Error::STREAM_STATE_ERROR.into()), + Kind::OutOfRange => Some(transport::Error::STREAM_LIMIT_ERROR.into()), // we don't have working crypto keys so we can't respond - Error::KeyReplayPrevented | Error::KeyReplayMaybePrevented { .. } => None, - Error::ApplicationError { error } => Some((*error).into()), + Kind::KeyReplayPrevented | Kind::KeyReplayMaybePrevented { .. } => None, + Kind::ApplicationError { error } => Some((*error).into()), } } } impl From> for Error { #[inline] + #[track_caller] fn from(value: buffer::Error) -> Self { match value { - buffer::Error::OutOfRange => Self::OutOfRange, - buffer::Error::InvalidFin => Self::InvalidFin, + buffer::Error::OutOfRange => Kind::OutOfRange.err(), + buffer::Error::InvalidFin => Kind::InvalidFin.err(), buffer::Error::ReaderError(error) => error, } } @@ -116,36 +196,36 @@ impl From> for Error { impl From for std::io::Error { #[inline] fn from(error: Error) -> Self { - Self::new(error.into(), error) + Self::new(error.kind.into(), error) } } -impl From for std::io::ErrorKind { +impl From for std::io::ErrorKind { #[inline] - fn from(error: Error) -> Self { + fn from(kind: Kind) -> Self { use std::io::ErrorKind; - match error { - Error::Decode => ErrorKind::InvalidData, - Error::Decrypt => ErrorKind::InvalidData, - Error::Duplicate => ErrorKind::InvalidData, - Error::StreamMismatch { .. } => ErrorKind::InvalidData, - Error::MaxDataExceeded => ErrorKind::ConnectionAborted, - Error::InvalidFin => ErrorKind::InvalidData, - Error::TruncatedTransport => ErrorKind::UnexpectedEof, - Error::OutOfRange => ErrorKind::ConnectionAborted, - Error::OutOfOrder { .. } => ErrorKind::InvalidData, - Error::UnexpectedRetransmission { .. } => ErrorKind::InvalidData, - Error::IdleTimeout => ErrorKind::TimedOut, - Error::KeyReplayPrevented => ErrorKind::PermissionDenied, - Error::KeyReplayMaybePrevented { .. } => ErrorKind::PermissionDenied, - Error::ApplicationError { .. } => ErrorKind::ConnectionReset, - Error::UnexpectedPacket { + match kind { + Kind::Decode => ErrorKind::InvalidData, + Kind::Crypto(_) => ErrorKind::InvalidData, + Kind::Duplicate => ErrorKind::InvalidData, + Kind::CredentialMismatch { .. } | Kind::StreamMismatch { .. } => ErrorKind::InvalidData, + Kind::MaxDataExceeded => ErrorKind::ConnectionAborted, + Kind::InvalidFin => ErrorKind::InvalidData, + Kind::TruncatedTransport => ErrorKind::UnexpectedEof, + Kind::OutOfRange => ErrorKind::ConnectionAborted, + Kind::OutOfOrder { .. } => ErrorKind::InvalidData, + Kind::UnexpectedRetransmission { .. } => ErrorKind::InvalidData, + Kind::IdleTimeout => ErrorKind::TimedOut, + Kind::KeyReplayPrevented => ErrorKind::PermissionDenied, + Kind::KeyReplayMaybePrevented { .. } => ErrorKind::PermissionDenied, + Kind::ApplicationError { .. } => ErrorKind::ConnectionReset, + Kind::UnexpectedPacket { packet: packet::Kind::UnknownPathSecret | packet::Kind::StaleKey | packet::Kind::ReplayDetected, } => ErrorKind::ConnectionRefused, - Error::UnexpectedPacket { + Kind::UnexpectedPacket { packet: packet::Kind::Stream | packet::Kind::Control | packet::Kind::Datagram, } => ErrorKind::InvalidData, } diff --git a/dc/s2n-quic-dc/src/stream/recv/packet.rs b/dc/s2n-quic-dc/src/stream/recv/packet.rs index 66b44a490e..5a5c0d99e7 100644 --- a/dc/s2n-quic-dc/src/stream/recv/packet.rs +++ b/dc/s2n-quic-dc/src/stream/recv/packet.rs @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::decrypt, + credentials::Credentials, + crypto, packet::stream, stream::recv::{state::State as Receiver, Error}, }; @@ -13,17 +14,29 @@ use s2n_quic_core::{ varint::VarInt, }; -pub struct Packet<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> { +pub struct Packet<'a, 'p, D, K, C> +where + D: crypto::open::Application, + K: crypto::open::control::Stream, + C: Clock + ?Sized, +{ pub packet: &'a mut stream::decoder::Packet<'p>, pub payload_cursor: usize, pub is_decrypted_in_place: bool, pub ecn: ExplicitCongestionNotification, pub clock: &'a C, pub opener: &'a D, + pub control: &'a K, + pub credentials: &'a Credentials, pub receiver: &'a mut Receiver, } -impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> reader::Storage for Packet<'a, 'p, D, C> { +impl<'a, 'p, D, K, C: Clock> reader::Storage for Packet<'a, 'p, D, K, C> +where + D: crypto::open::Application, + K: crypto::open::control::Stream, + C: Clock + ?Sized, +{ type Error = Error; #[inline] @@ -36,6 +49,8 @@ impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> reader::Storage for Packet<'a, if !self.is_decrypted_in_place { self.receiver.on_stream_packet_in_place( self.opener, + self.control, + self.credentials, self.packet, self.ecn, self.clock, @@ -79,6 +94,8 @@ impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> reader::Storage for Packet<'a, let did_write = dest.put_uninit_slice(self.packet.payload().len(), |dest| { self.receiver.on_stream_packet_copy( self.opener, + self.control, + self.credentials, self.packet, self.ecn, dest, @@ -97,7 +114,12 @@ impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> reader::Storage for Packet<'a, } } -impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> Reader for Packet<'a, 'p, D, C> { +impl<'a, 'p, D, K, C: Clock> Reader for Packet<'a, 'p, D, K, C> +where + D: crypto::open::Application, + K: crypto::open::control::Stream, + C: Clock + ?Sized, +{ #[inline] fn current_offset(&self) -> VarInt { self.packet.stream_offset() + self.payload_cursor diff --git a/dc/s2n-quic-dc/src/stream/recv/shared.rs b/dc/s2n-quic-dc/src/stream/recv/shared.rs index 6f7195eb09..b42a7018df 100644 --- a/dc/s2n-quic-dc/src/stream/recv/shared.rs +++ b/dc/s2n-quic-dc/src/stream/recv/shared.rs @@ -273,10 +273,16 @@ impl Inner { ) { let source_control_port = shared.source_control_port(); - shared.crypto.seal_with(|sealer| { - self.receiver - .on_transmit(sealer, source_control_port, send_buffer, &shared.clock) - }); + self.receiver.on_transmit( + shared + .crypto + .control_sealer() + .expect("control sealer should be available with recv transmissions"), + shared.credentials(), + source_control_port, + send_buffer, + &shared.clock, + ); ensure!(!send_buffer.is_empty()); @@ -373,6 +379,10 @@ impl Inner { let mut out_buf = buffer::duplex::Interposer::new(out_buf, &mut self.reassembler); + // this opener should never actually be used anywhere. any packets that try to use control + // authentication will result in stream closure. + let control_opener = &crate::crypto::open::control::stream::Reliable::default(); + loop { // consume the previous packet if let Some(packet_len) = prev_packet_len.take() { @@ -401,17 +411,14 @@ impl Inner { // otherwise, we'll need to receive more bytes from the stream to correctly // parse a packet - // if we have pending data greater than the max record size then it's never going to parse - let max_datagram_size = - shared.sender.path.load().max_datagram_size as usize; - - if msg.payload_len() > max_datagram_size { + // if we have pending data greater than the max datagram size then it's never going to parse + if msg.payload_len() > crate::stream::MAX_DATAGRAM_SIZE { tracing::error!( unconsumed = msg.payload_len(), remaining_capacity = msg.remaining_capacity() ); msg.clear(); - self.receiver.on_error(recv::Error::Decode); + self.receiver.on_error(recv::error::Kind::Decode); return; } @@ -433,7 +440,7 @@ impl Inner { // any other decoder errors mean the stream has been corrupted so // it's time to shut down the connection msg.clear(); - self.receiver.on_error(recv::Error::Decode); + self.receiver.on_error(recv::error::Kind::Decode); return; } }; @@ -445,7 +452,11 @@ impl Inner { debug_assert_eq!(Some(packet.total_len()), prev_packet_len); // make sure the packet looks OK before deriving openers from it - if self.receiver.precheck_stream_packet(packet).is_err() { + if self + .receiver + .precheck_stream_packet(shared.credentials(), packet) + .is_err() + { // check if the receiver returned an error if self.receiver.check_error().is_err() { msg.clear(); @@ -457,8 +468,15 @@ impl Inner { } let _ = shared.crypto.open_with(|opener| { - self.receiver - .on_stream_packet(opener, packet, ecn, clock, &mut out_buf)?; + self.receiver.on_stream_packet( + opener, + control_opener, + shared.credentials(), + packet, + ecn, + clock, + &mut out_buf, + )?; any_valid_packets = true; did_complete_handshake |= @@ -479,7 +497,7 @@ impl Inner { // if we get a packet we don't expect then it's fatal for streams msg.clear(); self.receiver - .on_error(recv::Error::UnexpectedPacket { packet: kind }); + .on_error(recv::error::Kind::UnexpectedPacket { packet: kind }); return; } } @@ -511,6 +529,10 @@ impl Inner { let mut did_complete_handshake = false; let mut out_buf = buffer::duplex::Interposer::new(out_buf, &mut self.reassembler); + let control_opener = shared + .crypto + .control_opener() + .expect("control opener should be available on unreliable transports"); for segment in msg.segments() { let segment_len = segment.len(); @@ -539,13 +561,17 @@ impl Inner { Packet::Stream(mut packet) => { // make sure the packet looks OK before deriving openers from it ensure!( - self.receiver.precheck_stream_packet(&packet).is_ok(), + self.receiver + .precheck_stream_packet(shared.credentials(), &packet) + .is_ok(), continue ); let _ = shared.crypto.open_with(|opener| { self.receiver.on_stream_packet( opener, + control_opener, + shared.credentials(), &mut packet, ecn, clock, diff --git a/dc/s2n-quic-dc/src/stream/recv/state.rs b/dc/s2n-quic-dc/src/stream/recv/state.rs index 5f432e6bb0..a0b0e35bd0 100644 --- a/dc/s2n-quic-dc/src/stream/recv/state.rs +++ b/dc/s2n-quic-dc/src/stream/recv/state.rs @@ -4,10 +4,15 @@ use crate::{ allocator::Allocator, clock, - crypto::{decrypt, encrypt, UninitSlice}, + credentials::Credentials, + crypto::{self, UninitSlice}, packet::{control, stream}, stream::{ - recv::{ack, packet, probes, Error}, + recv::{ + ack, + error::{self, Error}, + packet, probes, + }, TransportFeatures, DEFAULT_IDLE_TIMEOUT, }, }; @@ -107,7 +112,7 @@ impl State { pub fn stop_sending(&mut self, error: s2n_quic_core::application::Error) { // if we've already received everything then no need to notify the peer to stop ensure!(matches!(self.state, Receiver::Recv | Receiver::SizeKnown)); - self.on_error(Error::ApplicationError { error }); + self.on_error(error::Kind::ApplicationError { error }); } #[inline] @@ -151,16 +156,17 @@ impl State { #[inline] pub fn precheck_stream_packet( &mut self, + credentials: &Credentials, packet: &stream::decoder::Packet, ) -> Result<(), Error> { - match self.precheck_stream_packet_impl(packet) { + match self.precheck_stream_packet_impl(credentials, packet) { Ok(()) => Ok(()), Err(err) => { if err.is_fatal(&self.features) { tracing::error!(fatal_error = %err, ?packet); self.on_error(err); } else { - tracing::debug!(non_fatal_error = %err); + tracing::debug!(non_fatal_error = %err, ?packet); } Err(err) } @@ -170,15 +176,26 @@ impl State { #[inline] fn precheck_stream_packet_impl( &mut self, + credentials: &Credentials, packet: &stream::decoder::Packet, ) -> Result<(), Error> { + ensure!( + packet.credentials().id == credentials.id, + Err(error::Kind::CredentialMismatch { + expected: credentials.id, + actual: packet.credentials().id, + } + .err()) + ); + // make sure we're getting packets for the correct stream ensure!( *packet.stream_id() == self.stream_id, - Err(Error::StreamMismatch { + Err(error::Kind::StreamMismatch { expected: self.stream_id, actual: *packet.stream_id(), - }) + } + .err()) ); if self.features.is_stream() { @@ -191,10 +208,11 @@ impl State { let actual_pn = packet.packet_number().as_u64(); ensure!( expected_pn == actual_pn, - Err(Error::OutOfOrder { + Err(error::Kind::OutOfOrder { expected: expected_pn, actual: actual_pn, - }) + } + .err()) ); } @@ -202,7 +220,7 @@ impl State { // if the transport is reliable then we don't expect retransmissions ensure!( !packet.is_retransmission(), - Err(Error::UnexpectedRetransmission) + Err(error::Kind::UnexpectedRetransmission.err()) ); } @@ -210,21 +228,24 @@ impl State { } #[inline] - pub fn on_stream_packet( + pub fn on_stream_packet( &mut self, opener: &D, + control: &C, + credentials: &Credentials, packet: &mut stream::decoder::Packet, ecn: ExplicitCongestionNotification, clock: &Clk, out_buf: &mut B, ) -> Result<(), Error> where - D: decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, Clk: Clock + ?Sized, B: buffer::Duplex, { probes::on_stream_packet( - opener.credentials().id, + credentials.id, self.stream_id, packet.tag().packet_space(), packet.packet_number(), @@ -234,14 +255,15 @@ impl State { packet.is_retransmission(), ); - match self.on_stream_packet_impl(opener, packet, ecn, clock, out_buf) { + match self.on_stream_packet_impl(opener, control, credentials, packet, ecn, clock, out_buf) + { Ok(()) => Ok(()), Err(err) => { if err.is_fatal(&self.features) { tracing::error!(fatal_error = %err, ?packet); self.on_error(err); } else { - tracing::debug!(non_fatal_error = %err); + tracing::debug!(non_fatal_error = %err, ?packet); } Err(err) } @@ -249,22 +271,25 @@ impl State { } #[inline] - fn on_stream_packet_impl( + fn on_stream_packet_impl( &mut self, opener: &D, + control: &C, + credentials: &Credentials, packet: &mut stream::decoder::Packet, ecn: ExplicitCongestionNotification, clock: &Clk, out_buf: &mut B, ) -> Result<(), Error> where - D: decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, Clk: Clock + ?Sized, B: buffer::Duplex, { use buffer::reader::Storage as _; - self.precheck_stream_packet_impl(packet)?; + self.precheck_stream_packet_impl(credentials, packet)?; let is_max_data_ok = self.ensure_max_data(packet); @@ -276,6 +301,8 @@ impl State { ecn, clock, opener, + control, + credentials, receiver: self, }; @@ -293,7 +320,7 @@ impl State { .saturating_add(packet.packet.payload().len() as u64), ); - let error = Error::MaxDataExceeded; + let error = error::Kind::MaxDataExceeded.err(); self.on_error(error); return Err(error); } @@ -308,22 +335,25 @@ impl State { } #[inline] - pub(super) fn on_stream_packet_in_place( + pub(super) fn on_stream_packet_in_place( &mut self, crypto: &D, + control: &C, + credentials: &Credentials, packet: &mut stream::decoder::Packet, ecn: ExplicitCongestionNotification, clock: &Clk, ) -> Result<(), Error> where - D: decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, Clk: Clock + ?Sized, { // ensure the packet is authentic before processing it - let res = packet.decrypt_in_place(crypto); + let res = packet.decrypt_in_place(crypto, control); probes::on_stream_packet_decrypted( - crypto.credentials().id, + credentials.id, self.stream_id, packet.tag().packet_space(), packet.packet_number(), @@ -340,23 +370,26 @@ impl State { } #[inline] - pub(super) fn on_stream_packet_copy( + pub(super) fn on_stream_packet_copy( &mut self, crypto: &D, + control: &C, + credentials: &Credentials, packet: &mut stream::decoder::Packet, ecn: ExplicitCongestionNotification, payload_out: &mut UninitSlice, clock: &Clk, ) -> Result<(), Error> where - D: decrypt::Key, + D: crypto::open::Application, + C: crypto::open::control::Stream, Clk: Clock + ?Sized, { // ensure the packet is authentic before processing it - let res = packet.decrypt(crypto, payload_out); + let res = packet.decrypt(crypto, control, payload_out); probes::on_stream_packet_decrypted( - crypto.credentials().id, + credentials.id, self.stream_id, packet.tag().packet_space(), packet.packet_number(), @@ -407,7 +440,7 @@ impl State { }; ensure!( space.filter.on_packet(packet).is_ok(), - Err(Error::Duplicate) + Err(error::Kind::Duplicate.err()) ); let packet_number = PacketNumberSpace::Initial.new_packet_number(packet.packet_number()); @@ -454,7 +487,7 @@ impl State { // only error out if we're still expecting more data ensure!(matches!(self.state, Receiver::Recv | Receiver::SizeKnown)); - self.on_error(Error::TruncatedTransport); + self.on_error(error::Kind::TruncatedTransport); } #[inline] @@ -519,7 +552,12 @@ impl State { } #[inline] - pub fn on_error(&mut self, error: Error) { + #[track_caller] + pub fn on_error(&mut self, error: E) + where + Error: From, + { + let error = Error::from(error); debug_assert!(error.is_fatal(&self.features)); let _ = self.state.on_reset(); self.stream_ack.clear(); @@ -564,7 +602,7 @@ impl State { did_transition |= self.state.on_reset().is_ok(); did_transition |= self.state.on_app_read_reset().is_ok(); if did_transition { - self.on_error(Error::IdleTimeout); + self.on_error(error::Kind::IdleTimeout); // override the transmission since we're just timing out self._should_transmit = false; } @@ -610,14 +648,15 @@ impl State { } #[inline] - pub fn on_transmit( + pub fn on_transmit( &mut self, - encrypt_key: &E, + key: &K, + credentials: &Credentials, source_control_port: u16, output: &mut A, clock: &Clk, ) where - E: encrypt::Key, + K: crypto::seal::control::Stream, A: Allocator, Clk: Clock + ?Sized, { @@ -627,7 +666,8 @@ impl State { Self::on_transmit_error })( self, - encrypt_key, + key, + credentials, source_control_port, output, // avoid querying the clock for every transmitted packet @@ -636,14 +676,15 @@ impl State { } #[inline] - fn on_transmit_ack( + fn on_transmit_ack( &mut self, - encrypt_key: &E, + key: &K, + credentials: &Credentials, source_control_port: u16, output: &mut A, _clock: &Clk, ) where - E: encrypt::Key, + K: crypto::seal::control::Stream, A: Allocator, Clk: Clock + ?Sized, { @@ -695,7 +736,8 @@ impl State { &mut &[][..], encoding_size, &frame, - encrypt_key, + key, + credentials, ); match result { @@ -721,7 +763,7 @@ impl State { ); if let (Some(min), Some(max), Some(gaps)) = metrics { probes::on_transmit_control( - encrypt_key.credentials().id, + credentials.id, self.stream_id, space, packet_number, @@ -743,14 +785,15 @@ impl State { } #[inline] - fn on_transmit_error( + fn on_transmit_error( &mut self, - encrypt_key: &E, + control_key: &K, + credentials: &Credentials, source_control_port: u16, output: &mut A, _clock: &Clk, ) where - E: encrypt::Key, + K: crypto::seal::control::Stream, A: Allocator, Clk: Clock + ?Sized, { @@ -786,7 +829,8 @@ impl State { &mut &[][..], encoding_size, &frame, - encrypt_key, + control_key, + credentials, ); match result { @@ -807,7 +851,7 @@ impl State { self.recovery_ack.clear(); probes::on_transmit_close( - encrypt_key.credentials().id, + credentials.id, self.stream_id, packet_number, frame.error_code, diff --git a/dc/s2n-quic-dc/src/stream/send/application.rs b/dc/s2n-quic-dc/src/stream/send/application.rs index f5710e7df1..e25a6f88e5 100644 --- a/dc/s2n-quic-dc/src/stream/send/application.rs +++ b/dc/s2n-quic-dc/src/stream/send/application.rs @@ -154,7 +154,9 @@ impl Inner { trace!(?credits); - let mut batch = if self.sockets.write_application().features().is_reliable() { + let features = self.sockets.write_application().features(); + + let mut batch = if features.is_reliable() { // the protocol does recovery for us so no need to track the transmissions None } else { @@ -175,18 +177,29 @@ impl Inner { max_segments, &self.shared.sender.segment_alloc, |message, buf| { - self.shared.crypto.seal_with(|sealer| { - // push packets for transmission into our queue - app.transmit( - credits, - &path, - buf, - &self.shared.sender.packet_number, - sealer, - &clock::Cached::new(&self.shared.clock), - message, - ) - }) + self.shared.crypto.seal_with( + |sealer| { + // push packets for transmission into our queue + app.transmit( + credits, + &path, + buf, + &self.shared.sender.packet_number, + sealer, + self.shared.credentials(), + &clock::Cached::new(&self.shared.clock), + message, + ) + }, + |sealer| { + if features.is_reliable() { + sealer.update(); + } else { + // TODO enqueue a full flush of any pending transmissions before + // updating the key. + } + }, + ) }, )?; diff --git a/dc/s2n-quic-dc/src/stream/send/application/state.rs b/dc/s2n-quic-dc/src/stream/send/application/state.rs index e0c7e199e1..99ee74c7e0 100644 --- a/dc/s2n-quic-dc/src/stream/send/application/state.rs +++ b/dc/s2n-quic-dc/src/stream/send/application/state.rs @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::encrypt, + credentials::Credentials, + crypto::seal, packet::stream::{self, encoder}, stream::{ packet_number, @@ -43,11 +44,12 @@ impl State { storage: &mut I, packet_number: &packet_number::Counter, encrypt_key: &E, + credentials: &Credentials, clock: &Clk, message: &mut M, ) -> Result<(), Error> where - E: encrypt::Key, + E: seal::Application, I: buffer::reader::Storage, Clk: Clock, M: Message, @@ -94,7 +96,6 @@ impl State { self.source_control_port, self.source_stream_port, stream_id, - stream::PacketSpace::Stream, packet_number, path.next_expected_control_packet, VarInt::ZERO, @@ -103,6 +104,7 @@ impl State { &(), &mut reader, encrypt_key, + credentials, ); // buffer is clamped to u16::MAX so this is safe to cast without loss @@ -119,7 +121,7 @@ impl State { let time_sent = clock.get_time(); probes::on_transmit_stream( - encrypt_key.credentials().id, + credentials.id, stream_id, stream::PacketSpace::Stream, s2n_quic_core::packet::number::PacketNumberSpace::Initial diff --git a/dc/s2n-quic-dc/src/stream/send/error.rs b/dc/s2n-quic-dc/src/stream/send/error.rs index 2cddf70f6a..31c3e0a13d 100644 --- a/dc/s2n-quic-dc/src/stream/send/error.rs +++ b/dc/s2n-quic-dc/src/stream/send/error.rs @@ -2,10 +2,70 @@ // SPDX-License-Identifier: Apache-2.0 use crate::stream::packet_number; +use core::{fmt, panic::Location}; use s2n_quic_core::{buffer, varint::VarInt}; +#[derive(Clone, Copy)] +pub struct Error { + kind: Kind, + location: &'static Location<'static>, +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Error") + .field("kind", &self.kind) + .field("crate", &"s2n-quic-dc") + .field("file", &self.file()) + .field("line", &self.location.line()) + .finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Self { kind, location } = self; + let file = self.file(); + let line = location.line(); + write!(f, "[s2n-quic-dc::{file}:{line}]: {kind}") + } +} + +impl std::error::Error for Error {} + +impl Error { + #[track_caller] + #[inline] + pub fn new(kind: Kind) -> Self { + Self { + kind, + location: Location::caller(), + } + } + + #[inline] + pub fn kind(&self) -> &Kind { + &self.kind + } + + #[inline] + fn file(&self) -> &'static str { + self.location + .file() + .trim_start_matches(concat!(env!("CARGO_MANIFEST_DIR"), "/src/")) + } +} + +impl From for Error { + #[track_caller] + #[inline] + fn from(kind: Kind) -> Self { + Self::new(kind) + } +} + #[derive(Clone, Copy, Debug, thiserror::Error)] -pub enum Error { +pub enum Kind { #[error("payload provided is too large and exceeded the maximum offset")] PayloadTooLarge, #[error("the provided packet buffer is too small for the minimum packet size")] @@ -32,46 +92,57 @@ pub enum Error { FatalError, } +impl Kind { + #[inline] + #[track_caller] + pub(crate) fn err(self) -> Error { + Error::new(self) + } +} + impl From for std::io::Error { #[inline] + #[track_caller] fn from(error: Error) -> Self { - Self::new(error.into(), error) + Self::new(error.kind.into(), error) } } -impl From for std::io::ErrorKind { +impl From for std::io::ErrorKind { #[inline] - fn from(error: Error) -> Self { + fn from(kind: Kind) -> Self { use std::io::ErrorKind; - match error { - Error::PayloadTooLarge => ErrorKind::BrokenPipe, - Error::PacketBufferTooSmall => ErrorKind::InvalidInput, - Error::PacketNumberExhaustion => ErrorKind::BrokenPipe, - Error::RetransmissionFailure => ErrorKind::BrokenPipe, - Error::StreamFinished => ErrorKind::UnexpectedEof, - Error::FinalSizeChanged => ErrorKind::InvalidInput, - Error::IdleTimeout => ErrorKind::TimedOut, - Error::ApplicationError { .. } => ErrorKind::ConnectionReset, - Error::TransportError { .. } => ErrorKind::ConnectionAborted, - Error::FrameError { .. } => ErrorKind::InvalidData, - Error::FatalError => ErrorKind::BrokenPipe, + match kind { + Kind::PayloadTooLarge => ErrorKind::BrokenPipe, + Kind::PacketBufferTooSmall => ErrorKind::InvalidInput, + Kind::PacketNumberExhaustion => ErrorKind::BrokenPipe, + Kind::RetransmissionFailure => ErrorKind::BrokenPipe, + Kind::StreamFinished => ErrorKind::UnexpectedEof, + Kind::FinalSizeChanged => ErrorKind::InvalidInput, + Kind::IdleTimeout => ErrorKind::TimedOut, + Kind::ApplicationError { .. } => ErrorKind::ConnectionReset, + Kind::TransportError { .. } => ErrorKind::ConnectionAborted, + Kind::FrameError { .. } => ErrorKind::InvalidData, + Kind::FatalError => ErrorKind::BrokenPipe, } } } impl From for Error { #[inline] + #[track_caller] fn from(_error: packet_number::ExhaustionError) -> Self { - Self::PacketNumberExhaustion + Kind::PacketNumberExhaustion.err() } } impl From> for Error { #[inline] + #[track_caller] fn from(error: buffer::Error) -> Self { match error { - buffer::Error::OutOfRange => Self::PayloadTooLarge, - buffer::Error::InvalidFin => Self::FinalSizeChanged, + buffer::Error::OutOfRange => Kind::PayloadTooLarge.err(), + buffer::Error::InvalidFin => Kind::FinalSizeChanged.err(), buffer::Error::ReaderError(_) => unreachable!(), } } diff --git a/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs index c6bddf205e..f0bb1a4b21 100644 --- a/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs +++ b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs @@ -2,7 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use super::Credits; -use crate::stream::send::{error::Error, flow}; +use crate::stream::send::{ + error::{self, Error}, + flow, +}; use s2n_quic_core::{ensure, varint::VarInt}; use std::sync::{Condvar, Mutex}; @@ -40,7 +43,10 @@ impl State { /// Callers MUST ensure the provided offset is monotonic. #[inline] pub fn release(&self, flow_offset: VarInt) -> Result<(), Error> { - let mut guard = self.state.lock().map_err(|_| Error::FatalError)?; + let mut guard = self + .state + .lock() + .map_err(|_| error::Kind::FatalError.err())?; // only notify subscribers if we actually increment the offset debug_assert!( @@ -60,10 +66,13 @@ impl State { /// Called by the application to acquire flow credits #[inline] pub fn acquire(&self, mut request: flow::Request) -> Result { - let mut guard = self.state.lock().map_err(|_| Error::FatalError)?; + let mut guard = self + .state + .lock() + .map_err(|_| error::Kind::FatalError.err())?; loop { - ensure!(!guard.is_finished, Err(Error::FinalSizeChanged)); + ensure!(!guard.is_finished, Err(error::Kind::FinalSizeChanged.err())); // TODO check for an error @@ -88,7 +97,10 @@ impl State { } }) else { - guard = self.notify.wait(guard).map_err(|_| Error::FatalError)?; + guard = self + .notify + .wait(guard) + .map_err(|_| error::Kind::FatalError.err())?; continue; }; @@ -98,7 +110,7 @@ impl State { // update the stream offset with the given request guard.stream_offset = current_offset .checked_add_usize(request.len) - .ok_or(Error::PayloadTooLarge)?; + .ok_or_else(|| error::Kind::PayloadTooLarge.err())?; // update the finished status guard.is_finished |= request.is_fin; diff --git a/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs index 32d048c6d0..fb2c7ae5b8 100644 --- a/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs +++ b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs @@ -2,7 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use super::Credits; -use crate::stream::send::{error::Error, flow}; +use crate::stream::send::{ + error::{self, Error}, + flow, +}; use atomic_waker::AtomicWaker; use core::{ fmt, @@ -141,7 +144,7 @@ impl State { let mut new_offset = (current_offset & OFFSET_MASK) .checked_add(request.len as u64) .filter(|v| *v <= VarInt::MAX.as_u64()) - .ok_or(Error::PayloadTooLarge)?; + .ok_or_else(|| error::Kind::PayloadTooLarge.err())?; // record that we've sent the final offset if request.is_fin || current_offset & FINISHED_MASK == FINISHED_MASK { @@ -187,12 +190,12 @@ impl State { .stream_error .get() .copied() - .unwrap_or(Error::FatalError); + .unwrap_or_else(|| error::Kind::FatalError.err()); return Err(error); } if offset & FINISHED_MASK == FINISHED_MASK { - ensure!(request.len == 0, Err(Error::FinalSizeChanged)); + ensure!(request.len == 0, Err(error::Kind::FinalSizeChanged.err())); } Ok(offset) diff --git a/dc/s2n-quic-dc/src/stream/send/state.rs b/dc/s2n-quic-dc/src/stream/send/state.rs index 4a2f5e8c52..166a10af19 100644 --- a/dc/s2n-quic-dc/src/stream/send/state.rs +++ b/dc/s2n-quic-dc/src/stream/send/state.rs @@ -3,7 +3,8 @@ use crate::{ congestion, - crypto::{decrypt, encrypt, UninitSlice}, + credentials::Credentials, + crypto, packet::{ self, stream::{self, decoder, encoder}, @@ -12,7 +13,10 @@ use crate::{ stream::{ processing, send::{ - application, buffer, error::Error, filter::Filter, probes, + application, buffer, + error::{self, Error}, + filter::Filter, + probes, transmission::Type as TransmissionType, }, DEFAULT_IDLE_TIMEOUT, @@ -202,9 +206,10 @@ impl State { /// Called by the worker when it receives a control packet from the peer #[inline] - pub fn on_control_packet( + pub fn on_control_packet( &mut self, - decrypt_key: &D, + control_key: &C, + credentials: &Credentials, ecn: ExplicitCongestionNotification, packet: &mut packet::control::decoder::Packet, random: &mut dyn random::Generator, @@ -213,11 +218,12 @@ impl State { segment_alloc: &buffer::Allocator, ) -> Result<(), processing::Error> where - D: decrypt::Key, + C: crypto::open::control::Stream, Clk: Clock, { match self.on_control_packet_impl( - decrypt_key, + control_key, + credentials, ecn, packet, random, @@ -238,9 +244,10 @@ impl State { } #[inline(always)] - fn on_control_packet_impl( + fn on_control_packet_impl( &mut self, - decrypt_key: &D, + control_key: &C, + credentials: &Credentials, _ecn: ExplicitCongestionNotification, packet: &mut packet::control::decoder::Packet, random: &mut dyn random::Generator, @@ -249,30 +256,21 @@ impl State { segment_alloc: &buffer::Allocator, ) -> Result, Error> where - D: decrypt::Key, + C: crypto::open::control::Stream, Clk: Clock, { probes::on_control_packet( - decrypt_key.credentials().id, + credentials.id, self.stream_id, packet.packet_number(), packet.control_data().len(), ); - let key_phase = packet.tag().key_phase(); - // only process the packet after we know it's authentic - let res = decrypt_key.decrypt( - key_phase, - packet.crypto_nonce(), - packet.header(), - &[], - packet.auth_tag(), - UninitSlice::new(&mut []), - ); + let res = control_key.verify(packet.header(), packet.auth_tag()); probes::on_control_packet_decrypted( - decrypt_key.credentials().id, + credentials.id, self.stream_id, packet.packet_number(), packet.control_data().len(), @@ -289,7 +287,7 @@ impl State { self.control_filter.on_packet(packet).is_ok(), return { probes::on_control_packet_duplicate( - decrypt_key.credentials().id, + credentials.id, self.stream_id, packet.packet_number(), packet.control_data().len(), @@ -319,7 +317,7 @@ impl State { while !decoder.is_empty() { let (frame, remaining) = decoder .decode::() - .map_err(|decoder| Error::FrameError { decoder })?; + .map_err(|decoder| error::Kind::FrameError { decoder }.err())?; decoder = remaining; trace!(?frame); @@ -338,8 +336,8 @@ impl State { } if ack.ecn_counts.is_some() { - self.on_frame_ack::<_, _, _, true>( - decrypt_key, + self.on_frame_ack::<_, _, true>( + credentials, &ack, random, &recv_time, @@ -349,8 +347,8 @@ impl State { segment_alloc, )?; } else { - self.on_frame_ack::<_, _, _, false>( - decrypt_key, + self.on_frame_ack::<_, _, false>( + credentials, &ack, random, &recv_time, @@ -370,7 +368,7 @@ impl State { debug!(connection_close = ?close, state = ?self.state); probes::on_close( - decrypt_key.credentials().id, + credentials.id, self.stream_id, packet_number, close.error_code, @@ -394,15 +392,15 @@ impl State { let _ = self.state.on_send_reset(); let _ = self.state.on_recv_reset_ack(); let error = if close.frame_type.is_some() { - Error::TransportError { + error::Kind::TransportError { code: close.error_code, } } else { - Error::ApplicationError { + error::Kind::ApplicationError { error: close.error_code.into(), } }; - return Err(error); + return Err(error.err()); } _ => continue, } @@ -414,7 +412,7 @@ impl State { (stream::PacketSpace::Recovery, max_acked_recovery), ] { if let Some(pn) = pn { - self.detect_lost_packets(decrypt_key, random, &recv_time, space, pn)?; + self.detect_lost_packets(credentials, random, &recv_time, space, pn)?; } } @@ -436,9 +434,9 @@ impl State { } #[inline] - fn on_frame_ack( + fn on_frame_ack( &mut self, - decrypt_key: &D, + credentials: &Credentials, ack: &frame::Ack, random: &mut dyn random::Generator, clock: &Clk, @@ -448,7 +446,6 @@ impl State { segment_alloc: &buffer::Allocator, ) -> Result<(), Error> where - D: decrypt::Key, Ack: frame::ack::AckRanges, Clk: Clock, { @@ -492,7 +489,7 @@ impl State { } probes::on_packet_ack( - decrypt_key.credentials().id, + credentials.id, self.stream_id, stream::PacketSpace::$space, num.as_u64(), @@ -562,16 +559,15 @@ impl State { } #[inline] - fn detect_lost_packets( + fn detect_lost_packets( &mut self, - decrypt_key: &D, + credentials: &Credentials, random: &mut dyn random::Generator, clock: &Clk, packet_space: stream::PacketSpace, max: VarInt, ) -> Result<(), Error> where - D: decrypt::Key, Clk: Clock, { let Some(loss_threshold) = max.checked_sub(VarInt::from_u8(2)) else { @@ -597,7 +593,7 @@ impl State { ); probes::on_packet_lost( - decrypt_key.credentials().id, + credentials.id, self.stream_id, packet_space, num.as_u64(), @@ -647,7 +643,10 @@ impl State { } } - ensure!(!is_unrecoverable, Err(Error::RetransmissionFailure)); + ensure!( + !is_unrecoverable, + Err(error::Kind::RetransmissionFailure.err()) + ); self.invariants(); @@ -718,7 +717,7 @@ impl State { Ld: FnOnce() -> Timestamp, { if self.poll_idle_timer(clock, load_last_activity).is_ready() { - self.on_error(Error::IdleTimeout); + self.on_error(error::Kind::IdleTimeout); // we don't actually want to send any packets on idle timeout let _ = self.state.on_send_reset(); let _ = self.state.on_recv_reset_ack(); @@ -730,7 +729,7 @@ impl State { .poll_expiration(clock.get_time()) .is_ready() { - self.on_error(Error::IdleTimeout); + self.on_error(error::Kind::IdleTimeout); return; } @@ -923,17 +922,20 @@ impl State { } #[inline] - pub fn fill_transmit_queue( + pub fn fill_transmit_queue( &mut self, - encrypt_key: &E, + control_key: &C, + credentials: &Credentials, source_control_port: u16, clock: &Clk, ) -> Result<(), Error> where - E: encrypt::Key, + C: crypto::seal::control::Stream, Clk: Clock, { - if let Err(error) = self.fill_transmit_queue_impl(encrypt_key, source_control_port, clock) { + if let Err(error) = + self.fill_transmit_queue_impl(control_key, credentials, source_control_port, clock) + { self.on_error(error); return Err(error); } @@ -942,14 +944,15 @@ impl State { } #[inline] - fn fill_transmit_queue_impl( + fn fill_transmit_queue_impl( &mut self, - encrypt_key: &E, + control_key: &C, + credentials: &Credentials, source_control_port: u16, clock: &Clk, ) -> Result<(), Error> where - E: encrypt::Key, + C: crypto::seal::control::Stream, Clk: Clock, { // skip a packet number if we're probing @@ -957,20 +960,21 @@ impl State { self.recovery_packet_number += 1; } - self.try_transmit_retransmissions(encrypt_key, clock)?; - self.try_transmit_probe(encrypt_key, source_control_port, clock)?; + self.try_transmit_retransmissions(control_key, credentials, clock)?; + self.try_transmit_probe(control_key, credentials, source_control_port, clock)?; Ok(()) } #[inline] - fn try_transmit_retransmissions( + fn try_transmit_retransmissions( &mut self, - encrypt_key: &E, + control_key: &C, + credentials: &Credentials, clock: &Clk, ) -> Result<(), Error> where - E: encrypt::Key, + C: crypto::seal::control::Stream, Clk: Clock, { // We'll only have retransmissions if we're reliable @@ -1001,13 +1005,13 @@ impl State { buffer, stream::PacketSpace::Recovery, packet_number, - encrypt_key, + control_key, ) { Ok(info) => info, Err(err) => { // this shouldn't ever happen debug_assert!(false, "{err:?}"); - return Err(Error::RetransmissionFailure); + return Err(error::Kind::RetransmissionFailure.err()); } } }; @@ -1044,7 +1048,7 @@ impl State { }; probes::on_transmit_stream( - encrypt_key.credentials().id, + credentials.id, self.stream_id, stream::PacketSpace::Recovery, PacketNumberSpace::Initial.new_packet_number(packet_number), @@ -1067,14 +1071,15 @@ impl State { } #[inline] - pub fn try_transmit_probe( + pub fn try_transmit_probe( &mut self, - encrypt_key: &E, + control_key: &C, + credentials: &Credentials, source_control_port: u16, clock: &Clk, ) -> Result<(), Error> where - E: encrypt::Key, + C: crypto::seal::control::Stream, Clk: Clock, { while self.pto.transmissions() > 0 { @@ -1113,12 +1118,11 @@ impl State { }; let encoder = EncoderBuffer::new(&mut buffer); - let packet_len = encoder::encode( + let packet_len = encoder::probe( encoder, source_control_port, None, self.stream_id, - stream::PacketSpace::Recovery, packet_number, self.next_expected_control_packet, VarInt::ZERO, @@ -1126,7 +1130,8 @@ impl State { VarInt::ZERO, &(), &mut payload, - encrypt_key, + control_key, + credentials, ); let payload_len = 0; @@ -1211,9 +1216,13 @@ impl State { } #[inline] - pub fn on_error(&mut self, error: Error) { + #[track_caller] + pub fn on_error(&mut self, error: E) + where + Error: From, + { ensure!(self.error.is_none()); - self.error = Some(error); + self.error = Some(Error::from(error)); let _ = self.state.on_queue_reset(); self.clean_up(); diff --git a/dc/s2n-quic-dc/src/stream/send/worker.rs b/dc/s2n-quic-dc/src/stream/send/worker.rs index 0eab1fea93..ec37643727 100644 --- a/dc/s2n-quic-dc/src/stream/send/worker.rs +++ b/dc/s2n-quic-dc/src/stream/send/worker.rs @@ -7,8 +7,13 @@ use crate::{ msg::addr, packet::Packet, stream::{ - pacer, processing, - send::{error::Error, queue::Queue, shared::Event, state::State}, + pacer, + send::{ + error::{self, Error}, + queue::Queue, + shared::Event, + state::State, + }, shared::{self, Half}, socket::Socket, }, @@ -241,7 +246,7 @@ where } => { // if the application is panicking then we notify the peer if is_panicking { - let error = Error::ApplicationError { error: 1u8.into() }; + let error = error::Kind::ApplicationError { error: 1u8.into() }; self.sender.on_error(error); continue; } @@ -282,6 +287,11 @@ where let random = &mut self.random; let mut any_valid_packets = false; let clock = &clock::Cached::new(&self.shared.clock); + let opener = self + .shared + .crypto + .control_opener() + .expect("control crypto should be available"); for segment in self.recv_buffer.segments() { let segment_len = segment.len(); @@ -310,21 +320,20 @@ where continue ); - let _ = self.shared.crypto.open_with(|opener| { - self.sender.on_control_packet( - opener, - ecn, - &mut packet, - random, - clock, - &self.shared.sender.application_transmission_queue, - &self.shared.sender.segment_alloc, - )?; + let res = self.sender.on_control_packet( + opener, + self.shared.credentials(), + ecn, + &mut packet, + random, + clock, + &self.shared.sender.application_transmission_queue, + &self.shared.sender.segment_alloc, + ); + if res.is_ok() { any_valid_packets = true; - - >::Ok(()) - }); + } } other => self.shared.crypto.map().handle_unexpected_packet(&other), } @@ -354,15 +363,20 @@ where loop { ready!(self.poll_transmit_flush(cx)); + let control_sealer = self + .shared + .crypto + .control_sealer() + .expect("control crypto should be available"); + match self.state { waiting::State::Acking => { - let _ = self.shared.crypto.seal_with(|sealer| { - self.sender.fill_transmit_queue( - sealer, - self.socket.local_addr().unwrap().port(), - &self.shared.clock, - ) - }); + let _ = self.sender.fill_transmit_queue( + control_sealer, + self.shared.credentials(), + self.socket.local_addr().unwrap().port(), + &self.shared.clock, + ); } waiting::State::Detached => { // flush the remaining application queue @@ -392,13 +406,12 @@ where continue; } waiting::State::ShuttingDown => { - let _ = self.shared.crypto.seal_with(|sealer| { - self.sender.fill_transmit_queue( - sealer, - self.socket.local_addr().unwrap().port(), - &self.shared.clock, - ) - }); + let _ = self.sender.fill_transmit_queue( + control_sealer, + self.shared.credentials(), + self.socket.local_addr().unwrap().port(), + &self.shared.clock, + ); if self.sender.state.is_terminal() { let _ = self.state.on_finished(); diff --git a/dc/s2n-quic-dc/src/stream/shared.rs b/dc/s2n-quic-dc/src/stream/shared.rs index 2ba44e2a2b..b6600fb305 100644 --- a/dc/s2n-quic-dc/src/stream/shared.rs +++ b/dc/s2n-quic-dc/src/stream/shared.rs @@ -3,6 +3,7 @@ use crate::{ clock::Clock, + credentials::Credentials, stream::{ recv::shared as recv, send::{application, shared as send}, @@ -121,6 +122,14 @@ impl Shared { *self.common.fixed.source_control_port.get() } } + + #[inline] + pub fn credentials(&self) -> &Credentials { + unsafe { + // SAFETY: the fixed information doesn't change for the lifetime of the stream + &*self.common.fixed.credentials.get() + } + } } impl ops::Deref for Shared { @@ -167,6 +176,7 @@ pub struct FixedValues { pub remote_ip: UnsafeCell, pub source_control_port: UnsafeCell, pub application: UnsafeCell, + pub credentials: UnsafeCell, } unsafe impl Send for FixedValues {} diff --git a/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs b/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs index a47704eaa4..5ca159a59c 100644 --- a/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs +++ b/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs @@ -97,8 +97,7 @@ fn send_msghdr( let mut cmsg_storage = cmsg::Storage::<{ cmsg::ENCODER_LEN }>::default(); let mut cmsg = cmsg_storage.encoder(); if ecn != ExplicitCongestionNotification::NotEct { - // TODO enable this once we consolidate s2n-quic-core crates - // let _ = cmsg.encode_ecn(ecn, &addr); + let _ = cmsg.encode_ecn(ecn, &addr.get()); } if segments.len() > 1 { diff --git a/dc/wireshark/src/field.rs b/dc/wireshark/src/field.rs index 947ade8e1b..eb79ce9be7 100644 --- a/dc/wireshark/src/field.rs +++ b/dc/wireshark/src/field.rs @@ -572,5 +572,5 @@ mod masks { } pub const HAS_APPLICATION_HEADER: u64 = common_tag!(HAS_APPLICATION_HEADER_MASK); - pub const KEY_PHASE: u64 = common_tag!(KEY_PHASE_MASK); + pub const KEY_PHASE: u64 = stream::Tag::KEY_PHASE_MASK as _; } diff --git a/dc/wireshark/src/test.rs b/dc/wireshark/src/test.rs index 94ff8136c3..8fb717b4bc 100644 --- a/dc/wireshark/src/test.rs +++ b/dc/wireshark/src/test.rs @@ -5,11 +5,12 @@ use crate::{buffer::Buffer, dissect, value::Parsed}; use s2n_codec::EncoderBuffer; use s2n_quic_core::{ buffer::{reader::Storage, Reader}, + packet::KeyPhase, stream::testing::Data, varint::VarInt, }; use s2n_quic_dc::{ - credentials::{self, Credentials}, + credentials, packet::{self, stream, WireVersion}, }; use std::{collections::HashMap, num::NonZeroU16, ptr, time::Duration}; @@ -25,6 +26,7 @@ struct StreamPacket { next_expected_control_packet: VarInt, application_header: Data, payload: Data, + key_phase: KeyPhase, } #[test] @@ -36,7 +38,7 @@ fn check_stream_parse() { .with_type() .for_each(|packet: &StreamPacket| { let mut packet = packet.clone(); - let key = s2n_quic_dc::crypto::testing::Key::new(packet.credentials); + let key = TestKey(packet.key_phase); let sent_payload = packet.payload; let sent_app_header_len = packet.application_header.buffered_len(); let mut buffer = vec![ @@ -51,7 +53,6 @@ fn check_stream_parse() { NonZeroU16::get(packet.source_control_port), packet.source_stream_port.map(NonZeroU16::get), packet.stream_id, - packet.packet_space, packet.packet_number, packet.next_expected_control_packet, VarInt::new(packet.application_header.buffered_len() as u64).unwrap(), @@ -61,6 +62,7 @@ fn check_stream_parse() { &(), &mut packet.payload, &key, + &packet.credentials, ); let fields = crate::field::get(); @@ -194,6 +196,7 @@ struct DatagramPacket { next_expected_control_packet: Option, application_header: Data, payload: Data, + key_phase: KeyPhase, } #[test] @@ -208,7 +211,7 @@ fn check_datagram_parse() { if packet.next_expected_control_packet.is_some() && packet.packet_number.is_none() { packet.packet_number = Some(Default::default()); } - let key = s2n_quic_dc::crypto::testing::Key::new(packet.credentials); + let key = TestKey(packet.key_phase); let sent_payload = packet.payload; let sent_app_header_len = packet.application_header.buffered_len(); let mut buffer = vec![ @@ -230,6 +233,7 @@ fn check_datagram_parse() { VarInt::new(packet.payload.buffered_len() as u64).unwrap(), &mut packet.payload, &key, + &packet.credentials, ); let fields = crate::field::get(); @@ -324,6 +328,7 @@ struct ControlPacket { next_expected_control_packet: Option, application_header: Data, control_data: Data, + key_phase: KeyPhase, } #[test] @@ -335,7 +340,7 @@ fn check_control_parse() { .with_type() .for_each(|packet: &ControlPacket| { let mut packet = packet.clone(); - let key = s2n_quic_dc::crypto::testing::Key::new(packet.credentials); + let key = TestKey(packet.key_phase); let sent_app_header_len = packet.application_header.buffered_len(); let mut buffer = vec![ 0; @@ -355,6 +360,7 @@ fn check_control_parse() { // FIXME: Encode *real* control data, not random garbage. &&packet.control_data.read_chunk(usize::MAX).unwrap()[..], &key, + &packet.credentials, ); let fields = crate::field::get(); @@ -435,10 +441,7 @@ fn check_secret_control_parse() { .with_type() .for_each(|packet: &SecretControlPacket| { // Use a fixed key, we don't change key IDs per control packet anyway. - let key = s2n_quic_dc::crypto::testing::Key::new(Credentials { - id: [0; 16].into(), - key_id: VarInt::ZERO, - }); + let key = TestKey(KeyPhase::Zero); let mut buffer = vec![0; s2n_quic_dc::packet::secret_control::MAX_PACKET_SIZE]; let length = match packet { SecretControlPacket::UnknownPathSecret { id, auth_tag } => { @@ -707,3 +710,55 @@ impl crate::wireshark::Item for () { // no-op } } + +struct TestKey(KeyPhase); + +impl s2n_quic_dc::crypto::seal::Application for TestKey { + fn key_phase(&self) -> KeyPhase { + self.0 + } + + fn tag_len(&self) -> usize { + 16 + } + + fn encrypt( + &self, + _packet_number: u64, + _header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + if let Some(extra_payload) = extra_payload { + let offset = payload_and_tag.len() - self.tag_len() - extra_payload.len(); + let dest = &mut payload_and_tag[offset..]; + assert!(dest.len() == extra_payload.len() + self.tag_len()); + let (dest, tag) = dest.split_at_mut(extra_payload.len()); + dest.copy_from_slice(extra_payload); + tag.fill(0); + } + } +} + +impl s2n_quic_dc::crypto::seal::Control for TestKey { + fn tag_len(&self) -> usize { + 16 + } + + fn sign(&self, _header: &[u8], tag: &mut [u8]) { + tag.fill(0) + } +} + +impl s2n_quic_dc::crypto::seal::control::Stream for TestKey { + fn retransmission_tag( + &self, + _original_packet_number: u64, + _retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + tag_out.fill(0) + } +} + +impl s2n_quic_dc::crypto::seal::control::Secret for TestKey {} diff --git a/quic/s2n-quic-core/src/packet/key_phase.rs b/quic/s2n-quic-core/src/packet/key_phase.rs index cef35085c2..64dde34cf5 100644 --- a/quic/s2n-quic-core/src/packet/key_phase.rs +++ b/quic/s2n-quic-core/src/packet/key_phase.rs @@ -14,6 +14,10 @@ const KEY_PHASE_MASK: u8 = 0x04; pub struct ProtectedKeyPhase; #[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr( + any(test, feature = "bolero-generator"), + derive(bolero_generator::TypeGenerator) +)] pub enum KeyPhase { Zero, One,