diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml index 7d4e1a9dd7..ec5714ccb2 100644 --- a/dc/s2n-quic-dc/Cargo.toml +++ b/dc/s2n-quic-dc/Cargo.toml @@ -11,7 +11,9 @@ license = "Apache-2.0" exclude = ["corpus.tar.gz"] [features] +default = ["tokio"] testing = ["bolero-generator", "s2n-quic-core/testing"] +tokio = ["tokio/io-util", "tokio/net", "tokio/rt-multi-thread", "tokio/time"] [dependencies] arrayvec = "0.7" @@ -35,7 +37,7 @@ s2n-quic-core = { version = "=0.42.0", path = "../../quic/s2n-quic-core", defaul s2n-quic-platform = { version = "=0.42.0", path = "../../quic/s2n-quic-platform" } slotmap = "1" thiserror = "1" -tokio = { version = "1", features = ["sync"] } +tokio = { version = "1", default-features = false, features = ["sync"] } tracing = "0.1" zerocopy = { version = "0.7", features = ["derive"] } zeroize = "1" @@ -46,4 +48,4 @@ bolero-generator = "0.11" insta = "1" s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] } s2n-quic-core = { path = "../../quic/s2n-quic-core", features = ["testing"] } -tokio = { version = "1", features = ["sync"] } +tokio = { version = "1", features = ["full"] } diff --git a/dc/s2n-quic-dc/src/clock.rs b/dc/s2n-quic-dc/src/clock.rs index a8bf3cdcbc..8a870a007a 100644 --- a/dc/s2n-quic-dc/src/clock.rs +++ b/dc/s2n-quic-dc/src/clock.rs @@ -5,6 +5,7 @@ use core::{fmt, pin::Pin, task::Poll, time::Duration}; use s2n_quic_core::{ensure, time}; use tracing::trace; +#[cfg(feature = "tokio")] pub mod tokio; pub use time::clock::Cached; diff --git a/dc/s2n-quic-dc/src/crypto.rs b/dc/s2n-quic-dc/src/crypto.rs index 762e2f13c9..ffc2ee17ed 100644 --- a/dc/s2n-quic-dc/src/crypto.rs +++ b/dc/s2n-quic-dc/src/crypto.rs @@ -16,6 +16,8 @@ pub mod encrypt { pub trait Key { fn credentials(&self) -> &Credentials; + fn key_phase(&self) -> KeyPhase; + fn tag_len(&self) -> usize; /// Encrypt a payload @@ -75,6 +77,7 @@ pub mod decrypt { /// Decrypt a payload fn decrypt( &self, + key_phase: KeyPhase, nonce: N, header: &[u8], payload_in: &[u8], @@ -85,6 +88,7 @@ pub mod decrypt { /// Decrypt a payload fn decrypt_in_place( &self, + key_phase: KeyPhase, nonce: N, header: &[u8], payload_and_tag: &mut [u8], @@ -92,6 +96,7 @@ pub mod decrypt { fn retransmission_tag( &self, + key_phase: KeyPhase, original_packet_number: u64, retransmission_packet_number: u64, tag_out: &mut [u8], diff --git a/dc/s2n-quic-dc/src/crypto/awslc.rs b/dc/s2n-quic-dc/src/crypto/awslc.rs index a71228a9e5..c792806178 100644 --- a/dc/s2n-quic-dc/src/crypto/awslc.rs +++ b/dc/s2n-quic-dc/src/crypto/awslc.rs @@ -4,7 +4,7 @@ use super::IntoNonce; use crate::credentials::Credentials; use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; -use s2n_quic_core::assume; +use s2n_quic_core::{assume, packet::KeyPhase}; pub use aws_lc_rs::aead::{AES_128_GCM, AES_256_GCM}; @@ -41,6 +41,11 @@ impl super::encrypt::Key for EncryptKey { &self.credentials } + #[inline] + fn key_phase(&self) -> KeyPhase { + KeyPhase::Zero + } + #[inline(always)] fn tag_len(&self) -> usize { debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); @@ -137,6 +142,7 @@ impl super::decrypt::Key for DecryptKey { #[inline] fn decrypt( &self, + _key_phase: KeyPhase, nonce: N, header: &[u8], payload_in: &[u8], @@ -163,6 +169,7 @@ impl super::decrypt::Key for DecryptKey { #[inline] fn decrypt_in_place( &self, + _key_phase: KeyPhase, nonce: N, header: &[u8], payload_and_tag: &mut [u8], @@ -180,6 +187,7 @@ impl super::decrypt::Key for DecryptKey { #[inline] fn retransmission_tag( &self, + 
_key_phase: KeyPhase, original_packet_number: u64, retransmission_packet_number: u64, tag_out: &mut [u8], diff --git a/dc/s2n-quic-dc/src/crypto/testing.rs b/dc/s2n-quic-dc/src/crypto/testing.rs index ef8d6ea38b..f357bc5794 100644 --- a/dc/s2n-quic-dc/src/crypto/testing.rs +++ b/dc/s2n-quic-dc/src/crypto/testing.rs @@ -3,7 +3,7 @@ use super::IntoNonce; use crate::credentials::Credentials; -use s2n_quic_core::assume; +use s2n_quic_core::{assume, packet::KeyPhase}; #[derive(Clone, Debug)] pub struct Key { @@ -27,6 +27,11 @@ impl super::encrypt::Key for Key { &self.credentials } + #[inline] + fn key_phase(&self) -> KeyPhase { + KeyPhase::Zero + } + #[inline] fn tag_len(&self) -> usize { self.tag_len @@ -77,6 +82,7 @@ impl super::decrypt::Key for Key { #[inline] fn decrypt( &self, + _key_phase: KeyPhase, _nonce: N, _header: &[u8], payload_in: &[u8], @@ -90,6 +96,7 @@ impl super::decrypt::Key for Key { #[inline] fn decrypt_in_place( &self, + _key_phase: KeyPhase, _nonce: N, _header: &[u8], _payload_and_tag: &mut [u8], @@ -100,6 +107,7 @@ impl super::decrypt::Key for Key { #[inline] fn retransmission_tag( &self, + _key_phase: KeyPhase, _original_packet_number: u64, _retransmission_packet_number: u64, _tag_out: &mut [u8], diff --git a/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs index b57cc8adef..dd029a1a3a 100644 --- a/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs +++ b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs @@ -54,6 +54,7 @@ impl Receiver { debug_assert_eq!(packet.payload().len(), payload_out.len()); self.key.decrypt( + packet.tag().key_phase(), packet.crypto_nonce(), packet.header(), packet.payload(), diff --git a/dc/s2n-quic-dc/src/lib.rs b/dc/s2n-quic-dc/src/lib.rs index 0ce8c3e3af..98b5e5ad95 100644 --- a/dc/s2n-quic-dc/src/lib.rs +++ b/dc/s2n-quic-dc/src/lib.rs @@ -18,4 +18,7 @@ pub mod socket; pub mod stream; pub mod task; +#[cfg(any(test, feature = "testing"))] +pub mod testing; + pub use s2n_quic_core::dc::{Version, SUPPORTED_VERSIONS}; diff --git a/dc/s2n-quic-dc/src/packet.rs b/dc/s2n-quic-dc/src/packet.rs index e4e6bdec92..d993e9af83 100644 --- a/dc/s2n-quic-dc/src/packet.rs +++ b/dc/s2n-quic-dc/src/packet.rs @@ -10,6 +10,7 @@ pub type PayloadLen = VarInt; #[macro_use] pub mod tag; +pub mod wire_version; pub mod control; pub mod datagram; @@ -17,6 +18,17 @@ pub mod secret_control; pub mod stream; pub use tag::Tag; +pub use wire_version::WireVersion; + +#[derive(Clone, Copy, Debug)] +pub enum Kind { + Stream, + Datagram, + Control, + StaleKey, + ReplayDetected, + UnknownPathSecret, +} #[derive(Debug)] pub enum Packet<'a> { @@ -28,6 +40,20 @@ pub enum Packet<'a> { UnknownPathSecret(secret_control::unknown_path_secret::Packet<'a>), } +impl<'a> Packet<'a> { + #[inline] + pub fn kind(&self) -> Kind { + match self { + Packet::Stream(_) => Kind::Stream, + Packet::Datagram(_) => Kind::Datagram, + Packet::Control(_) => Kind::Control, + Packet::StaleKey(_) => Kind::StaleKey, + Packet::ReplayDetected(_) => Kind::ReplayDetected, + Packet::UnknownPathSecret(_) => Kind::UnknownPathSecret, + } + } +} + impl<'a> s2n_codec::DecoderParameterizedValueMut<'a> for Packet<'a> { type Parameter = usize; diff --git a/dc/s2n-quic-dc/src/packet/control.rs b/dc/s2n-quic-dc/src/packet/control.rs index 8617c9ec21..e8095a9f9e 100644 --- a/dc/s2n-quic-dc/src/packet/control.rs +++ b/dc/s2n-quic-dc/src/packet/control.rs @@ -3,6 +3,7 @@ use super::tag::Common; use core::fmt; +use s2n_quic_core::packet::KeyPhase; use zerocopy::{AsBytes, FromBytes, FromZeroes, 
Unaligned}; const NONCE_MASK: u64 = 1 << 63; @@ -28,13 +29,15 @@ impl fmt::Debug for Tag { f.debug_struct("control::Tag") .field("is_stream", &self.is_stream()) .field("has_application_header", &self.has_application_header()) + .field("key_phase", &self.key_phase()) .finish() } } impl Tag { - pub const IS_STREAM_MASK: u8 = 0b0010; - pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b00_0001; + pub const IS_STREAM_MASK: u8 = 0b0100; + pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b0010; + pub const KEY_PHASE_MASK: u8 = 0b0001; pub const MIN: u8 = 0b0101_0000; pub const MAX: u8 = 0b0101_1111; @@ -59,6 +62,24 @@ impl Tag { self.0.get(Self::HAS_APPLICATION_HEADER_MASK) } + #[inline] + pub fn set_key_phase(&mut self, key_phase: KeyPhase) { + let v = match key_phase { + KeyPhase::Zero => false, + KeyPhase::One => true, + }; + self.0.set(Self::KEY_PHASE_MASK, v) + } + + #[inline] + pub fn key_phase(&self) -> KeyPhase { + if self.0.get(Self::KEY_PHASE_MASK) { + KeyPhase::One + } else { + KeyPhase::Zero + } + } + #[inline] fn validate(&self) -> Result<(), s2n_codec::DecoderError> { let range = Self::MIN..=Self::MAX; diff --git a/dc/s2n-quic-dc/src/packet/control/decoder.rs b/dc/s2n-quic-dc/src/packet/control/decoder.rs index b4ca6dd3e5..12a6273294 100644 --- a/dc/s2n-quic-dc/src/packet/control/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/control/decoder.rs @@ -5,7 +5,7 @@ use crate::{ credentials::Credentials, packet::{ control::{self, Tag}, - stream, + stream, WireVersion, }, }; use s2n_codec::{ @@ -50,6 +50,7 @@ where #[derive(Debug)] pub struct Packet<'a> { tag: Tag, + wire_version: WireVersion, credentials: Credentials, source_control_port: u16, stream_id: Option, @@ -66,6 +67,11 @@ impl<'a> Packet<'a> { self.tag } + #[inline] + pub fn wire_version(&self) -> WireVersion { + self.wire_version + } + #[inline] pub fn credentials(&self) -> &Credentials { &self.credentials @@ -124,6 +130,7 @@ impl<'a> Packet<'a> { ) -> R { let ( tag, + wire_version, credentials, source_control_port, stream_id, @@ -147,6 +154,8 @@ impl<'a> Packet<'a> { let (tag, buffer) = buffer.decode()?; validator.validate_tag(tag)?; + let (wire_version, buffer) = buffer.decode()?; + let (credentials, buffer) = buffer.decode()?; let (source_control_port, buffer) = buffer.decode()?; @@ -181,6 +190,7 @@ impl<'a> Packet<'a> { ( tag, + wire_version, credentials, source_control_port, stream_id, @@ -222,6 +232,7 @@ impl<'a> Packet<'a> { let packet = Packet { tag, + wire_version, credentials, source_control_port, stream_id, diff --git a/dc/s2n-quic-dc/src/packet/control/encoder.rs b/dc/s2n-quic-dc/src/packet/control/encoder.rs index 14b6385dfc..8258108e09 100644 --- a/dc/s2n-quic-dc/src/packet/control/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/control/encoder.rs @@ -5,7 +5,7 @@ use crate::{ crypto::encrypt, packet::{ control::{Tag, NONCE_MASK}, - stream, + stream, WireVersion, }, }; use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; @@ -31,19 +31,16 @@ where debug_assert_ne!(source_control_port, 0); let mut tag = Tag::default(); + tag.set_key_phase(crypto.key_phase()); + tag.set_is_stream(stream_id.is_some()); + tag.set_has_application_header(*header_len > 0); + encoder.encode(&tag); - if stream_id.is_some() { - tag.set_is_stream(true); - } - - if *header_len > 0 { - tag.set_has_application_header(true); - } + // wire version - we only support `0` currently + encoder.encode(&WireVersion::ZERO); let nonce = *packet_number | NONCE_MASK; - encoder.encode(&tag); - // encode the credentials being used encoder.encode(crypto.credentials()); 
encoder.encode(&source_control_port); diff --git a/dc/s2n-quic-dc/src/packet/datagram.rs b/dc/s2n-quic-dc/src/packet/datagram.rs index 1a585412e5..8066c212aa 100644 --- a/dc/s2n-quic-dc/src/packet/datagram.rs +++ b/dc/s2n-quic-dc/src/packet/datagram.rs @@ -3,6 +3,7 @@ use super::{tag::Common, HeaderLen, PacketNumber, PayloadLen}; use core::fmt; +use s2n_quic_core::packet::KeyPhase; use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; pub mod decoder; @@ -27,6 +28,7 @@ impl fmt::Debug for Tag { .field("ack_eliciting", &self.ack_eliciting()) .field("is_connected", &self.is_connected()) .field("has_application_header", &self.has_application_header()) + .field("key_phase", &self.key_phase()) .finish() } } @@ -34,8 +36,8 @@ impl fmt::Debug for Tag { impl Tag { pub const ACK_ELICITING_MASK: u8 = 0b1000; pub const IS_CONNECTED_MASK: u8 = 0b0100; - const RESERVED: u8 = 0b0010; - pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b0001; + pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b0010; + pub const KEY_PHASE_MASK: u8 = 0b0001; pub const MIN: u8 = 0b0100_0000; pub const MAX: u8 = 0b0100_1111; @@ -70,14 +72,28 @@ impl Tag { self.0.get(Self::HAS_APPLICATION_HEADER_MASK) } + #[inline] + pub fn set_key_phase(&mut self, key_phase: KeyPhase) { + let v = match key_phase { + KeyPhase::Zero => false, + KeyPhase::One => true, + }; + self.0.set(Self::KEY_PHASE_MASK, v) + } + + #[inline] + pub fn key_phase(&self) -> KeyPhase { + if self.0.get(Self::KEY_PHASE_MASK) { + KeyPhase::One + } else { + KeyPhase::Zero + } + } + #[inline] fn validate(&self) -> Result<(), s2n_codec::DecoderError> { let range = Self::MIN..=Self::MAX; s2n_codec::decoder_invariant!(range.contains(&(self.0).0), "invalid datagram bit pattern"); - s2n_codec::decoder_invariant!( - self.0 .0 & Self::RESERVED == 0, - "invalid datagram bit pattern" - ); Ok(()) } } diff --git a/dc/s2n-quic-dc/src/packet/datagram/decoder.rs b/dc/s2n-quic-dc/src/packet/datagram/decoder.rs index baf391f9d5..be6ed09b8d 100644 --- a/dc/s2n-quic-dc/src/packet/datagram/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/datagram/decoder.rs @@ -1,7 +1,10 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{credentials::Credentials, packet::datagram::Tag}; +use crate::{ + credentials::Credentials, + packet::{datagram::Tag, WireVersion}, +}; use s2n_codec::{ decoder_invariant, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, DecoderError, }; @@ -43,6 +46,7 @@ where pub struct Packet<'a> { tag: Tag, + wire_version: WireVersion, credentials: Credentials, source_control_port: u16, packet_number: PacketNumber, @@ -58,6 +62,7 @@ impl<'a> std::fmt::Debug for Packet<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Packet") .field("tag", &self.tag) + .field("wire_version", &self.wire_version) .field("credentials", &self.credentials) .field("source_control_port", &self.source_control_port) .field("packet_number", &self.packet_number) @@ -80,6 +85,11 @@ impl<'a> Packet<'a> { self.tag } + #[inline] + pub fn wire_version(&self) -> WireVersion { + self.wire_version + } + #[inline] pub fn credentials(&self) -> &Credentials { &self.credentials @@ -143,6 +153,7 @@ impl<'a> Packet<'a> { ) -> R { let ( tag, + wire_version, credentials, source_control_port, packet_number, @@ -167,6 +178,8 @@ impl<'a> Packet<'a> { let (tag, buffer) = buffer.decode()?; validator.validate_tag(tag)?; + let (wire_version, buffer) = buffer.decode()?; + let (credentials, buffer) = buffer.decode()?; let (source_control_port, buffer) = buffer.decode()?; @@ -205,6 +218,7 @@ impl<'a> Packet<'a> { ( tag, + wire_version, credentials, source_control_port, packet_number, @@ -250,6 +264,7 @@ impl<'a> Packet<'a> { let packet = Packet { tag, + wire_version, credentials, source_control_port, packet_number, diff --git a/dc/s2n-quic-dc/src/packet/datagram/encoder.rs b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs index f5abd5e796..262852150e 100644 --- a/dc/s2n-quic-dc/src/packet/datagram/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs @@ -1,7 +1,11 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{credentials, crypto::encrypt, packet::datagram::Tag}; +use crate::{ + credentials, + crypto::encrypt, + packet::{datagram::Tag, WireVersion}, +}; use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; use s2n_quic_core::{assume, buffer, varint::VarInt}; @@ -19,6 +23,9 @@ pub fn estimate_len( let mut encoder = s2n_codec::EncoderLenEstimator::new(usize::MAX); encoder.encode(&Tag::default()); + // wire version + encoder.encode(&WireVersion::ZERO); + // credentials { encoder.write_zerocopy::(|_| {}); @@ -69,13 +76,16 @@ where tag.set_is_connected(packet_number.is_some()); tag.set_has_application_header(header_len != super::HeaderLen::ZERO); tag.set_ack_eliciting(next_expected_control_packet.is_some()); + tag.set_key_phase(crypto.key_phase()); + encoder.encode(&tag); + + // wire version - we only support `0` currently + encoder.encode(&WireVersion::ZERO); let header_len_usize = *header_len as usize; let payload_len_usize = *payload_len as usize; let nonce = *packet_number.unwrap_or(super::PacketNumber::ZERO); - encoder.encode(&tag); - // encode the credentials being used encoder.encode(crypto.credentials()); encoder.encode(&source_control_port); diff --git a/dc/s2n-quic-dc/src/packet/secret_control.rs b/dc/s2n-quic-dc/src/packet/secret_control.rs index ecd6cef26e..7f15b8eac7 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control.rs @@ -4,6 +4,7 @@ use crate::{ credentials, crypto::{decrypt, encrypt}, + packet::WireVersion, }; use s2n_codec::{ decoder_invariant, decoder_value, DecoderBuffer, DecoderBufferMut, diff --git a/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs index 65c08345c6..58d9e31fe9 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs @@ -46,6 +46,8 @@ macro_rules! 
impl_packet { crypto .decrypt( + // these don't rotate + s2n_quic_core::packet::KeyPhase::Zero, value.nonce(), header, &[], diff --git a/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs index dd76ebda4e..4e312d97c3 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs @@ -9,6 +9,7 @@ impl_packet!(ReplayDetected); #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(test, derive(bolero_generator::TypeGenerator))] pub struct ReplayDetected { + pub wire_version: WireVersion, pub credential_id: credentials::Id, pub rejected_key_id: VarInt, } @@ -20,6 +21,7 @@ impl ReplayDetected { C: encrypt::Key, { encoder.encode(&Tag::default()); + encoder.encode(&self.wire_version); encoder.encode(&self.credential_id); encoder.encode(&self.rejected_key_id); @@ -44,9 +46,11 @@ impl<'a> DecoderValue<'a> for ReplayDetected { fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { let (tag, buffer) = buffer.decode::()?; decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (wire_version, buffer) = buffer.decode()?; let (credential_id, buffer) = buffer.decode()?; let (rejected_key_id, buffer) = buffer.decode()?; let value = Self { + wire_version, credential_id, rejected_key_id, }; diff --git a/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs index d13c48d1d7..cffdef2513 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs @@ -9,6 +9,7 @@ impl_packet!(StaleKey); #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(test, derive(bolero_generator::TypeGenerator))] pub struct StaleKey { + pub wire_version: WireVersion, pub credential_id: credentials::Id, pub min_key_id: VarInt, } @@ -20,6 +21,7 @@ impl StaleKey { C: encrypt::Key, { encoder.encode(&Tag::default()); + encoder.encode(&self.wire_version); encoder.encode(&self.credential_id); encoder.encode(&self.min_key_id); @@ -44,9 +46,11 @@ impl<'a> DecoderValue<'a> for StaleKey { fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { let (tag, buffer) = buffer.decode::()?; decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (wire_version, buffer) = buffer.decode()?; let (credential_id, buffer) = buffer.decode()?; let (min_key_id, buffer) = buffer.decode()?; let value = Self { + wire_version, credential_id, min_key_id, }; diff --git a/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs index eb9e29de9b..b7972d6157 100644 --- a/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs +++ b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs @@ -23,7 +23,10 @@ impl<'a> Packet<'a> { ) -> Packet<'_> { Packet { header: &[], - value: UnknownPathSecret { credential_id: id }, + value: UnknownPathSecret { + wire_version: WireVersion::ZERO, + credential_id: id, + }, crypto_tag: &stateless_reset[..], } } @@ -59,6 +62,7 @@ impl<'a> Packet<'a> { #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(test, derive(bolero_generator::TypeGenerator))] pub struct UnknownPathSecret { + pub wire_version: WireVersion, pub credential_id: credentials::Id, } @@ -74,6 +78,7 @@ impl UnknownPathSecret { ) -> usize { let before = encoder.len(); encoder.encode(&Tag::default()); + encoder.encode(&self.wire_version); encoder.encode(&&self.credential_id[..]); 
encoder.encode(&&stateless_reset_tag[..]); let after = encoder.len(); @@ -91,8 +96,12 @@ impl<'a> DecoderValue<'a> for UnknownPathSecret { fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { let (tag, buffer) = buffer.decode::()?; decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (wire_version, buffer) = buffer.decode()?; let (credential_id, buffer) = buffer.decode()?; - let value = Self { credential_id }; + let value = Self { + wire_version, + credential_id, + }; Ok((value, buffer)) } } diff --git a/dc/s2n-quic-dc/src/packet/stream.rs b/dc/s2n-quic-dc/src/packet/stream.rs index 1a15a70eb0..32a8cef210 100644 --- a/dc/s2n-quic-dc/src/packet/stream.rs +++ b/dc/s2n-quic-dc/src/packet/stream.rs @@ -3,7 +3,7 @@ use super::tag::Common; use core::fmt; -use s2n_quic_core::{probe, varint::VarInt}; +use s2n_quic_core::{packet::KeyPhase, probe, varint::VarInt}; use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; pub mod decoder; @@ -66,16 +66,18 @@ impl fmt::Debug for Tag { .field("packet_space", &self.packet_space()) .field("has_final_offset", &self.has_final_offset()) .field("has_application_header", &self.has_application_header()) + .field("key_phase", &self.key_phase()) .finish() } } impl Tag { - pub const HAS_SOURCE_STREAM_PORT: u8 = 0b01_0000; - pub const IS_RECOVERY_PACKET: u8 = 0b00_1000; - pub const HAS_CONTROL_DATA_MASK: u8 = 0b00_0100; - pub const HAS_FINAL_OFFSET_MASK: u8 = 0b00_0010; - pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b00_0001; + pub const HAS_SOURCE_STREAM_PORT: u8 = 0b10_0000; + pub const IS_RECOVERY_PACKET: u8 = 0b01_0000; + pub const HAS_CONTROL_DATA_MASK: u8 = 0b00_1000; + pub const HAS_FINAL_OFFSET_MASK: u8 = 0b00_0100; + pub const HAS_APPLICATION_HEADER_MASK: u8 = 0b00_0010; + pub const KEY_PHASE_MASK: u8 = 0b00_0001; pub const MIN: u8 = 0b0000_0000; pub const MAX: u8 = 0b0011_1111; @@ -135,6 +137,24 @@ impl Tag { self.0.get(Self::HAS_APPLICATION_HEADER_MASK) } + #[inline] + pub fn set_key_phase(&mut self, key_phase: KeyPhase) { + let v = match key_phase { + KeyPhase::Zero => false, + KeyPhase::One => true, + }; + self.0.set(Self::KEY_PHASE_MASK, v) + } + + #[inline] + pub fn key_phase(&self) -> KeyPhase { + if self.0.get(Self::KEY_PHASE_MASK) { + KeyPhase::One + } else { + KeyPhase::Zero + } + } + #[inline] fn validate(&self) -> Result<(), s2n_codec::DecoderError> { let range = Self::MIN..=Self::MAX; diff --git a/dc/s2n-quic-dc/src/packet/stream/decoder.rs b/dc/s2n-quic-dc/src/packet/stream/decoder.rs index 1fac8074ba..5c178d0dac 100644 --- a/dc/s2n-quic-dc/src/packet/stream/decoder.rs +++ b/dc/s2n-quic-dc/src/packet/stream/decoder.rs @@ -4,7 +4,10 @@ use crate::{ credentials::Credentials, crypto, - packet::stream::{self, RelativeRetransmissionOffset, Tag}, + packet::{ + stream::{self, RelativeRetransmissionOffset, Tag}, + WireVersion, + }, }; use core::{fmt, mem::size_of}; use s2n_codec::{ @@ -49,6 +52,7 @@ where #[derive(Clone, Debug, PartialEq, Eq)] pub struct Owned { pub tag: Tag, + pub wire_version: WireVersion, pub credentials: Credentials, pub source_control_port: u16, pub source_stream_port: Option, @@ -72,6 +76,7 @@ impl<'a> From> for Owned { Self { tag: packet.tag, + wire_version: packet.wire_version, credentials: packet.credentials, source_control_port: packet.source_control_port, source_stream_port: packet.source_stream_port, @@ -92,6 +97,7 @@ impl<'a> From> for Owned { pub struct Packet<'a> { tag: Tag, + wire_version: WireVersion, credentials: Credentials, source_control_port: u16, source_stream_port: Option, @@ -113,6 +119,7 @@ 
impl<'a> fmt::Debug for Packet<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("stream::Packet") .field("tag", &self.tag) + .field("wire_version", &self.wire_version) .field("credentials", &self.credentials) .field("source_control_port", &self.source_control_port) .field("source_stream_port", &self.source_stream_port) @@ -133,6 +140,11 @@ impl<'a> Packet<'a> { self.tag } + #[inline] + pub fn wire_version(&self) -> WireVersion { + self.wire_version + } + #[inline] pub fn credentials(&self) -> &Credentials { &self.credentials @@ -228,6 +240,7 @@ impl<'a> Packet<'a> { where D: crypto::decrypt::Key, { + let key_phase = self.tag.key_phase(); let space = self.remove_retransmit(d); let nonce = space.packet_number_into_nonce(self.original_packet_number); @@ -236,7 +249,7 @@ impl<'a> Packet<'a> { let payload = &self.payload; let auth_tag = &self.auth_tag; - d.decrypt(nonce, header, payload, auth_tag, payload_out)?; + d.decrypt(key_phase, nonce, header, payload, auth_tag, payload_out)?; Ok(()) } @@ -246,6 +259,7 @@ impl<'a> Packet<'a> { where D: crypto::decrypt::Key, { + let key_phase = self.tag.key_phase(); let space = self.remove_retransmit(d); let nonce = space.packet_number_into_nonce(self.original_packet_number); @@ -261,7 +275,7 @@ impl<'a> Packet<'a> { core::slice::from_raw_parts_mut(payload_ptr, payload_len + tag_len) }; - d.decrypt_in_place(nonce, header, payload_and_tag)?; + d.decrypt_in_place(key_phase, nonce, header, payload_and_tag)?; Ok(()) } @@ -271,12 +285,14 @@ impl<'a> Packet<'a> { where D: crypto::decrypt::Key, { + let key_phase = self.tag.key_phase(); let space = self.tag.packet_space(); let original_packet_number = self.original_packet_number; let retransmission_packet_number = self.packet_number; if original_packet_number != retransmission_packet_number { d.retransmission_tag( + key_phase, original_packet_number.as_u64(), stream::PacketSpace::Recovery .packet_number_into_nonce(retransmission_packet_number), @@ -388,6 +404,8 @@ impl<'a> Packet<'a> { tag }; + let (_wire_version, buffer) = buffer.decode::()?; + let (credentials, buffer) = buffer.decode::()?; debug_assert_eq!(&credentials, key.credentials()); @@ -471,6 +489,7 @@ impl<'a> Packet<'a> { ) -> R { let ( tag, + wire_version, credentials, source_control_port, source_stream_port, @@ -501,6 +520,8 @@ impl<'a> Packet<'a> { let (tag, buffer) = buffer.decode()?; validator.validate_tag(tag)?; + let (wire_version, buffer) = buffer.decode::()?; + let (credentials, buffer) = buffer.decode()?; let (source_control_port, buffer) = buffer.decode()?; @@ -564,6 +585,7 @@ impl<'a> Packet<'a> { ( tag, + wire_version, credentials, source_control_port, source_stream_port, @@ -615,6 +637,7 @@ impl<'a> Packet<'a> { let packet = Packet { tag, + wire_version, credentials, source_control_port, source_stream_port, diff --git a/dc/s2n-quic-dc/src/packet/stream/encoder.rs b/dc/s2n-quic-dc/src/packet/stream/encoder.rs index d07270f9b1..3e2a190191 100644 --- a/dc/s2n-quic-dc/src/packet/stream/encoder.rs +++ b/dc/s2n-quic-dc/src/packet/stream/encoder.rs @@ -3,7 +3,10 @@ use crate::{ crypto::encrypt, - packet::stream::{self, RelativeRetransmissionOffset, Tag}, + packet::{ + stream::{self, RelativeRetransmissionOffset, Tag}, + WireVersion, + }, }; use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; use s2n_quic_core::{ @@ -41,32 +44,23 @@ where let stream_offset = payload.current_offset(); let final_offset = payload.final_offset(); - let mut tag = Tag::default(); - debug_assert_ne!(source_control_port, 0); 
debug_assert_ne!(source_stream_port, Some(0)); - if *control_data_len > 0 { - tag.set_has_control_data(true); - } - - if final_offset.is_some() { - tag.set_has_final_offset(true); - } - - if *header_len > 0 { - tag.set_has_application_header(true); - } + let mut tag = Tag::default(); + tag.set_key_phase(crypto.key_phase()); + tag.set_has_control_data(*control_data_len > 0); + tag.set_has_final_offset(final_offset.is_some()); + tag.set_has_application_header(*header_len > 0); + tag.set_has_source_stream_port(source_stream_port.is_some()); + tag.set_packet_space(packet_space); + encoder.encode(&tag); - if source_stream_port.is_some() { - tag.set_has_source_stream_port(true); - } + // wire version - we only support `0` currently + encoder.encode(&WireVersion::ZERO); - tag.set_packet_space(packet_space); let nonce = packet_space.packet_number_into_nonce(packet_number); - encoder.encode(&tag); - // encode the credentials being used encoder.encode(crypto.credentials()); encoder.encode(&source_control_port); diff --git a/dc/s2n-quic-dc/src/packet/wire_version.rs b/dc/s2n-quic-dc/src/packet/wire_version.rs new file mode 100644 index 0000000000..5cf9497aa5 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/wire_version.rs @@ -0,0 +1,32 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_codec::{decoder_invariant, decoder_value, Encoder, EncoderValue}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(test, derive(bolero_generator::TypeGenerator))] +pub struct WireVersion(#[cfg_attr(test, generator(bolero_generator::constant(0)))] pub u32); + +impl WireVersion { + pub const ZERO: Self = Self(0); +} + +decoder_value!( + impl<'a> WireVersion { + fn decode(buffer: Buffer) -> Result { + let (version, buffer) = buffer.decode::()?; + decoder_invariant!(version == 0, "only wire version 0 is supported currently"); + let version = Self(version as _); + Ok((version, buffer)) + } + } +); + +impl EncoderValue for WireVersion { + #[inline] + fn encode(&self, encoder: &mut E) { + debug_assert!(self.0 <= u8::MAX as u32); + let v = self.0 as u8; + v.encode(encoder); + } +} diff --git a/dc/s2n-quic-dc/src/path/secret/key.rs b/dc/s2n-quic-dc/src/path/secret/key.rs index ef6af3b8b3..02054db900 100644 --- a/dc/s2n-quic-dc/src/path/secret/key.rs +++ b/dc/s2n-quic-dc/src/path/secret/key.rs @@ -7,6 +7,7 @@ use crate::{ crypto::{awslc, decrypt, encrypt, IntoNonce, UninitSlice}, }; use core::mem::MaybeUninit; +use s2n_quic_core::packet::KeyPhase; use zeroize::Zeroize; #[derive(Debug)] @@ -20,6 +21,11 @@ impl encrypt::Key for Sealer { self.sealer.credentials() } + #[inline] + fn key_phase(&self) -> KeyPhase { + KeyPhase::Zero + } + #[inline] fn tag_len(&self) -> usize { self.sealer.tag_len() @@ -109,6 +115,7 @@ impl decrypt::Key for Opener { #[inline] fn decrypt( &self, + key_phase: KeyPhase, nonce: N, header: &[u8], payload_in: &[u8], @@ -116,7 +123,7 @@ impl decrypt::Key for Opener { payload_out: &mut UninitSlice, ) -> decrypt::Result { self.opener - .decrypt(nonce, header, payload_in, tag, payload_out)?; + .decrypt(key_phase, nonce, header, payload_in, tag, payload_out)?; self.on_decrypt_success(payload_out)?; @@ -126,12 +133,13 @@ impl decrypt::Key for Opener { #[inline] fn decrypt_in_place( &self, + key_phase: KeyPhase, nonce: N, header: &[u8], payload_and_tag: &mut [u8], ) -> decrypt::Result { self.opener - .decrypt_in_place(nonce, header, payload_and_tag)?; + .decrypt_in_place(key_phase, nonce, header, payload_and_tag)?; 
self.on_decrypt_success(UninitSlice::new(payload_and_tag))?; @@ -141,11 +149,13 @@ impl decrypt::Key for Opener { #[inline] fn retransmission_tag( &self, + key_phase: KeyPhase, original_packet_number: u64, retransmission_packet_number: u64, tag_out: &mut [u8], ) { self.opener.retransmission_tag( + key_phase, original_packet_number, retransmission_packet_number, tag_out, diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index 02fdfb7fd1..8e39245fb6 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -9,7 +9,7 @@ use super::{ use crate::{ credentials::{Credentials, Id}, crypto, - packet::{secret_control as control, Packet}, + packet::{secret_control as control, Packet, WireVersion}, }; use rand::Rng as _; use s2n_codec::EncoderBuffer; @@ -412,6 +412,7 @@ impl Map { let ids_guard = self.state.ids.guard(); let Some(state) = self.state.ids.get(&identity.id, &ids_guard) else { let packet = control::UnknownPathSecret { + wire_version: WireVersion::ZERO, credential_id: identity.id, }; control_out.resize(control::UnknownPathSecret::PACKET_SIZE, 0); @@ -472,7 +473,7 @@ impl Map { pub fn for_test_with_peers( peers: Vec<(schedule::Ciphersuite, dc::Version, SocketAddr)>, ) -> (Self, Vec) { - let provider = Self::new(Default::default()); + let provider = Self::new(stateless_reset::Signer::random()); let mut secret = [0; 32]; aws_lc_rs::rand::fill(&mut secret).unwrap(); let mut stateless_reset = [0; 16]; @@ -568,11 +569,13 @@ impl receiver::Error { let encoder = EncoderBuffer::new(&mut buffer[..]); let length = match self { receiver::Error::AlreadyExists => control::ReplayDetected { + wire_version: WireVersion::ZERO, credential_id: credentials.id, rejected_key_id: credentials.key_id, } .encode(encoder, &entry.secret.control_sealer()), receiver::Error::Unknown => control::StaleKey { + wire_version: WireVersion::ZERO, credential_id: credentials.id, min_key_id: entry.receiver.minimum_unseen_key_id(), } diff --git a/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs index f8e4a9bc76..26edbdc7a8 100644 --- a/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs +++ b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs @@ -10,20 +10,23 @@ pub struct Signer { prk: Prk, } -impl Default for Signer { - fn default() -> Self { - let mut secret = [0u8; 32]; - aws_lc_rs::rand::fill(&mut secret).unwrap(); - Self::new(&secret) - } -} - impl Signer { + /// Creates a signer with the given secret pub fn new(secret: &[u8]) -> Self { let prk = Salt::new(HKDF_SHA384, secret).extract(b"rst"); Self { prk } } + /// Returns a random `Signer` + /// + /// Note that this signer cannot be used across restarts and will result in an endpoint + /// producing invalid `UnknownPathSecret` packets. 
+ pub fn random() -> Self { + let mut secret = [0u8; 32]; + aws_lc_rs::rand::fill(&mut secret).unwrap(); + Self::new(&secret) + } + pub fn sign(&self, id: &Id) -> [u8; 16] { let mut stateless_reset = [0; 16]; diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs index ef3a82b922..941a51e204 100644 --- a/dc/s2n-quic-dc/src/stream.rs +++ b/dc/s2n-quic-dc/src/stream.rs @@ -8,14 +8,19 @@ pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30); /// The maximum time a send stream will wait for ACKs from inflight packets pub const DEFAULT_INFLIGHT_TIMEOUT: Duration = Duration::from_secs(5); +pub mod application; pub mod crypto; +pub mod endpoint; +pub mod environment; pub mod pacer; pub mod packet_map; pub mod packet_number; pub mod processing; pub mod recv; +pub mod runtime; pub mod send; pub mod server; +pub mod shared; pub mod socket; bitflags::bitflags! { diff --git a/dc/s2n-quic-dc/src/stream/application.rs b/dc/s2n-quic-dc/src/stream/application.rs new file mode 100644 index 0000000000..976caf13f9 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/application.rs @@ -0,0 +1,151 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::stream::{ + recv::application::{self as recv, Reader}, + send::application::{self as send, Writer}, + shared::ArcShared, + socket, +}; +use core::fmt; +use s2n_quic_core::buffer; +use std::{io, net::SocketAddr}; + +pub struct Builder { + pub read: recv::Builder, + pub write: send::Builder, + pub shared: ArcShared, + pub sockets: Box, +} + +impl Builder { + #[inline] + pub fn build(self) -> io::Result { + let Self { + read, + write, + shared, + sockets, + } = self; + let sockets = sockets.build()?; + let read = read.build(shared.clone(), sockets.clone()); + let write = write.build(shared, sockets); + Ok(Stream { read, write }) + } +} + +pub struct Stream { + read: Reader, + write: Writer, +} + +impl fmt::Debug for Stream { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Stream") + .field("peer_addr", &self.peer_addr().unwrap()) + .field("local_addr", &self.local_addr().unwrap()) + .finish() + } +} + +impl Stream { + #[inline] + pub fn peer_addr(&self) -> io::Result { + self.read.peer_addr() + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.read.local_addr() + } + + #[inline] + pub fn protocol(&self) -> socket::Protocol { + self.read.protocol() + } + + #[inline] + pub async fn write_from( + &mut self, + buf: &mut impl buffer::reader::storage::Infallible, + ) -> io::Result { + self.write.write_from(buf).await + } + + #[inline] + pub async fn read_into( + &mut self, + out_buf: &mut impl buffer::writer::Storage, + ) -> io::Result { + self.read.read_into(out_buf).await + } + + #[inline] + pub fn split(&mut self) -> (&mut Reader, &mut Writer) { + (&mut self.read, &mut self.write) + } + + #[inline] + pub fn into_split(self) -> (Reader, Writer) { + (self.read, self.write) + } +} + +#[cfg(feature = "tokio")] +mod tokio_impl { + use super::Stream; + use core::{ + pin::Pin, + task::{Context, Poll}, + }; + use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; + + impl AsyncRead for Stream { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.read).poll_read(cx, buf) + } + } + + impl AsyncWrite for Stream { + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut 
self.write).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[std::io::IoSlice], + ) -> Poll> { + Pin::new(&mut self.write).poll_write_vectored(cx, buf) + } + + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.write).poll_flush(cx) + } + + #[inline] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.write).poll_shutdown(cx) + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + self.write.is_write_vectored() + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/endpoint.rs b/dc/s2n-quic-dc/src/stream/endpoint.rs new file mode 100644 index 0000000000..3c9b97ef8f --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/endpoint.rs @@ -0,0 +1,323 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + crypto::encrypt::Key as _, + msg, packet, + path::secret::Map, + random::Random, + stream::{ + application, + environment::{Environment, Peer}, + recv, + send::{self, flow}, + server, shared, + }, +}; +use core::cell::UnsafeCell; +use s2n_quic_core::{ + dc, endpoint, + inet::{ExplicitCongestionNotification, SocketAddress}, + varint::VarInt, +}; +use std::{io, sync::Arc}; +use tracing::{debug_span, Instrument as _}; + +type Result = core::result::Result; + +pub struct AcceptError { + pub secret_control: Vec, + pub peer: Option, + pub error: io::Error, +} + +#[inline] +pub fn open_stream( + env: &Env, + handshake_addr: SocketAddress, + peer: P, + map: &Map, + parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, +) -> Result +where + Env: Environment, + P: Peer, +{ + // derive secrets for the new stream + let Some((sealer, opener, mut parameters)) = map.pair_for_peer(handshake_addr.into()) else { + // the application didn't perform a handshake with the server before opening the stream + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!("missing credentials for server: {handshake_addr}"), + )); + }; + + if let Some(o) = parameter_override { + parameters = o(parameters); + } + + // TODO get a flow ID. 
for now we'll use the sealer credentials + let key_id = sealer.credentials().key_id; + let stream_id = packet::stream::Id { + key_id, + is_reliable: true, + is_bidirectional: true, + }; + + let crypto = shared::Crypto::new(sealer, opener, map); + + build_stream( + env, + peer, + stream_id, + None, + crypto, + parameters, + None, + None, + endpoint::Type::Client, + ) +} + +#[inline] +pub fn accept_stream( + env: &Env, + mut peer: P, + packet: &server::InitialPacket, + handshake: Option, + buffer: Option<&mut msg::recv::Message>, + map: &Map, + parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, +) -> Result> +where + Env: Environment, + P: Peer, +{ + let credentials = &packet.credentials; + let mut secret_control = vec![]; + let Some((sealer, opener, mut parameters)) = + map.pair_for_credentials(credentials, &mut secret_control) + else { + let error = io::Error::new( + io::ErrorKind::NotFound, + format!("missing credentials for client: {credentials:?}"), + ); + let error = AcceptError { + secret_control, + peer: Some(peer), + error, + }; + return Err(error); + }; + + if let Some(o) = parameter_override { + parameters = o(parameters); + } + + // inform the value of what the source_control_port is + peer.with_source_control_port(packet.source_control_port); + + let crypto = shared::Crypto::new(sealer, opener, map); + + let res = build_stream( + env, + peer, + packet.stream_id, + packet.source_stream_port, + crypto, + parameters, + handshake, + buffer, + endpoint::Type::Server, + ); + + match res { + Ok(stream) => Ok(stream), + Err(error) => { + let error = AcceptError { + secret_control, + peer: None, + error, + }; + Err(error) + } + } +} + +#[inline] +fn build_stream( + env: &Env, + peer: P, + stream_id: packet::stream::Id, + remote_stream_port: Option, + crypto: shared::Crypto, + parameters: dc::ApplicationParams, + handshake: Option, + recv_buffer: Option<&mut msg::recv::Message>, + endpoint_type: endpoint::Type, +) -> Result +where + Env: Environment, + P: Peer, +{ + let features = peer.features(); + + let sockets = peer.setup(env)?; + + // construct shared reader state + let reader = recv::shared::State::new(stream_id, ¶meters, handshake, features, recv_buffer); + + let writer = { + let worker = sockets + .write_worker + .map(|socket| (send::state::State::new(stream_id, ¶meters), socket)); + + let (flow_offset, send_quantum, bandwidth) = + if let Some((worker, _socket)) = worker.as_ref() { + let flow_offset = worker.flow_offset(); + let send_quantum = worker.send_quantum_packets(); + let bandwidth = Some(worker.cca.bandwidth()); + + (flow_offset, send_quantum, bandwidth) + } else { + debug_assert!( + features.is_flow_controlled(), + "transports without flow control need background workers" + ); + + let flow_offset = VarInt::MAX; + let send_quantum = 10; + let bandwidth = None; + + (flow_offset, send_quantum, bandwidth) + }; + + let flow = flow::non_blocking::State::new(flow_offset); + + let path = send::path::Info { + max_datagram_size: parameters.max_datagram_size, + send_quantum, + ecn: ExplicitCongestionNotification::Ect0, + next_expected_control_packet: VarInt::ZERO, + }; + + // construct shared writer state + let state = send::shared::State::new(flow, path, bandwidth); + + (state, worker) + }; + + // construct shared common state between readers/writers + let common = { + let application = send::application::state::State { + stream_id, + source_control_port: sockets.source_control_port, + source_stream_port: sockets.source_stream_port, + }; + + let 
fixed = shared::FixedValues { + remote_ip: UnsafeCell::new(sockets.remote_addr.ip()), + source_control_port: UnsafeCell::new(sockets.source_control_port), + application: UnsafeCell::new(application), + }; + + let remote_port = sockets.remote_addr.port(); + let write_remote_port = remote_stream_port.unwrap_or(remote_port); + + shared::Common { + clock: env.clock().clone(), + gso: env.gso(), + read_remote_port: remote_port.into(), + write_remote_port: write_remote_port.into(), + last_peer_activity: Default::default(), + fixed, + closed_halves: 0u8.into(), + } + }; + + let shared = Arc::new(shared::Shared { + receiver: reader, + sender: writer.0, + common, + crypto, + }); + + // spawn the read worker + if let Some(socket) = sockets.read_worker { + let shared = shared.clone(); + + let task = async move { + let mut reader = recv::worker::Worker::new(socket, shared, endpoint_type); + + let mut prev_waker: Option = None; + core::future::poll_fn(|cx| { + // update the waker if needed + if prev_waker + .as_ref() + .map_or(true, |prev| !prev.will_wake(cx.waker())) + { + prev_waker = Some(cx.waker().clone()); + reader.update_waker(cx); + } + + // drive the reader to completion + reader.poll(cx) + }) + .await; + }; + + let span = debug_span!("worker::read"); + + if span.is_disabled() { + env.spawn_reader(task); + } else { + env.spawn_reader(task.instrument(span)); + } + } + + // spawn the write worker + if let Some((worker, socket)) = writer.1 { + let shared = shared.clone(); + + let task = async move { + let mut writer = + send::worker::Worker::new(socket, Random::default(), shared, worker, endpoint_type); + + let mut prev_waker: Option = None; + core::future::poll_fn(|cx| { + // update the waker if needed + if prev_waker + .as_ref() + .map_or(true, |prev| !prev.will_wake(cx.waker())) + { + prev_waker = Some(cx.waker().clone()); + writer.update_waker(cx); + } + + // drive the writer to completion + writer.poll(cx) + }) + .await; + }; + + let span = debug_span!("worker::write"); + + if span.is_disabled() { + env.spawn_writer(task); + } else { + env.spawn_writer(task.instrument(span)); + } + } + + let read = recv::application::Builder::new(endpoint_type, env.reader_rt()); + let write = send::application::Builder::new(env.writer_rt()); + + let stream = application::Builder { + read, + write, + shared, + sockets: sockets.application, + }; + + Ok(stream) +} diff --git a/dc/s2n-quic-dc/src/stream/environment.rs b/dc/s2n-quic-dc/src/stream/environment.rs new file mode 100644 index 0000000000..b5e2bc25f8 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/environment.rs @@ -0,0 +1,372 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clock, + crypto::encrypt::Key as _, + msg, packet, + path::secret::Map, + random::Random, + stream::{ + application, recv, runtime, + send::{self, flow}, + server, shared, socket, TransportFeatures, + }, +}; +use core::{cell::UnsafeCell, future::Future}; +use s2n_quic_core::{ + dc, endpoint, + inet::{ExplicitCongestionNotification, SocketAddress}, + varint::VarInt, +}; +use s2n_quic_platform::features; +use std::{io, sync::Arc}; +use tracing::{debug_span, Instrument as _}; + +type Result = core::result::Result; + +#[cfg(feature = "tokio")] +pub mod tokio; + +pub trait Environment { + type Clock: Clone + clock::Clock; + + fn clock(&self) -> &Self::Clock; + fn gso(&self) -> features::Gso; + fn reader_rt(&self) -> runtime::ArcHandle; + fn spawn_reader>(&self, f: F); + fn writer_rt(&self) -> runtime::ArcHandle; + fn spawn_writer>(&self, f: F); +} + +pub struct SocketSet { + pub application: Box, + pub read_worker: Option, + pub write_worker: Option, + pub remote_addr: SocketAddress, + pub source_control_port: u16, + pub source_stream_port: Option, +} + +pub trait Peer { + type WorkerSocket: socket::Socket; + + fn features(&self) -> TransportFeatures; + fn with_source_control_port(&mut self, port: u16); + fn setup(self, env: &E) -> Result>; +} + +pub struct AcceptError { + pub secret_control: Vec, + pub peer: Option, + pub error: io::Error, +} + +pub struct Builder { + env: E, +} + +impl Builder { + #[inline] + pub fn new(env: E) -> Self { + Self { env } + } + + #[inline] + pub fn clock(&self) -> &E::Clock { + self.env.clock() + } + + #[inline] + pub fn open_stream
<P>
( + &self, + handshake_addr: SocketAddress, + peer: P, + map: &Map, + parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, + ) -> Result + where + P: Peer, + { + // derive secrets for the new stream + let Some((sealer, opener, mut parameters)) = map.pair_for_peer(handshake_addr.into()) + else { + // the application didn't perform a handshake with the server before opening the stream + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!("missing credentials for server: {handshake_addr}"), + )); + }; + + if let Some(o) = parameter_override { + parameters = o(parameters); + } + + // TODO get a flow ID. for now we'll use the sealer credentials + let key_id = sealer.credentials().key_id; + let stream_id = packet::stream::Id { + key_id, + is_reliable: true, + is_bidirectional: true, + }; + + let crypto = shared::Crypto::new(sealer, opener, map); + + self.build_stream( + peer, + stream_id, + None, + crypto, + parameters, + None, + None, + endpoint::Type::Client, + ) + } + + #[inline] + pub fn accept_stream
<P>
( + &self, + mut peer: P, + packet: &server::InitialPacket, + handshake: Option, + buffer: Option<&mut msg::recv::Message>, + map: &Map, + parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>, + ) -> Result> + where + P: Peer, + { + let credentials = &packet.credentials; + let mut secret_control = vec![]; + let Some((sealer, opener, mut parameters)) = + map.pair_for_credentials(credentials, &mut secret_control) + else { + let error = io::Error::new( + io::ErrorKind::NotFound, + format!("missing credentials for client: {credentials:?}"), + ); + let error = AcceptError { + secret_control, + peer: Some(peer), + error, + }; + return Err(error); + }; + + if let Some(o) = parameter_override { + parameters = o(parameters); + } + + // inform the value of what the source_control_port is + peer.with_source_control_port(packet.source_control_port); + + let crypto = shared::Crypto::new(sealer, opener, map); + + let res = self.build_stream( + peer, + packet.stream_id, + packet.source_stream_port, + crypto, + parameters, + handshake, + buffer, + endpoint::Type::Server, + ); + + match res { + Ok(stream) => Ok(stream), + Err(error) => { + let error = AcceptError { + secret_control, + peer: None, + error, + }; + Err(error) + } + } + } + + #[inline] + fn build_stream
<P>
( + &self, + peer: P, + stream_id: packet::stream::Id, + remote_stream_port: Option, + crypto: shared::Crypto, + parameters: dc::ApplicationParams, + handshake: Option, + recv_buffer: Option<&mut msg::recv::Message>, + endpoint_type: endpoint::Type, + ) -> Result + where + P: Peer, + { + let features = peer.features(); + + let sockets = peer.setup(&self.env)?; + + // construct shared reader state + let reader = + recv::shared::State::new(stream_id, ¶meters, handshake, features, recv_buffer); + + let writer = { + let worker = sockets + .write_worker + .map(|socket| (send::state::State::new(stream_id, ¶meters), socket)); + + let (flow_offset, send_quantum, bandwidth) = + if let Some((worker, _socket)) = worker.as_ref() { + let flow_offset = worker.flow_offset(); + let send_quantum = worker.send_quantum_packets(); + let bandwidth = Some(worker.cca.bandwidth()); + + (flow_offset, send_quantum, bandwidth) + } else { + debug_assert!( + features.is_flow_controlled(), + "transports without flow control need background workers" + ); + + let flow_offset = VarInt::MAX; + let send_quantum = 10; + let bandwidth = None; + + (flow_offset, send_quantum, bandwidth) + }; + + let flow = flow::non_blocking::State::new(flow_offset); + + let path = send::path::Info { + max_datagram_size: parameters.max_datagram_size, + send_quantum, + ecn: ExplicitCongestionNotification::Ect0, + next_expected_control_packet: VarInt::ZERO, + }; + + // construct shared writer state + let state = send::shared::State::new(flow, path, bandwidth); + + (state, worker) + }; + + // construct shared common state between readers/writers + let common = { + let application = send::application::state::State { + stream_id, + source_control_port: sockets.source_control_port, + source_stream_port: sockets.source_stream_port, + }; + + let fixed = shared::FixedValues { + remote_ip: UnsafeCell::new(sockets.remote_addr.ip()), + source_control_port: UnsafeCell::new(sockets.source_control_port), + application: UnsafeCell::new(application), + }; + + let remote_port = sockets.remote_addr.port(); + let write_remote_port = remote_stream_port.unwrap_or(remote_port); + + shared::Common { + clock: self.env.clock().clone(), + gso: self.env.gso(), + read_remote_port: remote_port.into(), + write_remote_port: write_remote_port.into(), + last_peer_activity: Default::default(), + fixed, + closed_halves: 0u8.into(), + } + }; + + let shared = Arc::new(shared::Shared { + receiver: reader, + sender: writer.0, + common, + crypto, + }); + + // spawn the read worker + if let Some(socket) = sockets.read_worker { + let shared = shared.clone(); + + let task = async move { + let mut reader = recv::worker::Worker::new(socket, shared, endpoint_type); + + let mut prev_waker: Option = None; + core::future::poll_fn(|cx| { + // update the waker if needed + if prev_waker + .as_ref() + .map_or(true, |prev| !prev.will_wake(cx.waker())) + { + prev_waker = Some(cx.waker().clone()); + reader.update_waker(cx); + } + + // drive the reader to completion + reader.poll(cx) + }) + .await; + }; + + let span = debug_span!("worker::read"); + + if span.is_disabled() { + self.env.spawn_reader(task); + } else { + self.env.spawn_reader(task.instrument(span)); + } + } + + // spawn the write worker + if let Some((worker, socket)) = writer.1 { + let shared = shared.clone(); + + let task = async move { + let mut writer = send::worker::Worker::new( + socket, + Random::default(), + shared, + worker, + endpoint_type, + ); + + let mut prev_waker: Option = None; + core::future::poll_fn(|cx| { + // update 
the waker if needed + if prev_waker + .as_ref() + .map_or(true, |prev| !prev.will_wake(cx.waker())) + { + prev_waker = Some(cx.waker().clone()); + writer.update_waker(cx); + } + + // drive the writer to completion + writer.poll(cx) + }) + .await; + }; + + let span = debug_span!("worker::write"); + + if span.is_disabled() { + self.env.spawn_writer(task); + } else { + self.env.spawn_writer(task.instrument(span)); + } + } + + let read = recv::application::Builder::new(endpoint_type, self.env.reader_rt()); + let write = send::application::Builder::new(self.env.writer_rt()); + + let stream = application::Builder { + read, + write, + shared, + sockets: sockets.application, + }; + + Ok(stream) + } +} diff --git a/dc/s2n-quic-dc/src/stream/environment/tokio.rs b/dc/s2n-quic-dc/src/stream/environment/tokio.rs new file mode 100644 index 0000000000..8e67f310f0 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/environment/tokio.rs @@ -0,0 +1,277 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clock::tokio::Clock, + stream::{ + runtime::{tokio as runtime, ArcHandle}, + socket::{self, Socket as _}, + TransportFeatures, + }, +}; +use s2n_quic_core::{ + ensure, + inet::{SocketAddress, Unspecified}, +}; +use s2n_quic_platform::features; +use std::{io, net::UdpSocket, sync::Arc}; +use tokio::{io::unix::AsyncFd, net::TcpStream}; + +#[derive(Clone, Default)] +pub struct Builder { + clock: Option, + gso: Option, + socket_options: Option, + reader_rt: Option, + writer_rt: Option, + thread_name_prefix: Option, +} + +impl Builder { + #[inline] + pub fn build(self) -> io::Result { + let clock = self.clock.unwrap_or_default(); + let gso = self.gso.unwrap_or_default(); + let socket_options = self.socket_options.unwrap_or_default(); + + let thread_name_prefix = self.thread_name_prefix.as_deref().unwrap_or("dc_quic"); + + let reader_rt = self.reader_rt.map(>::Ok).unwrap_or_else(|| { + Ok(tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name(format!("{thread_name_prefix}::reader")) + .build()? + .into()) + })?; + let writer_rt = self.writer_rt.map(>::Ok).unwrap_or_else(|| { + Ok(tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name(format!("{thread_name_prefix}::writer")) + .build()? 
+ .into()) + })?; + + Ok(Environment { + clock, + gso, + socket_options, + reader_rt, + writer_rt, + }) + } +} + +#[derive(Clone)] +pub struct Environment { + clock: Clock, + gso: features::Gso, + socket_options: socket::Options, + reader_rt: runtime::Shared, + writer_rt: runtime::Shared, +} + +impl Default for Environment { + #[inline] + fn default() -> Self { + Self::builder().build().unwrap() + } +} + +impl Environment { + #[inline] + pub fn builder() -> Builder { + Default::default() + } +} + +impl super::Environment for Environment { + type Clock = Clock; + + #[inline] + fn clock(&self) -> &Self::Clock { + &self.clock + } + + #[inline] + fn gso(&self) -> features::Gso { + self.gso.clone() + } + + #[inline] + fn reader_rt(&self) -> ArcHandle { + self.reader_rt.handle() + } + + #[inline] + fn spawn_reader>(&self, f: F) { + self.reader_rt.spawn(f); + } + + #[inline] + fn writer_rt(&self) -> ArcHandle { + self.writer_rt.handle() + } + + #[inline] + fn spawn_writer>(&self, f: F) { + self.writer_rt.spawn(f); + } +} + +#[derive(Clone, Copy, Debug)] +pub struct UdpUnbound(pub SocketAddress); + +impl super::Peer for UdpUnbound { + type WorkerSocket = AsyncFd>; + + #[inline] + fn features(&self) -> TransportFeatures { + TransportFeatures::UDP + } + + #[inline] + fn with_source_control_port(&mut self, port: u16) { + self.0.set_port(port); + } + + #[inline] + fn setup(self, env: &Environment) -> super::Result> { + let mut options = env.socket_options.clone(); + let remote_addr = self.0; + + match remote_addr { + SocketAddress::IpV6(_) if options.addr.is_ipv4() => { + let addr: SocketAddress = options.addr.into(); + if addr.ip().is_unspecified() { + options.addr.set_ip(std::net::Ipv6Addr::UNSPECIFIED.into()); + } else { + let addr = addr.to_ipv6_mapped(); + options.addr = addr.into(); + } + } + SocketAddress::IpV4(_) if options.addr.is_ipv6() => { + let addr: SocketAddress = options.addr.into(); + if addr.ip().is_unspecified() { + options.addr.set_ip(std::net::Ipv4Addr::UNSPECIFIED.into()); + } else { + let addr = addr.unmap(); + // ensure the local IP maps to v4, otherwise it won't bind correctly + ensure!( + matches!(addr, SocketAddress::IpV4(_)), + Err(io::ErrorKind::Unsupported.into()) + ); + options.addr = addr.into(); + } + } + _ => {} + } + + let socket::Pair { writer, reader } = socket::Pair::open(options)?; + + let writer = Arc::new(writer); + let reader = Arc::new(reader); + + let read_worker = { + let _guard = env.reader_rt.enter(); + AsyncFd::new(reader.clone())? + }; + + let write_worker = { + let _guard = env.writer_rt.enter(); + AsyncFd::new(writer.clone())? + }; + + // if we're on a platform that requires two different ports then we need to create + // a socket for the writer as well + let multi_port = read_worker.local_port()? != write_worker.local_port()?; + + let source_control_port = write_worker.local_port()?; + + // if the reader port is different from the writer then tell the peer + let source_stream_port = if multi_port { + Some(read_worker.local_port()?) 
+ } else { + None + }; + + let application: Box = if multi_port { + Box::new(socket::application::builder::UdpPair { reader, writer }) + } else { + Box::new(reader) + }; + + let read_worker = Some(read_worker); + let write_worker = Some(write_worker); + + Ok(super::SocketSet { + application, + read_worker, + write_worker, + remote_addr, + source_control_port, + source_stream_port, + }) + } +} + +/// A socket that is already registered with the application runtime +pub struct TcpRegistered(pub TcpStream); + +impl super::Peer for TcpRegistered { + type WorkerSocket = TcpStream; + + fn features(&self) -> TransportFeatures { + TransportFeatures::TCP + } + + #[inline] + fn with_source_control_port(&mut self, port: u16) { + let _ = port; + } + + #[inline] + fn setup(self, _env: &Environment) -> super::Result> { + let remote_addr = self.0.peer_addr()?.into(); + let source_control_port = self.0.local_addr()?.port(); + let application = Box::new(self.0); + Ok(super::SocketSet { + application, + read_worker: None, + write_worker: None, + remote_addr, + source_control_port, + source_stream_port: None, + }) + } +} + +/// A socket that should be reregistered with the application runtime +pub struct TcpReregistered(pub TcpStream); + +impl super::Peer for TcpReregistered { + type WorkerSocket = TcpStream; + + fn features(&self) -> TransportFeatures { + TransportFeatures::TCP + } + + #[inline] + fn with_source_control_port(&mut self, port: u16) { + let _ = port; + } + + #[inline] + fn setup(self, _env: &Environment) -> super::Result> { + let remote_addr = self.0.peer_addr()?.into(); + let source_control_port = self.0.local_addr()?.port(); + let application = Box::new(self.0.into_std()?); + Ok(super::SocketSet { + application, + read_worker: None, + write_worker: None, + remote_addr, + source_control_port, + source_stream_port: None, + }) + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv.rs b/dc/s2n-quic-dc/src/stream/recv.rs index 8dfe9ca8ec..e1b7092e31 100644 --- a/dc/s2n-quic-dc/src/stream/recv.rs +++ b/dc/s2n-quic-dc/src/stream/recv.rs @@ -1,849 +1,13 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
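// A minimal sketch of the re-registration pattern used by `TcpReregistered` above:
// the tokio `TcpStream` is detached from its current reactor with `into_std` and
// re-registered on a different runtime with `from_std`. The function name and the
// use of a runtime `Handle` are illustrative assumptions, not crate API.
fn reregister_on(
    stream: tokio::net::TcpStream,
    target: &tokio::runtime::Handle,
) -> std::io::Result<tokio::net::TcpStream> {
    // the socket stays open and connected; `into_std` leaves it in nonblocking mode
    let std_stream = stream.into_std()?;
    // `from_std` must run inside a runtime context, so enter the target's handle
    let _guard = target.enter();
    tokio::net::TcpStream::from_std(std_stream)
}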
// SPDX-License-Identifier: Apache-2.0 -use super::{TransportFeatures, DEFAULT_IDLE_TIMEOUT}; -use crate::{ - allocator::Allocator, - clock, - crypto::{decrypt, encrypt, UninitSlice}, - packet::{control, stream}, -}; -use core::{task::Poll, time::Duration}; -use s2n_codec::{EncoderBuffer, EncoderValue}; -use s2n_quic_core::{ - buffer::{self, reader::storage::Infallible as _}, - dc::ApplicationParams, - ensure, - frame::{self, ack::EcnCounts}, - inet::ExplicitCongestionNotification, - packet::number::PacketNumberSpace, - ready, - stream::state::Receiver as State, - time::{ - timer::{self, Provider as _}, - Clock, Timer, Timestamp, - }, - varint::VarInt, -}; - mod ack; +pub mod application; mod error; mod packet; mod probes; +pub mod shared; +pub mod state; +pub mod worker; pub use error::Error; - -#[derive(Debug)] -pub struct Receiver { - stream_id: stream::Id, - ecn_counts: EcnCounts, - control_packet_number: u64, - stream_ack: ack::Space, - recovery_ack: ack::Space, - state: State, - idle_timer: Timer, - idle_timeout: Duration, - // maintains a stable tick timer to avoid timer churn in the platform timer - tick_timer: Timer, - _should_transmit: bool, - max_data: VarInt, - max_data_window: VarInt, - error: Option, - fin_ack_packet_number: Option, - features: TransportFeatures, -} - -impl Receiver { - #[inline] - pub fn new( - stream_id: stream::Id, - params: &ApplicationParams, - features: TransportFeatures, - ) -> Self { - let initial_max_data = params.local_recv_max_data; - Self { - stream_id, - ecn_counts: Default::default(), - control_packet_number: Default::default(), - stream_ack: Default::default(), - recovery_ack: Default::default(), - state: Default::default(), - idle_timer: Default::default(), - idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT), - tick_timer: Default::default(), - _should_transmit: false, - max_data: initial_max_data, - max_data_window: initial_max_data, - error: None, - fin_ack_packet_number: None, - features, - } - } - - #[inline] - pub fn id(&self) -> stream::Id { - self.stream_id - } - - #[inline] - pub fn state(&self) -> &State { - &self.state - } - - #[inline] - pub fn timer(&self) -> Option { - self.next_expiration() - } - - #[inline] - pub fn is_open(&self) -> bool { - !self.state.is_terminal() - } - - #[inline] - pub fn is_finished(&self) -> bool { - ensure!(self.state.is_terminal(), false); - ensure!(self.timer().is_none(), false); - true - } - - #[inline] - pub fn stop_sending(&mut self, error: s2n_quic_core::application::Error) { - // if we've already received everything then no need to notify the peer to stop - ensure!(matches!(self.state, State::Recv | State::SizeKnown)); - self.on_error(Error::ApplicationError { error }); - } - - #[inline] - pub fn on_read_buffer(&mut self, out_buf: &mut B, chunk: &mut C, _clock: &Clk) - where - B: buffer::Duplex, - C: buffer::writer::Storage, - Clk: Clock + ?Sized, - { - // try copying the out_buf into the application chunk, if possible - if chunk.has_remaining_capacity() && !out_buf.buffer_is_empty() { - out_buf.infallible_copy_into(chunk); - } - - // record our new max data value - let new_max_data = out_buf - .current_offset() - .saturating_add(self.max_data_window); - - if new_max_data > self.max_data { - self.max_data = new_max_data; - self.needs_transmission("new_max_data"); - } - - // if we know the final offset then update the sate - if out_buf.final_offset().is_some() { - let _ = self.state.on_receive_fin(); - } - - // if we've received everything then update the state - if 
out_buf.has_buffered_fin() && self.state.on_receive_all_data().is_ok() { - self.needs_transmission("receive_all_data"); - } - - // if we've completely drained the out buffer try transitioning to the final state - if out_buf.is_consumed() && self.state.on_app_read_all_data().is_ok() { - self.needs_transmission("app_read_all_data"); - } - } - - #[inline] - pub fn precheck_stream_packet( - &mut self, - packet: &stream::decoder::Packet, - ) -> Result<(), Error> { - match self.precheck_stream_packet_impl(packet) { - Ok(()) => Ok(()), - Err(err) => { - if err.is_fatal(&self.features) { - tracing::error!(fatal_error = %err, ?packet); - self.on_error(err); - } else { - tracing::debug!(non_fatal_error = %err); - } - Err(err) - } - } - } - - #[inline] - fn precheck_stream_packet_impl( - &mut self, - packet: &stream::decoder::Packet, - ) -> Result<(), Error> { - // make sure we're getting packets for the correct stream - ensure!( - *packet.stream_id() == self.stream_id, - Err(Error::StreamMismatch { - expected: self.stream_id, - actual: *packet.stream_id(), - }) - ); - - if self.features.is_stream() { - // if the transport is streaming then we expect packet numbers in order - let expected_pn = self - .stream_ack - .packets - .max_value() - .map_or(0, |v| v.as_u64() + 1); - let actual_pn = packet.packet_number().as_u64(); - ensure!( - expected_pn == actual_pn, - Err(Error::OutOfOrder { - expected: expected_pn, - actual: actual_pn, - }) - ); - } - - if self.features.is_reliable() { - // if the transport is reliable then we don't expect retransmissions - ensure!( - !packet.is_retransmission(), - Err(Error::UnexpectedRetransmission) - ); - } - - Ok(()) - } - - #[inline] - pub fn on_stream_packet( - &mut self, - opener: &D, - packet: &mut stream::decoder::Packet, - ecn: ExplicitCongestionNotification, - clock: &Clk, - out_buf: &mut B, - ) -> Result<(), Error> - where - D: decrypt::Key, - Clk: Clock + ?Sized, - B: buffer::Duplex, - { - probes::on_stream_packet( - opener.credentials().id, - self.stream_id, - packet.tag().packet_space(), - packet.packet_number(), - packet.stream_offset(), - packet.payload().len(), - packet.is_fin(), - packet.is_retransmission(), - ); - - match self.on_stream_packet_impl(opener, packet, ecn, clock, out_buf) { - Ok(()) => Ok(()), - Err(err) => { - if err.is_fatal(&self.features) { - tracing::error!(fatal_error = %err, ?packet); - self.on_error(err); - } else { - tracing::debug!(non_fatal_error = %err); - } - Err(err) - } - } - } - - #[inline] - fn on_stream_packet_impl( - &mut self, - opener: &D, - packet: &mut stream::decoder::Packet, - ecn: ExplicitCongestionNotification, - clock: &Clk, - out_buf: &mut B, - ) -> Result<(), Error> - where - D: decrypt::Key, - Clk: Clock + ?Sized, - B: buffer::Duplex, - { - use buffer::reader::Storage as _; - - self.precheck_stream_packet_impl(packet)?; - - let is_max_data_ok = self.ensure_max_data(packet); - - // wrap the parsed packet in a reader - let mut packet = packet::Packet { - packet: &mut *packet, - payload_cursor: 0, - is_decrypted_in_place: false, - ecn, - clock, - opener, - receiver: self, - }; - - if !is_max_data_ok { - // ensure the packet is authentic before resetting the stream - let _ = packet.read_chunk(usize::MAX)?; - - tracing::error!( - message = "max data exceeded", - allowed = packet.receiver.max_data.as_u64(), - requested = packet - .packet - .stream_offset() - .as_u64() - .saturating_add(packet.packet.payload().len() as u64), - ); - - let error = Error::MaxDataExceeded; - self.on_error(error); - return Err(error); - } 
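// A standalone sketch of the flow-control arithmetic behind the MaxDataExceeded
// check above: a packet is only admissible if stream_offset + payload_len stays
// within the advertised max_data; checked_sub keeps the comparison overflow-free.
fn within_max_data(max_data: u64, stream_offset: u64, payload_len: u64) -> bool {
    max_data
        .checked_sub(payload_len)
        .and_then(|remaining| remaining.checked_sub(stream_offset))
        .is_some()
}

// for example: an offset of 1 with a 1000-byte payload exceeds a 1000-byte window
// assert!(within_max_data(1_000, 0, 1_000));
// assert!(!within_max_data(1_000, 1, 1_000));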
- - // decrypt and write the packet to the provided buffer - out_buf.read_from(&mut packet)?; - - let mut chunk = buffer::writer::storage::Empty; - self.on_read_buffer(out_buf, &mut chunk, clock); - - Ok(()) - } - - #[inline] - fn on_stream_packet_in_place( - &mut self, - crypto: &D, - packet: &mut stream::decoder::Packet, - ecn: ExplicitCongestionNotification, - clock: &Clk, - ) -> Result<(), Error> - where - D: decrypt::Key, - Clk: Clock + ?Sized, - { - // ensure the packet is authentic before processing it - let res = packet.decrypt_in_place(crypto); - - probes::on_stream_packet_decrypted( - crypto.credentials().id, - self.stream_id, - packet.tag().packet_space(), - packet.packet_number(), - packet.stream_offset(), - packet.payload().len(), - packet.is_fin(), - packet.is_retransmission(), - res.is_ok(), - ); - - res?; - - self.on_cleartext_stream_packet(packet, ecn, clock) - } - - #[inline] - fn on_stream_packet_copy( - &mut self, - crypto: &D, - packet: &mut stream::decoder::Packet, - ecn: ExplicitCongestionNotification, - payload_out: &mut UninitSlice, - clock: &Clk, - ) -> Result<(), Error> - where - D: decrypt::Key, - Clk: Clock + ?Sized, - { - // ensure the packet is authentic before processing it - let res = packet.decrypt(crypto, payload_out); - - probes::on_stream_packet_decrypted( - crypto.credentials().id, - self.stream_id, - packet.tag().packet_space(), - packet.packet_number(), - packet.stream_offset(), - packet.payload().len(), - packet.is_fin(), - packet.is_retransmission(), - res.is_ok(), - ); - - res?; - - self.on_cleartext_stream_packet(packet, ecn, clock) - } - - #[inline] - fn ensure_max_data(&self, packet: &stream::decoder::Packet) -> bool { - // we only need to enforce flow control for non-controlled transport - ensure!(!self.features.is_flow_controlled(), true); - - self.max_data - .as_u64() - .checked_sub(packet.payload().len() as u64) - .and_then(|v| v.checked_sub(packet.stream_offset().as_u64())) - .is_some() - } - - #[inline] - fn on_cleartext_stream_packet( - &mut self, - packet: &stream::decoder::Packet, - ecn: ExplicitCongestionNotification, - clock: &Clk, - ) -> Result<(), Error> - where - Clk: Clock + ?Sized, - { - tracing::trace!( - stream_id = %packet.stream_id(), - stream_offset = packet.stream_offset().as_u64(), - payload_len = packet.payload().len(), - final_offset = ?packet.final_offset().map(|v| v.as_u64()), - ); - - let space = match packet.tag().packet_space() { - stream::PacketSpace::Stream => &mut self.stream_ack, - stream::PacketSpace::Recovery => &mut self.recovery_ack, - }; - ensure!( - space.filter.on_packet(packet).is_ok(), - Err(Error::Duplicate) - ); - - let packet_number = PacketNumberSpace::Initial.new_packet_number(packet.packet_number()); - if let Err(err) = space.packets.insert_packet_number(packet_number) { - tracing::debug!("could not record packet number {packet_number} with error {err:?}"); - } - - // if we got a new packet then we'll need to transmit an ACK - self.needs_transmission("new_packet"); - - // update the idle timer since we received a valid packet - if matches!(self.state, State::Recv | State::SizeKnown) - || packet.stream_offset() == VarInt::ZERO - { - self.update_idle_timer(clock); - } - - // TODO process control data - let _ = packet.control_data(); - - self.ecn_counts.increment(ecn); - - if !self.stream_id.is_reliable { - // TODO should we perform loss detection on the receiver and reset the stream if we have a big - // enough gap? 
- } - - // clean up any ACK state that we can - self.on_next_expected_control_packet(packet.next_expected_control_packet()); - - Ok(()) - } - - #[inline] - pub fn should_transmit(&self) -> bool { - self._should_transmit - } - - #[inline] - pub fn on_transport_close(&mut self) { - // only stream transports can be closed - ensure!(self.features.is_stream()); - - // only error out if we're still expecting more data - ensure!(matches!(self.state, State::Recv | State::SizeKnown)); - - self.on_error(Error::TruncatedTransport); - } - - #[inline] - fn needs_transmission(&mut self, reason: &str) { - if self.error.is_none() { - // we only transmit errors for reliable + flow-controlled transports - if self.features.is_reliable() && self.features.is_flow_controlled() { - tracing::trace!(skipping_transmission = reason); - return; - } - } - - if !self._should_transmit { - tracing::trace!(needs_transmission = reason); - } - self._should_transmit = true; - } - - #[inline] - fn on_next_expected_control_packet(&mut self, next_expected_control_packet: VarInt) { - if let Some(largest_delivered_control_packet) = - next_expected_control_packet.checked_sub(VarInt::from_u8(1)) - { - self.stream_ack - .on_largest_delivered_packet(largest_delivered_control_packet); - self.recovery_ack - .on_largest_delivered_packet(largest_delivered_control_packet); - - if let Some(fin_ack_packet_number) = self.fin_ack_packet_number { - // if the sender received our ACK to the fin, then we can shut down immediately - if largest_delivered_control_packet >= fin_ack_packet_number { - self.silent_shutdown(); - } - } - } - } - - #[inline] - fn update_idle_timer(&mut self, clock: &Clk) { - let target = clock.get_time() + self.idle_timeout; - self.idle_timer.set(target); - - // if the tick timer isn't armed then rearmed it; otherwise keep it stable to avoid - // churn - if !self.tick_timer.is_armed() { - self.tick_timer.set(target); - } - } - - #[inline] - fn mtu(&self) -> u16 { - // TODO should we pull this from somewhere - - // we want to make sure ACKs get through so use the minimum packet length for QUIC - 1200 - } - - #[inline] - fn ecn(&self) -> ExplicitCongestionNotification { - // TODO how do we decide what to send on control messages - ExplicitCongestionNotification::Ect0 - } - - #[inline] - pub fn on_error(&mut self, error: Error) { - debug_assert!(error.is_fatal(&self.features)); - let _ = self.state.on_reset(); - self.stream_ack.clear(); - self.recovery_ack.clear(); - self.needs_transmission("on_error"); - - // make sure we haven't already set the error from something else - ensure!(self.error.is_none()); - self.error = Some(error); - } - - #[inline] - pub fn check_error(&self) -> Result<(), Error> { - // if we already received/read all of the data then filter out errors - ensure!( - !matches!(self.state, State::DataRead | State::DataRecvd), - Ok(()) - ); - - if let Some(err) = self.error { - Err(err) - } else { - Ok(()) - } - } - - #[inline] - pub fn on_timeout(&mut self, clock: &Clk, load_last_activity: Ld) - where - Clk: Clock + ?Sized, - Ld: FnOnce() -> Timestamp, - { - let now = clock.get_time(); - if self.poll_idle_timer(clock, load_last_activity).is_ready() { - self.silent_shutdown(); - - // only transition to an error state if we didn't receive everything - ensure!(matches!(self.state, State::Recv | State::SizeKnown)); - - // we don't want to transmit anything so enter a terminal state - let mut did_transition = false; - did_transition |= self.state.on_reset().is_ok(); - did_transition |= 
self.state.on_app_read_reset().is_ok(); - if did_transition { - self.on_error(Error::IdleTimeout); - // override the transmission since we're just timing out - self._should_transmit = false; - } - - return; - } - - // if the tick timer expired, then copy the current idle timeout target - if self.tick_timer.poll_expiration(now).is_ready() { - self.tick_timer = self.idle_timer.clone(); - } - } - - #[inline] - fn poll_idle_timer(&mut self, clock: &Clk, load_last_activity: Ld) -> Poll<()> - where - Clk: Clock + ?Sized, - Ld: FnOnce() -> Timestamp, - { - let now = clock.get_time(); - - // check the idle timer first - ready!(self.idle_timer.poll_expiration(now)); - - // if that expired then load the last activity from the peer and update the idle timer with - // the value - let last_peer_activity = load_last_activity(); - self.update_idle_timer(&last_peer_activity); - - // check the idle timer once more before returning - ready!(self.idle_timer.poll_expiration(now)); - - Poll::Ready(()) - } - - #[inline] - fn silent_shutdown(&mut self) { - self._should_transmit = false; - self.idle_timer.cancel(); - self.tick_timer.cancel(); - self.stream_ack.clear(); - self.recovery_ack.clear(); - } - - #[inline] - pub fn on_transmit( - &mut self, - encrypt_key: &E, - source_control_port: u16, - output: &mut A, - clock: &Clk, - ) where - E: encrypt::Key, - A: Allocator, - Clk: Clock + ?Sized, - { - (if self.error.is_none() { - Self::on_transmit_ack - } else { - Self::on_transmit_error - })( - self, - encrypt_key, - source_control_port, - output, - // avoid querying the clock for every transmitted packet - &clock::Cached::new(clock), - ) - } - - #[inline] - fn on_transmit_ack( - &mut self, - encrypt_key: &E, - source_control_port: u16, - output: &mut A, - _clock: &Clk, - ) where - E: encrypt::Key, - A: Allocator, - Clk: Clock + ?Sized, - { - ensure!(self.should_transmit()); - - let mtu = self.mtu(); - - output.set_ecn(self.ecn()); - - let packet_number = self.next_pn(); - - ensure!(let Some(segment) = output.alloc()); - - let buffer = output.get_mut(&segment); - buffer.resize(mtu as _, 0); - - let encoder = EncoderBuffer::new(buffer); - - // TODO compute this by storing the time that we received the max packet number - let ack_delay = VarInt::ZERO; - - let max_data = frame::MaxData { - maximum_data: self.max_data, - }; - let max_data_encoding_size: VarInt = max_data.encoding_size().try_into().unwrap(); - - let (stream_ack, max_data_encoding_size) = self.stream_ack.encoding( - max_data_encoding_size, - ack_delay, - Some(self.ecn_counts), - mtu, - ); - let (recovery_ack, max_data_encoding_size) = - self.recovery_ack - .encoding(max_data_encoding_size, ack_delay, None, mtu); - - let encoding_size = max_data_encoding_size; - - tracing::trace!(?stream_ack, ?recovery_ack, ?max_data); - - let frame = ((max_data, stream_ack), recovery_ack); - - let result = control::encoder::encode( - encoder, - source_control_port, - Some(self.stream_id), - packet_number, - VarInt::ZERO, - &mut &[][..], - encoding_size, - &frame, - encrypt_key, - ); - - match result { - 0 => { - output.free(segment); - return; - } - packet_len => { - buffer.truncate(packet_len); - // TODO duplicate the transmission in case we have a lot of gaps in packets - output.push(segment); - } - } - - for (ack, space) in [ - (&self.stream_ack, stream::PacketSpace::Stream), - (&self.recovery_ack, stream::PacketSpace::Recovery), - ] { - let metrics = ( - ack.packets.min_value(), - ack.packets.max_value(), - ack.packets.interval_len().checked_sub(1), - ); - if let 
(Some(min), Some(max), Some(gaps)) = metrics { - probes::on_transmit_control( - encrypt_key.credentials().id, - self.stream_id, - space, - packet_number, - min, - max, - gaps, - ); - }; - } - - // make sure we sent a packet - ensure!(!output.is_empty()); - - // record the max value we've seen for removing old packet numbers - self.stream_ack.on_transmit(packet_number); - self.recovery_ack.on_transmit(packet_number); - - self.on_packet_sent(packet_number); - } - - #[inline] - fn on_transmit_error( - &mut self, - encrypt_key: &E, - source_control_port: u16, - output: &mut A, - _clock: &Clk, - ) where - E: encrypt::Key, - A: Allocator, - Clk: Clock + ?Sized, - { - ensure!(self.should_transmit()); - - let mtu = self.mtu() as usize; - - output.set_ecn(self.ecn()); - - let packet_number = self.next_pn(); - - ensure!(let Some(segment) = output.alloc()); - - let buffer = output.get_mut(&segment); - buffer.resize(mtu, 0); - - let encoder = EncoderBuffer::new(buffer); - - let frame = self - .error - .as_ref() - .and_then(|err| err.connection_close()) - .unwrap_or_else(|| s2n_quic_core::transport::Error::NO_ERROR.into()); - - let encoding_size = frame.encoding_size().try_into().unwrap(); - - let result = control::encoder::encode( - encoder, - source_control_port, - Some(self.stream_id), - packet_number, - VarInt::ZERO, - &mut &[][..], - encoding_size, - &frame, - encrypt_key, - ); - - match result { - 0 => { - output.free(segment); - return; - } - packet_len => { - buffer.truncate(packet_len); - output.push(segment); - } - } - - tracing::debug!(connection_close = ?frame); - - // clean things up - self.stream_ack.clear(); - self.recovery_ack.clear(); - - probes::on_transmit_close( - encrypt_key.credentials().id, - self.stream_id, - packet_number, - frame.error_code, - ); - - self.on_packet_sent(packet_number); - } - - #[inline] - fn next_pn(&mut self) -> VarInt { - VarInt::new(self.control_packet_number).expect("2^62 is a lot of packets") - } - - #[inline] - fn on_packet_sent(&mut self, packet_number: VarInt) { - // record the fin_ack packet number so we can shutdown more quickly - if !matches!(self.state, State::Recv | State::SizeKnown) - && self.fin_ack_packet_number.is_none() - { - self.fin_ack_packet_number = Some(packet_number); - } - - self.control_packet_number += 1; - self._should_transmit = false; - } -} - -impl timer::Provider for Receiver { - #[inline] - fn timers(&self, query: &mut Q) -> timer::Result { - self.idle_timer.timers(query)?; - self.tick_timer.timers(query)?; - Ok(()) - } -} diff --git a/dc/s2n-quic-dc/src/stream/recv/application.rs b/dc/s2n-quic-dc/src/stream/recv/application.rs new file mode 100644 index 0000000000..699d6d9a01 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/application.rs @@ -0,0 +1,368 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
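// The `Reader` defined in this file stores its state as a `ManuallyDrop` boxed
// value so that `Drop` can take the box by value and run an explicit shutdown
// path. A self-contained sketch of that idiom (the types here are illustrative):
use core::mem::ManuallyDrop;

struct Guarded(ManuallyDrop<Box<Vec<u8>>>);

impl Drop for Guarded {
    fn drop(&mut self) {
        // SAFETY: the inner value is taken exactly once, here in drop
        let inner: Box<Vec<u8>> = unsafe { ManuallyDrop::take(&mut self.0) };
        // with ownership of the box, a by-value shutdown routine can be invoked
        drop(inner);
    }
}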
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clock::Timer, + msg, + stream::{recv, runtime, shared::ArcShared, socket}, +}; +use core::{ + fmt, + mem::ManuallyDrop, + pin::Pin, + task::{Context, Poll}, +}; +use s2n_quic_core::{ + buffer::{self, writer::Storage as _}, + ensure, ready, + stream::state, + task::waker, + time::{clock::Timer as _, timer::Provider as _}, +}; +use std::{io, net::SocketAddr}; + +mod builder; +pub use builder::Builder; + +pub use crate::stream::recv::shared::AckMode; + +/// Defines what strategy to use when writing to the provided buffer +#[derive(Clone, Copy, Debug, Default)] +pub enum ReadMode { + /// Will attempt to read packets from the socket until the application buffer is full + UntilFull, + /// Will only attempt to read packets once + #[default] + Once, + /// Will attempt to drain the socket, even if the buffer isn't capable of reading it right now + Drain, +} + +pub struct Reader(ManuallyDrop>); + +pub(crate) struct Inner { + shared: ArcShared, + sockets: socket::ArcApplication, + send_buffer: msg::send::Message, + read_mode: ReadMode, + ack_mode: AckMode, + timer: Option, + local_state: LocalState, + runtime: runtime::ArcHandle, +} + +impl fmt::Debug for Reader { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Reader") + .field("peer_addr", &self.peer_addr().unwrap()) + .field("local_addr", &self.local_addr().unwrap()) + .finish() + } +} + +#[derive(Clone, Debug, Default)] +enum LocalState { + #[default] + Ready, + Reading, + Drained, + Errored(recv::Error), +} + +impl LocalState { + #[inline] + fn check(&self) -> Option> { + match self { + Self::Ready | Self::Reading => None, + Self::Drained => Some(Ok(())), + Self::Errored(err) => Some(Err((*err).into())), + } + } + + #[inline] + fn on_read(&mut self) { + ensure!(matches!(self, Self::Ready)); + *self = Self::Reading; + } + + #[inline] + fn transition(&mut self, target: Self, shared: &ArcShared) { + ensure!(matches!(self, Self::Ready | Self::Reading)); + *self = target; + + shared + .common + .closed_halves + .fetch_add(1, core::sync::atomic::Ordering::Relaxed); + } +} + +impl Reader { + #[inline] + pub fn peer_addr(&self) -> io::Result { + self.0.shared.common.ensure_open()?; + Ok(self.0.shared.read_remote_addr().into()) + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.0.sockets.read_application().local_addr() + } + + #[inline] + pub fn protocol(&self) -> socket::Protocol { + self.0.sockets.protocol() + } + + #[inline] + pub async fn read_into(&mut self, out_buf: &mut S) -> io::Result + where + S: buffer::writer::Storage, + { + core::future::poll_fn(|cx| self.poll_read_into(cx, out_buf)).await + } + + #[inline] + pub fn poll_read_into( + &mut self, + cx: &mut Context, + out_buf: &mut S, + ) -> Poll> + where + S: buffer::writer::Storage, + { + waker::debug_assert_contract(cx, |cx| { + let mut out_buf = out_buf.track_write(); + let res = self.0.poll_read_into(cx, &mut out_buf); + + if res.is_pending() { + debug_assert_eq!( + out_buf.written_len(), + 0, + "bytes should only be written on Ready(_)" + ); + } + + let res = ready!(res); + // record the first time we get `Poll::Ready` + self.0.local_state.on_read(); + res?; + + Ok(out_buf.written_len()).into() + }) + } +} + +impl Inner { + #[inline(always)] + fn poll_read_into( + &mut self, + cx: &mut Context, + out_buf: &mut buffer::writer::storage::Tracked, + ) -> Poll> + where + S: buffer::writer::Storage, + { + if let Some(res) = self.local_state.check() { + return res.into(); + } + + // force a 
read on the socket if the application gave us an empty buffer + let mut force_recv = !out_buf.has_remaining_capacity(); + + let shared = &self.shared; + let sockets = &self.sockets; + let transport_features = sockets.read_application().features(); + + let mut reader = shared.receiver.application_guard( + self.ack_mode, + &mut self.send_buffer, + shared, + sockets, + )?; + let reader = &mut *reader; + + loop { + // try to process any bytes we have in the recv buffer + reader.process_recv_buffer(out_buf, shared, transport_features); + + // if we still have remaining capacity in the `out_buf` make sure the reassembler is + // fully drained + if cfg!(debug_assertions) && out_buf.has_remaining_capacity() { + assert!(reader.reassembler.is_empty()); + } + + // make sure we don't have an error + if let Err(err) = reader.receiver.check_error() { + self.local_state + .transition(LocalState::Errored(err), &self.shared); + return Err(err.into()).into(); + } + + match reader.receiver.state() { + state::Receiver::Recv | state::Receiver::SizeKnown => { + // we haven't received everything so we still need to read from the socket + } + state::Receiver::DataRecvd => { + // make sure we have capacity in the buffer before looping back around + ensure!(out_buf.has_remaining_capacity(), Ok(()).into()); + + // if we've received everything from the sender then no need to poll + // the socket at this point + continue; + } + // if we've copied the entire buffer into the application then just return + state::Receiver::DataRead => { + self.local_state + .transition(LocalState::Drained, &self.shared); + break; + } + // we already checked for an error above + state::Receiver::ResetRecvd | state::Receiver::ResetRead => unreachable!(), + } + + match self.read_mode { + // ignore the mode if we have a forced receive + _ if force_recv => {} + // if we've completely filled the `out_buf` then we're done + ReadMode::UntilFull if !out_buf.has_remaining_capacity() => break, + // if we've read at least one byte then we're done + ReadMode::Once if out_buf.written_len() > 0 => break, + // otherwise keep going + _ => {} + } + + let before_len = reader.recv_buffer.payload_len(); + + let recv = reader.poll_fill_recv_buffer(cx, self.sockets.read_application()); + + match Self::handle_socket_result(cx, &mut reader.receiver, &mut self.timer, recv) { + Poll::Ready(res) => res?, + // if we've written at least one byte then return that amount + Poll::Pending if out_buf.written_len() > 0 => break, + Poll::Pending => return Poll::Pending, + } + + // clear the forced receive after performing it once + force_recv = false; + + let after_len = reader.recv_buffer.payload_len(); + + if before_len == after_len { + if transport_features.is_stream() { + // if we got a 0-length read then the stream was closed - notify the receiver + reader.receiver.on_transport_close(); + continue; + } else { + debug_assert!(false, "datagram recv buffers should never be empty"); + } + } + } + + Ok(()).into() + } + + #[inline] + fn handle_socket_result( + cx: &mut Context, + receiver: &mut recv::state::State, + timer: &mut Option, + res: Poll>, + ) -> Poll> { + if let Poll::Ready(res) = res { + return res.into(); + } + + // only check the timer if we have one + let Some(timer) = timer.as_mut() else { + return Poll::Pending; + }; + + // if we didn't get any packets then poll the timer + if let Some(target) = receiver.next_expiration() { + timer.update(target); + ready!(timer.poll_ready(cx)); + + // if the timer expired then keep going, even if the recv buffer is empty + 
Ok(()).into() + } else { + timer.cancel(); + Poll::Pending + } + } + + #[inline] + fn shutdown(mut self: Box) { + // If we haven't exited the `Ready` state then spawn a task to do it for the application + // + // This is important for processing any secret control packets that the server sends us + if let LocalState::Ready = self.local_state { + tracing::debug!("spawning task to read server's response"); + let runtime = self.runtime.clone(); + let handle = Shutdown(self); + runtime.spawn_recv_shutdown(handle); + return; + } + + // update the common closed state if needed + self.local_state + .transition(LocalState::Drained, &self.shared); + + // let the peer know if we shut down cleanly + let is_panicking = std::thread::panicking(); + + self.shared.receiver.shutdown(is_panicking); + } +} + +#[cfg(feature = "tokio")] +impl tokio::io::AsyncRead for Reader { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let mut buf = buffer::writer::storage::BufMut::new(buf); + ready!(self.poll_read_into(cx, &mut buf))?; + Ok(()).into() + } +} + +impl Drop for Reader { + #[inline] + fn drop(&mut self) { + let inner = unsafe { + // SAFETY: the inner type is only taken once + ManuallyDrop::take(&mut self.0) + }; + inner.shutdown(); + } +} + +pub struct Shutdown(Box); + +impl core::future::Future for Shutdown { + type Output = (); + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let mut storage = buffer::writer::storage::Empty; + let _ = ready!(self.0.poll_read_into(cx, &mut storage.track_write())); + Poll::Ready(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(dead_code)] + fn shutdown_traits_test(shutdown: &Shutdown) { + use crate::testing::*; + + assert_send(shutdown); + assert_sync(shutdown); + assert_static(shutdown); + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv/application/builder.rs b/dc/s2n-quic-dc/src/stream/recv/application/builder.rs new file mode 100644 index 0000000000..5fa8745aa7 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/application/builder.rs @@ -0,0 +1,67 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
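// The `crate::testing` helpers used in the test above are not part of this diff;
// a plausible minimal form (an assumption about that module, not its actual code)
// is a trio of generic functions whose bounds make the compiler prove the traits:
fn assert_send<T: Send + ?Sized>(_value: &T) {}
fn assert_sync<T: Sync + ?Sized>(_value: &T) {}
fn assert_static<T: 'static + ?Sized>(_value: &T) {}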
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clock::Timer, + msg, + stream::{ + recv::application::{Inner, LocalState, Reader}, + runtime, + shared::ArcShared, + socket, + }, +}; +use core::mem::ManuallyDrop; +use s2n_quic_core::endpoint; + +pub struct Builder { + endpoint: endpoint::Type, + runtime: runtime::ArcHandle, +} + +impl Builder { + #[inline] + pub fn new(endpoint: endpoint::Type, runtime: runtime::ArcHandle) -> Self { + Self { endpoint, runtime } + } +} + +impl Builder { + #[inline] + pub fn build(self, shared: ArcShared, sockets: socket::ArcApplication) -> Reader { + let Self { endpoint, runtime } = self; + + let remote_addr = shared.read_remote_addr(); + // we only need a timer for unreliable transports + let is_reliable = sockets.read_application().features().is_reliable(); + let timer = if is_reliable { + None + } else { + Some(Timer::new(&shared.clock)) + }; + let gso = shared.gso.clone(); + let send_buffer = msg::send::Message::new(remote_addr, gso); + let read_mode = Default::default(); + let ack_mode = Default::default(); + let local_state = match endpoint { + // reliable transports on the client need to read at least one packet in order to + // process secret control packets + endpoint::Type::Client if is_reliable => LocalState::Ready, + // unreliable transports use background workers to drive state + endpoint::Type::Client => LocalState::Reading, + // the server acceptor already read from the socket at least once + endpoint::Type::Server => LocalState::Reading, + }; + + Reader(ManuallyDrop::new(Box::new(Inner { + shared, + sockets, + send_buffer, + read_mode, + ack_mode, + timer, + local_state, + runtime, + }))) + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv/error.rs b/dc/s2n-quic-dc/src/stream/recv/error.rs index 0132d78b16..f95b2a42b2 100644 --- a/dc/s2n-quic-dc/src/stream/recv/error.rs +++ b/dc/s2n-quic-dc/src/stream/recv/error.rs @@ -1,7 +1,11 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{crypto::decrypt, packet::stream}; +use crate::{ + crypto::decrypt, + packet::{self, stream}, + stream::TransportFeatures, +}; use s2n_quic_core::{buffer, frame}; #[derive(Clone, Copy, Debug, thiserror::Error)] @@ -39,6 +43,8 @@ pub enum Error { ApplicationError { error: s2n_quic_core::application::Error, }, + #[error("unexpected packet: {packet:?}")] + UnexpectedPacket { packet: packet::Kind }, } impl From for Error { @@ -55,7 +61,7 @@ impl From for Error { impl Error { #[inline] - pub(super) fn is_fatal(&self, features: &super::TransportFeatures) -> bool { + pub(super) fn is_fatal(&self, features: &TransportFeatures) -> bool { // if the transport is a stream then any error we encounter is fatal, since the stream is // now likely corrupted if features.is_stream() { @@ -76,6 +82,7 @@ impl Error { | Error::Decrypt | Error::Duplicate | Error::StreamMismatch { .. } + | Error::UnexpectedPacket { .. } | Error::UnexpectedRetransmission => { // return protocol violation for the errors that are only fatal for reliable // transports @@ -132,6 +139,15 @@ impl From for std::io::ErrorKind { Error::KeyReplayPrevented => ErrorKind::PermissionDenied, Error::KeyReplayMaybePrevented { .. } => ErrorKind::PermissionDenied, Error::ApplicationError { .. 
} => ErrorKind::ConnectionReset, + Error::UnexpectedPacket { + packet: + packet::Kind::UnknownPathSecret + | packet::Kind::StaleKey + | packet::Kind::ReplayDetected, + } => ErrorKind::ConnectionRefused, + Error::UnexpectedPacket { + packet: packet::Kind::Stream | packet::Kind::Control | packet::Kind::Datagram, + } => ErrorKind::InvalidData, } } } diff --git a/dc/s2n-quic-dc/src/stream/recv/packet.rs b/dc/s2n-quic-dc/src/stream/recv/packet.rs index 16e4e90b20..66b44a490e 100644 --- a/dc/s2n-quic-dc/src/stream/recv/packet.rs +++ b/dc/s2n-quic-dc/src/stream/recv/packet.rs @@ -1,8 +1,17 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use super::*; -use s2n_quic_core::buffer::{reader, writer, Reader}; +use crate::{ + crypto::decrypt, + packet::stream, + stream::recv::{state::State as Receiver, Error}, +}; +use s2n_quic_core::{ + buffer::{reader, writer, Reader}, + inet::ExplicitCongestionNotification, + time::Clock, + varint::VarInt, +}; pub struct Packet<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> { pub packet: &'a mut stream::decoder::Packet<'p>, diff --git a/dc/s2n-quic-dc/src/stream/recv/shared.rs b/dc/s2n-quic-dc/src/stream/recv/shared.rs new file mode 100644 index 0000000000..6f7195eb09 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/shared.rs @@ -0,0 +1,575 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + allocator::Allocator, + clock, msg, + packet::{stream, Packet}, + stream::{ + recv, + server::handshake, + shared::{ArcShared, Half}, + socket::{self, Socket}, + TransportFeatures, + }, + task::waker::worker::Waker as WorkerWaker, +}; +use core::{ + mem::ManuallyDrop, + ops, + task::{Context, Poll}, +}; +use s2n_codec::{DecoderBufferMut, DecoderError}; +use s2n_quic_core::{buffer, dc, ensure, ready, stream::state, time::Clock}; +use std::{ + io, + sync::{ + atomic::{AtomicU64, AtomicU8, Ordering}, + Mutex, MutexGuard, + }, +}; + +/// Who will send ACKs? 
+#[derive(Clone, Copy, Debug, Default)] +pub enum AckMode { + /// The application task is sending ACKs + #[default] + Application, + /// The worker task is sending ACKs + Worker, +} + +pub enum ApplicationState { + Open, + Closed { is_panicking: bool }, +} + +impl ApplicationState { + const IS_CLOSED_MASK: u8 = 1; + const IS_PANICKING_MASK: u8 = 1 << 1; + + #[inline] + fn load(shared: &AtomicU8) -> Self { + let value = shared.load(Ordering::Acquire); + if value == 0 { + return Self::Open; + } + + let is_panicking = value & Self::IS_PANICKING_MASK != 0; + + Self::Closed { is_panicking } + } + + #[inline] + fn close(shared: &AtomicU8, is_panicking: bool) { + let mut value = Self::IS_CLOSED_MASK; + + if is_panicking { + value |= Self::IS_PANICKING_MASK; + } + + shared.store(value, Ordering::Release); + } +} + +#[derive(Debug)] +pub struct State { + inner: Mutex, + application_epoch: AtomicU64, + application_state: AtomicU8, + pub worker_waker: WorkerWaker, +} + +impl State { + #[inline] + pub fn new( + stream_id: stream::Id, + params: &dc::ApplicationParams, + handshake: Option, + features: TransportFeatures, + recv_buffer: Option<&mut msg::recv::Message>, + ) -> Self { + let recv_buffer = match recv_buffer { + Some(prev) => prev.take(), + None => msg::recv::Message::new(9000u16), + }; + let receiver = recv::state::State::new(stream_id, params, features); + let reassembler = Default::default(); + let inner = Inner { + receiver, + reassembler, + handshake, + recv_buffer, + }; + let inner = Mutex::new(inner); + Self { + inner, + application_epoch: AtomicU64::new(0), + application_state: AtomicU8::new(0), + worker_waker: Default::default(), + } + } + + #[inline] + pub fn application_state(&self) -> ApplicationState { + ApplicationState::load(&self.application_state) + } + + #[inline] + pub fn application_epoch(&self) -> u64 { + self.application_epoch.load(Ordering::Acquire) + } + + #[inline] + pub fn application_guard<'a>( + &'a self, + ack_mode: AckMode, + send_buffer: &'a mut msg::send::Message, + shared: &'a ArcShared, + sockets: &'a dyn socket::Application, + ) -> io::Result> { + // increment the epoch at which we acquired the guard + self.application_epoch.fetch_add(1, Ordering::AcqRel); + + let inner = self.inner.lock().map_err(|_| { + io::Error::new(io::ErrorKind::Other, "shared recv state has been poisoned") + })?; + + let initial_state = inner.receiver.state().clone(); + + let inner = ManuallyDrop::new(inner); + + Ok(AppGuard { + inner, + ack_mode, + send_buffer, + shared, + sockets, + initial_state, + }) + } + + #[inline] + pub fn shutdown(&self, is_panicking: bool) { + ApplicationState::close(&self.application_state, is_panicking); + self.worker_waker.wake(); + } + + #[inline] + pub fn worker_try_lock(&self) -> io::Result>> { + match self.inner.try_lock() { + Ok(lock) => Ok(Some(lock)), + Err(std::sync::TryLockError::WouldBlock) => Ok(None), + Err(_) => Err(io::Error::new( + io::ErrorKind::Other, + "shared recv state has been poisoned", + )), + } + } +} + +pub struct AppGuard<'a> { + inner: ManuallyDrop>, + ack_mode: AckMode, + send_buffer: &'a mut msg::send::Message, + shared: &'a ArcShared, + sockets: &'a dyn socket::Application, + initial_state: state::Receiver, +} + +impl<'a> AppGuard<'a> { + /// Returns `true` if the read worker should be woken + #[inline] + fn send_ack(&mut self) -> bool { + // we only send ACKs for unreliable protocols + ensure!( + !self.sockets.read_application().features().is_reliable(), + false + ); + + match self.ack_mode { + AckMode::Application => { + 
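// A self-contained sketch of the bit-packed close flag used by ApplicationState
// above: one atomic byte encodes "closed" and "closed while panicking", written
// with Release and read with Acquire so the closing thread's writes are observed.
use std::sync::atomic::{AtomicU8, Ordering};

const IS_CLOSED: u8 = 1;
const IS_PANICKING: u8 = 1 << 1;

fn close(state: &AtomicU8, is_panicking: bool) {
    let mut value = IS_CLOSED;
    if is_panicking {
        value |= IS_PANICKING;
    }
    state.store(value, Ordering::Release);
}

fn closed_while_panicking(state: &AtomicU8) -> bool {
    let value = state.load(Ordering::Acquire);
    value & IS_CLOSED != 0 && value & IS_PANICKING != 0
}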
self.inner + .fill_transmit_queue(self.shared, self.send_buffer); + + ensure!(!self.send_buffer.is_empty(), false); + + let did_send = self + .sockets + .read_application() + .try_send_buffer(self.send_buffer) + .is_ok(); + + // clear out the sender buffer if we didn't already + let _ = self.send_buffer.drain(); + + // only wake the worker if we weren't able to transmit the ACK + !did_send + } + AckMode::Worker => { + // only wake the worker if the receiver says we should + self.inner.receiver.should_transmit() + } + } + } +} + +impl<'a> ops::Deref for AppGuard<'a> { + type Target = Inner; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<'a> ops::DerefMut for AppGuard<'a> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a> Drop for AppGuard<'a> { + #[inline] + fn drop(&mut self) { + let wake_worker_for_ack = self.send_ack(); + + let current_state = self.inner.receiver.state().clone(); + + unsafe { + // SAFETY: inner is no longer used + ManuallyDrop::drop(&mut self.inner); + } + + if wake_worker_for_ack && !current_state.is_terminal() { + // TODO wake the worker + } + + // no need to look at anything if the state didn't change + ensure!(self.initial_state != current_state); + + // shut down the worker if we're in a terminal state + if current_state.is_terminal() { + self.shared.receiver.shutdown(false); + } + } +} + +#[derive(Debug)] +pub struct Inner { + pub receiver: recv::state::State, + pub reassembler: buffer::Reassembler, + pub recv_buffer: msg::recv::Message, + pub handshake: Option, +} + +impl Inner { + #[inline] + pub fn fill_transmit_queue( + &mut self, + shared: &ArcShared, + send_buffer: &mut msg::send::Message, + ) { + let source_control_port = shared.source_control_port(); + + shared.crypto.seal_with(|sealer| { + self.receiver + .on_transmit(sealer, source_control_port, send_buffer, &shared.clock) + }); + + ensure!(!send_buffer.is_empty()); + + // Update the remote address with the latest value + send_buffer.set_remote_address(shared.read_remote_addr()); + } + + #[inline] + pub fn poll_fill_recv_buffer(&mut self, cx: &mut Context, socket: &S) -> Poll> + where + S: ?Sized + Socket, + { + loop { + if let Some(chan) = self.handshake.as_mut() { + match chan.poll_recv(cx) { + Poll::Ready(Some(recv_buffer)) => { + debug_assert!(!recv_buffer.is_empty()); + // no point in doing anything with an empty buffer + ensure!(!recv_buffer.is_empty(), continue); + // we got a buffer from the handshake so return and process it + self.recv_buffer = recv_buffer; + return Ok(()).into(); + } + Poll::Ready(None) => { + // the channel was closed so drop it + self.handshake = None; + } + Poll::Pending => { + // keep going and read the socket + } + } + } + + ready!(socket.poll_recv_buffer(cx, &mut self.recv_buffer))?; + + return Ok(()).into(); + } + } + + #[inline] + pub fn process_recv_buffer( + &mut self, + out_buf: &mut impl buffer::writer::Storage, + shared: &ArcShared, + features: TransportFeatures, + ) -> bool { + let clock = clock::Cached::new(&shared.clock); + let clock = &clock; + + // try copying data out of the reassembler into the application buffer + self.receiver + .on_read_buffer(&mut self.reassembler, out_buf, clock); + + // check if we have any packets to process + if !self.recv_buffer.is_empty() { + if features.is_stream() { + self.dispatch_buffer_stream(out_buf, shared, clock, features) + } else { + self.dispatch_buffer_datagram(out_buf, shared, clock, features) + } + + // if we processed packets then we 
may have data to copy out + self.receiver + .on_read_buffer(&mut self.reassembler, out_buf, clock); + } + + // we only check for timeouts on unreliable transports + if !features.is_reliable() { + self.receiver + .on_timeout(clock, || shared.last_peer_activity()); + } + + // indicate to the caller if we need to transmit an ACK + self.receiver.should_transmit() + } + + #[inline] + fn dispatch_buffer_stream( + &mut self, + out_buf: &mut impl buffer::writer::Storage, + shared: &ArcShared, + clock: &C, + features: TransportFeatures, + ) { + let msg = &mut self.recv_buffer; + let remote_addr = msg.remote_address(); + let ecn = msg.ecn(); + let tag_len = shared.crypto.tag_len(); + + let mut any_valid_packets = false; + let mut did_complete_handshake = false; + + let mut prev_packet_len = None; + + let mut out_buf = buffer::duplex::Interposer::new(out_buf, &mut self.reassembler); + + loop { + // consume the previous packet + if let Some(packet_len) = prev_packet_len.take() { + msg.consume(packet_len); + } + + let segment = msg.peek(); + ensure!(!segment.is_empty(), break); + + let initial_len = segment.len(); + let decoder = DecoderBufferMut::new(segment); + + let mut packet = match decoder.decode_parameterized(tag_len) { + Ok((packet, remaining)) => { + prev_packet_len = Some(initial_len - remaining.len()); + packet + } + Err(decoder_error) => { + if let DecoderError::UnexpectedEof(len) = decoder_error { + // if making the buffer contiguous resulted in the slice increasing, then + // try to parse a packet again + if msg.make_contiguous().len() > initial_len { + continue; + } + + // otherwise, we'll need to receive more bytes from the stream to correctly + // parse a packet + + // if we have pending data greater than the max record size then it's never going to parse + let max_datagram_size = + shared.sender.path.load().max_datagram_size as usize; + + if msg.payload_len() > max_datagram_size { + tracing::error!( + unconsumed = msg.payload_len(), + remaining_capacity = msg.remaining_capacity() + ); + msg.clear(); + self.receiver.on_error(recv::Error::Decode); + return; + } + + tracing::trace!( + protocol_features = ?features, + unexpected_eof = len, + buffer_len = initial_len + ); + + break; + } + + tracing::error!( + protocol_features = ?features, + fatal_error = %decoder_error, + payload_len = msg.payload_len() + ); + + // any other decoder errors mean the stream has been corrupted so + // it's time to shut down the connection + msg.clear(); + self.receiver.on_error(recv::Error::Decode); + return; + } + }; + + tracing::trace!(?packet); + + match &mut packet { + Packet::Stream(packet) => { + debug_assert_eq!(Some(packet.total_len()), prev_packet_len); + + // make sure the packet looks OK before deriving openers from it + if self.receiver.precheck_stream_packet(packet).is_err() { + // check if the receiver returned an error + if self.receiver.check_error().is_err() { + msg.clear(); + return; + } else { + // move on to the next packet + continue; + } + } + + let _ = shared.crypto.open_with(|opener| { + self.receiver + .on_stream_packet(opener, packet, ecn, clock, &mut out_buf)?; + + any_valid_packets = true; + did_complete_handshake |= + packet.next_expected_control_packet().as_u64() > 0; + + >::Ok(()) + }); + + if self.receiver.check_error().is_err() { + msg.clear(); + return; + } + } + other => { + let kind = other.kind(); + shared.crypto.map().handle_unexpected_packet(other); + + // if we get a packet we don't expect then it's fatal for streams + msg.clear(); + self.receiver + 
.on_error(recv::Error::UnexpectedPacket { packet: kind }); + return; + } + } + } + + if let Some(len) = prev_packet_len.take() { + msg.consume(len); + } + + if any_valid_packets { + shared.on_valid_packet(&remote_addr, Half::Read, did_complete_handshake); + } + } + + #[inline] + fn dispatch_buffer_datagram( + &mut self, + out_buf: &mut impl buffer::writer::Storage, + shared: &ArcShared, + clock: &C, + features: TransportFeatures, + ) { + let msg = &mut self.recv_buffer; + let remote_addr = msg.remote_address(); + let ecn = msg.ecn(); + let tag_len = shared.crypto.tag_len(); + + let mut any_valid_packets = false; + let mut did_complete_handshake = false; + + let mut out_buf = buffer::duplex::Interposer::new(out_buf, &mut self.reassembler); + + for segment in msg.segments() { + let segment_len = segment.len(); + let mut decoder = DecoderBufferMut::new(segment); + + 'segment: while !decoder.is_empty() { + let packet = match decoder.decode_parameterized(tag_len) { + Ok((packet, remaining)) => { + decoder = remaining; + packet + } + Err(decoder_error) => { + // the packet was likely corrupted so log it and move on to the + // next segment + tracing::warn!( + protocol_features = ?features, + %decoder_error, + segment_len + ); + + break 'segment; + } + }; + + match packet { + Packet::Stream(mut packet) => { + // make sure the packet looks OK before deriving openers from it + ensure!( + self.receiver.precheck_stream_packet(&packet).is_ok(), + continue + ); + + let _ = shared.crypto.open_with(|opener| { + self.receiver.on_stream_packet( + opener, + &mut packet, + ecn, + clock, + &mut out_buf, + )?; + + any_valid_packets = true; + did_complete_handshake |= + packet.next_expected_control_packet().as_u64() > 0; + + >::Ok(()) + }); + } + other => { + shared.crypto.map().handle_unexpected_packet(&other); + + // TODO if the packet was authentic then close the receiver with an error + } + } + } + } + + if any_valid_packets { + shared.on_valid_packet(&remote_addr, Half::Read, did_complete_handshake); + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv/snapshots/s2n_quic_dc__stream__recv__worker__waiting__dot_test.snap b/dc/s2n-quic-dc/src/stream/recv/snapshots/s2n_quic_dc__stream__recv__worker__waiting__dot_test.snap new file mode 100644 index 0000000000..f53070819c --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/snapshots/s2n_quic_dc__stream__recv__worker__waiting__dot_test.snap @@ -0,0 +1,30 @@ +--- +source: dc/s2n-quic-dc/src/stream/recv/worker.rs +expression: "State::dot()" +--- +digraph { + label = "s2n_quic_dc::stream::recv::worker::waiting::State"; + Cooldown; + DataRecvd; + Detached; + EpochTimeout; + Finished; + PeekPacket; + PeekPacket -> EpochTimeout [label = "on_peek_packet"]; + Cooldown -> PeekPacket [label = "on_cooldown_elapsed"]; + EpochTimeout -> PeekPacket [label = "on_epoch_unchanged"]; + PeekPacket -> Cooldown [label = "on_application_progress"]; + EpochTimeout -> Cooldown [label = "on_application_progress"]; + Cooldown -> Cooldown [label = "on_application_progress"]; + PeekPacket -> Detached [label = "on_application_detach"]; + EpochTimeout -> Detached [label = "on_application_detach"]; + Cooldown -> Detached [label = "on_application_detach"]; + PeekPacket -> DataRecvd [label = "on_data_received"]; + EpochTimeout -> DataRecvd [label = "on_data_received"]; + Cooldown -> DataRecvd [label = "on_data_received"]; + PeekPacket -> Finished [label = "on_finished"]; + EpochTimeout -> Finished [label = "on_finished"]; + Cooldown -> Finished [label = "on_finished"]; + Detached -> Finished 
[label = "on_finished"]; + DataRecvd -> Finished [label = "on_finished"]; +} diff --git a/dc/s2n-quic-dc/src/stream/recv/state.rs b/dc/s2n-quic-dc/src/stream/recv/state.rs new file mode 100644 index 0000000000..5f432e6bb0 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/state.rs @@ -0,0 +1,845 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + allocator::Allocator, + clock, + crypto::{decrypt, encrypt, UninitSlice}, + packet::{control, stream}, + stream::{ + recv::{ack, packet, probes, Error}, + TransportFeatures, DEFAULT_IDLE_TIMEOUT, + }, +}; +use core::{task::Poll, time::Duration}; +use s2n_codec::{EncoderBuffer, EncoderValue}; +use s2n_quic_core::{ + buffer::{self, reader::storage::Infallible as _}, + dc::ApplicationParams, + ensure, + frame::{self, ack::EcnCounts}, + inet::ExplicitCongestionNotification, + packet::number::PacketNumberSpace, + ready, + stream::state::Receiver, + time::{ + timer::{self, Provider as _}, + Clock, Timer, Timestamp, + }, + varint::VarInt, +}; + +#[derive(Debug)] +pub struct State { + stream_id: stream::Id, + ecn_counts: EcnCounts, + control_packet_number: u64, + stream_ack: ack::Space, + recovery_ack: ack::Space, + state: Receiver, + idle_timer: Timer, + idle_timeout: Duration, + // maintains a stable tick timer to avoid timer churn in the platform timer + tick_timer: Timer, + _should_transmit: bool, + max_data: VarInt, + max_data_window: VarInt, + error: Option, + fin_ack_packet_number: Option, + features: TransportFeatures, +} + +impl State { + #[inline] + pub fn new( + stream_id: stream::Id, + params: &ApplicationParams, + features: TransportFeatures, + ) -> Self { + let initial_max_data = params.local_recv_max_data; + Self { + stream_id, + ecn_counts: Default::default(), + control_packet_number: Default::default(), + stream_ack: Default::default(), + recovery_ack: Default::default(), + state: Default::default(), + idle_timer: Default::default(), + idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT), + tick_timer: Default::default(), + _should_transmit: false, + max_data: initial_max_data, + max_data_window: initial_max_data, + error: None, + fin_ack_packet_number: None, + features, + } + } + + #[inline] + pub fn id(&self) -> stream::Id { + self.stream_id + } + + #[inline] + pub fn state(&self) -> &Receiver { + &self.state + } + + #[inline] + pub fn timer(&self) -> Option { + self.next_expiration() + } + + #[inline] + pub fn is_open(&self) -> bool { + !self.state.is_terminal() + } + + #[inline] + pub fn is_finished(&self) -> bool { + ensure!(self.state.is_terminal(), false); + ensure!(self.timer().is_none(), false); + true + } + + #[inline] + pub fn stop_sending(&mut self, error: s2n_quic_core::application::Error) { + // if we've already received everything then no need to notify the peer to stop + ensure!(matches!(self.state, Receiver::Recv | Receiver::SizeKnown)); + self.on_error(Error::ApplicationError { error }); + } + + #[inline] + pub fn on_read_buffer(&mut self, out_buf: &mut B, chunk: &mut C, _clock: &Clk) + where + B: buffer::Duplex, + C: buffer::writer::Storage, + Clk: Clock + ?Sized, + { + // try copying the out_buf into the application chunk, if possible + if chunk.has_remaining_capacity() && !out_buf.buffer_is_empty() { + out_buf.infallible_copy_into(chunk); + } + + // record our new max data value + let new_max_data = out_buf + .current_offset() + .saturating_add(self.max_data_window); + + if new_max_data > self.max_data { + self.max_data = 
new_max_data; + self.needs_transmission("new_max_data"); + } + + // if we know the final offset then update the sate + if out_buf.final_offset().is_some() { + let _ = self.state.on_receive_fin(); + } + + // if we've received everything then update the state + if out_buf.has_buffered_fin() && self.state.on_receive_all_data().is_ok() { + self.needs_transmission("receive_all_data"); + } + + // if we've completely drained the out buffer try transitioning to the final state + if out_buf.is_consumed() && self.state.on_app_read_all_data().is_ok() { + self.needs_transmission("app_read_all_data"); + } + } + + #[inline] + pub fn precheck_stream_packet( + &mut self, + packet: &stream::decoder::Packet, + ) -> Result<(), Error> { + match self.precheck_stream_packet_impl(packet) { + Ok(()) => Ok(()), + Err(err) => { + if err.is_fatal(&self.features) { + tracing::error!(fatal_error = %err, ?packet); + self.on_error(err); + } else { + tracing::debug!(non_fatal_error = %err); + } + Err(err) + } + } + } + + #[inline] + fn precheck_stream_packet_impl( + &mut self, + packet: &stream::decoder::Packet, + ) -> Result<(), Error> { + // make sure we're getting packets for the correct stream + ensure!( + *packet.stream_id() == self.stream_id, + Err(Error::StreamMismatch { + expected: self.stream_id, + actual: *packet.stream_id(), + }) + ); + + if self.features.is_stream() { + // if the transport is streaming then we expect packet numbers in order + let expected_pn = self + .stream_ack + .packets + .max_value() + .map_or(0, |v| v.as_u64() + 1); + let actual_pn = packet.packet_number().as_u64(); + ensure!( + expected_pn == actual_pn, + Err(Error::OutOfOrder { + expected: expected_pn, + actual: actual_pn, + }) + ); + } + + if self.features.is_reliable() { + // if the transport is reliable then we don't expect retransmissions + ensure!( + !packet.is_retransmission(), + Err(Error::UnexpectedRetransmission) + ); + } + + Ok(()) + } + + #[inline] + pub fn on_stream_packet( + &mut self, + opener: &D, + packet: &mut stream::decoder::Packet, + ecn: ExplicitCongestionNotification, + clock: &Clk, + out_buf: &mut B, + ) -> Result<(), Error> + where + D: decrypt::Key, + Clk: Clock + ?Sized, + B: buffer::Duplex, + { + probes::on_stream_packet( + opener.credentials().id, + self.stream_id, + packet.tag().packet_space(), + packet.packet_number(), + packet.stream_offset(), + packet.payload().len(), + packet.is_fin(), + packet.is_retransmission(), + ); + + match self.on_stream_packet_impl(opener, packet, ecn, clock, out_buf) { + Ok(()) => Ok(()), + Err(err) => { + if err.is_fatal(&self.features) { + tracing::error!(fatal_error = %err, ?packet); + self.on_error(err); + } else { + tracing::debug!(non_fatal_error = %err); + } + Err(err) + } + } + } + + #[inline] + fn on_stream_packet_impl( + &mut self, + opener: &D, + packet: &mut stream::decoder::Packet, + ecn: ExplicitCongestionNotification, + clock: &Clk, + out_buf: &mut B, + ) -> Result<(), Error> + where + D: decrypt::Key, + Clk: Clock + ?Sized, + B: buffer::Duplex, + { + use buffer::reader::Storage as _; + + self.precheck_stream_packet_impl(packet)?; + + let is_max_data_ok = self.ensure_max_data(packet); + + // wrap the parsed packet in a reader + let mut packet = packet::Packet { + packet: &mut *packet, + payload_cursor: 0, + is_decrypted_in_place: false, + ecn, + clock, + opener, + receiver: self, + }; + + if !is_max_data_ok { + // ensure the packet is authentic before resetting the stream + let _ = packet.read_chunk(usize::MAX)?; + + tracing::error!( + message = "max data 
exceeded", + allowed = packet.receiver.max_data.as_u64(), + requested = packet + .packet + .stream_offset() + .as_u64() + .saturating_add(packet.packet.payload().len() as u64), + ); + + let error = Error::MaxDataExceeded; + self.on_error(error); + return Err(error); + } + + // decrypt and write the packet to the provided buffer + out_buf.read_from(&mut packet)?; + + let mut chunk = buffer::writer::storage::Empty; + self.on_read_buffer(out_buf, &mut chunk, clock); + + Ok(()) + } + + #[inline] + pub(super) fn on_stream_packet_in_place( + &mut self, + crypto: &D, + packet: &mut stream::decoder::Packet, + ecn: ExplicitCongestionNotification, + clock: &Clk, + ) -> Result<(), Error> + where + D: decrypt::Key, + Clk: Clock + ?Sized, + { + // ensure the packet is authentic before processing it + let res = packet.decrypt_in_place(crypto); + + probes::on_stream_packet_decrypted( + crypto.credentials().id, + self.stream_id, + packet.tag().packet_space(), + packet.packet_number(), + packet.stream_offset(), + packet.payload().len(), + packet.is_fin(), + packet.is_retransmission(), + res.is_ok(), + ); + + res?; + + self.on_cleartext_stream_packet(packet, ecn, clock) + } + + #[inline] + pub(super) fn on_stream_packet_copy( + &mut self, + crypto: &D, + packet: &mut stream::decoder::Packet, + ecn: ExplicitCongestionNotification, + payload_out: &mut UninitSlice, + clock: &Clk, + ) -> Result<(), Error> + where + D: decrypt::Key, + Clk: Clock + ?Sized, + { + // ensure the packet is authentic before processing it + let res = packet.decrypt(crypto, payload_out); + + probes::on_stream_packet_decrypted( + crypto.credentials().id, + self.stream_id, + packet.tag().packet_space(), + packet.packet_number(), + packet.stream_offset(), + packet.payload().len(), + packet.is_fin(), + packet.is_retransmission(), + res.is_ok(), + ); + + res?; + + self.on_cleartext_stream_packet(packet, ecn, clock) + } + + #[inline] + fn ensure_max_data(&self, packet: &stream::decoder::Packet) -> bool { + // we only need to enforce flow control for non-controlled transport + ensure!(!self.features.is_flow_controlled(), true); + + self.max_data + .as_u64() + .checked_sub(packet.payload().len() as u64) + .and_then(|v| v.checked_sub(packet.stream_offset().as_u64())) + .is_some() + } + + #[inline] + fn on_cleartext_stream_packet( + &mut self, + packet: &stream::decoder::Packet, + ecn: ExplicitCongestionNotification, + clock: &Clk, + ) -> Result<(), Error> + where + Clk: Clock + ?Sized, + { + tracing::trace!( + stream_id = %packet.stream_id(), + stream_offset = packet.stream_offset().as_u64(), + payload_len = packet.payload().len(), + final_offset = ?packet.final_offset().map(|v| v.as_u64()), + ); + + let space = match packet.tag().packet_space() { + stream::PacketSpace::Stream => &mut self.stream_ack, + stream::PacketSpace::Recovery => &mut self.recovery_ack, + }; + ensure!( + space.filter.on_packet(packet).is_ok(), + Err(Error::Duplicate) + ); + + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet.packet_number()); + if let Err(err) = space.packets.insert_packet_number(packet_number) { + tracing::debug!("could not record packet number {packet_number} with error {err:?}"); + } + + // if we got a new packet then we'll need to transmit an ACK + self.needs_transmission("new_packet"); + + // update the idle timer since we received a valid packet + if matches!(self.state, Receiver::Recv | Receiver::SizeKnown) + || packet.stream_offset() == VarInt::ZERO + { + self.update_idle_timer(clock); + } + + // TODO process control data + 
let _ = packet.control_data(); + + self.ecn_counts.increment(ecn); + + if !self.stream_id.is_reliable { + // TODO should we perform loss detection on the receiver and reset the stream if we have a big + // enough gap? + } + + // clean up any ACK state that we can + self.on_next_expected_control_packet(packet.next_expected_control_packet()); + + Ok(()) + } + + #[inline] + pub fn should_transmit(&self) -> bool { + self._should_transmit + } + + #[inline] + pub fn on_transport_close(&mut self) { + // only stream transports can be closed + ensure!(self.features.is_stream()); + + // only error out if we're still expecting more data + ensure!(matches!(self.state, Receiver::Recv | Receiver::SizeKnown)); + + self.on_error(Error::TruncatedTransport); + } + + #[inline] + fn needs_transmission(&mut self, reason: &str) { + if self.error.is_none() { + // we only transmit errors for reliable + flow-controlled transports + if self.features.is_reliable() && self.features.is_flow_controlled() { + tracing::trace!(skipping_transmission = reason); + return; + } + } + + if !self._should_transmit { + tracing::trace!(needs_transmission = reason); + } + self._should_transmit = true; + } + + #[inline] + fn on_next_expected_control_packet(&mut self, next_expected_control_packet: VarInt) { + if let Some(largest_delivered_control_packet) = + next_expected_control_packet.checked_sub(VarInt::from_u8(1)) + { + self.stream_ack + .on_largest_delivered_packet(largest_delivered_control_packet); + self.recovery_ack + .on_largest_delivered_packet(largest_delivered_control_packet); + + if let Some(fin_ack_packet_number) = self.fin_ack_packet_number { + // if the sender received our ACK to the fin, then we can shut down immediately + if largest_delivered_control_packet >= fin_ack_packet_number { + self.silent_shutdown(); + } + } + } + } + + #[inline] + fn update_idle_timer(&mut self, clock: &Clk) { + let target = clock.get_time() + self.idle_timeout; + self.idle_timer.set(target); + + // if the tick timer isn't armed then rearmed it; otherwise keep it stable to avoid + // churn + if !self.tick_timer.is_armed() { + self.tick_timer.set(target); + } + } + + #[inline] + fn mtu(&self) -> u16 { + // TODO should we pull this from somewhere + + // we want to make sure ACKs get through so use the minimum packet length for QUIC + 1200 + } + + #[inline] + fn ecn(&self) -> ExplicitCongestionNotification { + // TODO how do we decide what to send on control messages + ExplicitCongestionNotification::Ect0 + } + + #[inline] + pub fn on_error(&mut self, error: Error) { + debug_assert!(error.is_fatal(&self.features)); + let _ = self.state.on_reset(); + self.stream_ack.clear(); + self.recovery_ack.clear(); + self.needs_transmission("on_error"); + + // make sure we haven't already set the error from something else + ensure!(self.error.is_none()); + self.error = Some(error); + } + + #[inline] + pub fn check_error(&self) -> Result<(), Error> { + // if we already received/read all of the data then filter out errors + ensure!( + !matches!(self.state, Receiver::DataRead | Receiver::DataRecvd), + Ok(()) + ); + + if let Some(err) = self.error { + Err(err) + } else { + Ok(()) + } + } + + #[inline] + pub fn on_timeout(&mut self, clock: &Clk, load_last_activity: Ld) + where + Clk: Clock + ?Sized, + Ld: FnOnce() -> Timestamp, + { + let now = clock.get_time(); + if self.poll_idle_timer(clock, load_last_activity).is_ready() { + self.silent_shutdown(); + + // only transition to an error state if we didn't receive everything + ensure!(matches!(self.state, 
Receiver::Recv | Receiver::SizeKnown)); + + // we don't want to transmit anything so enter a terminal state + let mut did_transition = false; + did_transition |= self.state.on_reset().is_ok(); + did_transition |= self.state.on_app_read_reset().is_ok(); + if did_transition { + self.on_error(Error::IdleTimeout); + // override the transmission since we're just timing out + self._should_transmit = false; + } + + return; + } + + // if the tick timer expired, then copy the current idle timeout target + if self.tick_timer.poll_expiration(now).is_ready() { + self.tick_timer = self.idle_timer.clone(); + } + } + + #[inline] + fn poll_idle_timer(&mut self, clock: &Clk, load_last_activity: Ld) -> Poll<()> + where + Clk: Clock + ?Sized, + Ld: FnOnce() -> Timestamp, + { + let now = clock.get_time(); + + // check the idle timer first + ready!(self.idle_timer.poll_expiration(now)); + + // if that expired then load the last activity from the peer and update the idle timer with + // the value + let last_peer_activity = load_last_activity(); + self.update_idle_timer(&last_peer_activity); + + // check the idle timer once more before returning + ready!(self.idle_timer.poll_expiration(now)); + + Poll::Ready(()) + } + + #[inline] + fn silent_shutdown(&mut self) { + self._should_transmit = false; + self.idle_timer.cancel(); + self.tick_timer.cancel(); + self.stream_ack.clear(); + self.recovery_ack.clear(); + } + + #[inline] + pub fn on_transmit( + &mut self, + encrypt_key: &E, + source_control_port: u16, + output: &mut A, + clock: &Clk, + ) where + E: encrypt::Key, + A: Allocator, + Clk: Clock + ?Sized, + { + (if self.error.is_none() { + Self::on_transmit_ack + } else { + Self::on_transmit_error + })( + self, + encrypt_key, + source_control_port, + output, + // avoid querying the clock for every transmitted packet + &clock::Cached::new(clock), + ) + } + + #[inline] + fn on_transmit_ack( + &mut self, + encrypt_key: &E, + source_control_port: u16, + output: &mut A, + _clock: &Clk, + ) where + E: encrypt::Key, + A: Allocator, + Clk: Clock + ?Sized, + { + ensure!(self.should_transmit()); + + let mtu = self.mtu(); + + output.set_ecn(self.ecn()); + + let packet_number = self.next_pn(); + + ensure!(let Some(segment) = output.alloc()); + + let buffer = output.get_mut(&segment); + buffer.resize(mtu as _, 0); + + let encoder = EncoderBuffer::new(buffer); + + // TODO compute this by storing the time that we received the max packet number + let ack_delay = VarInt::ZERO; + + let max_data = frame::MaxData { + maximum_data: self.max_data, + }; + let max_data_encoding_size: VarInt = max_data.encoding_size().try_into().unwrap(); + + let (stream_ack, max_data_encoding_size) = self.stream_ack.encoding( + max_data_encoding_size, + ack_delay, + Some(self.ecn_counts), + mtu, + ); + let (recovery_ack, max_data_encoding_size) = + self.recovery_ack + .encoding(max_data_encoding_size, ack_delay, None, mtu); + + let encoding_size = max_data_encoding_size; + + tracing::trace!(?stream_ack, ?recovery_ack, ?max_data); + + let frame = ((max_data, stream_ack), recovery_ack); + + let result = control::encoder::encode( + encoder, + source_control_port, + Some(self.stream_id), + packet_number, + VarInt::ZERO, + &mut &[][..], + encoding_size, + &frame, + encrypt_key, + ); + + match result { + 0 => { + output.free(segment); + return; + } + packet_len => { + buffer.truncate(packet_len); + // TODO duplicate the transmission in case we have a lot of gaps in packets + output.push(segment); + } + } + + for (ack, space) in [ + (&self.stream_ack, 
stream::PacketSpace::Stream), + (&self.recovery_ack, stream::PacketSpace::Recovery), + ] { + let metrics = ( + ack.packets.min_value(), + ack.packets.max_value(), + ack.packets.interval_len().checked_sub(1), + ); + if let (Some(min), Some(max), Some(gaps)) = metrics { + probes::on_transmit_control( + encrypt_key.credentials().id, + self.stream_id, + space, + packet_number, + min, + max, + gaps, + ); + }; + } + + // make sure we sent a packet + ensure!(!output.is_empty()); + + // record the max value we've seen for removing old packet numbers + self.stream_ack.on_transmit(packet_number); + self.recovery_ack.on_transmit(packet_number); + + self.on_packet_sent(packet_number); + } + + #[inline] + fn on_transmit_error( + &mut self, + encrypt_key: &E, + source_control_port: u16, + output: &mut A, + _clock: &Clk, + ) where + E: encrypt::Key, + A: Allocator, + Clk: Clock + ?Sized, + { + ensure!(self.should_transmit()); + + let mtu = self.mtu() as usize; + + output.set_ecn(self.ecn()); + + let packet_number = self.next_pn(); + + ensure!(let Some(segment) = output.alloc()); + + let buffer = output.get_mut(&segment); + buffer.resize(mtu, 0); + + let encoder = EncoderBuffer::new(buffer); + + let frame = self + .error + .as_ref() + .and_then(|err| err.connection_close()) + .unwrap_or_else(|| s2n_quic_core::transport::Error::NO_ERROR.into()); + + let encoding_size = frame.encoding_size().try_into().unwrap(); + + let result = control::encoder::encode( + encoder, + source_control_port, + Some(self.stream_id), + packet_number, + VarInt::ZERO, + &mut &[][..], + encoding_size, + &frame, + encrypt_key, + ); + + match result { + 0 => { + output.free(segment); + return; + } + packet_len => { + buffer.truncate(packet_len); + output.push(segment); + } + } + + tracing::debug!(connection_close = ?frame); + + // clean things up + self.stream_ack.clear(); + self.recovery_ack.clear(); + + probes::on_transmit_close( + encrypt_key.credentials().id, + self.stream_id, + packet_number, + frame.error_code, + ); + + self.on_packet_sent(packet_number); + } + + #[inline] + fn next_pn(&mut self) -> VarInt { + VarInt::new(self.control_packet_number).expect("2^62 is a lot of packets") + } + + #[inline] + fn on_packet_sent(&mut self, packet_number: VarInt) { + // record the fin_ack packet number so we can shutdown more quickly + if !matches!(self.state, Receiver::Recv | Receiver::SizeKnown) + && self.fin_ack_packet_number.is_none() + { + self.fin_ack_packet_number = Some(packet_number); + } + + self.control_packet_number += 1; + self._should_transmit = false; + } +} + +impl timer::Provider for State { + #[inline] + fn timers(&self, query: &mut Q) -> timer::Result { + self.idle_timer.timers(query)?; + self.tick_timer.timers(query)?; + Ok(()) + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv/worker.rs b/dc/s2n-quic-dc/src/stream/recv/worker.rs new file mode 100644 index 0000000000..7c6c947168 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv/worker.rs @@ -0,0 +1,341 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + allocator::Allocator, + clock::Timer, + msg, + stream::{shared::ArcShared, socket::Socket}, +}; +use core::task::{Context, Poll}; +use s2n_quic_core::{buffer, endpoint, ensure, ready, time::clock::Timer as _}; +use std::{io, time::Duration}; +use tracing::{debug, trace}; + +const INITIAL_TIMEOUT: Duration = Duration::from_millis(2); + +mod waiting { + use s2n_quic_core::state::{event, is}; + + #[derive(Clone, Debug, Default, PartialEq)] + pub enum State { + PeekPacket, + EpochTimeout, + #[default] + Cooldown, + DataRecvd, + Detached, + Finished, + } + + impl State { + is!(is_peek_packet, PeekPacket); + event! { + on_peek_packet(PeekPacket => EpochTimeout); + on_cooldown_elapsed(Cooldown => PeekPacket); + on_epoch_unchanged(EpochTimeout => PeekPacket); + on_application_progress(PeekPacket | EpochTimeout | Cooldown => Cooldown); + on_application_detach(PeekPacket | EpochTimeout | Cooldown => Detached); + on_data_received(PeekPacket | EpochTimeout | Cooldown => DataRecvd); + on_finished(PeekPacket | EpochTimeout | Cooldown | Detached | DataRecvd => Finished); + } + } + + #[test] + fn dot_test() { + insta::assert_snapshot!(State::dot()); + } +} + +#[repr(u8)] +pub(crate) enum ErrorCode { + /// The application dropped the stream without errors + None = 0, + /// General error code for application-level errors + Application = 1, +} + +pub struct Worker { + shared: ArcShared, + last_observed_epoch: u64, + send_buffer: msg::send::Message, + state: waiting::State, + timer: Timer, + backoff: u8, + socket: S, +} + +impl Worker { + #[inline] + pub fn new(socket: S, shared: ArcShared, endpoint: endpoint::Type) -> Self { + let send_buffer = msg::send::Message::new(shared.read_remote_addr(), shared.gso.clone()); + let timer = Timer::new_with_timeout(&shared.clock, INITIAL_TIMEOUT); + + let state = match endpoint { + // on the client we delay before reading from the socket + endpoint::Type::Client => waiting::State::Cooldown, + // on the server we need the application to read after accepting, otherwise the peer + // won't know what our port is + endpoint::Type::Server => waiting::State::EpochTimeout, + }; + + Self { + shared, + last_observed_epoch: 0, + send_buffer, + state, + timer, + backoff: 0, + socket, + } + } + + #[inline] + pub fn update_waker(&self, cx: &mut Context) { + self.shared.receiver.worker_waker.update(cx.waker()); + } + + #[inline] + pub fn poll(&mut self, cx: &mut Context) -> Poll<()> { + if let Poll::Ready(Err(err)) = self.poll_flush_socket(cx) { + tracing::error!(socket_error = ?err); + // TODO should we return? if we get a send error it's most likely fatal + return Poll::Ready(()); + } + + if let Poll::Ready(Err(err)) = self.poll_socket(cx) { + tracing::error!(socket_error = ?err); + // TODO should we return? 
if we get a recv error it's most likely fatal + return Poll::Ready(()); + } + + // go until we get into the finished state + if let waiting::State::Finished = &self.state { + Poll::Ready(()) + } else { + Poll::Pending + } + } + + #[inline] + fn poll_socket(&mut self, cx: &mut Context) -> Poll> { + loop { + match &self.state { + waiting::State::PeekPacket => { + // check to see if the application is progressing before peeking the socket + ensure!(!self.is_application_progressing(), continue); + + // peek the socket buffer size - we don't care how big, just that there is at + // least one packet there + let _len = ready!(self.socket.poll_peek_len(cx))?; + + self.arm_timer(); + self.state.on_peek_packet().unwrap(); + continue; + } + waiting::State::EpochTimeout => { + // check to see if the application is progressing before checking the timer + ensure!(!self.is_application_progressing(), continue); + + ready!(self.timer.poll_ready(cx)); + + // the application isn't making progress so emit the timer expired event + self.state.on_epoch_unchanged().unwrap(); + + // only log this message after the first observation + if self.last_observed_epoch > 0 { + debug!("application reading too slowly from socket"); + } + + // reset the backoff with the assumption that the application will go slow in + // the future + self.backoff = 0; + + // drain the socket if the application isn't going fast enough + return self.poll_drain_recv_socket(cx); + } + waiting::State::Cooldown => { + // check to see if the application is progressing before checking the timer + ensure!(!self.is_application_progressing(), continue); + + ready!(self.timer.poll_ready(cx)); + + // go back to waiting for a packet + self.state.on_cooldown_elapsed().unwrap(); + continue; + } + waiting::State::Detached | waiting::State::DataRecvd => { + ready!(self.poll_drain_recv_socket(cx))?; + } + waiting::State::Finished => { + // nothing left to do + return Ok(()).into(); + } + } + } + } + + #[inline] + fn is_application_progressing(&mut self) -> bool { + // check to see if the application shut down + if let super::shared::ApplicationState::Closed { is_panicking } = + self.shared.receiver.application_state() + { + if let Ok(Some(mut recv)) = self.shared.receiver.worker_try_lock() { + // check to see if we have anything in the reassembler as well + let is_buffer_empty = recv.recv_buffer.is_empty() && recv.reassembler.is_empty(); + + let error = if !is_buffer_empty || is_panicking { + // we still had data in our buffer so notify the sender + ErrorCode::Application as u8 + } else { + // no error - the application is just going away + ErrorCode::None as u8 + }; + + recv.receiver.stop_sending(error.into()); + + // TODO arm the timer so we can clean up when we're done + + if recv.receiver.is_finished() { + let _ = self.state.on_finished(); + } + } + + let _ = self.state.on_application_detach(); + + return true; + } + + let current_epoch = self.shared.receiver.application_epoch(); + + // make sure the epoch has changed since we last saw it before cooling down + ensure!(self.last_observed_epoch < current_epoch, false); + + // record the new observation + self.last_observed_epoch = current_epoch; + + // the application is making progress since the packet is different - loop back to cooldown + trace!("application is making progress"); + + // delay when we read from the socket again to avoid spinning + let _ = self.state.on_application_progress(); + self.arm_timer(); + + // after successful progress from the application we want to intervene less + 
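+        // (Note, assuming the constants above: `arm_timer` multiplies INITIAL_TIMEOUT
+        // (2ms) by `backoff + 1` for the non-peek waits, so the cap of 10 below bounds
+        // the worker's intervention delay at roughly 22ms.)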
self.backoff = (self.backoff + 1).min(10); + + true + } + + #[inline] + fn poll_drain_recv_socket(&mut self, cx: &mut Context) -> Poll> { + let mut should_transmit = false; + let mut received_packets = 0; + + let _res = self.process_packets(cx, &mut received_packets, &mut should_transmit); + + ensure!( + should_transmit, + if received_packets == 0 { + Poll::Pending + } else { + Ok(()).into() + } + ); + + // send an ACK if needed + if let Some(mut recv) = self.shared.receiver.worker_try_lock()? { + // use the latest value rather than trying to transmit an old one + if !self.send_buffer.is_empty() { + let _ = self.send_buffer.drain(); + } + + recv.fill_transmit_queue(&self.shared, &mut self.send_buffer); + + if recv.receiver.state().is_data_received() { + let _ = self.state.on_data_received(); + } + + if recv.receiver.is_finished() { + let _ = self.state.on_finished(); + } else { + // TODO update the timer so we get woken up on idle timeout + } + } + + ready!(self.poll_flush_socket(cx))?; + + Ok(()).into() + } + + #[inline] + fn process_packets( + &mut self, + cx: &mut Context, + received_packets: &mut usize, + should_transmit: &mut bool, + ) -> io::Result<()> { + // loop until we hit Pending from the socket + loop { + // try_lock the state before reading so we don't consume a packet the application is + // about to read + let Some(mut recv) = self.shared.receiver.worker_try_lock()? else { + // if the application is locking the state then we don't want to transmit, since it + // will do that for us + *should_transmit = false; + break; + }; + + // make sure to process any left over packets, if any + if !recv.recv_buffer.is_empty() { + *should_transmit |= recv.process_recv_buffer( + &mut buffer::writer::storage::Empty, + &self.shared, + self.socket.features(), + ); + } + + let res = recv.poll_fill_recv_buffer(cx, &self.socket); + + match res { + Poll::Pending => break, + Poll::Ready(res) => res?, + }; + + *received_packets += 1; + + // process the packet we just received + *should_transmit |= recv.process_recv_buffer( + &mut buffer::writer::storage::Empty, + &self.shared, + self.socket.features(), + ); + } + + Ok(()) + } + + #[inline] + fn poll_flush_socket(&mut self, cx: &mut Context) -> Poll> { + while !self.send_buffer.is_empty() { + ready!(self.socket.poll_send_buffer(cx, &mut self.send_buffer))?; + } + + Ok(()).into() + } + + #[inline] + fn arm_timer(&mut self) { + // TODO do we derive this from RTT? + let mut timeout = INITIAL_TIMEOUT; + // don't back off on packet peeks + if !self.state.is_peek_packet() { + timeout *= (self.backoff as u32) + 1; + } + let now = self.shared.clock.get_time(); + let target = now + timeout; + + self.timer.update(target); + } +} diff --git a/dc/s2n-quic-dc/src/stream/runtime.rs b/dc/s2n-quic-dc/src/stream/runtime.rs new file mode 100644 index 0000000000..2cb7580244 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/runtime.rs @@ -0,0 +1,15 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::stream::{recv, send}; +use std::sync::Arc; + +#[cfg(feature = "tokio")] +pub mod tokio; + +pub type ArcHandle = Arc; + +pub trait Handle: 'static + Send + Sync { + fn spawn_recv_shutdown(&self, shutdown: recv::application::Shutdown); + fn spawn_send_shutdown(&self, shutdown: send::application::Shutdown); +} diff --git a/dc/s2n-quic-dc/src/stream/runtime/tokio.rs b/dc/s2n-quic-dc/src/stream/runtime/tokio.rs new file mode 100644 index 0000000000..53bc682c56 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/runtime/tokio.rs @@ -0,0 +1,87 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::stream::{recv, send}; +use core::{mem::MaybeUninit, ops}; +use std::sync::Arc; + +impl super::Handle for tokio::runtime::Handle { + #[inline] + fn spawn_recv_shutdown(&self, shutdown: recv::application::Shutdown) { + self.spawn(async move { + // Note: Must be created inside spawn() since ambient runtime is otherwise not + // guaranteed and will cause a panic on the timeout future construction. + // make sure the task doesn't hang around indefinitely + tokio::time::timeout(core::time::Duration::from_secs(1), shutdown).await + }); + } + + #[inline] + fn spawn_send_shutdown(&self, shutdown: send::application::Shutdown) { + self.spawn(async move { + // Note: Must be created inside spawn() since ambient runtime is otherwise not + // guaranteed and will cause a panic on the timeout future construction. + // make sure the task doesn't hang around indefinitely + tokio::time::timeout(core::time::Duration::from_secs(1), shutdown).await + }); + } +} + +#[derive(Clone)] +pub struct Shared(Arc); + +impl Shared { + #[inline] + pub fn handle(&self) -> super::ArcHandle { + self.0.clone() + } +} + +impl ops::Deref for Shared { + type Target = tokio::runtime::Handle; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for Shared { + fn from(rt: tokio::runtime::Runtime) -> Self { + Self(Arc::new(SharedInner(MaybeUninit::new(rt)))) + } +} + +struct SharedInner(MaybeUninit); + +impl ops::Deref for SharedInner { + type Target = tokio::runtime::Handle; + + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { (self.0).assume_init_ref().handle() } + } +} + +impl super::Handle for SharedInner { + #[inline] + fn spawn_recv_shutdown(&self, shutdown: recv::application::Shutdown) { + (**self).spawn_recv_shutdown(shutdown) + } + + #[inline] + fn spawn_send_shutdown(&self, shutdown: send::application::Shutdown) { + (**self).spawn_send_shutdown(shutdown) + } +} + +impl Drop for SharedInner { + fn drop(&mut self) { + // drop the runtimes in a separate thread to avoid tokio complaining + let rt = unsafe { self.0.assume_init_read() }; + std::thread::spawn(move || { + // give enough time for all of the streams to shut down + rt.shutdown_timeout(core::time::Duration::from_secs(10)); + }); + } +} diff --git a/dc/s2n-quic-dc/src/stream/send.rs b/dc/s2n-quic-dc/src/stream/send.rs index dcdd1635e2..3ba3ffcb6b 100644 --- a/dc/s2n-quic-dc/src/stream/send.rs +++ b/dc/s2n-quic-dc/src/stream/send.rs @@ -10,6 +10,7 @@ pub mod path; pub mod probes; pub mod queue; pub mod shared; +pub mod state; pub mod transmission; pub mod worker; diff --git a/dc/s2n-quic-dc/src/stream/send/application.rs b/dc/s2n-quic-dc/src/stream/send/application.rs index a89fecc15a..f5710e7df1 100644 --- a/dc/s2n-quic-dc/src/stream/send/application.rs +++ b/dc/s2n-quic-dc/src/stream/send/application.rs @@ -2,175 +2,392 @@ // 
SPDX-License-Identifier: Apache-2.0 use crate::{ - crypto::encrypt, - packet::stream::{self, encoder}, + clock, msg, stream::{ - packet_number, - send::{error::Error, flow, path, probes}, + pacer, runtime, + send::{flow, queue}, + shared::ArcShared, + socket, }, }; -use bytes::buf::UninitSlice; -use s2n_codec::EncoderBuffer; -use s2n_quic_core::{ - buffer::{self, reader::Storage as _, Reader as _}, - ensure, - time::Clock, - varint::VarInt, +use core::{ + fmt, + pin::Pin, + sync::atomic::Ordering, + task::{Context, Poll}, }; +use s2n_quic_core::{buffer, ensure, ready, task::waker}; +use std::{io, net::SocketAddr}; +use tracing::trace; +mod builder; +pub mod state; pub mod transmission; -pub trait Message { - fn max_segments(&self) -> usize; - fn push transmission::Event<()>>( - &mut self, - buffer_len: usize, - p: P, - ); +pub use builder::Builder; + +pub struct Writer(Box); + +struct Inner { + shared: ArcShared, + sockets: socket::ArcApplication, + queue: queue::Queue, + pacer: pacer::Naive, + open: bool, + runtime: runtime::ArcHandle, } -#[derive(Clone, Copy, Debug)] -pub struct State { - pub stream_id: stream::Id, - pub source_control_port: u16, - pub source_stream_port: Option, +impl fmt::Debug for Writer { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Writer") + .field("peer_addr", &self.peer_addr().unwrap()) + .field("local_addr", &self.local_addr().unwrap()) + .finish() + } } -impl State { +impl Writer { #[inline] - pub fn transmit( - &self, - credits: flow::Credits, - path: &path::Info, - storage: &mut I, - packet_number: &packet_number::Counter, - encrypt_key: &E, - clock: &Clk, - message: &mut M, - ) -> Result<(), Error> + pub fn peer_addr(&self) -> io::Result { + self.0.shared.common.ensure_open()?; + Ok(self.0.shared.write_remote_addr().into()) + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.0.sockets.write_application().local_addr() + } + + #[inline] + pub fn protocol(&self) -> socket::Protocol { + self.0.sockets.protocol() + } + + #[inline] + pub async fn write_from(&mut self, buf: &mut S) -> io::Result + where + S: buffer::reader::storage::Infallible, + { + core::future::poll_fn(|cx| self.poll_write_from(cx, buf, false)).await + } + + #[inline] + pub fn poll_write_from( + &mut self, + cx: &mut Context, + buf: &mut S, + is_fin: bool, + ) -> Poll> where - E: encrypt::Key, - I: buffer::reader::Storage, - Clk: Clock, - M: Message, + S: buffer::reader::storage::Infallible, { + waker::debug_assert_contract(cx, |cx| { + // if we've already shut down the stream then return early + if !self.0.open { + ensure!( + buf.buffer_is_empty() && is_fin, + Err(io::Error::from(io::ErrorKind::BrokenPipe)).into() + ); + return Ok(0).into(); + } + + let res = ready!(self.0.poll_write_from(cx, buf, is_fin)); + + // if we got an error then shut down the stream if needed + if res.is_err() { + // use the `Drop` type so we send a RST instead + let _ = self.0.shutdown(ShutdownType::Drop { + is_panicking: false, + }); + } + + res.into() + }) + } + + /// Shutdown the stream for writing. 
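+    ///
+    /// Buffered data is flushed and the final offset is signaled to the peer;
+    /// dropping the `Writer` without calling this sends a reset instead.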
+ pub fn shutdown(&mut self) -> io::Result<()> { + self.0.shutdown(ShutdownType::Explicit) + } +} + +impl Inner { + #[inline(always)] + fn poll_write_from( + &mut self, + cx: &mut Context, + buf: &mut S, + is_fin: bool, + ) -> Poll> + where + S: buffer::reader::storage::Infallible, + { + // Try to flush any pending packets + let flushed_len = ready!(self.poll_flush_buffer(cx, buf.buffered_len()))?; + + // if the flushed len is non-zero then return it to the application before accepting more + // bytes to buffer + ensure!(flushed_len == 0, Ok(flushed_len).into()); + + // make sure the queue is drained before continuing + ensure!(self.queue.is_empty(), Ok(flushed_len).into()); + + let app = self.shared.application(); + let max_header_len = app.max_header_len(); + let max_segments = self.shared.gso.max_segments(); + + // create a flow request from the provided application input + let initial_len = buf.buffered_len(); + let mut request = flow::Request { + len: initial_len, + initial_len, + is_fin, + }; + + let path = self.shared.sender.path.load(); + + // clamp the flow request based on the path state + request.clamp(path.max_flow_credits(max_header_len, max_segments)); + + // acquire flow credits from the worker + let credits = ready!(self.shared.sender.flow.poll_acquire(cx, request))?; + + trace!(?credits); + + let mut batch = if self.sockets.write_application().features().is_reliable() { + // the protocol does recovery for us so no need to track the transmissions + None + } else { + // if we are using unreliable sockets then we need to write transmissions to a batch for the + // worker to track for recovery + + let batch = self + .shared + .sender + .application_transmission_queue + .alloc_batch(msg::segment::MAX_COUNT); + Some(batch) + }; + + self.queue.push_buffer( + buf, + &mut batch, + max_segments, + &self.shared.sender.segment_alloc, + |message, buf| { + self.shared.crypto.seal_with(|sealer| { + // push packets for transmission into our queue + app.transmit( + credits, + &path, + buf, + &self.shared.sender.packet_number, + sealer, + &clock::Cached::new(&self.shared.clock), + message, + ) + }) + }, + )?; + + if let Some(batch) = batch { + // send the transmission information off to the worker before flushing to the socket so the + // worker is prepared to handle ACKs from the peer + self.shared.sender.push_to_worker(batch)?; + } + + // flush the queue of packets to the socket + self.poll_flush_buffer(cx, usize::MAX) + } + + #[inline] + fn poll_flush_buffer( + &mut self, + cx: &mut Context, + limit: usize, + ) -> Poll> { + // if we're actually writing to the socket then we need to pace + if !self.queue.is_empty() { + ready!(self.pacer.poll_pacing(cx, &self.shared.clock)); + } + + let len = ready!(self.queue.poll_flush( + cx, + limit, + self.sockets.write_application(), + &msg::addr::Addr::new(self.shared.write_remote_addr()), + &self.shared.sender.segment_alloc, + &self.shared.gso, + ))?; + + Ok(len).into() + } + + #[inline] + fn shutdown(&mut self, ty: ShutdownType) -> io::Result<()> { + // make sure we haven't already shut down ensure!( - credits.len > 0 || storage.buffer_is_empty() || credits.is_fin, - Ok(()) + self.open, + // macos returns an error after the stream has already shut down + if cfg!(target_os = "macos") { + Err(io::ErrorKind::NotConnected.into()) + } else { + Ok(()) + } ); - let mut reader = buffer::reader::Incremental::new(credits.offset); - let mut reader = reader.with_storage(storage, credits.is_fin)?; - debug_assert!( - reader.buffered_len() >= credits.len, - 
"attempted to acquire more credits than what is buffered" - ); - let mut reader = reader.with_read_limit(credits.len); - - let stream_id = *self.stream_id(); - let max_header_len = self.max_header_len(); - - let mut total_payload_len = 0; - - loop { - let packet_number = packet_number.next()?; - - let buffer_len = { - let estimated_len = reader.buffered_len() + max_header_len; - (path.max_datagram_size as usize).min(estimated_len) - }; - - message.push(buffer_len, |buffer| { - let stream_offset = reader.current_offset(); - let mut reader = reader.track_read(); - - let buffer = unsafe { - // SAFETY: `buffer` is a valid `UninitSlice` but `EncoderBuffer` expects to - // write into a `&mut [u8]`. Here we construct a `&mut [u8]` since - // `EncoderBuffer` never actually reads from the slice and only writes to it. - core::slice::from_raw_parts_mut(buffer.as_mut_ptr(), buffer.len()) - }; - let encoder = EncoderBuffer::new(buffer); - let packet_len = encoder::encode( - encoder, - self.source_control_port, - self.source_stream_port, - stream_id, - stream::PacketSpace::Stream, - packet_number, - path.next_expected_control_packet, - VarInt::ZERO, - &mut &[][..], - VarInt::ZERO, - &(), - &mut reader, - encrypt_key, - ); + // TODO what do we want to do when we are panicking? + if !matches!(ty, ShutdownType::Drop { is_panicking: true }) { + // don't block on this actually completing since we want to also notify the worker + // immediately + let waker = s2n_quic_core::task::waker::noop(); + let mut cx = core::task::Context::from_waker(&waker); + let _ = self.poll_write_from(&mut cx, &mut buffer::reader::storage::Empty, true); + } - // buffer is clamped to u16::MAX so this is safe to cast without loss - let _: u16 = path.max_datagram_size; - let packet_len = packet_len as u16; - let payload_len = reader.consumed_len() as u16; - total_payload_len += payload_len as usize; + self.open = false; + self.shared + .common + .closed_halves + .fetch_add(1, Ordering::Relaxed); - let has_more_app_data = credits.initial_len > total_payload_len; + let queue = core::mem::take(&mut self.queue); - let included_fin = reader.final_offset().map_or(false, |fin| { - stream_offset.as_u64() + payload_len as u64 == fin.as_u64() - }); + // if we've transmitted everything we need to then finished the writing half + if matches!(ty, ShutdownType::Explicit) && queue.is_empty() { + self.sockets.write_application().send_finish()?; + } - let time_sent = clock.get_time(); - probes::on_transmit_stream( - encrypt_key.credentials().id, - stream_id, - stream::PacketSpace::Stream, - s2n_quic_core::packet::number::PacketNumberSpace::Initial - .new_packet_number(packet_number), - stream_offset, - payload_len, - included_fin, - false, - ); + // pass things to the worker if we need to gracefully shut down + if !self.sockets.write_application().features().is_stream() { + let is_panicking = matches!(ty, ShutdownType::Drop { is_panicking: true }); + self.shared.sender.shutdown(queue, is_panicking); + return Ok(()); + } - let info = transmission::Info { - packet_len, - retransmission: if stream_id.is_reliable { - Some(()) - } else { - None - }, - stream_offset, - payload_len, - included_fin, - time_sent, - ecn: path.ecn, - }; - - transmission::Event { - packet_number, - info, - has_more_app_data, - } + // if we're using TCP and we get blocked from writing a final offset then spawn a task + // to do it for us + if !queue.is_empty() { + let shared = self.shared.clone(); + let sockets = self.sockets.clone(); + self.runtime.spawn_send_shutdown(Shutdown { 
+ queue, + shared, + sockets, + ty, }); - - // bail if we've transmitted everything - ensure!(!reader.buffer_is_empty(), break); } Ok(()) } +} +#[cfg(feature = "tokio")] +impl tokio::io::AsyncWrite for Writer { #[inline] - fn stream_id(&self) -> &stream::Id { - &self.stream_id + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { + self.poll_write_from(cx, &mut buf, false) } #[inline] - pub fn max_header_len(&self) -> usize { - if self.stream_id().is_reliable { - encoder::MAX_RETRANSMISSION_HEADER_LEN - } else { - encoder::MAX_HEADER_LEN + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[std::io::IoSlice], + ) -> Poll> { + let mut buf = buffer::reader::storage::IoSlice::new(buf); + self.poll_write_from(cx, &mut buf, false) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // no-op to match TCP semantics + // https://github.com/tokio-rs/tokio/blob/ee68c1a8c211300ee862cbdd34c48292fa47ac3b/tokio/src/net/tcp/stream.rs#L1358 + Poll::Ready(Ok(())) + } + + #[inline] + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.0.shutdown(ShutdownType::Explicit).into() + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + true + } +} + +impl Drop for Writer { + #[inline] + fn drop(&mut self) { + let _ = self.0.shutdown(ShutdownType::Drop { + is_panicking: std::thread::panicking(), + }); + } +} + +#[derive(Clone, Copy, Debug)] +enum ShutdownType { + Explicit, + Drop { is_panicking: bool }, +} + +pub struct Shutdown { + queue: queue::Queue, + shared: ArcShared, + sockets: socket::ArcApplication, + ty: ShutdownType, +} + +impl core::future::Future for Shutdown { + type Output = (); + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let Self { + queue, + sockets, + shared, + ty, + } = self.get_mut(); + + // flush the buffer + let _ = ready!(queue.poll_flush( + cx, + usize::MAX, + sockets.write_application(), + &msg::addr::Addr::new(shared.write_remote_addr()), + &shared.sender.segment_alloc, + &shared.gso, + )); + + // If the application is explicitly shutting down then do the same. Otherwise let + // the stream `close` and send a RST + if matches!(ty, ShutdownType::Explicit) { + let _ = sockets.write_application().send_finish(); } + + Poll::Ready(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(dead_code)] + fn shutdown_traits_test(shutdown: &Shutdown) { + use crate::testing::*; + + assert_send(shutdown); + assert_sync(shutdown); + assert_static(shutdown); } } diff --git a/dc/s2n-quic-dc/src/stream/send/application/builder.rs b/dc/s2n-quic-dc/src/stream/send/application/builder.rs new file mode 100644 index 0000000000..38b1aa9e29 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/application/builder.rs @@ -0,0 +1,33 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::stream::{ + runtime, + send::application::{Inner, Writer}, + shared::ArcShared, + socket, +}; + +pub struct Builder { + runtime: runtime::ArcHandle, +} + +impl Builder { + #[inline] + pub fn new(runtime: runtime::ArcHandle) -> Self { + Self { runtime } + } + + #[inline] + pub fn build(self, shared: ArcShared, sockets: socket::ArcApplication) -> Writer { + let Self { runtime } = self; + Writer(Box::new(Inner { + shared, + sockets, + queue: Default::default(), + pacer: Default::default(), + open: true, + runtime, + })) + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/application/state.rs b/dc/s2n-quic-dc/src/stream/send/application/state.rs new file mode 100644 index 0000000000..e0c7e199e1 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/application/state.rs @@ -0,0 +1,174 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + crypto::encrypt, + packet::stream::{self, encoder}, + stream::{ + packet_number, + send::{application::transmission, error::Error, flow, path, probes}, + }, +}; +use bytes::buf::UninitSlice; +use s2n_codec::EncoderBuffer; +use s2n_quic_core::{ + buffer::{self, reader::Storage as _, Reader as _}, + ensure, + time::Clock, + varint::VarInt, +}; + +pub trait Message { + fn max_segments(&self) -> usize; + fn push transmission::Event<()>>( + &mut self, + buffer_len: usize, + p: P, + ); +} + +#[derive(Clone, Copy, Debug)] +pub struct State { + pub stream_id: stream::Id, + pub source_control_port: u16, + pub source_stream_port: Option, +} + +impl State { + #[inline] + pub fn transmit( + &self, + credits: flow::Credits, + path: &path::Info, + storage: &mut I, + packet_number: &packet_number::Counter, + encrypt_key: &E, + clock: &Clk, + message: &mut M, + ) -> Result<(), Error> + where + E: encrypt::Key, + I: buffer::reader::Storage, + Clk: Clock, + M: Message, + { + ensure!( + credits.len > 0 || storage.buffer_is_empty() || credits.is_fin, + Ok(()) + ); + + let mut reader = buffer::reader::Incremental::new(credits.offset); + let mut reader = reader.with_storage(storage, credits.is_fin)?; + debug_assert!( + reader.buffered_len() >= credits.len, + "attempted to acquire more credits than what is buffered" + ); + let mut reader = reader.with_read_limit(credits.len); + + let stream_id = *self.stream_id(); + let max_header_len = self.max_header_len(); + + let mut total_payload_len = 0; + + loop { + let packet_number = packet_number.next()?; + + let buffer_len = { + let estimated_len = reader.buffered_len() + max_header_len; + (path.max_datagram_size as usize).min(estimated_len) + }; + + message.push(buffer_len, |buffer| { + let stream_offset = reader.current_offset(); + let mut reader = reader.track_read(); + + let buffer = unsafe { + // SAFETY: `buffer` is a valid `UninitSlice` but `EncoderBuffer` expects to + // write into a `&mut [u8]`. Here we construct a `&mut [u8]` since + // `EncoderBuffer` never actually reads from the slice and only writes to it. 
+ core::slice::from_raw_parts_mut(buffer.as_mut_ptr(), buffer.len()) + }; + let encoder = EncoderBuffer::new(buffer); + let packet_len = encoder::encode( + encoder, + self.source_control_port, + self.source_stream_port, + stream_id, + stream::PacketSpace::Stream, + packet_number, + path.next_expected_control_packet, + VarInt::ZERO, + &mut &[][..], + VarInt::ZERO, + &(), + &mut reader, + encrypt_key, + ); + + // buffer is clamped to u16::MAX so this is safe to cast without loss + let _: u16 = path.max_datagram_size; + let packet_len = packet_len as u16; + let payload_len = reader.consumed_len() as u16; + total_payload_len += payload_len as usize; + + let has_more_app_data = credits.initial_len > total_payload_len; + + let included_fin = reader.final_offset().map_or(false, |fin| { + stream_offset.as_u64() + payload_len as u64 == fin.as_u64() + }); + + let time_sent = clock.get_time(); + probes::on_transmit_stream( + encrypt_key.credentials().id, + stream_id, + stream::PacketSpace::Stream, + s2n_quic_core::packet::number::PacketNumberSpace::Initial + .new_packet_number(packet_number), + stream_offset, + payload_len, + included_fin, + false, + ); + + let info = transmission::Info { + packet_len, + retransmission: if stream_id.is_reliable { + Some(()) + } else { + None + }, + stream_offset, + payload_len, + included_fin, + time_sent, + ecn: path.ecn, + }; + + transmission::Event { + packet_number, + info, + has_more_app_data, + } + }); + + // bail if we've transmitted everything + ensure!(!reader.buffer_is_empty(), break); + } + + Ok(()) + } + + #[inline] + fn stream_id(&self) -> &stream::Id { + &self.stream_id + } + + #[inline] + pub fn max_header_len(&self) -> usize { + if self.stream_id().is_reliable { + encoder::MAX_RETRANSMISSION_HEADER_LEN + } else { + encoder::MAX_HEADER_LEN + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/application/transmission.rs b/dc/s2n-quic-dc/src/stream/send/application/transmission.rs index cb4fb45ae1..4549337c9c 100644 --- a/dc/s2n-quic-dc/src/stream/send/application/transmission.rs +++ b/dc/s2n-quic-dc/src/stream/send/application/transmission.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::stream::send::worker::transmission; +use crate::stream::send::state::transmission; use crossbeam_queue::{ArrayQueue, SegQueue}; use s2n_quic_core::{ensure, varint::VarInt}; use std::collections::VecDeque; diff --git a/dc/s2n-quic-dc/src/stream/send/queue.rs b/dc/s2n-quic-dc/src/stream/send/queue.rs index d15bbd811d..aed2ddde55 100644 --- a/dc/s2n-quic-dc/src/stream/send/queue.rs +++ b/dc/s2n-quic-dc/src/stream/send/queue.rs @@ -6,7 +6,8 @@ use crate::{ stream::{ send::{ application::{self, transmission}, - buffer, worker, + buffer, + state::Transmission, }, socket::Socket, }, @@ -33,13 +34,13 @@ impl Segment { } pub struct Message<'a> { - batch: &'a mut Option>, + batch: &'a mut Option>, queue: &'a mut Queue, max_segments: usize, segment_alloc: &'a buffer::Allocator, } -impl<'a> application::Message for Message<'a> { +impl<'a> application::state::Message for Message<'a> { #[inline] fn max_segments(&self) -> usize { self.max_segments @@ -125,7 +126,7 @@ impl Queue { pub fn push_buffer( &mut self, buf: &mut B, - batch: &mut Option>, + batch: &mut Option>, max_segments: usize, segment_alloc: &buffer::Allocator, push: F, diff --git a/dc/s2n-quic-dc/src/stream/send/shared.rs b/dc/s2n-quic-dc/src/stream/send/shared.rs index 879d3c62bf..61cad7f3ad 100644 --- a/dc/s2n-quic-dc/src/stream/send/shared.rs +++ b/dc/s2n-quic-dc/src/stream/send/shared.rs @@ -6,7 +6,7 @@ use crate::{ packet_number, send::{ application::transmission, buffer, error::Error, flow, path, queue::Queue, - worker::Transmission, + state::Transmission, }, }, task::waker::worker::Waker as WorkerWaker, diff --git a/dc/s2n-quic-dc/src/stream/send/snapshots/s2n_quic_dc__stream__send__worker__waiting__dot_test.snap b/dc/s2n-quic-dc/src/stream/send/snapshots/s2n_quic_dc__stream__send__worker__waiting__dot_test.snap new file mode 100644 index 0000000000..7767cd4452 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/snapshots/s2n_quic_dc__stream__send__worker__waiting__dot_test.snap @@ -0,0 +1,15 @@ +--- +source: dc/s2n-quic-dc/src/stream/send/worker.rs +expression: "State::dot()" +--- +digraph { + label = "s2n_quic_dc::stream::send::worker::waiting::State"; + Acking; + Detached; + Finished; + ShuttingDown; + Acking -> Detached [label = "on_application_detach"]; + Acking -> ShuttingDown [label = "on_shutdown"]; + Detached -> ShuttingDown [label = "on_shutdown"]; + ShuttingDown -> Finished [label = "on_finished"]; +} diff --git a/dc/s2n-quic-dc/src/stream/send/state.rs b/dc/s2n-quic-dc/src/stream/send/state.rs new file mode 100644 index 0000000000..4a2f5e8c52 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/state.rs @@ -0,0 +1,1269 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + congestion, + crypto::{decrypt, encrypt, UninitSlice}, + packet::{ + self, + stream::{self, decoder, encoder}, + }, + recovery, + stream::{ + processing, + send::{ + application, buffer, error::Error, filter::Filter, probes, + transmission::Type as TransmissionType, + }, + DEFAULT_IDLE_TIMEOUT, + }, +}; +use core::{task::Poll, time::Duration}; +use s2n_codec::{DecoderBufferMut, EncoderBuffer}; +use s2n_quic_core::{ + dc::ApplicationParams, + ensure, + frame::{self, FrameMut}, + inet::ExplicitCongestionNotification, + interval_set::IntervalSet, + packet::number::PacketNumberSpace, + path::{ecn, INITIAL_PTO_BACKOFF}, + random, ready, + recovery::{Pto, RttEstimator}, + stream::state, + time::{ + timer::{self, Provider as _}, + Clock, Timer, Timestamp, + }, + varint::VarInt, +}; +use slotmap::SlotMap; +use std::collections::{BinaryHeap, VecDeque}; +use tracing::{debug, trace}; + +pub mod probe; +pub mod retransmission; +pub mod transmission; + +type PacketMap = s2n_quic_core::packet::number::Map; + +pub type Transmission = application::transmission::Event; + +slotmap::new_key_type! { + pub struct BufferIndex; +} + +#[derive(Clone, Copy, Debug)] +pub enum TransmitIndex { + Stream(BufferIndex), + Recovery(BufferIndex), +} + +#[derive(Debug)] +pub struct SentStreamPacket { + info: transmission::Info, + cc_info: congestion::PacketInfo, +} + +#[derive(Debug)] +pub struct SentRecoveryPacket { + info: transmission::Info, + cc_info: congestion::PacketInfo, + max_stream_packet_number: VarInt, +} + +#[derive(Debug)] +pub struct State { + pub stream_id: stream::Id, + pub rtt_estimator: RttEstimator, + pub sent_stream_packets: PacketMap, + pub stream_packet_buffers: SlotMap, + pub max_stream_packet_number: VarInt, + pub sent_recovery_packets: PacketMap, + pub recovery_packet_buffers: SlotMap>, + pub free_packet_buffers: Vec>, + pub recovery_packet_number: u64, + pub last_sent_recovery_packet: Option, + pub transmit_queue: VecDeque, + pub state: state::Sender, + pub control_filter: Filter, + pub retransmissions: BinaryHeap>, + pub next_expected_control_packet: VarInt, + pub cca: congestion::Controller, + pub ecn: ecn::Controller, + pub pto: Pto, + pub pto_backoff: u32, + pub inflight_timer: Timer, + pub idle_timer: Timer, + pub idle_timeout: Duration, + pub error: Option, + pub unacked_ranges: IntervalSet, + pub max_sent_offset: VarInt, + pub max_data: VarInt, + pub local_max_data_window: VarInt, + pub peer_activity: Option, + pub max_datagram_size: u16, + pub max_sent_segment_size: u16, +} + +#[derive(Clone, Copy, Debug)] +pub struct PeerActivity { + pub newly_acked_packets: bool, +} + +impl State { + #[inline] + pub fn new(stream_id: stream::Id, params: &ApplicationParams) -> Self { + let max_datagram_size = params.max_datagram_size; + let initial_max_data = params.remote_max_data; + let local_max_data = params.local_send_max_data; + + // initialize the pending data left to send + let mut unacked_ranges = IntervalSet::new(); + unacked_ranges.insert(VarInt::ZERO..=VarInt::MAX).unwrap(); + + let cca = congestion::Controller::new(max_datagram_size); + let max_sent_offset = VarInt::ZERO; + + Self { + stream_id, + next_expected_control_packet: VarInt::ZERO, + rtt_estimator: recovery::rtt_estimator(), + cca, + sent_stream_packets: Default::default(), + stream_packet_buffers: Default::default(), + max_stream_packet_number: VarInt::ZERO, + sent_recovery_packets: Default::default(), + recovery_packet_buffers: Default::default(), + recovery_packet_number: 
0, + last_sent_recovery_packet: None, + free_packet_buffers: Default::default(), + transmit_queue: Default::default(), + control_filter: Default::default(), + ecn: ecn::Controller::default(), + state: Default::default(), + retransmissions: Default::default(), + pto: Pto::default(), + pto_backoff: INITIAL_PTO_BACKOFF, + inflight_timer: Default::default(), + idle_timer: Default::default(), + idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT), + error: None, + unacked_ranges, + max_sent_offset, + max_data: initial_max_data, + local_max_data_window: local_max_data, + peer_activity: None, + max_datagram_size, + max_sent_segment_size: 0, + } + } + + /// Initializes the worker as a client + #[inline] + pub fn init_client(&mut self, clock: &impl Clock) { + debug_assert!(self.state.is_ready()); + // make sure a packet gets sent soon if the application doesn't + self.force_arm_pto_timer(clock); + } + + /// Returns the current flow offset + #[inline] + pub fn flow_offset(&self) -> VarInt { + let cca_offset = { + let extra_window = self + .cca + .congestion_window() + .saturating_sub(self.cca.bytes_in_flight()); + + self.max_sent_offset + extra_window as usize + }; + + let local_offset = { + let unacked_start = self.unacked_ranges.min_value().unwrap_or_default(); + let local_max_data_window = self.local_max_data_window; + + unacked_start.saturating_add(local_max_data_window) + }; + + let remote_offset = self.max_data; + + cca_offset.min(local_offset).min(remote_offset) + } + + #[inline] + pub fn send_quantum_packets(&self) -> u8 { + // TODO use div_ceil when we're on 1.73+ MSRV + // https://doc.rust-lang.org/std/primitive.u64.html#method.div_ceil + let send_quantum = (self.cca.send_quantum() as u64 + self.max_datagram_size as u64 - 1) + / self.max_datagram_size as u64; + send_quantum.try_into().unwrap_or(u8::MAX) + } + + /// Called by the worker when it receives a control packet from the peer + #[inline] + pub fn on_control_packet( + &mut self, + decrypt_key: &D, + ecn: ExplicitCongestionNotification, + packet: &mut packet::control::decoder::Packet, + random: &mut dyn random::Generator, + clock: &Clk, + transmission_queue: &application::transmission::Queue, + segment_alloc: &buffer::Allocator, + ) -> Result<(), processing::Error> + where + D: decrypt::Key, + Clk: Clock, + { + match self.on_control_packet_impl( + decrypt_key, + ecn, + packet, + random, + clock, + transmission_queue, + segment_alloc, + ) { + Ok(None) => {} + Ok(Some(error)) => return Err(error), + Err(error) => { + self.on_error(error); + } + } + + self.invariants(); + + Ok(()) + } + + #[inline(always)] + fn on_control_packet_impl( + &mut self, + decrypt_key: &D, + _ecn: ExplicitCongestionNotification, + packet: &mut packet::control::decoder::Packet, + random: &mut dyn random::Generator, + clock: &Clk, + transmission_queue: &application::transmission::Queue, + segment_alloc: &buffer::Allocator, + ) -> Result, Error> + where + D: decrypt::Key, + Clk: Clock, + { + probes::on_control_packet( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + ); + + let key_phase = packet.tag().key_phase(); + + // only process the packet after we know it's authentic + let res = decrypt_key.decrypt( + key_phase, + packet.crypto_nonce(), + packet.header(), + &[], + packet.auth_tag(), + UninitSlice::new(&mut []), + ); + + probes::on_control_packet_decrypted( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + res.is_ok(), + ); + 
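+        // (Note: the payload and output slices passed to `decrypt` above are empty, so
+        // the call is only used to authenticate the control packet header and tag.)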
+ // drop the packet if it failed to authenticate + if let Err(err) = res { + return Ok(Some(err.into())); + } + + // check if we've already seen the packet + ensure!( + self.control_filter.on_packet(packet).is_ok(), + return { + probes::on_control_packet_duplicate( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + ); + // drop the packet if we've already seen it + Ok(Some(processing::Error::Duplicate)) + } + ); + + let packet_number = packet.packet_number(); + + // raise our next expected control packet + { + let pn = packet_number.saturating_add(VarInt::from_u8(1)); + let pn = self.next_expected_control_packet.max(pn); + self.next_expected_control_packet = pn; + } + + let recv_time = clock.get_time(); + let mut newly_acked = false; + let mut max_acked_stream = None; + let mut max_acked_recovery = None; + let mut loaded_transmit_queue = false; + + { + let mut decoder = DecoderBufferMut::new(packet.control_data_mut()); + while !decoder.is_empty() { + let (frame, remaining) = decoder + .decode::() + .map_err(|decoder| Error::FrameError { decoder })?; + decoder = remaining; + + trace!(?frame); + + match frame { + FrameMut::Padding(_) => { + continue; + } + FrameMut::Ping(_) => { + // no need to do anything special here + } + FrameMut::Ack(ack) => { + if !core::mem::replace(&mut loaded_transmit_queue, true) { + // make sure we have a current view of the application transmissions + self.load_transmission_queue(transmission_queue); + } + + if ack.ecn_counts.is_some() { + self.on_frame_ack::<_, _, _, true>( + decrypt_key, + &ack, + random, + &recv_time, + &mut newly_acked, + &mut max_acked_stream, + &mut max_acked_recovery, + segment_alloc, + )?; + } else { + self.on_frame_ack::<_, _, _, false>( + decrypt_key, + &ack, + random, + &recv_time, + &mut newly_acked, + &mut max_acked_stream, + &mut max_acked_recovery, + segment_alloc, + )?; + } + } + FrameMut::MaxData(frame) => { + if self.max_data < frame.maximum_data { + self.max_data = frame.maximum_data; + } + } + FrameMut::ConnectionClose(close) => { + debug!(connection_close = ?close, state = ?self.state); + + probes::on_close( + decrypt_key.credentials().id, + self.stream_id, + packet_number, + close.error_code, + ); + + // if there was no error and we transmitted everything then just shut the + // stream down + if close.error_code == VarInt::ZERO + && close.frame_type.is_some() + && self.state.on_recv_all_acks().is_ok() + { + self.clean_up(); + // transmit one more PTO packet so we can ACK the peer's + // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, + // they'll need to hang around to respond to potential loss. 
+ self.pto.force_transmit(); + return Ok(None); + } + + // no need to transmit a reset back to the peer - just close it + let _ = self.state.on_send_reset(); + let _ = self.state.on_recv_reset_ack(); + let error = if close.frame_type.is_some() { + Error::TransportError { + code: close.error_code, + } + } else { + Error::ApplicationError { + error: close.error_code.into(), + } + }; + return Err(error); + } + _ => continue, + } + } + } + + for (space, pn) in [ + (stream::PacketSpace::Stream, max_acked_stream), + (stream::PacketSpace::Recovery, max_acked_recovery), + ] { + if let Some(pn) = pn { + self.detect_lost_packets(decrypt_key, random, &recv_time, space, pn)?; + } + } + + self.on_peer_activity(newly_acked); + + // try to transition to the final state if we've sent all of the data + if self.unacked_ranges.is_empty() + && self.error.is_none() + && self.state.on_recv_all_acks().is_ok() + { + self.clean_up(); + // transmit one more PTO packet so we can ACK the peer's + // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, + // they'll need to hang around to respond to potential loss. + self.pto.force_transmit(); + } + + Ok(None) + } + + #[inline] + fn on_frame_ack( + &mut self, + decrypt_key: &D, + ack: &frame::Ack, + random: &mut dyn random::Generator, + clock: &Clk, + newly_acked: &mut bool, + max_acked_stream: &mut Option, + max_acked_recovery: &mut Option, + segment_alloc: &buffer::Allocator, + ) -> Result<(), Error> + where + D: decrypt::Key, + Ack: frame::ack::AckRanges, + Clk: Clock, + { + let mut cca_args = None; + let mut bytes_acked = 0; + + macro_rules! impl_ack_processing { + ($space:ident, $sent_packets:ident, $on_packet_number:expr) => { + for range in ack.ack_ranges() { + let pmin = PacketNumberSpace::Initial.new_packet_number(*range.start()); + let pmax = PacketNumberSpace::Initial.new_packet_number(*range.end()); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(pmin, pmax); + for (num, packet) in self.$sent_packets.remove_range(range) { + let num_varint = unsafe { VarInt::new_unchecked(num.as_u64()) }; + + #[allow(clippy::redundant_closure_call)] + ($on_packet_number)(num_varint, &packet); + + let _ = self.unacked_ranges.remove(packet.info.tracking_range()); + + self.ecn + .on_packet_ack(packet.info.time_sent, packet.info.ecn); + bytes_acked += packet.info.cca_len() as usize; + + // record the most recent packet + if cca_args + .as_ref() + .map_or(true, |prev: &(Timestamp, _)| prev.0 < packet.info.time_sent) + { + cca_args = Some((packet.info.time_sent, packet.cc_info)); + } + + // free the retransmission segment + if let Some(segment) = packet.info.retransmission { + if let Some(segment) = self.stream_packet_buffers.remove(segment) { + // push the segment so the application can reuse it + if segment.capacity() >= self.max_sent_segment_size as usize { + segment_alloc.free(segment); + } + } + } + + probes::on_packet_ack( + decrypt_key.credentials().id, + self.stream_id, + stream::PacketSpace::$space, + num.as_u64(), + packet.info.packet_len, + packet.info.stream_offset, + packet.info.payload_len, + clock + .get_time() + .saturating_duration_since(packet.info.time_sent), + ); + + *newly_acked = true; + } + } + }; + } + + if IS_STREAM { + impl_ack_processing!( + Stream, + sent_stream_packets, + |packet_number: VarInt, _packet: &SentStreamPacket| { + *max_acked_stream = (*max_acked_stream).max(Some(packet_number)); + } + ); + } else { + impl_ack_processing!( + Recovery, + sent_recovery_packets, + |packet_number: VarInt, sent_packet: 
&SentRecoveryPacket| { + *max_acked_recovery = (*max_acked_recovery).max(Some(packet_number)); + *max_acked_stream = + (*max_acked_stream).max(Some(sent_packet.max_stream_packet_number)); + + // increase the max stream packet if this was a probe + if sent_packet.info.retransmission.is_none() { + self.max_stream_packet_number = self + .max_stream_packet_number + .max(sent_packet.max_stream_packet_number + 1); + } + } + ); + }; + + if let Some((time_sent, cc_info)) = cca_args { + let rtt_sample = clock.get_time().saturating_duration_since(time_sent); + + self.rtt_estimator.update_rtt( + ack.ack_delay(), + rtt_sample, + clock.get_time(), + true, + PacketNumberSpace::ApplicationData, + ); + + self.cca.on_packet_ack( + cc_info.first_sent_time, + bytes_acked, + cc_info, + &self.rtt_estimator, + random, + clock.get_time(), + ); + } + + Ok(()) + } + + #[inline] + fn detect_lost_packets( + &mut self, + decrypt_key: &D, + random: &mut dyn random::Generator, + clock: &Clk, + packet_space: stream::PacketSpace, + max: VarInt, + ) -> Result<(), Error> + where + D: decrypt::Key, + Clk: Clock, + { + let Some(loss_threshold) = max.checked_sub(VarInt::from_u8(2)) else { + return Ok(()); + }; + + let mut is_unrecoverable = false; + + macro_rules! impl_loss_detection { + ($sent_packets:ident, $on_packet:expr) => {{ + let lost_min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); + let lost_max = PacketNumberSpace::Initial.new_packet_number(loss_threshold); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(lost_min, lost_max); + for (num, packet) in self.$sent_packets.remove_range(range) { + // TODO create a path and publisher + // self.ecn.on_packet_loss(packet.time_sent, packet.ecn, now, path, publisher); + + self.cca.on_packet_lost( + packet.info.cca_len() as _, + packet.cc_info, + random, + clock.get_time(), + ); + + probes::on_packet_lost( + decrypt_key.credentials().id, + self.stream_id, + packet_space, + num.as_u64(), + packet.info.packet_len, + packet.info.stream_offset, + packet.info.payload_len, + clock + .get_time() + .saturating_duration_since(packet.info.time_sent), + packet.info.retransmission.is_some(), + ); + + #[allow(clippy::redundant_closure_call)] + ($on_packet)(&packet); + + if let Some(segment) = packet.info.retransmission { + // update our local packet number to be at least 1 more than the largest lost + // packet number + let min_recovery_packet_number = num.as_u64() + 1; + self.recovery_packet_number = + self.recovery_packet_number.max(min_recovery_packet_number); + + let retransmission = retransmission::Segment { + segment, + stream_offset: packet.info.stream_offset, + payload_len: packet.info.payload_len, + ty: TransmissionType::Stream, + included_fin: packet.info.included_fin, + }; + self.retransmissions.push(retransmission); + } else { + // we can only recover reliable streams + is_unrecoverable |= packet.info.payload_len > 0 && !self.stream_id.is_reliable; + } + }} + } + } + + match packet_space { + stream::PacketSpace::Stream => impl_loss_detection!(sent_stream_packets, |_| {}), + stream::PacketSpace::Recovery => { + impl_loss_detection!(sent_recovery_packets, |sent_packet: &SentRecoveryPacket| { + self.max_stream_packet_number = self + .max_stream_packet_number + .max(sent_packet.max_stream_packet_number + 1); + }) + } + } + + ensure!(!is_unrecoverable, Err(Error::RetransmissionFailure)); + + self.invariants(); + + Ok(()) + } + + #[inline] + fn on_peer_activity(&mut self, newly_acked_packets: bool) { + if let Some(prev) = self.peer_activity.as_mut() 
{ + prev.newly_acked_packets |= newly_acked_packets; + } else { + self.peer_activity = Some(PeerActivity { + newly_acked_packets, + }); + } + } + + #[inline] + pub fn before_sleep(&mut self, clock: &Clk) { + self.process_peer_activity(); + + // make sure our timers are armed + self.update_idle_timer(clock); + self.update_inflight_timer(clock); + self.update_pto_timer(clock); + + trace!( + unacked_ranges = ?self.unacked_ranges, + retransmissions = self.retransmissions.len(), + stream_packets_in_flight = self.sent_stream_packets.iter().count(), + recovery_packets_in_flight = self.sent_recovery_packets.iter().count(), + pto_timer = ?self.pto.next_expiration(), + inflight_timer = ?self.inflight_timer.next_expiration(), + idle_timer = ?self.idle_timer.next_expiration(), + ); + } + + #[inline] + fn process_peer_activity(&mut self) { + let Some(PeerActivity { + newly_acked_packets, + }) = self.peer_activity.take() + else { + return; + }; + + if newly_acked_packets { + self.reset_pto_timer(); + } + + // force probing when we've sent all of the data but haven't got an ACK for the final + // offset + if self.state.is_data_sent() && self.stream_packet_buffers.is_empty() { + self.pto.force_transmit(); + } + + // re-arm the idle timer as long as we're not in terminal state + if !self.state.is_terminal() { + self.idle_timer.cancel(); + self.inflight_timer.cancel(); + } + } + + #[inline] + pub fn on_time_update(&mut self, clock: &Clk, load_last_activity: Ld) + where + Clk: Clock, + Ld: FnOnce() -> Timestamp, + { + if self.poll_idle_timer(clock, load_last_activity).is_ready() { + self.on_error(Error::IdleTimeout); + // we don't actually want to send any packets on idle timeout + let _ = self.state.on_send_reset(); + let _ = self.state.on_recv_reset_ack(); + return; + } + + if self + .inflight_timer + .poll_expiration(clock.get_time()) + .is_ready() + { + self.on_error(Error::IdleTimeout); + return; + } + + if self + .pto + .on_timeout(self.has_inflight_packets(), clock.get_time()) + .is_ready() + { + // TODO where does this come from + let max_pto_backoff = 1024; + self.pto_backoff = self.pto_backoff.saturating_mul(2).min(max_pto_backoff); + } + } + + #[inline] + fn poll_idle_timer(&mut self, clock: &Clk, load_last_activity: Ld) -> Poll<()> + where + Clk: Clock, + Ld: FnOnce() -> Timestamp, + { + let now = clock.get_time(); + + // check the idle timer first + ready!(self.idle_timer.poll_expiration(now)); + + // if that expired then load the last activity from the peer and update the idle timer with + // the value + let last_peer_activity = load_last_activity(); + self.update_idle_timer(&last_peer_activity); + + // check the idle timer once more before returning + ready!(self.idle_timer.poll_expiration(now)); + + Poll::Ready(()) + } + + #[inline] + fn has_inflight_packets(&self) -> bool { + !self.sent_stream_packets.is_empty() + || !self.sent_recovery_packets.is_empty() + || !self.retransmissions.is_empty() + || !self.transmit_queue.is_empty() + } + + #[inline] + fn update_idle_timer(&mut self, clock: &impl Clock) { + ensure!(!self.idle_timer.is_armed()); + + let now = clock.get_time(); + self.idle_timer.set(now + self.idle_timeout); + } + + #[inline] + fn update_inflight_timer(&mut self, clock: &impl Clock) { + // TODO make this configurable + let inflight_timeout = crate::stream::DEFAULT_INFLIGHT_TIMEOUT; + + if self.has_inflight_packets() { + if !self.inflight_timer.is_armed() { + self.inflight_timer.set(clock.get_time() + inflight_timeout); + } + } else { + self.inflight_timer.cancel(); + } + } + + 
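Note: each time the PTO timer fires above, the backoff multiplier doubles and is clamped at 1024 (the source of that cap is left as a TODO in the code). A minimal sketch of that update, assuming the same saturating, capped semantics:

    // assumed cap; mirrors the literal used in on_time_update
    const MAX_PTO_BACKOFF: u32 = 1024;

    fn next_pto_backoff(current: u32) -> u32 {
        // exponential growth that can neither overflow nor exceed the cap
        current.saturating_mul(2).min(MAX_PTO_BACKOFF)
    }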
#[inline] + fn update_pto_timer(&mut self, clock: &impl Clock) { + ensure!(!self.pto.is_armed()); + + let mut should_arm = self.has_inflight_packets(); + + // if we have stream packet buffers in flight then arm the PTO + should_arm |= !self.stream_packet_buffers.is_empty(); + + // if we've sent all of the data/reset and are waiting to clean things up + should_arm |= self.state.is_data_sent() || self.state.is_reset_sent(); + + ensure!(should_arm); + + self.force_arm_pto_timer(clock); + } + + #[inline] + fn force_arm_pto_timer(&mut self, clock: &impl Clock) { + let pto_period = self + .rtt_estimator + .pto_period(self.pto_backoff, PacketNumberSpace::Initial); + self.pto.update(clock.get_time(), pto_period); + } + + #[inline] + fn reset_pto_timer(&mut self) { + self.pto_backoff = INITIAL_PTO_BACKOFF; + self.pto.cancel(); + } + + /// Called by the worker thread when it becomes aware of the application having transmitted a + /// segment + #[inline] + pub fn load_transmission_queue( + &mut self, + queue: &application::transmission::Queue, + ) -> bool { + let mut did_transmit_stream = false; + + for Transmission { + packet_number, + info, + has_more_app_data, + } in queue.drain() + { + self.max_sent_segment_size = self.max_sent_segment_size.max(info.packet_len); + let info = info.map(|buffer| self.stream_packet_buffers.insert(buffer)); + self.on_transmit_segment( + stream::PacketSpace::Stream, + packet_number, + info, + has_more_app_data, + ); + did_transmit_stream = true; + } + + if did_transmit_stream { + // if we just sent some packets then we can use those as probes + self.reset_pto_timer(); + } + + self.invariants(); + + did_transmit_stream + } + + #[inline] + fn on_transmit_segment( + &mut self, + packet_space: stream::PacketSpace, + packet_number: VarInt, + info: transmission::Info, + has_more_app_data: bool, + ) { + // the BBR implementation requires monotonic time so track that + let mut cca_time_sent = info.time_sent; + + match packet_space { + stream::PacketSpace::Stream => { + if let Some(min) = self.last_sent_recovery_packet { + cca_time_sent = info.time_sent.max(min); + } + } + stream::PacketSpace::Recovery => { + self.last_sent_recovery_packet = Some(info.time_sent); + } + } + + let cc_info = self.cca.on_packet_sent( + cca_time_sent, + info.cca_len(), + has_more_app_data, + &self.rtt_estimator, + ); + + // update the max offset that we've transmitted + self.max_sent_offset = self.max_sent_offset.max(info.end_offset()); + + // try to transition to start sending + let _ = self.state.on_send_stream(); + if info.included_fin { + // if the transmission included the final offset, then transition to that state + let _ = self.state.on_send_fin(); + } + + if let stream::PacketSpace::Recovery = packet_space { + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet_number); + let max_stream_packet_number = self.max_stream_packet_number; + self.sent_recovery_packets.insert( + packet_number, + SentRecoveryPacket { + info, + cc_info, + max_stream_packet_number, + }, + ); + } else { + self.max_stream_packet_number = self.max_stream_packet_number.max(packet_number); + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet_number); + self.sent_stream_packets + .insert(packet_number, SentStreamPacket { info, cc_info }); + } + } + + #[inline] + pub fn fill_transmit_queue( + &mut self, + encrypt_key: &E, + source_control_port: u16, + clock: &Clk, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + { + if let Err(error) = 
self.fill_transmit_queue_impl(encrypt_key, source_control_port, clock) { + self.on_error(error); + return Err(error); + } + + Ok(()) + } + + #[inline] + fn fill_transmit_queue_impl( + &mut self, + encrypt_key: &E, + source_control_port: u16, + clock: &Clk, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + { + // skip a packet number if we're probing + if self.pto.transmissions() > 0 { + self.recovery_packet_number += 1; + } + + self.try_transmit_retransmissions(encrypt_key, clock)?; + self.try_transmit_probe(encrypt_key, source_control_port, clock)?; + + Ok(()) + } + + #[inline] + fn try_transmit_retransmissions( + &mut self, + encrypt_key: &E, + clock: &Clk, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + { + // We'll only have retransmissions if we're reliable + ensure!(self.stream_id.is_reliable, Ok(())); + + while let Some(retransmission) = self.retransmissions.peek() { + // make sure we fit in the current congestion window + let remaining_cca_window = self + .cca + .congestion_window() + .saturating_sub(self.cca.bytes_in_flight()); + ensure!( + retransmission.payload_len as u32 <= remaining_cca_window, + break + ); + + let buffer = self.stream_packet_buffers[retransmission.segment].make_mut(); + + debug_assert!(!buffer.is_empty(), "empty retransmission buffer submitted"); + + let packet_number = + VarInt::new(self.recovery_packet_number).expect("2^62 is a lot of packets"); + self.recovery_packet_number += 1; + + { + let buffer = DecoderBufferMut::new(buffer); + match decoder::Packet::retransmit( + buffer, + stream::PacketSpace::Recovery, + packet_number, + encrypt_key, + ) { + Ok(info) => info, + Err(err) => { + // this shouldn't ever happen + debug_assert!(false, "{err:?}"); + return Err(Error::RetransmissionFailure); + } + } + }; + + let time_sent = clock.get_time(); + let packet_len = buffer.len() as u16; + + { + let info = self + .retransmissions + .pop() + .expect("retransmission should be available"); + + // enqueue the transmission + self.transmit_queue + .push_back(TransmitIndex::Stream(info.segment)); + + let stream_offset = info.stream_offset; + let payload_len = info.payload_len; + let included_fin = info.included_fin; + let retransmission = Some(info.segment); + + // TODO store this as part of the packet queue + let ecn = ExplicitCongestionNotification::Ect0; + + let info = transmission::Info { + packet_len, + stream_offset, + payload_len, + included_fin, + retransmission, + time_sent, + ecn, + }; + + probes::on_transmit_stream( + encrypt_key.credentials().id, + self.stream_id, + stream::PacketSpace::Recovery, + PacketNumberSpace::Initial.new_packet_number(packet_number), + stream_offset, + payload_len, + included_fin, + true, + ); + + self.on_transmit_segment(stream::PacketSpace::Recovery, packet_number, info, false); + + // consider this transmission a probe if needed + if self.pto.transmissions() > 0 { + self.pto.on_transmit_once(); + } + } + } + + Ok(()) + } + + #[inline] + pub fn try_transmit_probe( + &mut self, + encrypt_key: &E, + source_control_port: u16, + clock: &Clk, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + { + while self.pto.transmissions() > 0 { + // probes are not congestion-controlled + + let packet_number = + VarInt::new(self.recovery_packet_number).expect("2^62 is a lot of packets"); + self.recovery_packet_number += 1; + + let mut buffer = self.free_packet_buffers.pop().unwrap_or_default(); + + // resize the buffer to what we need + { + let min_len = 
stream::encoder::MAX_RETRANSMISSION_HEADER_LEN + 128; + + if buffer.capacity() < min_len { + buffer.reserve(min_len - buffer.len()); + } + + unsafe { + debug_assert!(buffer.capacity() >= min_len); + buffer.set_len(min_len); + } + } + + let offset = self.max_sent_offset; + let final_offset = if self.state.is_data_sent() { + Some(offset) + } else { + None + }; + + let mut payload = probe::Probe { + offset, + final_offset, + }; + + let encoder = EncoderBuffer::new(&mut buffer); + let packet_len = encoder::encode( + encoder, + source_control_port, + None, + self.stream_id, + stream::PacketSpace::Recovery, + packet_number, + self.next_expected_control_packet, + VarInt::ZERO, + &mut &[][..], + VarInt::ZERO, + &(), + &mut payload, + encrypt_key, + ); + + let payload_len = 0; + let included_fin = final_offset.is_some(); + buffer.truncate(packet_len); + + debug_assert!( + packet_len < u16::MAX as usize, + "cannot write larger packets than 2^16" + ); + let packet_len = packet_len as u16; + + let time_sent = clock.get_time(); + + // TODO store this as part of the packet queue + let ecn = ExplicitCongestionNotification::Ect0; + + // enqueue the transmission + let buffer_index = self.recovery_packet_buffers.insert(buffer); + self.transmit_queue + .push_back(TransmitIndex::Recovery(buffer_index)); + + let info = transmission::Info { + packet_len, + stream_offset: offset, + payload_len, + included_fin, + retransmission: None, // PTO packets are not retransmitted + time_sent, + ecn, + }; + + self.on_transmit_segment(stream::PacketSpace::Recovery, packet_number, info, false); + + self.pto.on_transmit_once(); + } + + Ok(()) + } + + #[inline] + pub fn transmit_queue_iter( + &mut self, + clock: &Clk, + ) -> impl Iterator + '_ { + let ecn = self + .ecn + .ecn(s2n_quic_core::transmission::Mode::Normal, clock.get_time()); + let stream_packet_buffers = &self.stream_packet_buffers; + let recovery_packet_buffers = &self.recovery_packet_buffers; + + self.transmit_queue.iter().filter_map(move |index| { + let buf = match *index { + TransmitIndex::Stream(index) => stream_packet_buffers.get(index)?.as_slice(), + TransmitIndex::Recovery(index) => recovery_packet_buffers.get(index)?, + }; + + Some((ecn, buf)) + }) + } + + #[inline] + pub fn on_transmit_queue(&mut self, count: usize) { + for transmission in self.transmit_queue.drain(..count) { + match transmission { + TransmitIndex::Stream(index) => { + // make sure the packet wasn't freed between when we wanted to transmit and + // when we actually did + ensure!(self.stream_packet_buffers.get(index).is_some(), continue); + } + TransmitIndex::Recovery(index) => { + // make sure the packet wasn't freed between when we wanted to transmit and + // when we actually did + let Some(mut buffer) = self.recovery_packet_buffers.remove(index) else { + continue; + }; + buffer.clear(); + self.free_packet_buffers.push(buffer); + } + }; + } + } + + #[inline] + pub fn on_error(&mut self, error: Error) { + ensure!(self.error.is_none()); + self.error = Some(error); + let _ = self.state.on_queue_reset(); + + self.clean_up(); + } + + #[inline] + fn clean_up(&mut self) { + self.retransmissions.clear(); + let min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); + let max = PacketNumberSpace::Initial.new_packet_number(VarInt::MAX); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(min, max); + let _ = self.sent_stream_packets.remove_range(range); + let _ = self.sent_recovery_packets.remove_range(range); + + self.idle_timer.cancel(); + self.inflight_timer.cancel(); 
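Note: probe packet buffers are recycled rather than reallocated: on_transmit_queue clears each drained recovery buffer back into free_packet_buffers, and try_transmit_probe pops from that list before falling back to a fresh Vec. A rough sketch of the freelist shape, with illustrative names only:

    struct BufferFreeList {
        buffers: Vec<Vec<u8>>,
    }

    impl BufferFreeList {
        // reuse a previously transmitted buffer when one is available
        fn checkout(&mut self) -> Vec<u8> {
            self.buffers.pop().unwrap_or_default()
        }

        // clear and park a buffer for the next probe transmission
        fn recycle(&mut self, mut buffer: Vec<u8>) {
            buffer.clear();
            self.buffers.push(buffer);
        }
    }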
+ self.pto.cancel(); + self.unacked_ranges.clear(); + + self.transmit_queue.clear(); + for buffer in self.stream_packet_buffers.drain() { + // TODO push buffer into free segment queue + let _ = buffer; + } + for (_idx, mut buffer) in self.recovery_packet_buffers.drain() { + buffer.clear(); + self.free_packet_buffers.push(buffer); + } + + self.invariants(); + } + + #[cfg(debug_assertions)] + #[inline] + fn invariants(&self) { + // TODO + } + + #[cfg(not(debug_assertions))] + #[inline(always)] + fn invariants(&self) {} +} + +impl timer::Provider for State { + #[inline] + fn timers(&self, query: &mut Q) -> timer::Result { + // if we're in a terminal state then no timers are needed + ensure!(!self.state.is_terminal(), Ok(())); + self.pto.timers(query)?; + self.idle_timer.timers(query)?; + Ok(()) + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker/probe.rs b/dc/s2n-quic-dc/src/stream/send/state/probe.rs similarity index 100% rename from dc/s2n-quic-dc/src/stream/send/worker/probe.rs rename to dc/s2n-quic-dc/src/stream/send/state/probe.rs diff --git a/dc/s2n-quic-dc/src/stream/send/worker/retransmission.rs b/dc/s2n-quic-dc/src/stream/send/state/retransmission.rs similarity index 100% rename from dc/s2n-quic-dc/src/stream/send/worker/retransmission.rs rename to dc/s2n-quic-dc/src/stream/send/state/retransmission.rs diff --git a/dc/s2n-quic-dc/src/stream/send/worker/transmission.rs b/dc/s2n-quic-dc/src/stream/send/state/transmission.rs similarity index 100% rename from dc/s2n-quic-dc/src/stream/send/worker/transmission.rs rename to dc/s2n-quic-dc/src/stream/send/state/transmission.rs diff --git a/dc/s2n-quic-dc/src/stream/send/worker.rs b/dc/s2n-quic-dc/src/stream/send/worker.rs index 7301b6a1f0..0eab1fea93 100644 --- a/dc/s2n-quic-dc/src/stream/send/worker.rs +++ b/dc/s2n-quic-dc/src/stream/send/worker.rs @@ -2,1265 +2,475 @@ // SPDX-License-Identifier: Apache-2.0 use crate::{ - congestion, - crypto::{decrypt, encrypt, UninitSlice}, - packet::{ - self, - stream::{self, decoder, encoder}, - }, - recovery, + clock::{Clock, Timer}, + msg, + msg::addr, + packet::Packet, stream::{ - processing, - send::{ - application, buffer, error::Error, filter::Filter, probes, - transmission::Type as TransmissionType, - }, - DEFAULT_IDLE_TIMEOUT, + pacer, processing, + send::{error::Error, queue::Queue, shared::Event, state::State}, + shared::{self, Half}, + socket::Socket, }, }; -use core::{task::Poll, time::Duration}; -use s2n_codec::{DecoderBufferMut, EncoderBuffer}; +use core::task::{Context, Poll}; +use s2n_codec::DecoderBufferMut; use s2n_quic_core::{ - dc::ApplicationParams, - ensure, - frame::{self, FrameMut}, + endpoint, ensure, inet::ExplicitCongestionNotification, - interval_set::IntervalSet, - packet::number::PacketNumberSpace, - path::{ecn, INITIAL_PTO_BACKOFF}, random, ready, - recovery::{Pto, RttEstimator}, - stream::state, + recovery::bandwidth::Bandwidth, time::{ - timer::{self, Provider as _}, - Clock, Timer, Timestamp, + clock::{self, Timer as _}, + timer::Provider as _, + Timestamp, }, varint::VarInt, }; -use slotmap::SlotMap; -use std::collections::{BinaryHeap, VecDeque}; -use tracing::{debug, trace}; - -pub mod probe; -pub mod retransmission; -pub mod transmission; - -type PacketMap = s2n_quic_core::packet::number::Map; - -pub type Transmission = application::transmission::Event; - -slotmap::new_key_type! 
{ - pub struct BufferIndex; -} +use std::sync::Arc; + +mod waiting { + use s2n_quic_core::state::event; + + #[derive(Clone, Debug, Default, PartialEq)] + pub enum State { + #[default] + Acking, + Detached, + ShuttingDown, + Finished, + } -#[derive(Clone, Copy, Debug)] -pub enum TransmitIndex { - Stream(BufferIndex), - Recovery(BufferIndex), -} + impl State { + event! { + on_application_detach(Acking => Detached); + on_shutdown(Acking | Detached => ShuttingDown); + on_finished(ShuttingDown => Finished); + } + } -#[derive(Debug)] -pub struct SentStreamPacket { - info: transmission::Info, - cc_info: congestion::PacketInfo, + #[test] + fn dot_test() { + insta::assert_snapshot!(State::dot()); + } } -#[derive(Debug)] -pub struct SentRecoveryPacket { - info: transmission::Info, - cc_info: congestion::PacketInfo, - max_stream_packet_number: VarInt, +pub struct Worker { + shared: Arc>, + sender: State, + recv_buffer: msg::recv::Message, + random: R, + state: waiting::State, + timer: Timer, + application_queue: Queue, + pacer: pacer::Naive, + socket: S, } #[derive(Debug)] -pub struct Worker { - pub stream_id: stream::Id, - pub rtt_estimator: RttEstimator, - pub sent_stream_packets: PacketMap, - pub stream_packet_buffers: SlotMap, - pub max_stream_packet_number: VarInt, - pub sent_recovery_packets: PacketMap, - pub recovery_packet_buffers: SlotMap>, - pub free_packet_buffers: Vec>, - pub recovery_packet_number: u64, - pub last_sent_recovery_packet: Option, - pub transmit_queue: VecDeque, - pub state: state::Sender, - pub control_filter: Filter, - pub retransmissions: BinaryHeap>, - pub next_expected_control_packet: VarInt, - pub cca: congestion::Controller, - pub ecn: ecn::Controller, - pub pto: Pto, - pub pto_backoff: u32, - pub inflight_timer: Timer, - pub idle_timer: Timer, - pub idle_timeout: Duration, - pub error: Option, - pub unacked_ranges: IntervalSet, - pub max_sent_offset: VarInt, - pub max_data: VarInt, - pub local_max_data_window: VarInt, - pub peer_activity: Option, - pub max_datagram_size: u16, - pub max_sent_segment_size: u16, +struct Snapshot { + flow_offset: VarInt, + has_pending_retransmissions: bool, + send_quantum: usize, + max_datagram_size: u16, + ecn: ExplicitCongestionNotification, + next_expected_control_packet: VarInt, + timeout: Option, + bandwidth: Bandwidth, + error: Option, } -#[derive(Clone, Copy, Debug)] -pub struct PeerActivity { - pub newly_acked_packets: bool, -} - -impl Worker { +impl Snapshot { #[inline] - pub fn new(stream_id: stream::Id, params: &ApplicationParams) -> Self { - let max_datagram_size = params.max_datagram_size; - let initial_max_data = params.remote_max_data; - let local_max_data = params.local_send_max_data; + fn apply(&self, initial: &Self, shared: &shared::Shared) { + if initial.flow_offset < self.flow_offset { + shared.sender.flow.release(self.flow_offset); + } else if initial.has_pending_retransmissions && !self.has_pending_retransmissions { + // we were waiting to clear out our retransmission queue before giving the application + // more flow credits + shared.sender.flow.release_max(self.flow_offset); + } - // initialize the pending data left to send - let mut unacked_ranges = IntervalSet::new(); - unacked_ranges.insert(VarInt::ZERO..=VarInt::MAX).unwrap(); + if initial.send_quantum != self.send_quantum { + let send_quantum = (self.send_quantum as u64 + self.max_datagram_size as u64 - 1) + / self.max_datagram_size as u64; + let send_quantum = send_quantum.try_into().unwrap_or(u8::MAX); + shared + .sender + .path + .update_info(self.ecn, 
send_quantum, self.max_datagram_size); + } - let cca = congestion::Controller::new(max_datagram_size); - let max_sent_offset = VarInt::ZERO; + if initial.next_expected_control_packet < self.next_expected_control_packet { + shared + .sender + .path + .set_next_expected_control_packet(self.next_expected_control_packet); + } - Self { - stream_id, - next_expected_control_packet: VarInt::ZERO, - rtt_estimator: recovery::rtt_estimator(), - cca, - sent_stream_packets: Default::default(), - stream_packet_buffers: Default::default(), - max_stream_packet_number: VarInt::ZERO, - sent_recovery_packets: Default::default(), - recovery_packet_buffers: Default::default(), - recovery_packet_number: 0, - last_sent_recovery_packet: None, - free_packet_buffers: Default::default(), - transmit_queue: Default::default(), - control_filter: Default::default(), - ecn: ecn::Controller::default(), - state: Default::default(), - retransmissions: Default::default(), - pto: Pto::default(), - pto_backoff: INITIAL_PTO_BACKOFF, - inflight_timer: Default::default(), - idle_timer: Default::default(), - idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT), - error: None, - unacked_ranges, - max_sent_offset, - max_data: initial_max_data, - local_max_data_window: local_max_data, - peer_activity: None, - max_datagram_size, - max_sent_segment_size: 0, + if initial.bandwidth != self.bandwidth { + shared.sender.set_bandwidth(self.bandwidth); } - } - /// Initializes the worker as a client - #[inline] - pub fn init_client(&mut self, clock: &impl Clock) { - debug_assert!(self.state.is_ready()); - // make sure a packet gets sent soon if the application doesn't - self.force_arm_pto_timer(clock); + if let Some(error) = self.error { + if initial.error.is_none() { + shared.sender.flow.set_error(error); + } + } } +} - /// Returns the current flow offset +impl Worker +where + S: Socket, + R: random::Generator, + C: Clock, +{ #[inline] - pub fn flow_offset(&self) -> VarInt { - let cca_offset = { - let extra_window = self - .cca - .congestion_window() - .saturating_sub(self.cca.bytes_in_flight()); - - self.max_sent_offset + extra_window as usize - }; - - let local_offset = { - let unacked_start = self.unacked_ranges.min_value().unwrap_or_default(); - let local_max_data_window = self.local_max_data_window; - - unacked_start.saturating_add(local_max_data_window) - }; - - let remote_offset = self.max_data; + pub fn new( + socket: S, + random: R, + shared: Arc>, + mut sender: State, + endpoint: endpoint::Type, + ) -> Self { + let timer = Timer::new(&shared.clock); + let recv_buffer = msg::recv::Message::new(u16::MAX); + let state = Default::default(); + + // if this is a client then set up the sender + if endpoint.is_client() { + sender.init_client(&shared.clock); + } - cca_offset.min(local_offset).min(remote_offset) + Self { + shared, + sender, + recv_buffer, + random, + state, + timer, + application_queue: Default::default(), + pacer: Default::default(), + socket, + } } #[inline] - pub fn send_quantum_packets(&self) -> u8 { - // TODO use div_ceil when we're on 1.73+ MSRV - // https://doc.rust-lang.org/std/primitive.u64.html#method.div_ceil - let send_quantum = (self.cca.send_quantum() as u64 + self.max_datagram_size as u64 - 1) - / self.max_datagram_size as u64; - send_quantum.try_into().unwrap_or(u8::MAX) + pub fn update_waker(&self, cx: &mut Context) { + self.shared.sender.worker_waker.update(cx.waker()); } - /// Called by the worker when it receives a control packet from the peer #[inline] - pub fn on_control_packet( - &mut 
self, - decrypt_key: &D, - ecn: ExplicitCongestionNotification, - packet: &mut packet::control::decoder::Packet, - random: &mut dyn random::Generator, - clock: &Clk, - transmission_queue: &application::transmission::Queue, - segment_alloc: &buffer::Allocator, - ) -> Result<(), processing::Error> - where - D: decrypt::Key, - Clk: Clock, - { - match self.on_control_packet_impl( - decrypt_key, - ecn, - packet, - random, - clock, - transmission_queue, - segment_alloc, - ) { - Ok(None) => {} - Ok(Some(error)) => return Err(error), - Err(error) => { - self.on_error(error); - } - } - - self.invariants(); - - Ok(()) - } - - #[inline(always)] - fn on_control_packet_impl( - &mut self, - decrypt_key: &D, - _ecn: ExplicitCongestionNotification, - packet: &mut packet::control::decoder::Packet, - random: &mut dyn random::Generator, - clock: &Clk, - transmission_queue: &application::transmission::Queue, - segment_alloc: &buffer::Allocator, - ) -> Result, Error> - where - D: decrypt::Key, - Clk: Clock, - { - probes::on_control_packet( - decrypt_key.credentials().id, - self.stream_id, - packet.packet_number(), - packet.control_data().len(), - ); + pub fn poll(&mut self, cx: &mut Context) -> Poll<()> { + let initial = self.snapshot(); - // only process the packet after we know it's authentic - let res = decrypt_key.decrypt( - packet.crypto_nonce(), - packet.header(), - &[], - packet.auth_tag(), - UninitSlice::new(&mut []), - ); + let is_initial = self.sender.state.is_ready(); - probes::on_control_packet_decrypted( - decrypt_key.credentials().id, - self.stream_id, - packet.packet_number(), - packet.control_data().len(), - res.is_ok(), + tracing::trace!( + view = "before", + sender_state = ?self.sender.state, + worker_state = ?self.state, + snapshot = ?initial, ); - // drop the packet if it failed to authenticate - if let Err(err) = res { - return Ok(Some(err.into())); - } + self.shared.sender.worker_waker.on_worker_wake(); - // check if we've already seen the packet - ensure!( - self.control_filter.on_packet(packet).is_ok(), - return { - probes::on_control_packet_duplicate( - decrypt_key.credentials().id, - self.stream_id, - packet.packet_number(), - packet.control_data().len(), - ); - // drop the packet if we've already seen it - Ok(Some(processing::Error::Duplicate)) - } - ); - - let packet_number = packet.packet_number(); + self.poll_once(cx); - // raise our next expected control packet + // check if the application sent us any more messages + if !self + .shared + .sender + .worker_waker + .on_worker_sleep() + .is_working() { - let pn = packet_number.saturating_add(VarInt::from_u8(1)); - let pn = self.next_expected_control_packet.max(pn); - self.next_expected_control_packet = pn; + // yield to the runtime + cx.waker().wake_by_ref(); } - let recv_time = clock.get_time(); - let mut newly_acked = false; - let mut max_acked_stream = None; - let mut max_acked_recovery = None; - let mut loaded_transmit_queue = false; + let current = self.snapshot(); - { - let mut decoder = DecoderBufferMut::new(packet.control_data_mut()); - while !decoder.is_empty() { - let (frame, remaining) = decoder - .decode::() - .map_err(|decoder| Error::FrameError { decoder })?; - decoder = remaining; - - trace!(?frame); - - match frame { - FrameMut::Padding(_) => { - continue; - } - FrameMut::Ping(_) => { - // no need to do anything special here - } - FrameMut::Ack(ack) => { - if !core::mem::replace(&mut loaded_transmit_queue, true) { - // make sure we have a current view of the application transmissions - 
self.load_transmission_queue(transmission_queue); - } - - if ack.ecn_counts.is_some() { - self.on_frame_ack::<_, _, _, true>( - decrypt_key, - &ack, - random, - &recv_time, - &mut newly_acked, - &mut max_acked_stream, - &mut max_acked_recovery, - segment_alloc, - )?; - } else { - self.on_frame_ack::<_, _, _, false>( - decrypt_key, - &ack, - random, - &recv_time, - &mut newly_acked, - &mut max_acked_stream, - &mut max_acked_recovery, - segment_alloc, - )?; - } - } - FrameMut::MaxData(frame) => { - if self.max_data < frame.maximum_data { - self.max_data = frame.maximum_data; - } - } - FrameMut::ConnectionClose(close) => { - debug!(connection_close = ?close, state = ?self.state); - - probes::on_close( - decrypt_key.credentials().id, - self.stream_id, - packet_number, - close.error_code, - ); + tracing::trace!( + view = "after", + sender_state = ?self.sender.state, + worker_state = ?self.state, + snapshot = ?current, + ); - // if there was no error and we transmitted everything then just shut the - // stream down - if close.error_code == VarInt::ZERO - && close.frame_type.is_some() - && self.state.on_recv_all_acks().is_ok() - { - self.clean_up(); - // transmit one more PTO packet so we can ACK the peer's - // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, - // they'll need to hang around to respond to potential loss. - self.pto.force_transmit(); - return Ok(None); - } - - // no need to transmit a reset back to the peer - just close it - let _ = self.state.on_send_reset(); - let _ = self.state.on_recv_reset_ack(); - let error = if close.frame_type.is_some() { - Error::TransportError { - code: close.error_code, - } - } else { - Error::ApplicationError { - error: close.error_code.into(), - } - }; - return Err(error); - } - _ => continue, + if is_initial || initial.timeout != current.timeout { + if let Some(target) = current.timeout { + self.timer.update(target); + if self.timer.poll_ready(cx).is_ready() { + cx.waker().wake_by_ref(); } + } else { + self.timer.cancel(); } } - for (space, pn) in [ - (stream::PacketSpace::Stream, max_acked_stream), - (stream::PacketSpace::Recovery, max_acked_recovery), - ] { - if let Some(pn) = pn { - self.detect_lost_packets(decrypt_key, random, &recv_time, space, pn)?; - } - } - - self.on_peer_activity(newly_acked); + current.apply(&initial, &self.shared); - // try to transition to the final state if we've sent all of the data - if self.unacked_ranges.is_empty() - && self.error.is_none() - && self.state.on_recv_all_acks().is_ok() - { - self.clean_up(); - // transmit one more PTO packet so we can ACK the peer's - // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, - // they'll need to hang around to respond to potential loss. - self.pto.force_transmit(); + if let waiting::State::Finished = &self.state { + Poll::Ready(()) + } else { + Poll::Pending } - - Ok(None) } #[inline] - fn on_frame_ack( - &mut self, - decrypt_key: &D, - ack: &frame::Ack, - random: &mut dyn random::Generator, - clock: &Clk, - newly_acked: &mut bool, - max_acked_stream: &mut Option, - max_acked_recovery: &mut Option, - segment_alloc: &buffer::Allocator, - ) -> Result<(), Error> - where - D: decrypt::Key, - Ack: frame::ack::AckRanges, - Clk: Clock, - { - let mut cca_args = None; - let mut bytes_acked = 0; - - macro_rules! 
impl_ack_processing { - ($space:ident, $sent_packets:ident, $on_packet_number:expr) => { - for range in ack.ack_ranges() { - let pmin = PacketNumberSpace::Initial.new_packet_number(*range.start()); - let pmax = PacketNumberSpace::Initial.new_packet_number(*range.end()); - let range = s2n_quic_core::packet::number::PacketNumberRange::new(pmin, pmax); - for (num, packet) in self.$sent_packets.remove_range(range) { - let num_varint = unsafe { VarInt::new_unchecked(num.as_u64()) }; - - #[allow(clippy::redundant_closure_call)] - ($on_packet_number)(num_varint, &packet); - - let _ = self.unacked_ranges.remove(packet.info.tracking_range()); - - self.ecn - .on_packet_ack(packet.info.time_sent, packet.info.ecn); - bytes_acked += packet.info.cca_len() as usize; - - // record the most recent packet - if cca_args - .as_ref() - .map_or(true, |prev: &(Timestamp, _)| prev.0 < packet.info.time_sent) - { - cca_args = Some((packet.info.time_sent, packet.cc_info)); - } - - // free the retransmission segment - if let Some(segment) = packet.info.retransmission { - if let Some(segment) = self.stream_packet_buffers.remove(segment) { - // push the segment so the application can reuse it - if segment.capacity() >= self.max_sent_segment_size as usize { - segment_alloc.free(segment); - } - } - } - - probes::on_packet_ack( - decrypt_key.credentials().id, - self.stream_id, - stream::PacketSpace::$space, - num.as_u64(), - packet.info.packet_len, - packet.info.stream_offset, - packet.info.payload_len, - clock - .get_time() - .saturating_duration_since(packet.info.time_sent), - ); - - *newly_acked = true; - } - } - }; - } - - if IS_STREAM { - impl_ack_processing!( - Stream, - sent_stream_packets, - |packet_number: VarInt, _packet: &SentStreamPacket| { - *max_acked_stream = (*max_acked_stream).max(Some(packet_number)); - } - ); - } else { - impl_ack_processing!( - Recovery, - sent_recovery_packets, - |packet_number: VarInt, sent_packet: &SentRecoveryPacket| { - *max_acked_recovery = (*max_acked_recovery).max(Some(packet_number)); - *max_acked_stream = - (*max_acked_stream).max(Some(sent_packet.max_stream_packet_number)); - - // increase the max stream packet if this was a probe - if sent_packet.info.retransmission.is_none() { - self.max_stream_packet_number = self - .max_stream_packet_number - .max(sent_packet.max_stream_packet_number + 1); - } - } - ); - }; - - if let Some((time_sent, cc_info)) = cca_args { - let rtt_sample = clock.get_time().saturating_duration_since(time_sent); - - self.rtt_estimator.update_rtt( - ack.ack_delay(), - rtt_sample, - clock.get_time(), - true, - PacketNumberSpace::ApplicationData, - ); - - self.cca.on_packet_ack( - cc_info.first_sent_time, - bytes_acked, - cc_info, - &self.rtt_estimator, - random, - clock.get_time(), - ); - } + fn poll_once(&mut self, cx: &mut Context) { + let _ = self.poll_messages(cx); + let _ = self.poll_socket(cx); - Ok(()) + let _ = self.poll_timers(cx); + let _ = self.poll_transmit(cx); + self.after_transmit(); } #[inline] - fn detect_lost_packets( - &mut self, - decrypt_key: &D, - random: &mut dyn random::Generator, - clock: &Clk, - packet_space: stream::PacketSpace, - max: VarInt, - ) -> Result<(), Error> - where - D: decrypt::Key, - Clk: Clock, - { - let Some(loss_threshold) = max.checked_sub(VarInt::from_u8(2)) else { - return Ok(()); - }; - - let mut is_unrecoverable = false; - - macro_rules! 
impl_loss_detection { - ($sent_packets:ident, $on_packet:expr) => {{ - let lost_min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); - let lost_max = PacketNumberSpace::Initial.new_packet_number(loss_threshold); - let range = s2n_quic_core::packet::number::PacketNumberRange::new(lost_min, lost_max); - for (num, packet) in self.$sent_packets.remove_range(range) { - // TODO create a path and publisher - // self.ecn.on_packet_loss(packet.time_sent, packet.ecn, now, path, publisher); - - self.cca.on_packet_lost( - packet.info.cca_len() as _, - packet.cc_info, - random, - clock.get_time(), - ); + fn poll_messages(&mut self, cx: &mut Context) -> Poll<()> { + let _ = cx; + + while let Some(message) = self.shared.sender.pop_worker_message() { + match message.event { + Event::Shutdown { + queue, + is_panicking, + } => { + // if the application is panicking then we notify the peer + if is_panicking { + let error = Error::ApplicationError { error: 1u8.into() }; + self.sender.on_error(error); + continue; + } - probes::on_packet_lost( - decrypt_key.credentials().id, - self.stream_id, - packet_space, - num.as_u64(), - packet.info.packet_len, - packet.info.stream_offset, - packet.info.payload_len, - clock - .get_time() - .saturating_duration_since(packet.info.time_sent), - packet.info.retransmission.is_some(), - ); + // transition to a detached state + if self.state.on_application_detach().is_ok() { + debug_assert!( + self.application_queue.is_empty(), + "dropped queue twice for same stream" + ); - #[allow(clippy::redundant_closure_call)] - ($on_packet)(&packet); - - if let Some(segment) = packet.info.retransmission { - // update our local packet number to be at least 1 more than the largest lost - // packet number - let min_recovery_packet_number = num.as_u64() + 1; - self.recovery_packet_number = - self.recovery_packet_number.max(min_recovery_packet_number); - - let retransmission = retransmission::Segment { - segment, - stream_offset: packet.info.stream_offset, - payload_len: packet.info.payload_len, - ty: TransmissionType::Stream, - included_fin: packet.info.included_fin, - }; - self.retransmissions.push(retransmission); - } else { - // we can only recover reliable streams - is_unrecoverable |= packet.info.payload_len > 0 && !self.stream_id.is_reliable; + self.application_queue = queue; + continue; } - }} - } - } - - match packet_space { - stream::PacketSpace::Stream => impl_loss_detection!(sent_stream_packets, |_| {}), - stream::PacketSpace::Recovery => { - impl_loss_detection!(sent_recovery_packets, |sent_packet: &SentRecoveryPacket| { - self.max_stream_packet_number = self - .max_stream_packet_number - .max(sent_packet.max_stream_packet_number + 1); - }) + } } } - ensure!(!is_unrecoverable, Err(Error::RetransmissionFailure)); - - self.invariants(); - - Ok(()) - } - - #[inline] - fn on_peer_activity(&mut self, newly_acked_packets: bool) { - if let Some(prev) = self.peer_activity.as_mut() { - prev.newly_acked_packets |= newly_acked_packets; - } else { - self.peer_activity = Some(PeerActivity { - newly_acked_packets, - }); - } - } - - #[inline] - pub fn before_sleep(&mut self, clock: &Clk) { - self.process_peer_activity(); - - // make sure our timers are armed - self.update_idle_timer(clock); - self.update_inflight_timer(clock); - self.update_pto_timer(clock); - - trace!( - unacked_ranges = ?self.unacked_ranges, - retransmissions = self.retransmissions.len(), - stream_packets_in_flight = self.sent_stream_packets.iter().count(), - recovery_packets_in_flight = 
self.sent_recovery_packets.iter().count(), - pto_timer = ?self.pto.next_expiration(), - inflight_timer = ?self.inflight_timer.next_expiration(), - idle_timer = ?self.idle_timer.next_expiration(), - ); - } - - #[inline] - fn process_peer_activity(&mut self) { - let Some(PeerActivity { - newly_acked_packets, - }) = self.peer_activity.take() - else { - return; - }; - - if newly_acked_packets { - self.reset_pto_timer(); - } - - // force probing when we've sent all of the data but haven't got an ACK for the final - // offset - if self.state.is_data_sent() && self.stream_packet_buffers.is_empty() { - self.pto.force_transmit(); - } - - // re-arm the idle timer as long as we're not in terminal state - if !self.state.is_terminal() { - self.idle_timer.cancel(); - self.inflight_timer.cancel(); - } - } - - #[inline] - pub fn on_time_update(&mut self, clock: &Clk, load_last_activity: Ld) - where - Clk: Clock, - Ld: FnOnce() -> Timestamp, - { - if self.poll_idle_timer(clock, load_last_activity).is_ready() { - self.on_error(Error::IdleTimeout); - // we don't actually want to send any packets on idle timeout - let _ = self.state.on_send_reset(); - let _ = self.state.on_recv_reset_ack(); - return; - } - - if self - .inflight_timer - .poll_expiration(clock.get_time()) - .is_ready() - { - self.on_error(Error::IdleTimeout); - return; - } - - if self - .pto - .on_timeout(self.has_inflight_packets(), clock.get_time()) - .is_ready() - { - // TODO where does this come from - let max_pto_backoff = 1024; - self.pto_backoff = self.pto_backoff.saturating_mul(2).min(max_pto_backoff); - } - } - - #[inline] - fn poll_idle_timer(&mut self, clock: &Clk, load_last_activity: Ld) -> Poll<()> - where - Clk: Clock, - Ld: FnOnce() -> Timestamp, - { - let now = clock.get_time(); - - // check the idle timer first - ready!(self.idle_timer.poll_expiration(now)); - - // if that expired then load the last activity from the peer and update the idle timer with - // the value - let last_peer_activity = load_last_activity(); - self.update_idle_timer(&last_peer_activity); - - // check the idle timer once more before returning - ready!(self.idle_timer.poll_expiration(now)); - Poll::Ready(()) } #[inline] - fn has_inflight_packets(&self) -> bool { - !self.sent_stream_packets.is_empty() - || !self.sent_recovery_packets.is_empty() - || !self.retransmissions.is_empty() - || !self.transmit_queue.is_empty() - } - - #[inline] - fn update_idle_timer(&mut self, clock: &impl Clock) { - ensure!(!self.idle_timer.is_armed()); - - let now = clock.get_time(); - self.idle_timer.set(now + self.idle_timeout); - } - - #[inline] - fn update_inflight_timer(&mut self, clock: &impl Clock) { - // TODO make this configurable - let inflight_timeout = crate::stream::DEFAULT_INFLIGHT_TIMEOUT; - - if self.has_inflight_packets() { - if !self.inflight_timer.is_armed() { - self.inflight_timer.set(clock.get_time() + inflight_timeout); - } - } else { - self.inflight_timer.cancel(); + fn poll_socket(&mut self, cx: &mut Context) -> Poll<()> { + loop { + // try to receive until we get blocked + let _ = ready!(self.socket.poll_recv_buffer(cx, &mut self.recv_buffer)); + self.process_recv_buffer(); } } #[inline] - fn update_pto_timer(&mut self, clock: &impl Clock) { - ensure!(!self.pto.is_armed()); - - let mut should_arm = self.has_inflight_packets(); - - // if we have stream packet buffers in flight then arm the PTO - should_arm |= !self.stream_packet_buffers.is_empty(); + fn process_recv_buffer(&mut self) { + ensure!(!self.recv_buffer.is_empty()); - // if we've sent all of the 
data/reset and are waiting to clean things up - should_arm |= self.state.is_data_sent() || self.state.is_reset_sent(); + let remote_addr = self.recv_buffer.remote_address(); + let tag_len = self.shared.crypto.tag_len(); + let ecn = self.recv_buffer.ecn(); + let random = &mut self.random; + let mut any_valid_packets = false; + let clock = &clock::Cached::new(&self.shared.clock); - ensure!(should_arm); + for segment in self.recv_buffer.segments() { + let segment_len = segment.len(); + let mut decoder = DecoderBufferMut::new(segment); - self.force_arm_pto_timer(clock); - } - - #[inline] - fn force_arm_pto_timer(&mut self, clock: &impl Clock) { - let pto_period = self - .rtt_estimator - .pto_period(self.pto_backoff, PacketNumberSpace::Initial); - self.pto.update(clock.get_time(), pto_period); - } + while !decoder.is_empty() { + let remaining_len = decoder.len(); - #[inline] - fn reset_pto_timer(&mut self) { - self.pto_backoff = INITIAL_PTO_BACKOFF; - self.pto.cancel(); - } + let packet = match decoder.decode_parameterized(tag_len) { + Ok((packet, remaining)) => { + decoder = remaining; + packet + } + Err(err) => { + // we couldn't parse the rest of the packet so bail + tracing::error!(decoder_error = %err, segment_len, remaining_len); + break; + } + }; - /// Called by the worker thread when it becomes aware of the application having transmitted a - /// segment - #[inline] - pub fn load_transmission_queue( - &mut self, - queue: &application::transmission::Queue, - ) -> bool { - let mut did_transmit_stream = false; - - for Transmission { - packet_number, - info, - has_more_app_data, - } in queue.drain() - { - self.max_sent_segment_size = self.max_sent_segment_size.max(info.packet_len); - let info = info.map(|buffer| self.stream_packet_buffers.insert(buffer)); - self.on_transmit_segment( - stream::PacketSpace::Stream, - packet_number, - info, - has_more_app_data, - ); - did_transmit_stream = true; - } + match packet { + Packet::Control(mut packet) => { + // make sure we're processing the expected stream + ensure!( + packet.stream_id() == Some(&self.shared.application().stream_id), + continue + ); - if did_transmit_stream { - // if we just sent some packets then we can use those as probes - self.reset_pto_timer(); - } + let _ = self.shared.crypto.open_with(|opener| { + self.sender.on_control_packet( + opener, + ecn, + &mut packet, + random, + clock, + &self.shared.sender.application_transmission_queue, + &self.shared.sender.segment_alloc, + )?; - self.invariants(); + any_valid_packets = true; - did_transmit_stream - } - - #[inline] - fn on_transmit_segment( - &mut self, - packet_space: stream::PacketSpace, - packet_number: VarInt, - info: transmission::Info, - has_more_app_data: bool, - ) { - // the BBR implementation requires monotonic time so track that - let mut cca_time_sent = info.time_sent; - - match packet_space { - stream::PacketSpace::Stream => { - if let Some(min) = self.last_sent_recovery_packet { - cca_time_sent = info.time_sent.max(min); + >::Ok(()) + }); + } + other => self.shared.crypto.map().handle_unexpected_packet(&other), } } - stream::PacketSpace::Recovery => { - self.last_sent_recovery_packet = Some(info.time_sent); - } } - let cc_info = self.cca.on_packet_sent( - cca_time_sent, - info.cca_len(), - has_more_app_data, - &self.rtt_estimator, - ); - - // update the max offset that we've transmitted - self.max_sent_offset = self.max_sent_offset.max(info.end_offset()); - - // try to transition to start sending - let _ = self.state.on_send_stream(); - if info.included_fin { - // if 
the transmission included the final offset, then transition to that state - let _ = self.state.on_send_fin(); - } - - if let stream::PacketSpace::Recovery = packet_space { - let packet_number = PacketNumberSpace::Initial.new_packet_number(packet_number); - let max_stream_packet_number = self.max_stream_packet_number; - self.sent_recovery_packets.insert( - packet_number, - SentRecoveryPacket { - info, - cc_info, - max_stream_packet_number, - }, - ); - } else { - self.max_stream_packet_number = self.max_stream_packet_number.max(packet_number); - let packet_number = PacketNumberSpace::Initial.new_packet_number(packet_number); - self.sent_stream_packets - .insert(packet_number, SentStreamPacket { info, cc_info }); + if any_valid_packets { + // if the writer saw any ACKs then we're done handshaking + let did_complete_handshake = true; + self.shared + .on_valid_packet(&remote_addr, Half::Write, did_complete_handshake); } } #[inline] - pub fn fill_transmit_queue( - &mut self, - encrypt_key: &E, - source_control_port: u16, - clock: &Clk, - ) -> Result<(), Error> - where - E: encrypt::Key, - Clk: Clock, - { - if let Err(error) = self.fill_transmit_queue_impl(encrypt_key, source_control_port, clock) { - self.on_error(error); - return Err(error); - } - - Ok(()) + fn poll_timers(&mut self, cx: &mut Context) -> Poll<()> { + let _ = cx; + let shared = &self.shared; + let clock = clock::Cached::new(&shared.clock); + self.sender + .on_time_update(&clock, || shared.last_peer_activity()); + Poll::Ready(()) } #[inline] - fn fill_transmit_queue_impl( - &mut self, - encrypt_key: &E, - source_control_port: u16, - clock: &Clk, - ) -> Result<(), Error> - where - E: encrypt::Key, - Clk: Clock, - { - // skip a packet number if we're probing - if self.pto.transmissions() > 0 { - self.recovery_packet_number += 1; - } + fn poll_transmit(&mut self, cx: &mut Context) -> Poll<()> { + loop { + ready!(self.poll_transmit_flush(cx)); + + match self.state { + waiting::State::Acking => { + let _ = self.shared.crypto.seal_with(|sealer| { + self.sender.fill_transmit_queue( + sealer, + self.socket.local_addr().unwrap().port(), + &self.shared.clock, + ) + }); + } + waiting::State::Detached => { + // flush the remaining application queue + let _ = ready!(self.application_queue.poll_flush( + cx, + usize::MAX, + &self.socket, + &addr::Addr::new(self.shared.write_remote_addr()), + &self.shared.sender.segment_alloc, + &self.shared.gso, + )); + + // make sure we have the current view from the application + self.sender.load_transmission_queue( + &self.shared.sender.application_transmission_queue, + ); - self.try_transmit_retransmissions(encrypt_key, clock)?; - self.try_transmit_probe(encrypt_key, source_control_port, clock)?; + // try to transition to having sent all of the data + if self.sender.state.on_send_fin().is_ok() { + // arm the PTO now to force it to transmit a final packet + self.sender.pto.force_transmit(); + } - Ok(()) - } + // transition to shutting down + let _ = self.state.on_shutdown(); - #[inline] - fn try_transmit_retransmissions( - &mut self, - encrypt_key: &E, - clock: &Clk, - ) -> Result<(), Error> - where - E: encrypt::Key, - Clk: Clock, - { - // We'll only have retransmissions if we're reliable - ensure!(self.stream_id.is_reliable, Ok(())); - - while let Some(retransmission) = self.retransmissions.peek() { - // make sure we fit in the current congestion window - let remaining_cca_window = self - .cca - .congestion_window() - .saturating_sub(self.cca.bytes_in_flight()); - ensure!( - retransmission.payload_len as 
u32 <= remaining_cca_window, - break - ); - - let buffer = self.stream_packet_buffers[retransmission.segment].make_mut(); - - debug_assert!(!buffer.is_empty(), "empty retransmission buffer submitted"); - - let packet_number = - VarInt::new(self.recovery_packet_number).expect("2^62 is a lot of packets"); - self.recovery_packet_number += 1; - - { - let buffer = DecoderBufferMut::new(buffer); - match decoder::Packet::retransmit( - buffer, - stream::PacketSpace::Recovery, - packet_number, - encrypt_key, - ) { - Ok(info) => info, - Err(err) => { - // this shouldn't ever happen - debug_assert!(false, "{err:?}"); - return Err(Error::RetransmissionFailure); - } + continue; } - }; - - let time_sent = clock.get_time(); - let packet_len = buffer.len() as u16; - - { - let info = self - .retransmissions - .pop() - .expect("retransmission should be available"); - - // enqueue the transmission - self.transmit_queue - .push_back(TransmitIndex::Stream(info.segment)); - - let stream_offset = info.stream_offset; - let payload_len = info.payload_len; - let included_fin = info.included_fin; - let retransmission = Some(info.segment); - - // TODO store this as part of the packet queue - let ecn = ExplicitCongestionNotification::Ect0; - - let info = transmission::Info { - packet_len, - stream_offset, - payload_len, - included_fin, - retransmission, - time_sent, - ecn, - }; - - probes::on_transmit_stream( - encrypt_key.credentials().id, - self.stream_id, - stream::PacketSpace::Recovery, - PacketNumberSpace::Initial.new_packet_number(packet_number), - stream_offset, - payload_len, - included_fin, - true, - ); - - self.on_transmit_segment(stream::PacketSpace::Recovery, packet_number, info, false); - - // consider this transmission a probe if needed - if self.pto.transmissions() > 0 { - self.pto.on_transmit_once(); + waiting::State::ShuttingDown => { + let _ = self.shared.crypto.seal_with(|sealer| { + self.sender.fill_transmit_queue( + sealer, + self.socket.local_addr().unwrap().port(), + &self.shared.clock, + ) + }); + + if self.sender.state.is_terminal() { + let _ = self.state.on_finished(); + } } + waiting::State::Finished => break, } + + ensure!(!self.sender.transmit_queue.is_empty(), break); } - Ok(()) + Poll::Ready(()) } #[inline] - pub fn try_transmit_probe( - &mut self, - encrypt_key: &E, - source_control_port: u16, - clock: &Clk, - ) -> Result<(), Error> - where - E: encrypt::Key, - Clk: Clock, - { - while self.pto.transmissions() > 0 { - // probes are not congestion-controlled - - let packet_number = - VarInt::new(self.recovery_packet_number).expect("2^62 is a lot of packets"); - self.recovery_packet_number += 1; - - let mut buffer = self.free_packet_buffers.pop().unwrap_or_default(); - - // resize the buffer to what we need - { - let min_len = stream::encoder::MAX_RETRANSMISSION_HEADER_LEN + 128; - - if buffer.capacity() < min_len { - buffer.reserve(min_len - buffer.len()); - } + fn poll_transmit_flush(&mut self, cx: &mut Context) -> Poll<()> { + ensure!(!self.sender.transmit_queue.is_empty(), Poll::Ready(())); - unsafe { - debug_assert!(buffer.capacity() >= min_len); - buffer.set_len(min_len); - } - } + let mut max_segments = self.shared.gso.max_segments(); + let addr = self.shared.write_remote_addr(); + let addr = addr::Addr::new(addr); + let clock = &self.shared.clock; - let offset = self.max_sent_offset; - let final_offset = if self.state.is_data_sent() { - Some(offset) - } else { - None - }; - - let mut payload = probe::Probe { - offset, - final_offset, - }; - - let encoder = EncoderBuffer::new(&mut 
buffer); - let packet_len = encoder::encode( - encoder, - source_control_port, - None, - self.stream_id, - stream::PacketSpace::Recovery, - packet_number, - self.next_expected_control_packet, - VarInt::ZERO, - &mut &[][..], - VarInt::ZERO, - &(), - &mut payload, - encrypt_key, - ); - - let payload_len = 0; - let included_fin = final_offset.is_some(); - buffer.truncate(packet_len); - - debug_assert!( - packet_len < u16::MAX as usize, - "cannot write larger packets than 2^16" - ); - let packet_len = packet_len as u16; - - let time_sent = clock.get_time(); - - // TODO store this as part of the packet queue - let ecn = ExplicitCongestionNotification::Ect0; - - // enqueue the transmission - let buffer_index = self.recovery_packet_buffers.insert(buffer); - self.transmit_queue - .push_back(TransmitIndex::Recovery(buffer_index)); - - let info = transmission::Info { - packet_len, - stream_offset: offset, - payload_len, - included_fin, - retransmission: None, // PTO packets are not retransmitted - time_sent, - ecn, - }; - - self.on_transmit_segment(stream::PacketSpace::Recovery, packet_number, info, false); - - self.pto.on_transmit_once(); - } + while !self.sender.transmit_queue.is_empty() { + // pace out retransmissions + ready!(self.pacer.poll_pacing(cx, &self.shared.clock)); - Ok(()) - } + // construct all of the segments we're going to send in this batch + let segments = + msg::segment::Batch::new(self.sender.transmit_queue_iter(clock).take(max_segments)); - #[inline] - pub fn transmit_queue_iter( - &mut self, - clock: &Clk, - ) -> impl Iterator + '_ { - let ecn = self - .ecn - .ecn(s2n_quic_core::transmission::Mode::Normal, clock.get_time()); - let stream_packet_buffers = &self.stream_packet_buffers; - let recovery_packet_buffers = &self.recovery_packet_buffers; - - self.transmit_queue.iter().filter_map(move |index| { - let buf = match *index { - TransmitIndex::Stream(index) => stream_packet_buffers.get(index)?.as_slice(), - TransmitIndex::Recovery(index) => recovery_packet_buffers.get(index)?, - }; - - Some((ecn, buf)) - }) - } + let ecn = segments.ecn(); + let res = ready!(self.socket.poll_send(cx, &addr, ecn, &segments)); - #[inline] - pub fn on_transmit_queue(&mut self, count: usize) { - for transmission in self.transmit_queue.drain(..count) { - match transmission { - TransmitIndex::Stream(index) => { - // make sure the packet wasn't freed between when we wanted to transmit and - // when we actually did - ensure!(self.stream_packet_buffers.get(index).is_some(), continue); - } - TransmitIndex::Recovery(index) => { - // make sure the packet wasn't freed between when we wanted to transmit and - // when we actually did - let Some(mut buffer) = self.recovery_packet_buffers.remove(index) else { - continue; - }; - buffer.clear(); - self.free_packet_buffers.push(buffer); + if let Err(error) = res { + if self.shared.gso.handle_socket_error(&error).is_some() { + // update the max_segments value if it was changed due to the error + max_segments = 1; } - }; - } - } - - #[inline] - pub fn on_error(&mut self, error: Error) { - ensure!(self.error.is_none()); - self.error = Some(error); - let _ = self.state.on_queue_reset(); - - self.clean_up(); - } + } - #[inline] - fn clean_up(&mut self) { - self.retransmissions.clear(); - let min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); - let max = PacketNumberSpace::Initial.new_packet_number(VarInt::MAX); - let range = s2n_quic_core::packet::number::PacketNumberRange::new(min, max); - let _ = self.sent_stream_packets.remove_range(range); - let _ = 
self.sent_recovery_packets.remove_range(range); - - self.idle_timer.cancel(); - self.inflight_timer.cancel(); - self.pto.cancel(); - self.unacked_ranges.clear(); - - self.transmit_queue.clear(); - for buffer in self.stream_packet_buffers.drain() { - // TODO push buffer into free segment queue - let _ = buffer; - } - for (_idx, mut buffer) in self.recovery_packet_buffers.drain() { - buffer.clear(); - self.free_packet_buffers.push(buffer); + // consume the segments that we transmitted + let segment_count = segments.len(); + drop(segments); + self.sender.on_transmit_queue(segment_count); } - self.invariants(); + Poll::Ready(()) } - #[cfg(debug_assertions)] #[inline] - fn invariants(&self) { - // TODO - } + fn after_transmit(&mut self) { + self.sender + .load_transmission_queue(&self.shared.sender.application_transmission_queue); - #[cfg(not(debug_assertions))] - #[inline(always)] - fn invariants(&self) {} -} + self.sender + .before_sleep(&clock::Cached::new(&self.shared.clock)); + } -impl timer::Provider for Worker { #[inline] - fn timers(&self, query: &mut Q) -> timer::Result { - // if we're in a terminal state then no timers are needed - ensure!(!self.state.is_terminal(), Ok(())); - self.pto.timers(query)?; - self.idle_timer.timers(query)?; - Ok(()) + fn snapshot(&self) -> Snapshot { + Snapshot { + flow_offset: self.sender.flow_offset(), + has_pending_retransmissions: self.sender.transmit_queue.is_empty(), + send_quantum: self.sender.cca.send_quantum(), + // TODO get this from the ECN controller + ecn: ExplicitCongestionNotification::Ect0, + max_datagram_size: self.sender.max_datagram_size, + next_expected_control_packet: self.sender.next_expected_control_packet, + timeout: self.sender.next_expiration(), + bandwidth: self.sender.cca.bandwidth(), + error: self.sender.error, + } } } diff --git a/dc/s2n-quic-dc/src/stream/shared.rs b/dc/s2n-quic-dc/src/stream/shared.rs new file mode 100644 index 0000000000..2ba44e2a2b --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/shared.rs @@ -0,0 +1,173 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clock::Clock, + stream::{ + recv::shared as recv, + send::{application, shared as send}, + }, +}; +use core::{ + cell::UnsafeCell, + ops, + sync::atomic::{AtomicU16, AtomicU64, AtomicU8, Ordering}, + time::Duration, +}; +use s2n_quic_core::{ + ensure, + inet::{IpAddress, SocketAddress}, + time::Timestamp, +}; +use s2n_quic_platform::features; +use std::sync::Arc; + +pub use crate::stream::crypto::Crypto; + +#[derive(Clone, Copy, Debug)] +pub enum Half { + Read, + Write, +} + +pub type ArcShared = Arc>; + +#[derive(Debug)] +pub struct Shared { + pub receiver: recv::State, + pub sender: send::State, + pub crypto: Crypto, + pub common: Common, +} + +impl Shared { + #[inline] + pub fn on_valid_packet( + &self, + remote_addr: &SocketAddress, + half: Half, + did_complete_handshake: bool, + ) { + if did_complete_handshake { + /* + // TODO only update this if this if we are done "handshaking" + let remote_port = match half { + Half::Read => &self.read_remote_port, + Half::Write => &self.write_remote_port, + }; + remote_port.store(remote_addr.port(), Ordering::Relaxed); + */ + let _ = half; + let remote_port = remote_addr.port(); + if remote_port != 0 { + self.read_remote_port.store(remote_port, Ordering::Relaxed); + self.write_remote_port.store(remote_port, Ordering::Relaxed); + } + } + + // update the last time we've seen peer activity + self.on_peer_activity(); + } + + #[inline] + pub fn on_peer_activity(&self) { + self.last_peer_activity.fetch_max( + unsafe { self.clock.get_time().as_duration().as_micros() as _ }, + Ordering::Relaxed, + ); + } +} + +impl Shared { + #[inline] + pub fn last_peer_activity(&self) -> Timestamp { + let timestamp = self.last_peer_activity.load(Ordering::Relaxed); + let timestamp = Duration::from_micros(timestamp); + unsafe { Timestamp::from_duration(timestamp) } + } + + #[inline] + pub fn write_remote_addr(&self) -> SocketAddress { + self.remote_ip() + .with_port(self.common.write_remote_port.load(Ordering::Relaxed)) + } + + #[inline] + pub fn read_remote_addr(&self) -> SocketAddress { + self.remote_ip() + .with_port(self.common.read_remote_port.load(Ordering::Relaxed)) + } + + #[inline] + pub fn remote_ip(&self) -> IpAddress { + unsafe { + // SAFETY: the fixed information doesn't change for the lifetime of the stream + *self.common.fixed.remote_ip.get() + } + } + + #[inline] + pub fn application(&self) -> application::state::State { + unsafe { + // SAFETY: the fixed information doesn't change for the lifetime of the stream + *self.common.fixed.application.get() + } + } + + #[inline] + pub fn source_control_port(&self) -> u16 { + unsafe { + // SAFETY: the fixed information doesn't change for the lifetime of the stream + *self.common.fixed.source_control_port.get() + } + } +} + +impl ops::Deref for Shared { + type Target = Common; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.common + } +} + +#[derive(Debug)] +pub struct Common { + pub gso: features::Gso, + pub read_remote_port: AtomicU16, + pub write_remote_port: AtomicU16, + pub fixed: FixedValues, + /// The last time we received a packet from the peer + pub last_peer_activity: AtomicU64, + pub closed_halves: AtomicU8, + pub clock: Clock, +} + +impl Common { + #[inline] + pub fn ensure_open(&self) -> std::io::Result<()> { + ensure!( + self.closed_halves.load(Ordering::Relaxed) < 2, + // macos returns a different error kind + Err(if cfg!(target_os = "macos") { + std::io::ErrorKind::InvalidInput + } else { + std::io::ErrorKind::NotConnected + } + 
.into()) + ); + Ok(()) + } +} + +/// Values that don't change while the state is shared between threads +#[derive(Debug)] +pub struct FixedValues { + pub remote_ip: UnsafeCell, + pub source_control_port: UnsafeCell, + pub application: UnsafeCell, +} + +unsafe impl Send for FixedValues {} +unsafe impl Sync for FixedValues {} diff --git a/dc/s2n-quic-dc/src/stream/socket.rs b/dc/s2n-quic-dc/src/stream/socket.rs index d2ab7406c5..a170f16863 100644 --- a/dc/s2n-quic-dc/src/stream/socket.rs +++ b/dc/s2n-quic-dc/src/stream/socket.rs @@ -4,8 +4,9 @@ use super::TransportFeatures; pub mod application; -mod fd; +pub mod fd; mod handle; +#[cfg(feature = "tokio")] mod tokio; mod tracing; diff --git a/dc/s2n-quic-dc/src/stream/socket/application/builder.rs b/dc/s2n-quic-dc/src/stream/socket/application/builder.rs index e2586d6144..5bf56ea8ed 100644 --- a/dc/s2n-quic-dc/src/stream/socket/application/builder.rs +++ b/dc/s2n-quic-dc/src/stream/socket/application/builder.rs @@ -1,72 +1,82 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use super::super::{ArcApplication, Tracing}; -use std::{io, sync::Arc}; -use tokio::io::unix::AsyncFd; +use crate::stream::socket::ArcApplication; +use std::io; pub trait Builder: 'static + Send + Sync { fn build(self: Box) -> io::Result; } -impl Builder for std::net::UdpSocket { - #[inline] - fn build(self: Box) -> io::Result { - let v = AsyncFd::new(*self)?; - let v = Tracing(v); - let v = super::Single(v); - let v = Arc::new(v); - Ok(v) +#[cfg(feature = "tokio")] +mod tokio_impl { + use super::*; + use crate::stream::socket::{application, Tracing}; + use std::{io, sync::Arc}; + use tokio::io::unix::AsyncFd; + + impl Builder for std::net::UdpSocket { + #[inline] + fn build(self: Box) -> io::Result { + let v = AsyncFd::new(*self)?; + let v = Tracing(v); + let v = application::Single(v); + let v = Arc::new(v); + Ok(v) + } } -} -impl Builder for std::net::TcpStream { - #[inline] - fn build(self: Box) -> io::Result { - let v = tokio::net::TcpStream::from_std(*self)?; - let v = Tracing(v); - let v = super::Single(v); - let v = Arc::new(v); - Ok(v) + impl Builder for Arc { + #[inline] + fn build(self: Box) -> io::Result { + // TODO avoid the Box> indirection here? + let v = AsyncFd::new(*self)?; + let v = Tracing(v); + let v = application::Single(v); + let v = Arc::new(v); + Ok(v) + } } -} -impl Builder for tokio::net::TcpStream { - #[inline] - fn build(self: Box) -> io::Result { - let v = Tracing(*self); - let v = super::Single(v); - let v = Arc::new(v); - Ok(v) + pub struct UdpPair { + pub reader: Arc, + pub writer: Arc, } -} -impl Builder for Arc { - #[inline] - fn build(self: Box) -> io::Result { - // TODO avoid the Box> indirection here? 
- let v = AsyncFd::new(*self)?; - let v = Tracing(v); - let v = super::Single(v); - let v = Arc::new(v); - Ok(v) + impl Builder for UdpPair { + #[inline] + fn build(self: Box) -> io::Result { + let read = AsyncFd::new(self.reader)?; + let read = Tracing(read); + let write = AsyncFd::new(self.writer)?; + let write = Tracing(write); + let v = application::Pair { read, write }; + let v = Arc::new(v); + Ok(v) + } } -} -pub struct UdpPair { - pub reader: Arc, - pub writer: Arc, -} + impl Builder for std::net::TcpStream { + #[inline] + fn build(self: Box) -> io::Result { + let v = tokio::net::TcpStream::from_std(*self)?; + let v = Tracing(v); + let v = application::Single(v); + let v = Arc::new(v); + Ok(v) + } + } -impl Builder for UdpPair { - #[inline] - fn build(self: Box) -> io::Result { - let read = AsyncFd::new(self.reader)?; - let read = Tracing(read); - let write = AsyncFd::new(self.writer)?; - let write = Tracing(write); - let v = super::Pair { read, write }; - let v = Arc::new(v); - Ok(v) + impl Builder for tokio::net::TcpStream { + #[inline] + fn build(self: Box) -> io::Result { + let v = Tracing(*self); + let v = application::Single(v); + let v = Arc::new(v); + Ok(v) + } } } + +#[cfg(feature = "tokio")] +pub use tokio_impl::*; diff --git a/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs index 51f676666b..983edc9160 100644 --- a/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs +++ b/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs @@ -22,7 +22,7 @@ impl Socket for TcpStream { #[inline] fn protocol(&self) -> Protocol { - Protocol::Udp + Protocol::Tcp } #[inline] diff --git a/dc/s2n-quic-dc/src/testing.rs b/dc/s2n-quic-dc/src/testing.rs new file mode 100644 index 0000000000..93775f5ffc --- /dev/null +++ b/dc/s2n-quic-dc/src/testing.rs @@ -0,0 +1,9 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub fn assert_debug(_v: &T) {} +pub fn assert_send(_v: &T) {} +pub fn assert_sync(_v: &T) {} +pub fn assert_static(_v: &T) {} +pub fn assert_async_read(_v: &T) {} +pub fn assert_async_write(_v: &T) {} diff --git a/dc/wireshark/src/bin/generate-pcap.rs b/dc/wireshark/src/bin/generate-pcap.rs index 0a4922413c..7b5b1c5785 100644 --- a/dc/wireshark/src/bin/generate-pcap.rs +++ b/dc/wireshark/src/bin/generate-pcap.rs @@ -237,6 +237,8 @@ impl Packet { fn add_dcquic(&mut self, packet_idx: u64) { // dcQUIC datagram. 
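// Wire layout written below: tag byte (0x46, dcQUIC datagram), then the wire
// version byte (0) introduced by this change, then the 16-byte path secret ID,
// then the key ID; the field order follows the writes immediately below.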
self.buffer.write_all(&[0x46]).unwrap(); + // wire version = 0 + self.buffer.write_all(&[0]).unwrap(); // Path secret ID self.buffer.write_all(&[0x43; 16]).unwrap(); // Key ID diff --git a/dc/wireshark/src/dissect.rs b/dc/wireshark/src/dissect.rs index 3ba1814272..0200eb9dbd 100644 --- a/dc/wireshark/src/dissect.rs +++ b/dc/wireshark/src/dissect.rs @@ -9,7 +9,7 @@ use crate::{ }; use s2n_codec::DecoderBufferMut; use s2n_quic_core::{frame::FrameMut, varint::VarInt}; -use s2n_quic_dc::packet::{self, stream}; +use s2n_quic_dc::packet::{self, stream, WireVersion}; #[derive(Clone, Copy, Debug)] #[allow(dead_code)] @@ -72,12 +72,16 @@ pub fn stream( fields.has_control_data, fields.has_final_offset, fields.has_application_header, + fields.key_phase, ] { tag_tree.add_boolean(buffer, field, tag); } let tag = tag.value; + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); @@ -166,12 +170,19 @@ pub fn control( let tag_item = tag.record(buffer, tree, fields.tag); let mut tag_tree = tree.add_subtree(tag_item, fields.tag_subtree); - for field in [fields.is_stream, fields.has_application_header] { + for field in [ + fields.is_stream, + fields.has_application_header, + fields.key_phase, + ] { tag_tree.add_boolean(buffer, field, tag); } let tag = tag.value; + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); @@ -401,12 +412,16 @@ pub fn datagram( fields.is_ack_eliciting, fields.is_connected, fields.has_application_header, + fields.key_phase, ] { tag_tree.add_boolean(buffer, field, tag); } let tag = tag.value; + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); @@ -490,6 +505,9 @@ pub fn secret_control( packet::Tag::UnknownPathSecret(_) => { item.append_text(c" (UnknownPathSecret)"); + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); @@ -504,6 +522,9 @@ pub fn secret_control( packet::Tag::StaleKey(_) => { item.append_text(c" (StaleKey)"); + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); @@ -521,6 +542,9 @@ pub fn secret_control( packet::Tag::ReplayDetected(_) => { item.append_text(c" (ReplayDetected)"); + let wire_version = buffer.consume::()?; + wire_version.record(buffer, tree, fields.wire_version); + let path_secret_id = buffer.consume_bytes(16)?; path_secret_id.record(buffer, tree, fields.path_secret_id); diff --git a/dc/wireshark/src/field.rs b/dc/wireshark/src/field.rs index a4616ac28f..947ade8e1b 100644 --- a/dc/wireshark/src/field.rs +++ b/dc/wireshark/src/field.rs @@ -31,6 +31,8 @@ pub struct Registration { pub is_recovery_packet: i32, pub has_control_data: i32, pub has_final_offset: i32, + pub key_phase: i32, + pub wire_version: i32, pub path_secret_id: i32, pub key_id: i32, pub source_control_port: i32, @@ -268,6 +270,19 @@ fn init() -> Registration { ) .with_mask(masks::HAS_FINAL_OFFSET) .register(), + key_phase: 
protocol + .field(c"Key Phase", c"dcquic.tag.key_phase", BOOLEAN, SEP_DOT, c"") + .with_mask(masks::KEY_PHASE) + .register(), + wire_version: protocol + .field( + c"Wire Version", + c"dcquic.wire_version", + UINT32, + BASE_DEC, + c"dcQUIC wire version", + ) + .register(), path_secret_id: protocol .field( c"Path Secret ID", @@ -546,13 +561,16 @@ mod masks { pub const HAS_CONTROL_DATA: u64 = stream::Tag::HAS_CONTROL_DATA_MASK as _; pub const HAS_FINAL_OFFSET: u64 = stream::Tag::HAS_FINAL_OFFSET_MASK as _; - pub const HAS_APPLICATION_HEADER: u64 = { - // Statically assert that the masks line up between all three packets. - const _: [(); stream::Tag::HAS_APPLICATION_HEADER_MASK as usize] = - [(); datagram::Tag::HAS_APPLICATION_HEADER_MASK as usize]; - const _: [(); stream::Tag::HAS_APPLICATION_HEADER_MASK as usize] = - [(); control::Tag::HAS_APPLICATION_HEADER_MASK as usize]; + macro_rules! common_tag { + ($name:ident) => {{ + // Statically assert that the masks line up between all three packets. + const _: [(); stream::Tag::$name as usize] = [(); datagram::Tag::$name as usize]; + const _: [(); stream::Tag::$name as usize] = [(); control::Tag::$name as usize]; - datagram::Tag::HAS_APPLICATION_HEADER_MASK as _ - }; + datagram::Tag::$name as _ + }}; + } + + pub const HAS_APPLICATION_HEADER: u64 = common_tag!(HAS_APPLICATION_HEADER_MASK); + pub const KEY_PHASE: u64 = common_tag!(KEY_PHASE_MASK); } diff --git a/dc/wireshark/src/test.rs b/dc/wireshark/src/test.rs index 68f620797c..94ff8136c3 100644 --- a/dc/wireshark/src/test.rs +++ b/dc/wireshark/src/test.rs @@ -10,7 +10,7 @@ use s2n_quic_core::{ }; use s2n_quic_dc::{ credentials::{self, Credentials}, - packet::{self, stream}, + packet::{self, stream, WireVersion}, }; use std::{collections::HashMap, num::NonZeroU16, ptr, time::Duration}; @@ -72,6 +72,7 @@ fn check_stream_parse() { let tag: Parsed = tag.map(|v| v.into()); assert_eq!(tracker.remove(fields.tag), Field::Integer(tag.value as u64)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(packet.credentials.id.to_vec()) @@ -126,6 +127,7 @@ fn check_stream_parse() { fields.has_control_data, fields.has_final_offset, fields.has_application_header, + fields.key_phase, ] { assert_eq!(tracker.remove(field), Field::Integer(tag.value as u64)); } @@ -239,6 +241,7 @@ fn check_datagram_parse() { let tag: Parsed = tag.map(|v| v.into()); assert_eq!(tracker.remove(fields.tag), Field::Integer(tag.value as u64)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(packet.credentials.id.to_vec()) @@ -267,6 +270,7 @@ fn check_datagram_parse() { fields.is_ack_eliciting, fields.has_application_header, fields.is_connected, + fields.key_phase, ] { assert_eq!(tracker.remove(field), Field::Integer(tag.value as u64)); } @@ -362,6 +366,7 @@ fn check_control_parse() { let tag: Parsed = tag.map(|v| v.into()); assert_eq!(tracker.remove(fields.tag), Field::Integer(tag.value as u64)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(packet.credentials.id.to_vec()) @@ -437,11 +442,15 @@ fn check_secret_control_parse() { let mut buffer = vec![0; s2n_quic_dc::packet::secret_control::MAX_PACKET_SIZE]; let length = match packet { SecretControlPacket::UnknownPathSecret { id, auth_tag } => { - s2n_quic_dc::packet::secret_control::UnknownPathSecret { credential_id: *id } - 
.encode(EncoderBuffer::new(&mut buffer), auth_tag) + s2n_quic_dc::packet::secret_control::UnknownPathSecret { + wire_version: WireVersion::ZERO, + credential_id: *id, + } + .encode(EncoderBuffer::new(&mut buffer), auth_tag) } SecretControlPacket::StaleKey { id, key_id } => { s2n_quic_dc::packet::secret_control::StaleKey { + wire_version: WireVersion::ZERO, credential_id: *id, min_key_id: *key_id, } @@ -449,6 +458,7 @@ fn check_secret_control_parse() { } SecretControlPacket::ReplayDetected { id, key_id } => { s2n_quic_dc::packet::secret_control::ReplayDetected { + wire_version: WireVersion::ZERO, credential_id: *id, rejected_key_id: *key_id, } @@ -468,6 +478,7 @@ fn check_secret_control_parse() { match packet { SecretControlPacket::UnknownPathSecret { id, auth_tag } => { assert_eq!(tracker.remove(fields.tag), Field::Integer(0b0110_0000)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(id.to_vec()) @@ -479,6 +490,7 @@ fn check_secret_control_parse() { } SecretControlPacket::StaleKey { id, key_id } => { assert_eq!(tracker.remove(fields.tag), Field::Integer(0b0110_0001)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(id.to_vec()) @@ -492,6 +504,7 @@ fn check_secret_control_parse() { } SecretControlPacket::ReplayDetected { id, key_id } => { assert_eq!(tracker.remove(fields.tag), Field::Integer(0b0110_0010)); + assert_eq!(tracker.remove(fields.wire_version), Field::Integer(0)); assert_eq!( tracker.remove(fields.path_secret_id), Field::Slice(id.to_vec()) diff --git a/dc/wireshark/src/value.rs b/dc/wireshark/src/value.rs index b0e8dd0fe2..279eea633b 100644 --- a/dc/wireshark/src/value.rs +++ b/dc/wireshark/src/value.rs @@ -3,7 +3,7 @@ use crate::{buffer::Buffer, wireshark::Node}; use s2n_quic_core::varint::VarInt; -use s2n_quic_dc::packet; +use s2n_quic_dc::packet::{self, WireVersion}; #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct Parsed { @@ -64,6 +64,12 @@ impl Parsed { } } +impl Parsed { + pub fn record(&self, buffer: &Buffer, tree: &mut T, field: i32) -> T::AddedItem { + tree.add_u32(buffer, field, self.map(|v| v.0)) + } +} + impl Parsed { pub fn record(&self, buffer: &Buffer, tree: &mut T, field: i32) -> T::AddedItem { tree.add_u32(buffer, field, *self)