diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs index ad03674d..d69e1603 100644 --- a/src/transport/quic/mod.rs +++ b/src/transport/quic/mod.rs @@ -34,7 +34,11 @@ use crate::{ PeerId, }; -use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; +use futures::{ + future::BoxFuture, + stream::{AbortHandle, FuturesUnordered}, + Stream, StreamExt, TryFutureExt, +}; use multiaddr::{Multiaddr, Protocol}; use quinn::{ClientConfig, Connecting, Connection, Endpoint, IdleTimeout}; @@ -66,6 +70,25 @@ struct NegotiatedConnection { connection: Connection, } +#[derive(Debug)] +enum RawConnectionResult { + /// The first successful connection. + Connected { + connection_id: ConnectionId, + address: Multiaddr, + stream: NegotiatedConnection, + }, + + /// All connection attempts failed. + Failed { + connection_id: ConnectionId, + errors: Vec<(Multiaddr, DialError)>, + }, + + /// Future was canceled. + Canceled { connection_id: ConnectionId }, +} + /// QUIC transport object. pub(crate) struct QuicTransport { /// Transport handle. @@ -92,21 +115,15 @@ pub(crate) struct QuicTransport { pending_open: HashMap, /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture< - 'static, - Result< - (ConnectionId, Multiaddr, NegotiatedConnection), - (ConnectionId, Vec<(Multiaddr, DialError)>), - >, - >, - >, + pending_raw_connections: FuturesUnordered>, /// Opened raw connection, waiting for approval/rejection from `TransportManager`. opened_raw: HashMap, /// Canceled raw connections. canceled: HashSet, + + cancel_futures: HashMap, } impl QuicTransport { @@ -225,6 +242,7 @@ impl TransportBuilder for QuicTransport { pending_inbound_connections: HashMap::new(), pending_raw_connections: FuturesUnordered::new(), pending_connections: FuturesUnordered::new(), + cancel_futures: HashMap::new(), }, listen_addresses, )) @@ -407,12 +425,18 @@ impl Transport for QuicTransport { }) .collect(); - self.pending_raw_connections.push(Box::pin(async move { + // Future that will resolve to the first successful connection. + let future = async move { let mut errors = Vec::with_capacity(num_addresses); while let Some(result) = futures.next().await { match result { - Ok((address, connection)) => return Ok((connection_id, address, connection)), + Ok((address, stream)) => + return RawConnectionResult::Connected { + connection_id, + address, + stream, + }, Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -425,8 +449,16 @@ impl Transport for QuicTransport { } } - Err((connection_id, errors)) - })); + RawConnectionResult::Failed { + connection_id, + errors, + } + }; + + let (fut, handle) = futures::future::abortable(future); + let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id }); + self.pending_raw_connections.push(Box::pin(fut)); + self.cancel_futures.insert(connection_id, handle); Ok(()) } @@ -446,6 +478,7 @@ impl Transport for QuicTransport { /// Cancel opening connections. fn cancel(&mut self, connection_id: ConnectionId) { self.canceled.insert(connection_id); + self.cancel_futures.remove(&connection_id).map(|handle| handle.abort()); } } @@ -470,16 +503,14 @@ impl Stream for QuicTransport { } while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); + tracing::trace!(target: LOG_TARGET, ?result, "raw connection result"); + match result { + RawConnectionResult::Connected { + connection_id, + address, + stream, + } => if !self.canceled.remove(&connection_id) { self.opened_raw.insert(connection_id, (stream, address.clone())); @@ -487,15 +518,20 @@ impl Stream for QuicTransport { connection_id, address, })); - } - } - Err((connection_id, errors)) => + }, + RawConnectionResult::Failed { + connection_id, + errors, + } => if !self.canceled.remove(&connection_id) { return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id, errors, })); }, + RawConnectionResult::Canceled { connection_id } => { + self.canceled.remove(&connection_id); + } } }