diff --git a/dc/s2n-quic-dc/src/stream/client/tokio.rs b/dc/s2n-quic-dc/src/stream/client/tokio.rs index 7cfdc6e56..d1c85c528 100644 --- a/dc/s2n-quic-dc/src/stream/client/tokio.rs +++ b/dc/s2n-quic-dc/src/stream/client/tokio.rs @@ -65,7 +65,18 @@ where // Make sure TCP_NODELAY is set let _ = socket.set_nodelay(true); - let stream = endpoint::open_stream(env, peer, env::TcpRegistered(socket), subscriber, None)?; + let local_port = socket.local_addr()?.port(); + let stream = endpoint::open_stream( + env, + peer, + env::TcpRegistered { + socket, + peer_addr: acceptor_addr.into(), + local_port, + }, + subscriber, + None, + )?; // build the stream inside the application context let mut stream = stream.connect()?; @@ -85,14 +96,26 @@ where #[inline] pub async fn connect_tcp_with( peer: secret::map::Peer, - stream: TcpStream, + socket: TcpStream, env: &Environment, subscriber: Sub, ) -> io::Result> where Sub: event::Subscriber, { - let stream = endpoint::open_stream(env, peer, env::TcpRegistered(stream), subscriber, None)?; + let local_port = socket.local_addr()?.port(); + let peer_addr = socket.peer_addr()?.into(); + let stream = endpoint::open_stream( + env, + peer, + env::TcpRegistered { + socket, + peer_addr, + local_port, + }, + subscriber, + None, + )?; // build the stream inside the application context let mut stream = stream.connect()?; diff --git a/dc/s2n-quic-dc/src/stream/environment/tokio.rs b/dc/s2n-quic-dc/src/stream/environment/tokio.rs index 3c5091b82..df82da4cc 100644 --- a/dc/s2n-quic-dc/src/stream/environment/tokio.rs +++ b/dc/s2n-quic-dc/src/stream/environment/tokio.rs @@ -255,7 +255,11 @@ where } /// A socket that is already registered with the application runtime -pub struct TcpRegistered(pub TcpStream); +pub struct TcpRegistered { + pub socket: TcpStream, + pub peer_addr: SocketAddress, + pub local_port: u16, +} impl super::Peer> for TcpRegistered where @@ -274,9 +278,9 @@ where #[inline] fn setup(self, _env: &Environment) -> super::Result> { - let remote_addr = self.0.peer_addr()?.into(); - let source_control_port = self.0.local_addr()?.port(); - let application = Box::new(self.0); + let remote_addr = self.peer_addr; + let source_control_port = self.local_port; + let application = Box::new(self.socket); Ok(super::SocketSet { application, read_worker: None, @@ -289,7 +293,11 @@ where } /// A socket that should be reregistered with the application runtime -pub struct TcpReregistered(pub TcpStream, pub SocketAddress); +pub struct TcpReregistered { + pub socket: TcpStream, + pub peer_addr: SocketAddress, + pub local_port: u16, +} impl super::Peer> for TcpReregistered where @@ -308,9 +316,9 @@ where #[inline] fn setup(self, _env: &Environment) -> super::Result> { - let remote_addr = self.1; - let source_control_port = self.0.local_addr()?.port(); - let application = Box::new(self.0.into_std()?); + let source_control_port = self.local_port; + let remote_addr = self.peer_addr; + let application = Box::new(self.socket.into_std()?); Ok(super::SocketSet { application, read_worker: None, diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs index a52957ec0..b4dd603e0 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs @@ -225,6 +225,8 @@ impl FreshQueue { remaining -= 1; if remaining == 0 { + // if we're yielding then we need to wake ourselves up again + cx.waker().wake_by_ref(); break; } } @@ -444,6 +446,7 @@ where secrets: secret::Map, accept_flavor: accept::Flavor, subscriber: Sub, + local_port: u16, } impl WorkerContext @@ -458,6 +461,7 @@ where secrets: acceptor.secrets.clone(), accept_flavor: acceptor.accept_flavor, subscriber: acceptor.subscriber.clone(), + local_port: acceptor.socket.local_addr().unwrap().port(), } } } @@ -691,7 +695,11 @@ impl WorkerState { let stream_builder = match endpoint::accept_stream( now, &context.env, - env::TcpReregistered(socket, remote_address), + env::TcpReregistered { + socket, + peer_addr: remote_address, + local_port: context.local_port, + }, &initial_packet, None, Some(recv_buffer), @@ -702,7 +710,7 @@ impl WorkerState { ) { Ok(stream) => stream, Err(error) => { - if let Some(env::TcpReregistered(socket, remote_address)) = error.peer { + if let Some(env::TcpReregistered { socket, .. }) = error.peer { if !error.secret_control.is_empty() { // if we need to send an error then update the state and loop back // around