diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs deleted file mode 100644 index 67d4b66c0a0..00000000000 --- a/protocols/ping/src/handler.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright 2019 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use crate::{protocol, PROTOCOL_NAME}; -use futures::future::BoxFuture; -use futures::prelude::*; -use futures_timer::Delay; -use libp2p_core::upgrade::ReadyUpgrade; -use libp2p_core::{upgrade::NegotiationError, UpgradeError}; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; -use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, - NegotiatedSubstream, SubstreamProtocol, -}; -use std::collections::VecDeque; -use std::{ - error::Error, - fmt, io, - num::NonZeroU32, - task::{Context, Poll}, - time::Duration, -}; -use void::Void; - -/// The configuration for outbound pings. -#[derive(Debug, Clone)] -pub struct Config { - /// The timeout of an outbound ping. - timeout: Duration, - /// The duration between the last successful outbound or inbound ping - /// and the next outbound ping. - interval: Duration, - /// The maximum number of failed outbound pings before the associated - /// connection is deemed unhealthy, indicating to the `Swarm` that it - /// should be closed. - max_failures: NonZeroU32, - /// Whether the connection should generally be kept alive unless - /// `max_failures` occur. - keep_alive: bool, -} - -impl Config { - /// Creates a new [`Config`] with the following default settings: - /// - /// * [`Config::with_interval`] 15s - /// * [`Config::with_timeout`] 20s - /// * [`Config::with_max_failures`] 1 - /// * [`Config::with_keep_alive`] false - /// - /// These settings have the following effect: - /// - /// * A ping is sent every 15 seconds on a healthy connection. - /// * Every ping sent must yield a response within 20 seconds in order to - /// be successful. - /// * A single ping failure is sufficient for the connection to be subject - /// to being closed. - /// * The connection may be closed at any time as far as the ping protocol - /// is concerned, i.e. the ping protocol itself does not keep the - /// connection alive. - pub fn new() -> Self { - Self { - timeout: Duration::from_secs(20), - interval: Duration::from_secs(15), - max_failures: NonZeroU32::new(1).expect("1 != 0"), - keep_alive: false, - } - } - - /// Sets the ping timeout. 
- pub fn with_timeout(mut self, d: Duration) -> Self { - self.timeout = d; - self - } - - /// Sets the ping interval. - pub fn with_interval(mut self, d: Duration) -> Self { - self.interval = d; - self - } - - /// Sets the maximum number of consecutive ping failures upon which the remote - /// peer is considered unreachable and the connection closed. - pub fn with_max_failures(mut self, n: NonZeroU32) -> Self { - self.max_failures = n; - self - } - - /// Sets whether the ping protocol itself should keep the connection alive, - /// apart from the maximum allowed failures. - /// - /// By default, the ping protocol itself allows the connection to be closed - /// at any time, i.e. in the absence of ping failures the connection lifetime - /// is determined by other protocol handlers. - /// - /// If the maximum number of allowed ping failures is reached, the - /// connection is always terminated as a result of [`ConnectionHandler::poll`] - /// returning an error, regardless of the keep-alive setting. - #[deprecated( - since = "0.40.0", - note = "Use `libp2p::swarm::behaviour::KeepAlive` if you need to keep connections alive unconditionally." - )] - pub fn with_keep_alive(mut self, b: bool) -> Self { - self.keep_alive = b; - self - } -} - -impl Default for Config { - fn default() -> Self { - Self::new() - } -} - -/// The successful result of processing an inbound or outbound ping. -#[derive(Debug)] -pub enum Success { - /// Received a ping and sent back a pong. - Pong, - /// Sent a ping and received back a pong. - /// - /// Includes the round-trip time. - Ping { rtt: Duration }, -} - -/// An outbound ping failure. -#[derive(Debug)] -pub enum Failure { - /// The ping timed out, i.e. no response was received within the - /// configured ping timeout. - Timeout, - /// The peer does not support the ping protocol. - Unsupported, - /// The ping failed for reasons other than a timeout. - Other { - error: Box, - }, -} - -impl fmt::Display for Failure { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Failure::Timeout => f.write_str("Ping timeout"), - Failure::Other { error } => write!(f, "Ping error: {}", error), - Failure::Unsupported => write!(f, "Ping protocol not supported"), - } - } -} - -impl Error for Failure { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - Failure::Timeout => None, - Failure::Other { error } => Some(&**error), - Failure::Unsupported => None, - } - } -} - -/// Protocol handler that handles pinging the remote at a regular period -/// and answering ping queries. -/// -/// If the remote doesn't respond, produces an error that closes the connection. -pub struct Handler { - /// Configuration options. - config: Config, - /// The timer used for the delay to the next ping as well as - /// the ping timeout. - timer: Delay, - /// Outbound ping failures that are pending to be processed by `poll()`. - pending_errors: VecDeque, - /// The number of consecutive ping failures that occurred. - /// - /// Each successful ping resets this counter to 0. - failures: u32, - /// The outbound ping state. - outbound: Option, - /// The inbound pong handler, i.e. if there is an inbound - /// substream, this is always a future that waits for the - /// next inbound ping to be answered. - inbound: Option, - /// Tracks the state of our handler. - state: State, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum State { - /// We are inactive because the other peer doesn't support ping. 
- Inactive { - /// Whether or not we've reported the missing support yet. - /// - /// This is used to avoid repeated events being emitted for a specific connection. - reported: bool, - }, - /// We are actively pinging the other peer. - Active, -} - -impl Handler { - /// Builds a new [`Handler`] with the given configuration. - pub fn new(config: Config) -> Self { - Handler { - config, - timer: Delay::new(Duration::new(0, 0)), - pending_errors: VecDeque::with_capacity(2), - failures: 0, - outbound: None, - inbound: None, - state: State::Active, - } - } - - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { error, .. }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - self.outbound = None; // Request a new substream on the next `poll`. - - let error = match error { - ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { - debug_assert_eq!(self.state, State::Active); - - self.state = State::Inactive { reported: false }; - return; - } - // Note: This timeout only covers protocol negotiation. - ConnectionHandlerUpgrErr::Timeout => Failure::Timeout, - e => Failure::Other { error: Box::new(e) }, - }; - - self.pending_errors.push_front(error); - } -} - -impl ConnectionHandler for Handler { - type InEvent = Void; - type OutEvent = crate::Result; - type Error = Failure; - type InboundProtocol = ReadyUpgrade<&'static [u8]>; - type OutboundProtocol = ReadyUpgrade<&'static [u8]>; - type OutboundOpenInfo = (); - type InboundOpenInfo = (); - - fn listen_protocol(&self) -> SubstreamProtocol, ()> { - SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ()) - } - - fn on_behaviour_event(&mut self, _: Void) {} - - fn connection_keep_alive(&self) -> KeepAlive { - if self.config.keep_alive { - KeepAlive::Yes - } else { - KeepAlive::No - } - } - - fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, (), crate::Result, Self::Error>> - { - match self.state { - State::Inactive { reported: true } => { - return Poll::Pending; // nothing to do on this connection - } - State::Inactive { reported: false } => { - self.state = State::Inactive { reported: true }; - return Poll::Ready(ConnectionHandlerEvent::Custom(Err(Failure::Unsupported))); - } - State::Active => {} - } - - // Respond to inbound pings. - if let Some(fut) = self.inbound.as_mut() { - match fut.poll_unpin(cx) { - Poll::Pending => {} - Poll::Ready(Err(e)) => { - log::debug!("Inbound ping error: {:?}", e); - self.inbound = None; - } - Poll::Ready(Ok(stream)) => { - // A ping from a remote peer has been answered, wait for the next. - self.inbound = Some(protocol::recv_ping(stream).boxed()); - return Poll::Ready(ConnectionHandlerEvent::Custom(Ok(Success::Pong))); - } - } - } - - loop { - // Check for outbound ping failures. - if let Some(error) = self.pending_errors.pop_back() { - log::debug!("Ping failure: {:?}", error); - - self.failures += 1; - - // Note: For backward-compatibility, with configured - // `max_failures == 1`, the first failure is always "free" - // and silent. This allows peers who still use a new substream - // for each ping to have successful ping exchanges with peers - // that use a single substream, since every successful ping - // resets `failures` to `0`, while at the same time emitting - // events only for `max_failures - 1` failures, as before. - if self.failures > 1 || self.config.max_failures.get() > 1 { - if self.failures >= self.config.max_failures.get() { - log::debug!("Too many failures ({}). 
Closing connection.", self.failures); - return Poll::Ready(ConnectionHandlerEvent::Close(error)); - } - - return Poll::Ready(ConnectionHandlerEvent::Custom(Err(error))); - } - } - - // Continue outbound pings. - match self.outbound.take() { - Some(OutboundState::Ping(mut ping)) => match ping.poll_unpin(cx) { - Poll::Pending => { - if self.timer.poll_unpin(cx).is_ready() { - self.pending_errors.push_front(Failure::Timeout); - } else { - self.outbound = Some(OutboundState::Ping(ping)); - break; - } - } - Poll::Ready(Ok((stream, rtt))) => { - self.failures = 0; - self.timer.reset(self.config.interval); - self.outbound = Some(OutboundState::Idle(stream)); - return Poll::Ready(ConnectionHandlerEvent::Custom(Ok(Success::Ping { - rtt, - }))); - } - Poll::Ready(Err(e)) => { - self.pending_errors - .push_front(Failure::Other { error: Box::new(e) }); - } - }, - Some(OutboundState::Idle(stream)) => match self.timer.poll_unpin(cx) { - Poll::Pending => { - self.outbound = Some(OutboundState::Idle(stream)); - break; - } - Poll::Ready(()) => { - self.timer.reset(self.config.timeout); - self.outbound = - Some(OutboundState::Ping(protocol::send_ping(stream).boxed())); - } - }, - Some(OutboundState::OpenStream) => { - self.outbound = Some(OutboundState::OpenStream); - break; - } - None => { - self.outbound = Some(OutboundState::OpenStream); - let protocol = SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ()) - .with_timeout(self.config.timeout); - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol, - }); - } - } - } - - Poll::Pending - } - - fn on_connection_event( - &mut self, - event: ConnectionEvent< - Self::InboundProtocol, - Self::OutboundProtocol, - Self::InboundOpenInfo, - Self::OutboundOpenInfo, - >, - ) { - match event { - ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { - protocol: stream, - .. - }) => { - self.inbound = Some(protocol::recv_ping(stream).boxed()); - } - ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { - protocol: stream, - .. - }) => { - self.timer.reset(self.config.timeout); - self.outbound = Some(OutboundState::Ping(protocol::send_ping(stream).boxed())); - } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) - } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} - } - } -} - -type PingFuture = BoxFuture<'static, Result<(NegotiatedSubstream, Duration), io::Error>>; -type PongFuture = BoxFuture<'static, Result>; - -/// The current state w.r.t. outbound pings. -enum OutboundState { - /// A new substream is being negotiated for the ping protocol. - OpenStream, - /// The substream is idle, waiting to send the next ping. - Idle(NegotiatedSubstream), - /// A ping is being sent and the response awaited. 
- Ping(PingFuture), -} diff --git a/protocols/ping/src/lib.rs b/protocols/ping/src/lib.rs index 6e481500df9..51112a96b5e 100644 --- a/protocols/ping/src/lib.rs +++ b/protocols/ping/src/lib.rs @@ -42,52 +42,41 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -mod handler; mod protocol; -use handler::Handler; -pub use handler::{Config, Failure, Success}; -use libp2p_core::{connection::ConnectionId, PeerId}; +use crate::protocol::{recv_ping, send_ping}; +use futures::future::Either; +use futures_timer::Delay; +use libp2p_core::connection::ConnectionId; +use libp2p_core::PeerId; +use libp2p_swarm::behaviour::{ConnectionEstablished, FromSwarm}; +use libp2p_swarm::handler::from_fn; use libp2p_swarm::{ - behaviour::FromSwarm, NetworkBehaviour, NetworkBehaviourAction, PollParameters, + CloseConnection, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; -use std::{ - collections::VecDeque, - task::{Context, Poll}, -}; - -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Config instead.")] -pub type PingConfig = Config; - -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Event instead.")] -pub type PingEvent = Event; - -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Success instead.")] -pub type PingSuccess = Success; - -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Failure instead.")] -pub type PingFailure = Failure; +use std::collections::{HashMap, VecDeque}; +use std::error::Error; +use std::num::NonZeroU32; +use std::task::{Context, Poll}; +use std::time::Duration; +use std::{fmt, io}; -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Result instead.")] -pub type PingResult = Result; +pub use crate::protocol::PROTOCOL_NAME; -#[deprecated(since = "0.39.1", note = "Use libp2p::ping::Behaviour instead.")] -pub type Ping = Behaviour; - -pub use self::protocol::PROTOCOL_NAME; - -/// The result of an inbound or outbound ping. pub type Result = std::result::Result; +type Handler = from_fn::FromFnProto; + /// A [`NetworkBehaviour`] that responds to inbound pings and /// periodically sends outbound pings on every established connection. /// /// See the crate root documentation for more information. +#[derive(Default)] pub struct Behaviour { /// Configuration for outbound pings. config: Config, - /// Queue of events to yield to the swarm. - events: VecDeque, + actions: VecDeque>, + failures: HashMap<(PeerId, ConnectionId), (u32, VecDeque)>, } /// Event generated by the `Ping` network behaviour. @@ -104,14 +93,138 @@ impl Behaviour { pub fn new(config: Config) -> Self { Self { config, - events: VecDeque::new(), + actions: Default::default(), + failures: Default::default(), + } + } + + fn reset_num_failures(&mut self, peer: PeerId, connection_id: ConnectionId) { + self.failures.entry((peer, connection_id)).or_default().0 = 0; + } + + fn record_failure(&mut self, peer: PeerId, connection_id: ConnectionId, e: Failure) { + self.failures + .entry((peer, connection_id)) + .or_default() + .1 + .push_back(e); + } +} + +/// The configuration for outbound pings. +#[derive(Debug, Clone)] +pub struct Config { + /// The timeout of an outbound ping. + pub(crate) timeout: Duration, + /// The duration between the last successful outbound or inbound ping + /// and the next outbound ping. + pub(crate) interval: Duration, + /// The maximum number of failed outbound pings before the associated + /// connection is deemed unhealthy, indicating to the `Swarm` that it + /// should be closed. 
+ pub(crate) max_failures: NonZeroU32, +} + +impl Config { + /// Creates a new [`Config`] with the following default settings: + /// + /// * [`Config::with_interval`] 15s + /// * [`Config::with_timeout`] 20s + /// * [`Config::with_max_failures`] 1 + /// * [`Config::with_keep_alive`] false + /// + /// These settings have the following effect: + /// + /// * A ping is sent every 15 seconds on a healthy connection. + /// * Every ping sent must yield a response within 20 seconds in order to + /// be successful. + /// * A single ping failure is sufficient for the connection to be subject + /// to being closed. + /// * The connection may be closed at any time as far as the ping protocol + /// is concerned, i.e. the ping protocol itself does not keep the + /// connection alive. + pub fn new() -> Self { + Self { + timeout: Duration::from_secs(20), + interval: Duration::from_secs(15), + max_failures: NonZeroU32::new(1).expect("1 != 0"), } } + + /// Sets the ping timeout. + pub fn with_timeout(mut self, d: Duration) -> Self { + self.timeout = d; + self + } + + /// Sets the ping interval. + pub fn with_interval(mut self, d: Duration) -> Self { + self.interval = d; + self + } + + /// Sets the maximum number of consecutive ping failures upon which the remote + /// peer is considered unreachable and the connection closed. + pub fn with_max_failures(mut self, n: NonZeroU32) -> Self { + self.max_failures = n; + self + } } -impl Default for Behaviour { +impl Default for Config { fn default() -> Self { - Self::new(Config::new()) + Self::new() + } +} + +/// The successful result of processing an inbound or outbound ping. +#[derive(Debug)] +pub enum Success { + /// Received a ping and sent back a pong. + Pong, + /// Sent a ping and received back a pong. + /// + /// Includes the round-trip time. + Ping { rtt: Duration }, +} + +/// An outbound ping failure. +#[derive(Debug)] +pub enum Failure { + /// The ping timed out, i.e. no response was received within the + /// configured ping timeout. + Timeout, + /// The peer does not support the ping protocol. + Unsupported, + /// The ping failed for reasons other than a timeout. 
+ Other { + error: Box, + }, +} + +impl From for Failure { + fn from(e: io::Error) -> Self { + Failure::Other { error: Box::new(e) } + } +} + +impl fmt::Display for Failure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Failure::Timeout => f.write_str("Ping timeout"), + Failure::Other { error } => write!(f, "Ping error: {}", error), + Failure::Unsupported => write!(f, "Ping protocol not supported"), + } + } +} + +impl Error for Failure { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Failure::Timeout => None, + Failure::Other { error } => Some(&**error), + Failure::Unsupported => None, + } } } @@ -120,30 +233,92 @@ impl NetworkBehaviour for Behaviour { type OutEvent = Event; fn new_handler(&mut self) -> Self::ConnectionHandler { - Handler::new(self.config.clone()) - } + from_fn::from_fn(std::str::from_utf8(PROTOCOL_NAME).unwrap()) + .without_state() + .with_streaming_inbound_handler(1, |stream, _, _, _| { + futures::stream::try_unfold(stream, |stream| async move { + let stream = recv_ping(stream).await?; + + Ok(Some((Success::Pong, stream))) + }) + }) + .with_streaming_outbound_handler(1, { + let interval = self.config.interval; + let timeout = self.config.timeout; + + move |stream, _, _, _, _| { + futures::stream::try_unfold(stream, move |stream| async move { + Delay::new(interval).await; - fn on_connection_handler_event(&mut self, peer: PeerId, _: ConnectionId, result: Result) { - self.events.push_front(Event { peer, result }) + let ping = send_ping(stream); + futures::pin_mut!(ping); + + match futures::future::select(Delay::new(timeout), ping).await { + Either::Left(((), _unfinished_ping)) => Err(Failure::Timeout), + Either::Right((Ok((stream, rtt)), _)) => { + Ok(Some((Success::Ping { rtt }, stream))) + } + Either::Right((Err(e), _)) => { + Err(Failure::Other { error: Box::new(e) }) + } + } + }) + } + }) } - fn poll( + fn on_connection_handler_event( &mut self, - _: &mut Context<'_>, - _: &mut impl PollParameters, - ) -> Poll> { - if let Some(e) = self.events.pop_back() { - let Event { result, peer } = &e; - - match result { - Ok(Success::Ping { .. }) => log::debug!("Ping sent to {:?}", peer), - Ok(Success::Pong) => log::debug!("Ping received from {:?}", peer), - _ => {} + peer: PeerId, + connection: ConnectionId, + event: from_fn::OutEvent, + ) { + match event { + from_fn::OutEvent::InboundEmitted(Ok(success)) => { + self.actions + .push_back(NetworkBehaviourAction::GenerateEvent(Event { + peer, + result: Ok(success), + })) + } + from_fn::OutEvent::OutboundEmitted(Ok(success)) => { + self.actions + .push_back(NetworkBehaviourAction::GenerateEvent(Event { + peer, + result: Ok(success), + })); + self.reset_num_failures(peer, connection); + } + from_fn::OutEvent::InboundEmitted(Err(e)) => { + log::debug!("Inbound ping error: {:?}", e); + } + from_fn::OutEvent::OutboundEmitted(Err(e)) => { + self.record_failure(peer, connection, e); + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::Timeout(())) => { + self.record_failure(peer, connection, Failure::Timeout); + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::Unsupported { + open_info: (), + .. + }) => { + self.record_failure(peer, connection, Failure::Unsupported); + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::NegotiationFailed((), error)) => { + self.record_failure( + peer, + connection, + Failure::Other { + error: Box::new(error), + }, + ); + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::LimitExceeded { + open_info: (), + .. 
+ }) => { + unreachable!("We only ever open a new stream if the old one is dead.") } - - Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)) - } else { - Poll::Pending } } @@ -152,8 +327,14 @@ impl NetworkBehaviour for Behaviour { event: libp2p_swarm::behaviour::FromSwarm, ) { match event { - FromSwarm::ConnectionEstablished(_) - | FromSwarm::ConnectionClosed(_) + FromSwarm::ConnectionEstablished(ConnectionEstablished { + peer_id, + connection_id, + .. + }) => self + .actions + .push_back(start_ping_action(peer_id, connection_id)), + FromSwarm::ConnectionClosed(_) | FromSwarm::AddressChange(_) | FromSwarm::DialFailure(_) | FromSwarm::ListenFailure(_) @@ -166,4 +347,61 @@ impl NetworkBehaviour for Behaviour { | FromSwarm::ExpiredExternalAddr(_) => {} } } + + fn poll( + &mut self, + _: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll> { + if let Some(action) = self.actions.pop_front() { + return Poll::Ready(action); + } + + for ((peer, connection), (failures, pending_errors)) in self.failures.iter_mut() { + // Check for outbound ping failures. + if let Some(error) = pending_errors.pop_back() { + log::debug!("Ping failure: {:?}", error); + + *failures += 1; + + // Note: For backward-compatibility, with configured + // `max_failures == 1`, the first failure is always "free" + // and silent. This allows peers who still use a new substream + // for each ping to have successful ping exchanges with peers + // that use a single substream, since every successful ping + // resets `failures` to `0`, while at the same time emitting + // events only for `max_failures - 1` failures, as before. + if *failures > 1 || self.config.max_failures.get() > 1 { + if *failures >= self.config.max_failures.get() { + log::debug!("Too many failures ({}). Closing connection.", failures); + return Poll::Ready(NetworkBehaviourAction::CloseConnection { + peer_id: *peer, + connection: CloseConnection::One(*connection), + }); + } + } + + self.actions + .push_back(start_ping_action(*peer, *connection)); + + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(Event { + peer: *peer, + result: Err(error), + })); + } + } + + Poll::Pending + } +} + +fn start_ping_action( + peer_id: PeerId, + connection: ConnectionId, +) -> NetworkBehaviourAction { + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::One(connection), + event: from_fn::InEvent::NewOutbound(()), + } } diff --git a/protocols/rendezvous/src/client.rs b/protocols/rendezvous/src/client.rs index 5d44354992e..e6c7debf38b 100644 --- a/protocols/rendezvous/src/client.rs +++ b/protocols/rendezvous/src/client.rs @@ -18,33 +18,38 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
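// Illustrative sketch (not part of this patch): constructing the refactored ping
// behaviour above with its `Config` builder. The crate path `libp2p_ping` is an
// assumption; `Config::new`, `with_timeout`, `with_interval`, `with_max_failures`
// and `Behaviour::new` are the methods shown in protocols/ping/src/lib.rs.
use std::num::NonZeroU32;
use std::time::Duration;

fn ping_behaviour() -> libp2p_ping::Behaviour {
    let config = libp2p_ping::Config::new()
        .with_interval(Duration::from_secs(15)) // send a ping every 15s on a healthy connection
        .with_timeout(Duration::from_secs(20)) // every ping must be answered within 20s
        .with_max_failures(NonZeroU32::new(3).expect("3 != 0")); // close after 3 consecutive failures

    libp2p_ping::Behaviour::new(config)
}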
-use crate::codec::{Cookie, ErrorCode, Namespace, NewRegistration, Registration, Ttl}; -use crate::handler; -use crate::handler::outbound; -use crate::handler::outbound::OpenInfo; -use crate::substream_handler::SubstreamConnectionHandler; +use crate::codec::{ + Cookie, Error, ErrorCode, Message, Namespace, NewRegistrationRequest, Registration, + RendezvousCodec, Ttl, +}; +use crate::PROTOCOL_IDENT; +use asynchronous_codec::Framed; use futures::future::BoxFuture; use futures::future::FutureExt; use futures::stream::FuturesUnordered; use futures::stream::StreamExt; -use instant::Duration; +use futures::SinkExt; use libp2p_core::connection::ConnectionId; use libp2p_core::identity::error::SigningError; use libp2p_core::identity::Keypair; -use libp2p_core::{Multiaddr, PeerId, PeerRecord}; +use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, PeerRecord}; use libp2p_swarm::behaviour::FromSwarm; +use libp2p_swarm::handler::from_fn; use libp2p_swarm::{ - CloseConnection, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, + NegotiatedSubstream, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; use std::collections::{HashMap, VecDeque}; use std::iter::FromIterator; +use std::sync::Arc; use std::task::{Context, Poll}; +use std::time::Duration; +use void::Void; pub struct Behaviour { events: VecDeque< NetworkBehaviourAction< Event, - SubstreamConnectionHandler, + from_fn::FromFnProto, OpenInfo, ()>, >, >, keypair: Keypair, @@ -87,9 +92,7 @@ impl Behaviour { self.events .push_back(NetworkBehaviourAction::NotifyHandler { peer_id: rendezvous_node, - event: handler::OutboundInEvent::NewSubstream { - open_info: OpenInfo::UnregisterRequest(namespace), - }, + event: from_fn::InEvent::NewOutbound(OpenInfo::UnregisterRequest(namespace)), handler: NotifyHandler::Any, }); } @@ -111,13 +114,11 @@ impl Behaviour { self.events .push_back(NetworkBehaviourAction::NotifyHandler { peer_id: rendezvous_node, - event: handler::OutboundInEvent::NewSubstream { - open_info: OpenInfo::DiscoverRequest { - namespace: ns, - cookie, - limit, - }, - }, + event: from_fn::InEvent::NewOutbound(OpenInfo::DiscoverRequest { + namespace: ns, + cookie, + limit, + }), handler: NotifyHandler::Any, }); } @@ -164,15 +165,45 @@ pub enum Event { Expired { peer: PeerId }, } +#[derive(Debug, Clone)] +pub enum OutboundEvent { + Registered { + namespace: Namespace, + ttl: Ttl, + }, + RegisterFailed(Namespace, ErrorCode), + Discovered { + registrations: Vec, + cookie: Cookie, + }, + DiscoverFailed { + namespace: Option, + error: ErrorCode, + }, +} + +#[allow(clippy::large_enum_variant)] +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +pub enum OpenInfo { + RegisterRequest(NewRegistrationRequest), + UnregisterRequest(Namespace), + DiscoverRequest { + namespace: Option, + cookie: Option, + limit: Option, + }, +} + impl NetworkBehaviour for Behaviour { - type ConnectionHandler = - SubstreamConnectionHandler; + type ConnectionHandler = from_fn::FromFnProto, OpenInfo, ()>; type OutEvent = Event; fn new_handler(&mut self) -> Self::ConnectionHandler { - let initial_keep_alive = Duration::from_secs(30); - - SubstreamConnectionHandler::new_outbound_only(initial_keep_alive) + from_fn::from_fn(PROTOCOL_IDENT) + .without_state() + .without_inbound_handler() + .with_outbound_handler(10, outbound_stream_handler) } fn addresses_of_peer(&mut self, peer: &PeerId) -> Vec { @@ -187,25 +218,89 @@ impl NetworkBehaviour for Behaviour { fn on_connection_handler_event( &mut self, peer_id: PeerId, - 
connection_id: ConnectionId, - event: handler::OutboundOutEvent, + _: ConnectionId, + event: from_fn::OutEvent, OpenInfo>, ) { let new_events = match event { - handler::OutboundOutEvent::InboundEvent { message, .. } => void::unreachable(message), - handler::OutboundOutEvent::OutboundEvent { message, .. } => handle_outbound_event( - message, - peer_id, - &mut self.discovered_peers, - &mut self.expiring_registrations, - ), - handler::OutboundOutEvent::InboundError { error, .. } => void::unreachable(error), - handler::OutboundOutEvent::OutboundError { error, .. } => { - log::warn!("Connection with peer {} failed: {}", peer_id, error); - - vec![NetworkBehaviourAction::CloseConnection { - peer_id, - connection: CloseConnection::One(connection_id), - }] + from_fn::OutEvent::InboundEmitted(never) => void::unreachable(never), + from_fn::OutEvent::OutboundEmitted(Ok(OutboundEvent::Discovered { + registrations, + cookie, + })) => { + self.discovered_peers + .extend(registrations.iter().map(|registration| { + let peer_id = registration.record.peer_id(); + let namespace = registration.namespace.clone(); + let addresses = registration.record.addresses().to_vec(); + + ((peer_id, namespace), addresses) + })); + self.expiring_registrations + .extend(registrations.iter().cloned().map(|registration| { + async move { + // if the timer errors we consider it expired + futures_timer::Delay::new(Duration::from_secs(registration.ttl as u64)) + .await; + + (registration.record.peer_id(), registration.namespace) + } + .boxed() + })); + + vec![NetworkBehaviourAction::GenerateEvent(Event::Discovered { + rendezvous_node: peer_id, + registrations, + cookie, + })] + } + from_fn::OutEvent::OutboundEmitted(Ok(OutboundEvent::Registered { + namespace, + ttl, + })) => { + vec![NetworkBehaviourAction::GenerateEvent(Event::Registered { + rendezvous_node: peer_id, + ttl, + namespace, + })] + } + from_fn::OutEvent::OutboundEmitted(Ok(OutboundEvent::DiscoverFailed { + namespace, + error, + })) => { + vec![NetworkBehaviourAction::GenerateEvent( + Event::DiscoverFailed { + rendezvous_node: peer_id, + namespace, + error, + }, + )] + } + from_fn::OutEvent::OutboundEmitted(Ok(OutboundEvent::RegisterFailed( + namespace, + error, + ))) => { + vec![NetworkBehaviourAction::GenerateEvent( + Event::RegisterFailed(RegisterError::Remote { + rendezvous_node: peer_id, + namespace, + error, + }), + )] + } + from_fn::OutEvent::OutboundEmitted(Err(_)) => { + todo!() + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::Timeout(_)) => { + todo!() + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::NegotiationFailed(..)) => { + todo!() + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::LimitExceeded { .. }) => { + todo!() + } + from_fn::OutEvent::FailedToOpen(from_fn::OpenError::Unsupported { .. 
}) => { + todo!() } }; @@ -238,13 +333,13 @@ impl NetworkBehaviour for Behaviour { let action = match PeerRecord::new(&self.keypair, external_addresses) { Ok(peer_record) => NetworkBehaviourAction::NotifyHandler { peer_id: rendezvous_node, - event: handler::OutboundInEvent::NewSubstream { - open_info: OpenInfo::RegisterRequest(NewRegistration { + event: from_fn::InEvent::NewOutbound(OpenInfo::RegisterRequest( + NewRegistrationRequest { namespace, record: peer_record, ttl, - }), - }, + }, + )), handler: NotifyHandler::Any, }, Err(signing_error) => NetworkBehaviourAction::GenerateEvent(Event::RegisterFailed( @@ -285,70 +380,69 @@ impl NetworkBehaviour for Behaviour { } } -fn handle_outbound_event( - event: outbound::OutEvent, - peer_id: PeerId, - discovered_peers: &mut HashMap<(PeerId, Namespace), Vec>, - expiring_registrations: &mut FuturesUnordered>, -) -> Vec< - NetworkBehaviourAction< - Event, - SubstreamConnectionHandler, - >, -> { - match event { - outbound::OutEvent::Registered { namespace, ttl } => { - vec![NetworkBehaviourAction::GenerateEvent(Event::Registered { - rendezvous_node: peer_id, - ttl, +async fn outbound_stream_handler( + substream: NegotiatedSubstream, + _: PeerId, + _: ConnectedPoint, + _: Arc<()>, + request: OpenInfo, +) -> Result { + let mut substream = Framed::new(substream, RendezvousCodec::default()); + + substream + .send(match request.clone() { + OpenInfo::RegisterRequest(new_registration) => Message::Register(new_registration), + OpenInfo::UnregisterRequest(namespace) => Message::Unregister(namespace), + OpenInfo::DiscoverRequest { namespace, - })] + cookie, + limit, + } => Message::Discover { + namespace, + cookie, + limit, + }, + }) + .await?; + + let response = substream.next().await.transpose()?; + + let out_event = match (request, response) { + (OpenInfo::RegisterRequest(r), Some(Message::RegisterResponse(Ok(ttl)))) => { + OutboundEvent::Registered { + namespace: r.namespace, + ttl, + } } - outbound::OutEvent::RegisterFailed(namespace, error) => { - vec![NetworkBehaviourAction::GenerateEvent( - Event::RegisterFailed(RegisterError::Remote { - rendezvous_node: peer_id, - namespace, - error, - }), - )] + (OpenInfo::RegisterRequest(r), Some(Message::RegisterResponse(Err(e)))) => { + OutboundEvent::RegisterFailed(r.namespace, e) } - outbound::OutEvent::Discovered { + ( + OpenInfo::DiscoverRequest { .. }, + Some(Message::DiscoverResponse(Ok((registrations, cookie)))), + ) => OutboundEvent::Discovered { registrations, cookie, - } => { - discovered_peers.extend(registrations.iter().map(|registration| { - let peer_id = registration.record.peer_id(); - let namespace = registration.namespace.clone(); - - let addresses = registration.record.addresses().to_vec(); - - ((peer_id, namespace), addresses) - })); - expiring_registrations.extend(registrations.iter().cloned().map(|registration| { - async move { - // if the timer errors we consider it expired - futures_timer::Delay::new(Duration::from_secs(registration.ttl)).await; - - (registration.record.peer_id(), registration.namespace) - } - .boxed() - })); + }, + ( + OpenInfo::DiscoverRequest { namespace, .. }, + Some(Message::DiscoverResponse(Err(error))), + ) => OutboundEvent::DiscoverFailed { namespace, error }, + (OpenInfo::UnregisterRequest(_), None) => { + // All good. 
- vec![NetworkBehaviourAction::GenerateEvent(Event::Discovered { - rendezvous_node: peer_id, - registrations, - cookie, - })] + todo!() } - outbound::OutEvent::DiscoverFailed { namespace, error } => { - vec![NetworkBehaviourAction::GenerateEvent( - Event::DiscoverFailed { - rendezvous_node: peer_id, - namespace, - error, - }, - )] + (_, None) => { + // EOF? + todo!() } - } + _ => { + panic!("protocol violation") // TODO: Make two different codecs to avoid this? + } + }; + + substream.close().await?; + + Ok(out_event) } diff --git a/protocols/rendezvous/src/codec.rs b/protocols/rendezvous/src/codec.rs index 88af5a1fa98..16f62d29a29 100644 --- a/protocols/rendezvous/src/codec.rs +++ b/protocols/rendezvous/src/codec.rs @@ -18,7 +18,6 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::DEFAULT_TTL; use asynchronous_codec::{BytesMut, Decoder, Encoder}; use libp2p_core::{peer_record, signed_envelope, PeerRecord, SignedEnvelope}; use rand::RngCore; @@ -30,7 +29,7 @@ pub type Ttl = u64; #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone)] pub enum Message { - Register(NewRegistration), + Register(NewRegistrationRequest), RegisterResponse(Result), Unregister(Namespace), Discover { @@ -161,26 +160,12 @@ impl Cookie { pub struct InvalidCookie; #[derive(Debug, Clone)] -pub struct NewRegistration { +pub struct NewRegistrationRequest { pub namespace: Namespace, pub record: PeerRecord, pub ttl: Option, } -impl NewRegistration { - pub fn new(namespace: Namespace, record: PeerRecord, ttl: Option) -> Self { - Self { - namespace, - record, - ttl, - } - } - - pub fn effective_ttl(&self) -> Ttl { - self.ttl.unwrap_or(DEFAULT_TTL) - } -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct Registration { pub namespace: Namespace, @@ -251,7 +236,7 @@ impl From for wire::Message { use wire::message::*; match message { - Message::Register(NewRegistration { + Message::Register(NewRegistrationRequest { namespace, record, ttl, @@ -375,16 +360,20 @@ impl TryFrom for Message { signed_peer_record: Some(signed_peer_record), }), .. - } => Message::Register(NewRegistration { - namespace: ns + } => { + let namespace = ns .map(Namespace::new) .transpose()? - .ok_or(ConversionError::MissingNamespace)?, - ttl, - record: PeerRecord::from_signed_envelope(SignedEnvelope::from_protobuf_encoding( - &signed_peer_record, - )?)?, - }), + .ok_or(ConversionError::MissingNamespace)?; + let record = PeerRecord::from_signed_envelope( + SignedEnvelope::from_protobuf_encoding(&signed_peer_record)?, + )?; + Message::Register(NewRegistrationRequest { + namespace, + record, + ttl, + }) + } wire::Message { r#type: Some(1), register_response: diff --git a/protocols/rendezvous/src/handler.rs b/protocols/rendezvous/src/handler.rs index d07bf4d248f..b7f06647aa3 100644 --- a/protocols/rendezvous/src/handler.rs +++ b/protocols/rendezvous/src/handler.rs @@ -22,7 +22,7 @@ use crate::codec; use crate::codec::Message; use void::Void; -const PROTOCOL_IDENT: &[u8] = b"/rendezvous/1.0.0"; +pub const PROTOCOL_IDENT: &str = "/rendezvous/1.0.0"; pub mod inbound; pub mod outbound; diff --git a/protocols/rendezvous/src/handler/inbound.rs b/protocols/rendezvous/src/handler/inbound.rs deleted file mode 100644 index 3f432bee6bd..00000000000 --- a/protocols/rendezvous/src/handler/inbound.rs +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2021 COMIT Network. 
-// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use crate::codec::{ - Cookie, ErrorCode, Message, Namespace, NewRegistration, Registration, RendezvousCodec, Ttl, -}; -use crate::handler::Error; -use crate::handler::PROTOCOL_IDENT; -use crate::substream_handler::{Next, PassthroughProtocol, SubstreamHandler}; -use asynchronous_codec::Framed; -use futures::{SinkExt, StreamExt}; -use libp2p_swarm::{NegotiatedSubstream, SubstreamProtocol}; -use std::fmt; -use std::task::{Context, Poll}; - -/// The state of an inbound substream (i.e. the remote node opened it). -#[allow(clippy::large_enum_variant)] -#[allow(clippy::enum_variant_names)] -pub enum Stream { - /// We are in the process of reading a message from the substream. - PendingRead(Framed), - /// We read a message, dispatched it to the behaviour and are waiting for the response. - PendingBehaviour(Framed), - /// We are in the process of sending a response. - PendingSend(Framed, Message), - /// We've sent the message and are now closing down the substream. 
- PendingClose(Framed), -} - -impl fmt::Debug for Stream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Stream::PendingRead(_) => write!(f, "Inbound::PendingRead"), - Stream::PendingBehaviour(_) => write!(f, "Inbound::PendingBehaviour"), - Stream::PendingSend(_, _) => write!(f, "Inbound::PendingSend"), - Stream::PendingClose(_) => write!(f, "Inbound::PendingClose"), - } - } -} - -#[allow(clippy::large_enum_variant)] -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone)] -pub enum OutEvent { - RegistrationRequested(NewRegistration), - UnregisterRequested(Namespace), - DiscoverRequested { - namespace: Option, - cookie: Option, - limit: Option, - }, -} - -#[derive(Debug)] -pub enum InEvent { - RegisterResponse { - ttl: Ttl, - }, - DeclineRegisterRequest(ErrorCode), - DiscoverResponse { - discovered: Vec, - cookie: Cookie, - }, - DeclineDiscoverRequest(ErrorCode), -} - -impl SubstreamHandler for Stream { - type InEvent = InEvent; - type OutEvent = OutEvent; - type Error = Error; - type OpenInfo = (); - - fn upgrade( - open_info: Self::OpenInfo, - ) -> SubstreamProtocol { - SubstreamProtocol::new(PassthroughProtocol::new(PROTOCOL_IDENT), open_info) - } - - fn new(substream: NegotiatedSubstream, _: Self::OpenInfo) -> Self { - Stream::PendingRead(Framed::new(substream, RendezvousCodec::default())) - } - - fn inject_event(self, event: Self::InEvent) -> Self { - match (event, self) { - (InEvent::RegisterResponse { ttl }, Stream::PendingBehaviour(substream)) => { - Stream::PendingSend(substream, Message::RegisterResponse(Ok(ttl))) - } - (InEvent::DeclineRegisterRequest(error), Stream::PendingBehaviour(substream)) => { - Stream::PendingSend(substream, Message::RegisterResponse(Err(error))) - } - ( - InEvent::DiscoverResponse { discovered, cookie }, - Stream::PendingBehaviour(substream), - ) => Stream::PendingSend( - substream, - Message::DiscoverResponse(Ok((discovered, cookie))), - ), - (InEvent::DeclineDiscoverRequest(error), Stream::PendingBehaviour(substream)) => { - Stream::PendingSend(substream, Message::DiscoverResponse(Err(error))) - } - (event, inbound) => { - debug_assert!(false, "{:?} cannot handle event {:?}", inbound, event); - - inbound - } - } - } - - fn advance(self, cx: &mut Context<'_>) -> Result, Self::Error> { - let next_state = match self { - Stream::PendingRead(mut substream) => { - match substream.poll_next_unpin(cx).map_err(Error::ReadMessage)? { - Poll::Ready(Some(msg)) => { - let event = match msg { - Message::Register(registration) => { - OutEvent::RegistrationRequested(registration) - } - Message::Discover { - cookie, - namespace, - limit, - } => OutEvent::DiscoverRequested { - cookie, - namespace, - limit, - }, - Message::Unregister(namespace) => { - OutEvent::UnregisterRequested(namespace) - } - other => return Err(Error::BadMessage(other)), - }; - - Next::EmitEvent { - event, - next_state: Stream::PendingBehaviour(substream), - } - } - Poll::Ready(None) => return Err(Error::UnexpectedEndOfStream), - Poll::Pending => Next::Pending { - next_state: Stream::PendingRead(substream), - }, - } - } - Stream::PendingBehaviour(substream) => Next::Pending { - next_state: Stream::PendingBehaviour(substream), - }, - Stream::PendingSend(mut substream, message) => match substream - .poll_ready_unpin(cx) - .map_err(Error::WriteMessage)? 
- { - Poll::Ready(()) => { - substream - .start_send_unpin(message) - .map_err(Error::WriteMessage)?; - - Next::Continue { - next_state: Stream::PendingClose(substream), - } - } - Poll::Pending => Next::Pending { - next_state: Stream::PendingSend(substream, message), - }, - }, - Stream::PendingClose(mut substream) => match substream.poll_close_unpin(cx) { - Poll::Ready(Ok(())) => Next::Done, - Poll::Ready(Err(_)) => Next::Done, // there is nothing we can do about an error during close - Poll::Pending => Next::Pending { - next_state: Stream::PendingClose(substream), - }, - }, - }; - - Ok(next_state) - } -} diff --git a/protocols/rendezvous/src/handler/outbound.rs b/protocols/rendezvous/src/handler/outbound.rs deleted file mode 100644 index d461e7c7294..00000000000 --- a/protocols/rendezvous/src/handler/outbound.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2021 COMIT Network. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. 
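// Illustrative sketch (not part of this patch): the one-shot request/response
// pattern that the new `outbound_stream_handler` in client.rs follows, shown here
// for an `Unregister` request. The free function name `unregister` is hypothetical;
// the codec types and the Framed send/next/close sequence are taken from this diff.
use crate::codec::{Error, Message, Namespace, RendezvousCodec};
use asynchronous_codec::Framed;
use futures::{SinkExt, StreamExt};
use libp2p_swarm::NegotiatedSubstream;

async fn unregister(substream: NegotiatedSubstream, namespace: Namespace) -> Result<(), Error> {
    // Frame the raw substream with the rendezvous wire codec.
    let mut substream = Framed::new(substream, RendezvousCodec::default());

    // Send exactly one request ...
    substream.send(Message::Unregister(namespace)).await?;

    // ... `Unregister` is not answered by the server, so a clean end-of-stream means success.
    match substream.next().await.transpose()? {
        None => {}
        Some(other) => panic!("protocol violation: unexpected {:?}", other),
    }

    substream.close().await?;

    Ok(())
}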
- -use crate::codec::{Cookie, Message, NewRegistration, RendezvousCodec}; -use crate::handler::Error; -use crate::handler::PROTOCOL_IDENT; -use crate::substream_handler::{FutureSubstream, Next, PassthroughProtocol, SubstreamHandler}; -use crate::{ErrorCode, Namespace, Registration, Ttl}; -use asynchronous_codec::Framed; -use futures::{SinkExt, TryFutureExt, TryStreamExt}; -use libp2p_swarm::{NegotiatedSubstream, SubstreamProtocol}; -use std::task::Context; -use void::Void; - -pub struct Stream(FutureSubstream); - -impl SubstreamHandler for Stream { - type InEvent = Void; - type OutEvent = OutEvent; - type Error = Error; - type OpenInfo = OpenInfo; - - fn upgrade( - open_info: Self::OpenInfo, - ) -> SubstreamProtocol { - SubstreamProtocol::new(PassthroughProtocol::new(PROTOCOL_IDENT), open_info) - } - - fn new(substream: NegotiatedSubstream, info: Self::OpenInfo) -> Self { - let mut stream = Framed::new(substream, RendezvousCodec::default()); - let sent_message = match info { - OpenInfo::RegisterRequest(new_registration) => Message::Register(new_registration), - OpenInfo::UnregisterRequest(namespace) => Message::Unregister(namespace), - OpenInfo::DiscoverRequest { - namespace, - cookie, - limit, - } => Message::Discover { - namespace, - cookie, - limit, - }, - }; - - Self(FutureSubstream::new(async move { - use Message::*; - use OutEvent::*; - - stream - .send(sent_message.clone()) - .map_err(Error::WriteMessage) - .await?; - let received_message = stream.try_next().map_err(Error::ReadMessage).await?; - let received_message = received_message.ok_or(Error::UnexpectedEndOfStream)?; - - let event = match (sent_message, received_message) { - (Register(registration), RegisterResponse(Ok(ttl))) => Registered { - namespace: registration.namespace, - ttl, - }, - (Register(registration), RegisterResponse(Err(error))) => { - RegisterFailed(registration.namespace, error) - } - (Discover { .. }, DiscoverResponse(Ok((registrations, cookie)))) => Discovered { - registrations, - cookie, - }, - (Discover { namespace, .. }, DiscoverResponse(Err(error))) => { - DiscoverFailed { namespace, error } - } - (.., other) => return Err(Error::BadMessage(other)), - }; - - stream.close().map_err(Error::WriteMessage).await?; - - Ok(event) - })) - } - - fn inject_event(self, event: Self::InEvent) -> Self { - void::unreachable(event) - } - - fn advance(self, cx: &mut Context<'_>) -> Result, Self::Error> { - Ok(self.0.advance(cx)?.map_state(Stream)) - } -} - -#[derive(Debug, Clone)] -pub enum OutEvent { - Registered { - namespace: Namespace, - ttl: Ttl, - }, - RegisterFailed(Namespace, ErrorCode), - Discovered { - registrations: Vec, - cookie: Cookie, - }, - DiscoverFailed { - namespace: Option, - error: ErrorCode, - }, -} - -#[allow(clippy::large_enum_variant)] -#[allow(clippy::enum_variant_names)] -#[derive(Debug)] -pub enum OpenInfo { - RegisterRequest(NewRegistration), - UnregisterRequest(Namespace), - DiscoverRequest { - namespace: Option, - cookie: Option, - limit: Option, - }, -} diff --git a/protocols/rendezvous/src/lib.rs b/protocols/rendezvous/src/lib.rs index 337e554ea00..fceed1d8217 100644 --- a/protocols/rendezvous/src/lib.rs +++ b/protocols/rendezvous/src/lib.rs @@ -24,9 +24,9 @@ pub use self::codec::{Cookie, ErrorCode, Namespace, NamespaceTooLong, Registration, Ttl}; +const PROTOCOL_IDENT: &str = "/rendezvous/1.0.0"; + mod codec; -mod handler; -mod substream_handler; /// If unspecified, rendezvous nodes should assume a TTL of 2h. 
/// diff --git a/protocols/rendezvous/src/server.rs b/protocols/rendezvous/src/server.rs index 4126b6e3e28..3dd1f9ca282 100644 --- a/protocols/rendezvous/src/server.rs +++ b/protocols/rendezvous/src/server.rs @@ -18,32 +18,40 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::codec::{Cookie, ErrorCode, Namespace, NewRegistration, Registration, Ttl}; -use crate::handler::inbound; -use crate::substream_handler::{InboundSubstreamId, SubstreamConnectionHandler}; -use crate::{handler, MAX_TTL, MIN_TTL}; +use crate::codec::{ + Cookie, Error, ErrorCode, Message, Namespace, NewRegistrationRequest, Registration, + RendezvousCodec, Ttl, +}; +use crate::{DEFAULT_TTL, MAX_TTL, MIN_TTL, PROTOCOL_IDENT}; +use asynchronous_codec::Framed; use bimap::BiMap; use futures::future::BoxFuture; -use futures::ready; use futures::stream::FuturesUnordered; +use futures::{ready, SinkExt}; use futures::{FutureExt, StreamExt}; use libp2p_core::connection::ConnectionId; -use libp2p_core::PeerId; +use libp2p_core::{ConnectedPoint, PeerId, PeerRecord}; use libp2p_swarm::behaviour::FromSwarm; +use libp2p_swarm::handler::from_fn; use libp2p_swarm::{ - CloseConnection, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, + from_fn, NegotiatedSubstream, NetworkBehaviour, NetworkBehaviourAction, PollParameters, }; use std::collections::{HashMap, HashSet, VecDeque}; -use std::iter::FromIterator; +use std::io; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use void::Void; pub struct Behaviour { + next_expiry: FuturesUnordered>, + registrations: from_fn::Shared, events: VecDeque< - NetworkBehaviourAction>, + NetworkBehaviourAction< + Event, + from_fn::FromFnProto, Error>, Void, Void, Registrations>, + >, >, - registrations: Registrations, } pub struct Config { @@ -76,8 +84,39 @@ impl Behaviour { /// Create a new instance of the rendezvous [`NetworkBehaviour`]. pub fn new(config: Config) -> Self { Self { + next_expiry: FuturesUnordered::new(), + registrations: from_fn::Shared::new(Registrations::with_config(config)), events: Default::default(), - registrations: Registrations::with_config(config), + } + } + + fn poll_expiry_timers(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let expired_registration = match ready!(self.next_expiry.poll_next_unpin(cx)) { + Some(r) => r, + None => return Poll::Ready(None), // TODO: Register waker and return Pending? 
+ }; + + // clean up our cookies + self.registrations.cookies.retain(|_, registrations| { + registrations.remove(&expired_registration); + + // retain all cookies where there are still registrations left + !registrations.is_empty() + }); + + self.registrations + .registrations_for_peer + .remove_by_right(&expired_registration); + + match self + .registrations + .registrations + .remove(&expired_registration) + { + Some(registration) => return Poll::Ready(Some(ExpiredRegistration(registration))), + None => continue, + } } } } @@ -110,38 +149,105 @@ pub enum Event { } impl NetworkBehaviour for Behaviour { - type ConnectionHandler = SubstreamConnectionHandler; + type ConnectionHandler = + from_fn::FromFnProto, Error>, Void, Void, Registrations>; type OutEvent = Event; fn new_handler(&mut self) -> Self::ConnectionHandler { - let initial_keep_alive = Duration::from_secs(30); + from_fn(PROTOCOL_IDENT) + .with_state(&self.registrations) + .with_inbound_handler(10, inbound_stream_handler) + .without_outbound_handler() + } - SubstreamConnectionHandler::new_inbound_only(initial_keep_alive) + fn on_swarm_event(&mut self, event: FromSwarm) { + self.registrations.on_swarm_event(&event); } fn on_connection_handler_event( &mut self, - peer_id: PeerId, - connection: ConnectionId, - event: handler::InboundOutEvent, + peer: PeerId, + _: ConnectionId, + event: from_fn::OutEvent, Error>, Void, Void>, ) { let new_events = match event { - handler::InboundOutEvent::InboundEvent { id, message } => { - handle_inbound_event(message, peer_id, connection, id, &mut self.registrations) + from_fn::OutEvent::InboundEmitted(Ok(Some(InboundOutEvent::NewRegistration( + new_registration, + )))) => { + let (registration, expiry) = self.registrations.add(new_registration); + self.next_expiry.push(expiry); + + vec![NetworkBehaviourAction::GenerateEvent( + Event::PeerRegistered { peer, registration }, + )] } - handler::InboundOutEvent::OutboundEvent { message, .. } => void::unreachable(message), - handler::InboundOutEvent::InboundError { error, .. } => { - log::warn!("Connection with peer {} failed: {}", peer_id, error); - - vec![NetworkBehaviourAction::CloseConnection { - peer_id, - connection: CloseConnection::One(connection), - }] + from_fn::OutEvent::InboundEmitted(Ok(Some(InboundOutEvent::RegistrationFailed { + error, + namespace, + }))) => { + vec![NetworkBehaviourAction::GenerateEvent( + Event::PeerNotRegistered { + peer, + error, + namespace, + }, + )] + } + from_fn::OutEvent::InboundEmitted(Ok(Some(InboundOutEvent::Discovered { + cookie, + registrations, + previous_registrations, + }))) => { + self.registrations + .cookies + .insert(cookie, previous_registrations); + + vec![NetworkBehaviourAction::GenerateEvent( + Event::DiscoverServed { + enquirer: peer, + registrations, + }, + )] + } + from_fn::OutEvent::InboundEmitted(Ok(Some(InboundOutEvent::DiscoverFailed { + error, + }))) => { + vec![NetworkBehaviourAction::GenerateEvent( + Event::DiscoverNotServed { + enquirer: peer, + error, + }, + )] + } + from_fn::OutEvent::InboundEmitted(Ok(Some(InboundOutEvent::Unregister(namespace)))) => { + self.registrations.remove(namespace.clone(), peer); + + vec![NetworkBehaviourAction::GenerateEvent( + Event::PeerUnregistered { peer, namespace }, + )] + } + from_fn::OutEvent::OutboundEmitted(never) => void::unreachable(never), + from_fn::OutEvent::FailedToOpen(never) => match never { + from_fn::OpenError::Timeout(never) => void::unreachable(never), + from_fn::OpenError::LimitExceeded { + open_info: never, .. 
+ } => void::unreachable(never), + from_fn::OpenError::NegotiationFailed(never, _) => void::unreachable(never), + from_fn::OpenError::Unsupported { + open_info: never, .. + } => void::unreachable(never), + }, + from_fn::OutEvent::InboundEmitted(Err(error)) => { + log::debug!("Inbound stream from {peer} failed: {error}"); + + vec![] + } + from_fn::OutEvent::InboundEmitted(Ok(None)) => { + vec![] } - handler::InboundOutEvent::OutboundError { error, .. } => void::unreachable(error), }; - self.events.extend(new_events); + self.events.extend(new_events) } fn poll( @@ -149,170 +255,123 @@ impl NetworkBehaviour for Behaviour { cx: &mut Context<'_>, _: &mut impl PollParameters, ) -> Poll> { - if let Poll::Ready(ExpiredRegistration(registration)) = self.registrations.poll(cx) { + if let Some(event) = self.events.pop_front() { + return Poll::Ready(event); + } + + if let Poll::Ready(Some(ExpiredRegistration(registration))) = self.poll_expiry_timers(cx) { return Poll::Ready(NetworkBehaviourAction::GenerateEvent( Event::RegistrationExpired(registration), )); } - if let Some(event) = self.events.pop_front() { - return Poll::Ready(event); + if let Poll::Ready(action) = self.registrations.poll(cx) { + return Poll::Ready(action); } Poll::Pending } - - fn on_swarm_event(&mut self, event: FromSwarm) { - match event { - FromSwarm::ConnectionEstablished(_) - | FromSwarm::ConnectionClosed(_) - | FromSwarm::AddressChange(_) - | FromSwarm::DialFailure(_) - | FromSwarm::ListenFailure(_) - | FromSwarm::NewListener(_) - | FromSwarm::NewListenAddr(_) - | FromSwarm::ExpiredListenAddr(_) - | FromSwarm::ListenerError(_) - | FromSwarm::ListenerClosed(_) - | FromSwarm::NewExternalAddr(_) - | FromSwarm::ExpiredExternalAddr(_) => {} - } - } } -fn handle_inbound_event( - event: inbound::OutEvent, - peer_id: PeerId, - connection: ConnectionId, - id: InboundSubstreamId, - registrations: &mut Registrations, -) -> Vec>> { - match event { - // bad registration - inbound::OutEvent::RegistrationRequested(registration) - if registration.record.peer_id() != peer_id => - { - let error = ErrorCode::NotAuthorized; - - vec![ - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: handler::InboundInEvent::NotifyInboundSubstream { - id, - message: inbound::InEvent::DeclineRegisterRequest(error), - }, - }, - NetworkBehaviourAction::GenerateEvent(Event::PeerNotRegistered { - peer: peer_id, - namespace: registration.namespace, - error, - }), - ] - } - inbound::OutEvent::RegistrationRequested(registration) => { - let namespace = registration.namespace.clone(); - - match registrations.add(registration) { - Ok(registration) => { - vec![ - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: handler::InboundInEvent::NotifyInboundSubstream { - id, - message: inbound::InEvent::RegisterResponse { - ttl: registration.ttl, - }, - }, - }, - NetworkBehaviourAction::GenerateEvent(Event::PeerRegistered { - peer: peer_id, - registration, - }), - ] +async fn inbound_stream_handler( + substream: NegotiatedSubstream, + peer: PeerId, + _: ConnectedPoint, + registrations: Arc, +) -> Result, Error> { + let mut substream = Framed::new(substream, RendezvousCodec::default()); + + let message = substream + .next() + .await + .ok_or_else(|| Error::Io(io::ErrorKind::UnexpectedEof.into()))??; + + let out_event = match message { + Message::Register(new_registration) => { + let namespace = new_registration.namespace.clone(); + + match registrations.new_registration(peer, 
new_registration) { + Ok(new_registration) => { + substream + .send(Message::RegisterResponse(Ok( + new_registration.effective_ttl() + ))) + .await?; + + Some(InboundOutEvent::NewRegistration(new_registration)) } - Err(TtlOutOfRange::TooLong { .. }) | Err(TtlOutOfRange::TooShort { .. }) => { - let error = ErrorCode::InvalidTtl; - - vec![ - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: handler::InboundInEvent::NotifyInboundSubstream { - id, - message: inbound::InEvent::DeclineRegisterRequest(error), - }, - }, - NetworkBehaviourAction::GenerateEvent(Event::PeerNotRegistered { - peer: peer_id, - namespace, - error, - }), - ] + Err(error) => { + substream + .send(Message::RegisterResponse(Err(error))) + .await?; + + Some(InboundOutEvent::RegistrationFailed { namespace, error }) } } } - inbound::OutEvent::DiscoverRequested { + Message::Unregister(namespace) => Some(InboundOutEvent::Unregister(namespace)), + Message::Discover { namespace, cookie, limit, } => match registrations.get(namespace, cookie, limit) { - Ok((registrations, cookie)) => { - let discovered = registrations.cloned().collect::>(); - - vec![ - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: handler::InboundInEvent::NotifyInboundSubstream { - id, - message: inbound::InEvent::DiscoverResponse { - discovered: discovered.clone(), - cookie, - }, - }, - }, - NetworkBehaviourAction::GenerateEvent(Event::DiscoverServed { - enquirer: peer_id, - registrations: discovered, - }), - ] + Ok((registrations, cookie, sent_registrations)) => { + let registrations = registrations.cloned().collect::>(); + + substream + .send(Message::DiscoverResponse(Ok(( + registrations.clone(), + cookie.clone(), + )))) + .await?; + + Some(InboundOutEvent::Discovered { + cookie, + registrations, + previous_registrations: sent_registrations, + }) } Err(_) => { - let error = ErrorCode::InvalidCookie; - - vec![ - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: handler::InboundInEvent::NotifyInboundSubstream { - id, - message: inbound::InEvent::DeclineDiscoverRequest(error), - }, - }, - NetworkBehaviourAction::GenerateEvent(Event::DiscoverNotServed { - enquirer: peer_id, - error, - }), - ] + substream + .send(Message::DiscoverResponse(Err(ErrorCode::InvalidCookie))) + .await?; + + Some(InboundOutEvent::DiscoverFailed { + error: ErrorCode::InvalidCookie, + }) } }, - inbound::OutEvent::UnregisterRequested(namespace) => { - registrations.remove(namespace.clone(), peer_id); - - vec![NetworkBehaviourAction::GenerateEvent( - Event::PeerUnregistered { - peer: peer_id, - namespace, - }, - )] + Message::DiscoverResponse(_) | Message::RegisterResponse(_) => { + panic!("protocol violation") } - } + }; + + substream.close().await?; + + Ok(out_event) +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum InboundOutEvent { + NewRegistration(NewRegistration), + RegistrationFailed { + namespace: Namespace, + error: ErrorCode, + }, + Discovered { + cookie: Cookie, + registrations: Vec, + previous_registrations: HashSet, + }, + DiscoverFailed { + error: ErrorCode, + }, + Unregister(Namespace), } #[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)] -struct RegistrationId(u64); +pub struct RegistrationId(u64); impl RegistrationId { fn new() -> Self { @@ -323,21 +382,26 @@ impl RegistrationId { #[derive(Debug, PartialEq)] struct ExpiredRegistration(Registration); +#[derive(Debug, Clone)] pub struct 
Registrations { registrations_for_peer: BiMap<(PeerId, Namespace), RegistrationId>, registrations: HashMap, cookies: HashMap>, min_ttl: Ttl, max_ttl: Ttl, - next_expiry: FuturesUnordered>, } -#[derive(Debug, thiserror::Error)] -pub enum TtlOutOfRange { - #[error("Requested TTL ({requested}s) is too long; max {bound}s")] - TooLong { bound: Ttl, requested: Ttl }, - #[error("Requested TTL ({requested}s) is too short; min {bound}s")] - TooShort { bound: Ttl, requested: Ttl }, +#[derive(Debug)] +pub struct NewRegistration { + namespace: Namespace, + record: PeerRecord, + ttl: Option, +} + +impl NewRegistration { + pub fn effective_ttl(&self) -> Ttl { + self.ttl.unwrap_or(DEFAULT_TTL) + } } impl Default for Registrations { @@ -349,33 +413,19 @@ impl Default for Registrations { impl Registrations { pub fn with_config(config: Config) -> Self { Self { - registrations_for_peer: Default::default(), - registrations: Default::default(), min_ttl: config.min_ttl, max_ttl: config.max_ttl, + registrations_for_peer: Default::default(), + registrations: Default::default(), cookies: Default::default(), - next_expiry: FuturesUnordered::from_iter(vec![futures::future::pending().boxed()]), } } pub fn add( &mut self, new_registration: NewRegistration, - ) -> Result { + ) -> (Registration, BoxFuture<'static, RegistrationId>) { let ttl = new_registration.effective_ttl(); - if ttl > self.max_ttl { - return Err(TtlOutOfRange::TooLong { - bound: self.max_ttl, - requested: ttl, - }); - } - if ttl < self.min_ttl { - return Err(TtlOutOfRange::TooShort { - bound: self.min_ttl, - requested: ttl, - }); - } - let namespace = new_registration.namespace; let registration_id = RegistrationId::new(); @@ -399,13 +449,36 @@ impl Registrations { self.registrations .insert(registration_id, registration.clone()); - let next_expiry = futures_timer::Delay::new(Duration::from_secs(ttl)) + let expiry = futures_timer::Delay::new(Duration::from_secs(ttl as u64)) .map(move |_| registration_id) .boxed(); - self.next_expiry.push(next_expiry); + (registration, expiry) + } + + pub fn new_registration( + &self, + from: PeerId, + new_registration: NewRegistrationRequest, + ) -> Result { + let ttl = new_registration.ttl.unwrap_or(DEFAULT_TTL); + + if ttl > self.max_ttl { + return Err(ErrorCode::InvalidTtl); + } + if ttl < self.min_ttl { + return Err(ErrorCode::InvalidTtl); + } + + if new_registration.record.peer_id() != from { + return Err(ErrorCode::NotAuthorized); + } - Ok(registration) + Ok(NewRegistration { + namespace: new_registration.namespace, + record: new_registration.record, + ttl: new_registration.ttl, + }) } pub fn remove(&mut self, namespace: Namespace, peer_id: PeerId) { @@ -419,11 +492,18 @@ impl Registrations { } pub fn get( - &mut self, + &self, discover_namespace: Option, cookie: Option, limit: Option, - ) -> Result<(impl Iterator + '_, Cookie), CookieNamespaceMismatch> { + ) -> Result< + ( + impl Iterator + '_, + Cookie, + HashSet, + ), + CookieNamespaceMismatch, + > { let cookie_namespace = cookie.as_ref().and_then(|cookie| cookie.namespace()); match (discover_namespace.as_ref(), cookie_namespace) { @@ -469,36 +549,13 @@ impl Registrations { let new_cookie = discover_namespace .map(Cookie::for_namespace) .unwrap_or_else(Cookie::for_all_namespaces); - self.cookies - .insert(new_cookie.clone(), reggos_of_last_discover); let reggos = &self.registrations; let registrations = ids .into_iter() .map(move |id| reggos.get(&id).expect("bad internal datastructure")); - Ok((registrations, new_cookie)) - } - - fn poll(&mut self, cx: &mut 
Context<'_>) -> Poll { - let expired_registration = ready!(self.next_expiry.poll_next_unpin(cx)).expect( - "This stream should never finish because it is initialised with a pending future", - ); - - // clean up our cookies - self.cookies.retain(|_, registrations| { - registrations.remove(&expired_registration); - - // retain all cookies where there are still registrations left - !registrations.is_empty() - }); - - self.registrations_for_peer - .remove_by_right(&expired_registration); - match self.registrations.remove(&expired_registration) { - None => self.poll(cx), - Some(registration) => Poll::Ready(ExpiredRegistration(registration)), - } + Ok((registrations, new_cookie, reggos_of_last_discover)) } } @@ -508,9 +565,6 @@ pub struct CookieNamespaceMismatch; #[cfg(test)] mod tests { - use instant::SystemTime; - use std::option::Option::None; - use libp2p_core::{identity, PeerRecord}; use super::*; @@ -518,23 +572,27 @@ mod tests { #[test] fn given_cookie_from_discover_when_discover_again_then_only_get_diff() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("foo")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("foo")); - let (initial_discover, cookie) = registrations.get(None, None, None).unwrap(); + let (initial_discover, cookie, existing_registrations) = + registrations.get(None, None, None).unwrap(); assert_eq!(initial_discover.count(), 2); + registrations + .cookies + .insert(cookie.clone(), existing_registrations); - let (subsequent_discover, _) = registrations.get(None, Some(cookie), None).unwrap(); + let (subsequent_discover, _, _) = registrations.get(None, Some(cookie), None).unwrap(); assert_eq!(subsequent_discover.count(), 0); } #[test] fn given_registrations_when_discover_all_then_all_are_returned() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("foo")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("foo")); - let (discover, _) = registrations.get(None, None, None).unwrap(); + let (discover, _, _) = registrations.get(None, None, None).unwrap(); assert_eq!(discover.count(), 2); } @@ -543,10 +601,10 @@ mod tests { fn given_registrations_when_discover_only_for_specific_namespace_then_only_those_are_returned() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("bar")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("bar")); - let (discover, _) = registrations + let (discover, _, _) = registrations .get(Some(Namespace::from_static("foo")), None, None) .unwrap(); @@ -560,14 +618,10 @@ mod tests { fn given_reregistration_old_registration_is_discarded() { let alice = identity::Keypair::generate_ed25519(); let mut registrations = Registrations::default(); - registrations - .add(new_registration("foo", alice.clone(), None)) - .unwrap(); - registrations - .add(new_registration("foo", alice, None)) - .unwrap(); + registrations.add(new_registration("foo", alice.clone(), None)); + registrations.add(new_registration("foo", alice, None)); - let (discover, _) = registrations + let (discover, _, _) = registrations .get(Some(Namespace::from_static("foo")), None, None) .unwrap(); @@ -580,26 +634,34 @@ mod tests { 
#[test] fn given_cookie_from_2nd_discover_does_not_return_nodes_from_first_discover() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("foo")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("foo")); - let (initial_discover, cookie1) = registrations.get(None, None, None).unwrap(); + let (initial_discover, cookie1, existing_registrations) = + registrations.get(None, None, None).unwrap(); assert_eq!(initial_discover.count(), 2); + registrations + .cookies + .insert(cookie1.clone(), existing_registrations); - let (subsequent_discover, cookie2) = registrations.get(None, Some(cookie1), None).unwrap(); + let (subsequent_discover, cookie2, existing_registrations) = + registrations.get(None, Some(cookie1), None).unwrap(); assert_eq!(subsequent_discover.count(), 0); + registrations + .cookies + .insert(cookie2.clone(), existing_registrations); - let (subsequent_discover, _) = registrations.get(None, Some(cookie2), None).unwrap(); + let (subsequent_discover, _, _) = registrations.get(None, Some(cookie2), None).unwrap(); assert_eq!(subsequent_discover.count(), 0); } #[test] fn cookie_from_different_discover_request_is_not_valid() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("bar")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("bar")); - let (_, foo_discover_cookie) = registrations + let (_, foo_discover_cookie, _) = registrations .get(Some(Namespace::from_static("foo")), None, None) .unwrap(); let result = registrations.get( @@ -611,107 +673,13 @@ mod tests { assert!(matches!(result, Err(CookieNamespaceMismatch))) } - #[tokio::test] - async fn given_two_registration_ttls_one_expires_one_lives() { - let mut registrations = Registrations::with_config(Config { - min_ttl: 0, - max_ttl: 4, - }); - - let start_time = SystemTime::now(); - - registrations - .add(new_dummy_registration_with_ttl("foo", 1)) - .unwrap(); - registrations - .add(new_dummy_registration_with_ttl("bar", 4)) - .unwrap(); - - let event = registrations.next_event().await; - - let elapsed = start_time.elapsed().unwrap(); - assert!(elapsed.as_secs() >= 1); - assert!(elapsed.as_secs() < 2); - - assert_eq!(event.0.namespace, Namespace::from_static("foo")); - - { - let (mut discovered_foo, _) = registrations - .get(Some(Namespace::from_static("foo")), None, None) - .unwrap(); - assert!(discovered_foo.next().is_none()); - } - let (mut discovered_bar, _) = registrations - .get(Some(Namespace::from_static("bar")), None, None) - .unwrap(); - assert!(discovered_bar.next().is_some()); - } - - #[tokio::test] - async fn given_peer_unregisters_before_expiry_do_not_emit_registration_expired() { - let mut registrations = Registrations::with_config(Config { - min_ttl: 1, - max_ttl: 10, - }); - let dummy_registration = new_dummy_registration_with_ttl("foo", 2); - let namespace = dummy_registration.namespace.clone(); - let peer_id = dummy_registration.record.peer_id(); - - registrations.add(dummy_registration).unwrap(); - registrations.no_event_for(1).await; - registrations.remove(namespace, peer_id); - - registrations.no_event_for(3).await - } - - /// FuturesUnordered stop polling for ready futures when poll_next() is called until a None - /// value is returned. 
To prevent the next_expiry future from going to "sleep", next_expiry - /// is initialised with a future that always returns pending. This test ensures that - /// FuturesUnordered does not stop polling for ready futures. - #[tokio::test] - async fn given_all_registrations_expired_then_successfully_handle_new_registration_and_expiry() - { - let mut registrations = Registrations::with_config(Config { - min_ttl: 0, - max_ttl: 10, - }); - let dummy_registration = new_dummy_registration_with_ttl("foo", 1); - - registrations.add(dummy_registration.clone()).unwrap(); - let _ = registrations.next_event_in_at_most(2).await; - - registrations.no_event_for(1).await; - - registrations.add(dummy_registration).unwrap(); - let _ = registrations.next_event_in_at_most(2).await; - } - - #[tokio::test] - async fn cookies_are_cleaned_up_if_registrations_expire() { - let mut registrations = Registrations::with_config(Config { - min_ttl: 1, - max_ttl: 10, - }); - - registrations - .add(new_dummy_registration_with_ttl("foo", 2)) - .unwrap(); - let (_, _) = registrations.get(None, None, None).unwrap(); - - assert_eq!(registrations.cookies.len(), 1); - - let _ = registrations.next_event_in_at_most(3).await; - - assert_eq!(registrations.cookies.len(), 0); - } - #[test] fn given_limit_discover_only_returns_n_results() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("foo")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("foo")); - let (registrations, _) = registrations.get(None, None, Some(1)).unwrap(); + let (registrations, _, _) = registrations.get(None, None, Some(1)).unwrap(); assert_eq!(registrations.count(), 1); } @@ -719,13 +687,17 @@ mod tests { #[test] fn given_limit_cookie_can_be_used_for_pagination() { let mut registrations = Registrations::default(); - registrations.add(new_dummy_registration("foo")).unwrap(); - registrations.add(new_dummy_registration("foo")).unwrap(); + registrations.add(new_dummy_registration("foo")); + registrations.add(new_dummy_registration("foo")); - let (discover1, cookie) = registrations.get(None, None, Some(1)).unwrap(); + let (discover1, cookie, existing_registrations) = + registrations.get(None, None, Some(1)).unwrap(); assert_eq!(discover1.count(), 1); + registrations + .cookies + .insert(cookie.clone(), existing_registrations); - let (discover2, _) = registrations.get(None, Some(cookie), None).unwrap(); + let (discover2, _, _) = registrations.get(None, Some(cookie), None).unwrap(); assert_eq!(discover2.count(), 1); } @@ -735,42 +707,16 @@ mod tests { new_registration(namespace, identity, None) } - fn new_dummy_registration_with_ttl(namespace: &'static str, ttl: Ttl) -> NewRegistration { - let identity = identity::Keypair::generate_ed25519(); - - new_registration(namespace, identity, Some(ttl)) - } - fn new_registration( namespace: &'static str, identity: identity::Keypair, ttl: Option, ) -> NewRegistration { - NewRegistration::new( - Namespace::from_static(namespace), - PeerRecord::new(&identity, vec!["/ip4/127.0.0.1/tcp/1234".parse().unwrap()]).unwrap(), + NewRegistration { + namespace: Namespace::from_static(namespace), + record: PeerRecord::new(&identity, vec!["/ip4/127.0.0.1/tcp/1234".parse().unwrap()]) + .unwrap(), ttl, - ) - } - - /// Defines utility functions that make the tests more readable. 
- impl Registrations { - async fn next_event(&mut self) -> ExpiredRegistration { - futures::future::poll_fn(|cx| self.poll(cx)).await - } - - /// Polls [`Registrations`] for `seconds` and panics if it returns a event during this time. - async fn no_event_for(&mut self, seconds: u64) { - tokio::time::timeout(Duration::from_secs(seconds), self.next_event()) - .await - .unwrap_err(); - } - - /// Polls [`Registrations`] for at most `seconds` and panics if doesn't return an event within that time. - async fn next_event_in_at_most(&mut self, seconds: u64) -> ExpiredRegistration { - tokio::time::timeout(Duration::from_secs(seconds), self.next_event()) - .await - .unwrap() } } } diff --git a/protocols/rendezvous/src/substream_handler.rs b/protocols/rendezvous/src/substream_handler.rs deleted file mode 100644 index f57dfded6c9..00000000000 --- a/protocols/rendezvous/src/substream_handler.rs +++ /dev/null @@ -1,553 +0,0 @@ -// Copyright 2021 COMIT Network. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! A generic [`ConnectionHandler`] that delegates the handling of substreams to [`SubstreamHandler`]s. -//! -//! This module is an attempt to simplify the implementation of protocols by freeing implementations from dealing with aspects such as concurrent substreams. -//! Particularly for outbound substreams, it greatly simplifies the definition of protocols through the [`FutureSubstream`] helper. -//! -//! At the moment, this module is an implementation detail of the rendezvous protocol but the intent is for it to be provided as a generic module that is accessible to other protocols as well. - -use futures::future::{self, BoxFuture, Fuse, FusedFuture}; -use futures::FutureExt; -use instant::Instant; -use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; -use libp2p_swarm::handler::{InboundUpgradeSend, OutboundUpgradeSend}; -use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, - NegotiatedSubstream, SubstreamProtocol, -}; -use std::collections::{HashMap, VecDeque}; -use std::fmt; -use std::future::Future; -use std::hash::Hash; -use std::task::{Context, Poll}; -use std::time::Duration; -use void::Void; - -/// Handles a substream throughout its lifetime. 
-pub trait SubstreamHandler: Sized { - type InEvent; - type OutEvent; - type Error; - type OpenInfo; - - fn upgrade(open_info: Self::OpenInfo) - -> SubstreamProtocol; - fn new(substream: NegotiatedSubstream, info: Self::OpenInfo) -> Self; - fn inject_event(self, event: Self::InEvent) -> Self; - fn advance(self, cx: &mut Context<'_>) -> Result, Self::Error>; -} - -/// The result of advancing a [`SubstreamHandler`]. -pub enum Next { - /// Return the given event and set the handler into `next_state`. - EmitEvent { event: TEvent, next_state: TState }, - /// The handler currently cannot do any more work, set its state back into `next_state`. - Pending { next_state: TState }, - /// The handler performed some work and wants to continue in the given state. - /// - /// This variant is useful because it frees the handler from implementing a loop internally. - Continue { next_state: TState }, - /// The handler finished. - Done, -} - -impl Next { - pub fn map_state( - self, - map: impl FnOnce(TState) -> TNextState, - ) -> Next { - match self { - Next::EmitEvent { event, next_state } => Next::EmitEvent { - event, - next_state: map(next_state), - }, - Next::Pending { next_state } => Next::Pending { - next_state: map(next_state), - }, - Next::Continue { next_state } => Next::Pending { - next_state: map(next_state), - }, - Next::Done => Next::Done, - } - } -} - -#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] -pub struct InboundSubstreamId(u64); - -impl InboundSubstreamId { - fn fetch_and_increment(&mut self) -> Self { - let next_id = *self; - self.0 += 1; - - next_id - } -} - -impl fmt::Display for InboundSubstreamId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] -pub struct OutboundSubstreamId(u64); - -impl OutboundSubstreamId { - fn fetch_and_increment(&mut self) -> Self { - let next_id = *self; - self.0 += 1; - - next_id - } -} - -impl fmt::Display for OutboundSubstreamId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -pub struct PassthroughProtocol { - ident: Option<&'static [u8]>, -} - -impl PassthroughProtocol { - pub fn new(ident: &'static [u8]) -> Self { - Self { ident: Some(ident) } - } -} - -impl UpgradeInfo for PassthroughProtocol { - type Info = &'static [u8]; - type InfoIter = std::option::IntoIter; - - fn protocol_info(&self) -> Self::InfoIter { - self.ident.into_iter() - } -} - -impl InboundUpgrade for PassthroughProtocol { - type Output = C; - type Error = Void; - type Future = BoxFuture<'static, Result>; - - fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future { - match self.ident { - Some(_) => future::ready(Ok(socket)).boxed(), - None => future::pending().boxed(), - } - } -} - -impl OutboundUpgrade for PassthroughProtocol { - type Output = C; - type Error = Void; - type Future = BoxFuture<'static, Result>; - - fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { - match self.ident { - Some(_) => future::ready(Ok(socket)).boxed(), - None => future::pending().boxed(), - } - } -} - -/// An implementation of [`ConnectionHandler`] that delegates to individual [`SubstreamHandler`]s. 
-pub struct SubstreamConnectionHandler { - inbound_substreams: HashMap, - outbound_substreams: HashMap, - next_inbound_substream_id: InboundSubstreamId, - next_outbound_substream_id: OutboundSubstreamId, - - new_substreams: VecDeque, - - initial_keep_alive_deadline: Instant, -} - -impl - SubstreamConnectionHandler -{ - pub fn new(initial_keep_alive: Duration) -> Self { - Self { - inbound_substreams: Default::default(), - outbound_substreams: Default::default(), - next_inbound_substream_id: InboundSubstreamId(0), - next_outbound_substream_id: OutboundSubstreamId(0), - new_substreams: Default::default(), - initial_keep_alive_deadline: Instant::now() + initial_keep_alive, - } - } -} - -impl - SubstreamConnectionHandler -{ - pub fn new_outbound_only(initial_keep_alive: Duration) -> Self { - Self { - inbound_substreams: Default::default(), - outbound_substreams: Default::default(), - next_inbound_substream_id: InboundSubstreamId(0), - next_outbound_substream_id: OutboundSubstreamId(0), - new_substreams: Default::default(), - initial_keep_alive_deadline: Instant::now() + initial_keep_alive, - } - } -} - -impl - SubstreamConnectionHandler -{ - pub fn new_inbound_only(initial_keep_alive: Duration) -> Self { - Self { - inbound_substreams: Default::default(), - outbound_substreams: Default::default(), - next_inbound_substream_id: InboundSubstreamId(0), - next_outbound_substream_id: OutboundSubstreamId(0), - new_substreams: Default::default(), - initial_keep_alive_deadline: Instant::now() + initial_keep_alive, - } - } -} - -/// Poll all substreams within the given HashMap. -/// -/// This is defined as a separate function because we call it with two different fields stored within [`SubstreamConnectionHandler`]. -fn poll_substreams( - substreams: &mut HashMap, - cx: &mut Context<'_>, -) -> Poll> -where - TSubstream: SubstreamHandler, - TId: Copy + Eq + Hash + fmt::Display, -{ - let substream_ids = substreams.keys().copied().collect::>(); - - 'loop_substreams: for id in substream_ids { - let mut handler = substreams - .remove(&id) - .expect("we just got the key out of the map"); - - let (next_state, poll) = 'loop_handler: loop { - match handler.advance(cx) { - Ok(Next::EmitEvent { next_state, event }) => { - break (next_state, Poll::Ready(Ok((id, event)))) - } - Ok(Next::Pending { next_state }) => break (next_state, Poll::Pending), - Ok(Next::Continue { next_state }) => { - handler = next_state; - continue 'loop_handler; - } - Ok(Next::Done) => { - log::debug!("Substream handler {} finished", id); - continue 'loop_substreams; - } - Err(e) => return Poll::Ready(Err((id, e))), - } - }; - - substreams.insert(id, next_state); - - return poll; - } - - Poll::Pending -} - -/// Event sent from the [`libp2p_swarm::NetworkBehaviour`] to the [`SubstreamConnectionHandler`]. -#[allow(clippy::enum_variant_names)] -#[derive(Debug)] -pub enum InEvent { - /// Open a new substream using the provided `open_info`. - /// - /// For "client-server" protocols, this is typically the initial message to be sent to the other party. - NewSubstream { open_info: I }, - NotifyInboundSubstream { - id: InboundSubstreamId, - message: TInboundEvent, - }, - NotifyOutboundSubstream { - id: OutboundSubstreamId, - message: TOutboundEvent, - }, -} - -/// Event produced by the [`SubstreamConnectionHandler`] for the corresponding [`libp2p_swarm::NetworkBehaviour`]. -#[derive(Debug)] -pub enum OutEvent { - /// An inbound substream produced an event. 
- InboundEvent { - id: InboundSubstreamId, - message: TInbound, - }, - /// An outbound substream produced an event. - OutboundEvent { - id: OutboundSubstreamId, - message: TOutbound, - }, - /// An inbound substream errored irrecoverably. - InboundError { - id: InboundSubstreamId, - error: TInboundError, - }, - /// An outbound substream errored irrecoverably. - OutboundError { - id: OutboundSubstreamId, - error: TOutboundError, - }, -} - -impl< - TInboundInEvent, - TInboundOutEvent, - TOutboundInEvent, - TOutboundOutEvent, - TOutboundOpenInfo, - TInboundError, - TOutboundError, - TInboundSubstreamHandler, - TOutboundSubstreamHandler, - > ConnectionHandler - for SubstreamConnectionHandler< - TInboundSubstreamHandler, - TOutboundSubstreamHandler, - TOutboundOpenInfo, - > -where - TInboundSubstreamHandler: SubstreamHandler< - InEvent = TInboundInEvent, - OutEvent = TInboundOutEvent, - Error = TInboundError, - OpenInfo = (), - >, - TOutboundSubstreamHandler: SubstreamHandler< - InEvent = TOutboundInEvent, - OutEvent = TOutboundOutEvent, - Error = TOutboundError, - OpenInfo = TOutboundOpenInfo, - >, - TInboundInEvent: fmt::Debug + Send + 'static, - TInboundOutEvent: fmt::Debug + Send + 'static, - TOutboundInEvent: fmt::Debug + Send + 'static, - TOutboundOutEvent: fmt::Debug + Send + 'static, - TOutboundOpenInfo: fmt::Debug + Send + 'static, - TInboundError: fmt::Debug + Send + 'static, - TOutboundError: fmt::Debug + Send + 'static, - TInboundSubstreamHandler: Send + 'static, - TOutboundSubstreamHandler: Send + 'static, -{ - type InEvent = InEvent; - type OutEvent = OutEvent; - type Error = Void; - type InboundProtocol = PassthroughProtocol; - type OutboundProtocol = PassthroughProtocol; - type InboundOpenInfo = (); - type OutboundOpenInfo = TOutboundOpenInfo; - - fn listen_protocol(&self) -> SubstreamProtocol { - TInboundSubstreamHandler::upgrade(()) - } - - fn inject_fully_negotiated_inbound( - &mut self, - protocol: ::Output, - _: Self::InboundOpenInfo, - ) { - self.inbound_substreams.insert( - self.next_inbound_substream_id.fetch_and_increment(), - TInboundSubstreamHandler::new(protocol, ()), - ); - } - - fn inject_fully_negotiated_outbound( - &mut self, - protocol: ::Output, - info: Self::OutboundOpenInfo, - ) { - self.outbound_substreams.insert( - self.next_outbound_substream_id.fetch_and_increment(), - TOutboundSubstreamHandler::new(protocol, info), - ); - } - - fn inject_event(&mut self, event: Self::InEvent) { - match event { - InEvent::NewSubstream { open_info } => self.new_substreams.push_back(open_info), - InEvent::NotifyInboundSubstream { id, message } => { - match self.inbound_substreams.remove(&id) { - Some(handler) => { - let new_handler = handler.inject_event(message); - - self.inbound_substreams.insert(id, new_handler); - } - None => { - log::debug!("Substream with ID {} not found", id); - } - } - } - InEvent::NotifyOutboundSubstream { id, message } => { - match self.outbound_substreams.remove(&id) { - Some(handler) => { - let new_handler = handler.inject_event(message); - - self.outbound_substreams.insert(id, new_handler); - } - None => { - log::debug!("Substream with ID {} not found", id); - } - } - } - } - } - - fn inject_dial_upgrade_error( - &mut self, - _: Self::OutboundOpenInfo, - _: ConnectionHandlerUpgrErr, - ) { - // TODO: Handle upgrade errors properly - } - - fn connection_keep_alive(&self) -> KeepAlive { - // Rudimentary keep-alive handling, to be extended as needed as this abstraction is used more by other protocols. 
- - if Instant::now() < self.initial_keep_alive_deadline { - return KeepAlive::Yes; - } - - if self.inbound_substreams.is_empty() - && self.outbound_substreams.is_empty() - && self.new_substreams.is_empty() - { - return KeepAlive::No; - } - - KeepAlive::Yes - } - - fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Poll< - ConnectionHandlerEvent< - Self::OutboundProtocol, - Self::OutboundOpenInfo, - Self::OutEvent, - Self::Error, - >, - > { - if let Some(open_info) = self.new_substreams.pop_front() { - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: TOutboundSubstreamHandler::upgrade(open_info), - }); - } - - match poll_substreams(&mut self.inbound_substreams, cx) { - Poll::Ready(Ok((id, message))) => { - return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::InboundEvent { - id, - message, - })) - } - Poll::Ready(Err((id, error))) => { - return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::InboundError { - id, - error, - })) - } - Poll::Pending => {} - } - - match poll_substreams(&mut self.outbound_substreams, cx) { - Poll::Ready(Ok((id, message))) => { - return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::OutboundEvent { - id, - message, - })) - } - Poll::Ready(Err((id, error))) => { - return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::OutboundError { - id, - error, - })) - } - Poll::Pending => {} - } - - Poll::Pending - } -} - -/// A helper struct for substream handlers that can be implemented as async functions. -/// -/// This only works for substreams without an `InEvent` because - once constructed - the state of an inner future is opaque. -pub struct FutureSubstream { - future: Fuse>>, -} - -impl FutureSubstream { - pub fn new(future: impl Future> + Send + 'static) -> Self { - Self { - future: future.boxed().fuse(), - } - } - - pub fn advance(mut self, cx: &mut Context<'_>) -> Result, TError> { - if self.future.is_terminated() { - return Ok(Next::Done); - } - - match self.future.poll_unpin(cx) { - Poll::Ready(Ok(event)) => Ok(Next::EmitEvent { - event, - next_state: self, - }), - Poll::Ready(Err(error)) => Err(error), - Poll::Pending => Ok(Next::Pending { next_state: self }), - } - } -} - -impl SubstreamHandler for void::Void { - type InEvent = void::Void; - type OutEvent = void::Void; - type Error = void::Void; - type OpenInfo = (); - - fn new(_: NegotiatedSubstream, _: Self::OpenInfo) -> Self { - unreachable!("we should never yield a substream") - } - - fn inject_event(self, event: Self::InEvent) -> Self { - void::unreachable(event) - } - - fn advance(self, _: &mut Context<'_>) -> Result, Self::Error> { - void::unreachable(self) - } - - fn upgrade( - open_info: Self::OpenInfo, - ) -> SubstreamProtocol { - SubstreamProtocol::new(PassthroughProtocol { ident: None }, open_info) - } -} diff --git a/protocols/rendezvous/tests/rendezvous.rs b/protocols/rendezvous/tests/rendezvous.rs index 85fdacd8ae4..ae452da8646 100644 --- a/protocols/rendezvous/tests/rendezvous.rs +++ b/protocols/rendezvous/tests/rendezvous.rs @@ -274,7 +274,7 @@ async fn registration_on_clients_expire() { let roberts_peer_id = *robert.local_peer_id(); robert.spawn_into_runtime(); - let registration_ttl = 3; + let registration_ttl = 2; alice .behaviour_mut() @@ -292,7 +292,7 @@ async fn registration_on_clients_expire() { } }; - tokio::time::sleep(Duration::from_secs(registration_ttl + 5)).await; + tokio::time::sleep(Duration::from_secs(registration_ttl + 3)).await; let event = bob.select_next_some().await; let error = 
bob.dial(*alice.local_peer_id()).unwrap_err(); diff --git a/swarm/src/handler.rs b/swarm/src/handler.rs index 8d34509c085..fc503be9139 100644 --- a/swarm/src/handler.rs +++ b/swarm/src/handler.rs @@ -39,6 +39,7 @@ //! > [`NetworkBehaviour`](crate::behaviour::NetworkBehaviour) trait. pub mod either; +pub mod from_fn; mod map_in; mod map_out; pub mod multi; diff --git a/swarm/src/handler/from_fn.rs b/swarm/src/handler/from_fn.rs new file mode 100644 index 00000000000..355a7e3057e --- /dev/null +++ b/swarm/src/handler/from_fn.rs @@ -0,0 +1,925 @@ +use crate::behaviour::{ConnectionClosed, ConnectionEstablished, FromSwarm}; +use crate::handler::{InboundUpgradeSend, OutboundUpgradeSend}; +use crate::{ + ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, IntoConnectionHandler, + KeepAlive, NegotiatedSubstream, NetworkBehaviourAction, NotifyHandler, SubstreamProtocol, +}; +use futures::stream::{BoxStream, SelectAll}; +use futures::{Stream, StreamExt}; +use libp2p_core::connection::ConnectionId; +use libp2p_core::either::EitherOutput; +use libp2p_core::upgrade::{ + DeniedUpgrade, EitherUpgrade, NegotiationError, ProtocolError, ReadyUpgrade, +}; +use libp2p_core::{ConnectedPoint, PeerId, UpgradeError}; +use std::collections::{HashSet, VecDeque}; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use std::time::{Duration, Instant}; +use void::Void; + +/// A low-level building block for protocols that can be expressed as async functions. +/// +/// An async function in Rust is executed within an executor and thus only has access to its local state +/// and by extension, the state that was available when the [`Future`] was constructed. +/// +/// This [`ConnectionHandler`] aims to reduce the boilerplate for protocols which can be expressed as +/// a static sequence of reads and writes to a socket where the response can be generated with limited +/// "local" knowledge or in other words, the local state within the [`Future`]. +/// +/// For outbound substreams, arbitrary data can be supplied via [`InEvent::NewOutbound`] which will be +/// made available to the callback once the stream is fully negotiated. +/// +/// Inbound substreams may be opened at any time by the remote. To facilitate this one and more usecases, +/// the supplied callbacks for inbound and outbound substream are given access to the handler's `state` +/// field. State can be shared between the [`NetworkBehaviour`] and a [`FromFn`] [`ConnectionHandler`] +/// via the [`Shared`] abstraction. +/// +/// [`Shared`] tracks a piece of state and updates all registered [`ConnectionHandler`]s whenever the +/// state changes. 
+/// +/// [`NetworkBehaviour`]: crate::NetworkBehaviour +pub fn from_fn(protocol: &'static str) -> Builder { + Builder { + protocol, + phase: WantState {}, + } +} + +pub struct Builder { + protocol: &'static str, + phase: TPhase, +} + +pub struct WantState {} + +pub struct WantInboundHandler { + state: Arc, +} + +pub struct WantOutboundHandler { + state: Arc, + max_inbound: usize, + inbound_handler: Option< + Box< + dyn Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + ) -> BoxStream<'static, TInbound> + + Send, + >, + >, +} + +impl Builder { + pub fn with_state( + self, + shared: &Shared, + ) -> Builder> { + Builder { + protocol: self.protocol, + phase: WantInboundHandler { + state: shared.shared.clone(), + }, + } + } + + pub fn without_state(self) -> Builder> { + Builder { + protocol: self.protocol, + phase: WantInboundHandler { + state: Arc::new(()), + }, + } + } +} + +impl Builder> { + pub fn with_inbound_handler( + self, + max_inbound: usize, + handler: impl Fn(NegotiatedSubstream, PeerId, ConnectedPoint, Arc) -> TInboundStream + + Send + + 'static, + ) -> Builder> + where + TInboundStream: Future + Send + 'static, + { + self.with_streaming_inbound_handler(max_inbound, move |stream, peer, endpoint, state| { + futures::stream::once(handler(stream, peer, endpoint, state)) + }) + } + + pub fn with_streaming_inbound_handler( + self, + max_inbound: usize, + handler: impl Fn(NegotiatedSubstream, PeerId, ConnectedPoint, Arc) -> TInboundStream + + Send + + 'static, + ) -> Builder> + where + TInboundStream: Stream + Send + 'static, + { + Builder { + protocol: self.protocol, + phase: WantOutboundHandler { + state: self.phase.state, + max_inbound, + inbound_handler: Some(Box::new(move |stream, peer, endpoint, state| { + handler(stream, peer, endpoint, state).boxed() + })), + }, + } + } + + pub fn without_inbound_handler(self) -> Builder> { + Builder { + protocol: self.protocol, + phase: WantOutboundHandler { + state: self.phase.state, + max_inbound: 0, + inbound_handler: None, + }, + } + } +} + +impl Builder> { + pub fn with_outbound_handler( + self, + max_pending_outbound: usize, + handler: impl Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + TOutboundInfo, + ) -> TOutboundStream + + Send + + 'static, + ) -> FromFnProto + where + TOutboundStream: Future + Send + 'static, + { + self.with_streaming_outbound_handler( + max_pending_outbound, + move |stream, peer, endpoint, state, info| { + futures::stream::once(handler(stream, peer, endpoint, state, info)) + }, + ) + } + + pub fn with_streaming_outbound_handler( + self, + max_pending_outbound: usize, + handler: impl Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + TOutboundInfo, + ) -> TOutboundStream + + Send + + 'static, + ) -> FromFnProto + where + TOutboundStream: Stream + Send + 'static, + { + FromFnProto { + protocol: self.protocol, + on_new_inbound: self.phase.inbound_handler, + on_new_outbound: Box::new(move |stream, peer, endpoint, state, info| { + handler(stream, peer, endpoint, state, info).boxed() + }), + inbound_streams_limit: self.phase.max_inbound, + pending_outbound_streams_limit: max_pending_outbound, + state: self.phase.state, + } + } + + pub fn without_outbound_handler(self) -> FromFnProto { + FromFnProto { + protocol: self.protocol, + on_new_inbound: self.phase.inbound_handler, + on_new_outbound: Box::new(|_, _, _, _, info| void::unreachable(info)), + inbound_streams_limit: self.phase.max_inbound, + pending_outbound_streams_limit: 0, + state: self.phase.state, + } + } +} + 
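For reference, a minimal sketch of how the builder above is chained from a behaviour's `new_handler`, mirroring the rendezvous server elsewhere in this patch. The protocol string, the `Registry` state type and the `InboundEvent` type are illustrative assumptions, not part of the patch. By analogy with the rendezvous `type ConnectionHandler`, the result is a `FromFnProto` parameterised over the inbound event type, the (here unused) outbound event and open-info types, and the shared state.

// Editorial sketch, not part of the patch. Assumed placeholders: the protocol
// string, `Registry` (the shared state) and `InboundEvent` (what the inbound
// closure reports back to the behaviour).
fn new_handler(&mut self) -> Self::ConnectionHandler {
    from_fn("/example/registry/1.0.0")
        // `self.registry` is a `Shared<Registry>` field on the behaviour; each
        // connection handler receives an `Arc<Registry>` snapshot of it.
        .with_state(&self.registry)
        // Accept at most 8 concurrent inbound streams. The closure is handed the
        // negotiated stream, the remote peer, the connected endpoint and the
        // state; the future's output is surfaced to the behaviour as
        // `from_fn::OutEvent::InboundEmitted`.
        .with_inbound_handler(8, |stream, peer, _endpoint, registry| async move {
            // Read a request from `stream`, consult `registry`, write a
            // response, then report an event (or an error) to the behaviour.
            let _ = (stream, peer, registry);
            Ok::<Option<InboundEvent>, std::io::Error>(None)
        })
        // A server-only protocol never opens streams itself, so the outbound
        // half stays `Void`, exactly as in the rendezvous behaviour.
        .without_outbound_handler()
}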
+#[derive(Debug)] +pub enum OutEvent { + InboundEmitted(I), + OutboundEmitted(O), + FailedToOpen(OpenError), +} + +#[derive(Debug)] +pub enum OpenError { + /// The time limit for the negotiation handshake was exceeded. + Timeout(OpenInfo), + /// We have hit the configured limit for the maximum number of pending substreams. + LimitExceeded { + open_info: OpenInfo, + pending_substreams: usize, + }, + /// The remote does not support this protocol. + Unsupported { + open_info: OpenInfo, + protocol: &'static str, + }, + NegotiationFailed(OpenInfo, ProtocolError), +} + +/// A wrapper for state that is shared across all connections. +/// +/// Any update to the state will "automatically" be relayed to all connections, assuming this struct +/// is correctly wired into your [`NetworkBehaviour`](crate::swarm::NetworkBehaviour). +/// +/// This struct implements an observer pattern. All registered connections will receive updates that +/// are made to the state. +pub struct Shared { + inner: T, + + shared: Arc, + + dirty: bool, + waker: Option, + connections: HashSet<(PeerId, ConnectionId)>, + pending_update_events: VecDeque<(PeerId, ConnectionId, Arc)>, +} + +impl Shared +where + T: Clone, +{ + pub fn new(state: T) -> Self { + Self { + inner: state.clone(), + shared: Arc::new(state), + dirty: false, + waker: None, + connections: HashSet::default(), + pending_update_events: VecDeque::default(), + } + } + + pub fn on_swarm_event(&mut self, event: &FromSwarm) + where + H: IntoConnectionHandler, + { + match event { + FromSwarm::ConnectionEstablished(ConnectionEstablished { + peer_id, + connection_id, + .. + }) => { + self.connections.insert((*peer_id, *connection_id)); + } + FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id, + .. + }) => { + self.connections.remove(&(*peer_id, *connection_id)); + } + _ => {} + } + } + + pub fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> + where + THandler: IntoConnectionHandler, + { + if self.dirty { + self.shared = Arc::new(self.inner.clone()); + + self.pending_update_events = self + .connections + .iter() + .map(|(peer_id, conn_id)| (*peer_id, *conn_id, self.shared.clone())) + .collect(); + + self.dirty = false; + } + + if let Some((peer_id, conn_id, state)) = self.pending_update_events.pop_front() { + return Poll::Ready(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::One(conn_id), + event: InEvent::UpdateState(state), + }); + } + + self.waker = Some(cx.waker().clone()); + Poll::Pending + } +} + +impl Deref for Shared { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for Shared { + fn deref_mut(&mut self) -> &mut Self::Target { + self.dirty = true; + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + &mut self.inner + } +} + +impl fmt::Display for OpenError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OpenError::Timeout(_) => write!(f, "opening new substream timed out"), + OpenError::LimitExceeded { + pending_substreams, .. + } => write!( + f, + "limit for pending openings ({pending_substreams}) exceeded" + ), + OpenError::NegotiationFailed(_, _) => Ok(()), // Don't print anything to avoid double printing of error. + OpenError::Unsupported { protocol, .. 
} => { + write!(f, "remote peer does not support {protocol}") + } + } + } +} + +impl Error for OpenError +where + OpenInfo: fmt::Debug, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + OpenError::Timeout(_) => None, + OpenError::LimitExceeded { .. } => None, + OpenError::NegotiationFailed(_, source) => Some(source), + OpenError::Unsupported { .. } => None, + } + } +} + +#[derive(Debug)] +pub enum InEvent { + UpdateState(Arc), + NewOutbound(TOutboundOpenInfo), +} + +pub struct FromFnProto { + protocol: &'static str, + + on_new_inbound: Option< + Box< + dyn Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + ) -> BoxStream<'static, TInbound> + + Send, + >, + >, + on_new_outbound: Box< + dyn Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + TOutboundOpenInfo, + ) -> BoxStream<'static, TOutbound> + + Send, + >, + + inbound_streams_limit: usize, + pending_outbound_streams_limit: usize, + + state: Arc, +} + +impl IntoConnectionHandler + for FromFnProto +where + TInbound: fmt::Debug + Send + 'static, + TOutbound: fmt::Debug + Send + 'static, + TOutboundOpenInfo: fmt::Debug + Send + 'static, + TState: fmt::Debug + Send + Sync + 'static, +{ + type Handler = FromFn; + + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler { + FromFn { + protocol: self.protocol, + remote_peer_id: *remote_peer_id, + connected_point: connected_point.clone(), + inbound_streams: SelectAll::default(), + outbound_streams: SelectAll::default(), + on_new_inbound: self.on_new_inbound, + on_new_outbound: self.on_new_outbound, + inbound_streams_limit: self.inbound_streams_limit, + pending_outbound_streams: VecDeque::default(), + pending_outbound_streams_limit: self.pending_outbound_streams_limit, + failed_open: VecDeque::default(), + state: self.state, + keep_alive: KeepAlive::Yes, + } + } + + fn inbound_protocol(&self) -> ::InboundProtocol { + if self.on_new_inbound.is_some() { + EitherUpgrade::B(ReadyUpgrade::new(self.protocol)) + } else { + EitherUpgrade::A(DeniedUpgrade) + } + } +} + +pub struct FromFn { + protocol: &'static str, + remote_peer_id: PeerId, + connected_point: ConnectedPoint, + + inbound_streams: SelectAll>, + outbound_streams: SelectAll>, + + on_new_inbound: Option< + Box< + dyn Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + ) -> BoxStream<'static, TInbound> + + Send, + >, + >, + on_new_outbound: Box< + dyn Fn( + NegotiatedSubstream, + PeerId, + ConnectedPoint, + Arc, + TOutboundInfo, + ) -> BoxStream<'static, TOutbound> + + Send, + >, + + inbound_streams_limit: usize, + + pending_outbound_streams: VecDeque, + pending_outbound_streams_limit: usize, + + failed_open: VecDeque>, + + state: Arc, + + keep_alive: KeepAlive, +} + +impl ConnectionHandler + for FromFn +where + TOutboundInfo: fmt::Debug + Send + 'static, + TInbound: fmt::Debug + Send + 'static, + TOutbound: fmt::Debug + Send + 'static, + TState: fmt::Debug + Send + Sync + 'static, +{ + type InEvent = InEvent; + type OutEvent = OutEvent; + type Error = Void; + type InboundProtocol = EitherUpgrade>; + type OutboundProtocol = ReadyUpgrade<&'static str>; + type InboundOpenInfo = (); + type OutboundOpenInfo = TOutboundInfo; + + fn listen_protocol(&self) -> SubstreamProtocol { + if self.on_new_inbound.is_some() { + SubstreamProtocol::new(EitherUpgrade::B(ReadyUpgrade::new(self.protocol)), ()) + } else { + SubstreamProtocol::new(EitherUpgrade::A(DeniedUpgrade), ()) + } + } + + fn inject_fully_negotiated_inbound( + &mut self, + protocol: 
::Output, + _: Self::InboundOpenInfo, + ) { + match protocol { + EitherOutput::First(never) => void::unreachable(never), + EitherOutput::Second(protocol) => { + if self.inbound_streams.len() >= self.inbound_streams_limit { + log::debug!( + "Dropping inbound substream because limit ({}) would be exceeded", + self.inbound_streams_limit + ); + return; + } + + let on_new_inbound = self + .on_new_inbound + .as_ref() + .expect("to have callback when protocol was negotiated"); + + let inbound_future = (on_new_inbound)( + protocol, + self.remote_peer_id, + self.connected_point.clone(), + Arc::clone(&self.state), + ); + self.inbound_streams.push(inbound_future); + } + } + } + + fn inject_fully_negotiated_outbound( + &mut self, + protocol: ::Output, + info: Self::OutboundOpenInfo, + ) { + let outbound_future = (self.on_new_outbound)( + protocol, + self.remote_peer_id, + self.connected_point.clone(), + Arc::clone(&self.state), + info, + ); + self.outbound_streams.push(outbound_future); + } + + fn inject_event(&mut self, event: Self::InEvent) { + match event { + InEvent::UpdateState(new_state) => self.state = new_state, + InEvent::NewOutbound(open_info) => { + if self.pending_outbound_streams.len() >= self.pending_outbound_streams_limit { + self.failed_open.push_back(OpenError::LimitExceeded { + open_info, + pending_substreams: self.pending_outbound_streams.len(), + }); + } else { + self.pending_outbound_streams.push_back(open_info); + } + } + } + } + + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ConnectionHandlerUpgrErr<::Error>, + ) { + match error { + ConnectionHandlerUpgrErr::Timeout => { + self.failed_open.push_back(OpenError::Timeout(info)) + } + ConnectionHandlerUpgrErr::Timer => {} + ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::ProtocolError(error), + )) => self + .failed_open + .push_back(OpenError::NegotiationFailed(info, error)), + ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { + self.failed_open.push_back(OpenError::Unsupported { + open_info: info, + protocol: self.protocol, + }) + } + ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply(apply)) => { + void::unreachable(apply) + } + } + } + + fn connection_keep_alive(&self) -> KeepAlive { + self.keep_alive + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ConnectionHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { + if let Some(error) = self.failed_open.pop_front() { + return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::FailedToOpen( + error, + ))); + } + + match self.outbound_streams.poll_next_unpin(cx) { + Poll::Ready(Some(outbound_done)) => { + return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::OutboundEmitted( + outbound_done, + ))); + } + Poll::Ready(None) => { + // Normally, we'd register a waker here but `Connection` polls us anyway again + // after calling `inject` on us which is where we'd use the waker. + } + Poll::Pending => {} + }; + + match self.inbound_streams.poll_next_unpin(cx) { + Poll::Ready(Some(inbound_done)) => { + return Poll::Ready(ConnectionHandlerEvent::Custom(OutEvent::InboundEmitted( + inbound_done, + ))); + } + Poll::Ready(None) => { + // Normally, we'd register a waker here but `Connection` polls us anyway again + // after calling `inject` on us which is where we'd use the waker. 
+ } + Poll::Pending => {} + }; + + if let Some(outbound_open_info) = self.pending_outbound_streams.pop_front() { + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new( + ReadyUpgrade::new(self.protocol), + outbound_open_info, + ), + }); + } + + if self.inbound_streams.is_empty() + && self.outbound_streams.is_empty() + && self.pending_outbound_streams.is_empty() + { + if self.keep_alive.is_yes() { + // TODO: Make timeout configurable + self.keep_alive = KeepAlive::Until(Instant::now() + Duration::from_secs(10)) + } + } else { + self.keep_alive = KeepAlive::Yes + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + IntoConnectionHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters, Swarm, + SwarmEvent, + }; + use futures::{AsyncReadExt, AsyncWriteExt}; + use libp2p_core::connection::ConnectionId; + use libp2p_core::transport::MemoryTransport; + use libp2p_core::upgrade::Version; + use libp2p_core::{identity, PeerId, Transport}; + use libp2p_plaintext::PlainText2Config; + use libp2p_yamux as yamux; + use std::collections::HashMap; + use std::io; + use std::ops::AddAssign; + + #[async_std::test] + async fn greetings() { + let _ = env_logger::try_init(); + + let mut alice = make_swarm("Alice"); + let mut bob = make_swarm("Bob"); + + let bob_peer_id = *bob.local_peer_id(); + let listen_id = alice.listen_on("/memory/0".parse().unwrap()).unwrap(); + + let alice_listen_addr = loop { + if let SwarmEvent::NewListenAddr { + address, + listener_id, + } = alice.select_next_some().await + { + if listener_id == listen_id { + break address; + } + } + }; + bob.dial(alice_listen_addr).unwrap(); + + futures::future::join( + async { + while !matches!( + alice.select_next_some().await, + SwarmEvent::ConnectionEstablished { .. } + ) {} + }, + async { + while !matches!( + bob.select_next_some().await, + SwarmEvent::ConnectionEstablished { .. 
} + ) {} + }, + ) + .await; + + futures::future::join( + async { + alice.behaviour_mut().say_hello(bob_peer_id); + + loop { + if let SwarmEvent::Behaviour(greetings) = alice.select_next_some().await { + assert_eq!(*greetings.get(&Name("Bob".to_owned())).unwrap(), 1); + break; + } + } + }, + async { + loop { + if let SwarmEvent::Behaviour(greetings) = bob.select_next_some().await { + assert_eq!(*greetings.get(&Name("Alice".to_owned())).unwrap(), 1); + break; + } + } + }, + ) + .await; + + alice.behaviour_mut().state.name = Name("Carol".to_owned()); + bob.behaviour_mut().state.name = Name("Steve".to_owned()); + + futures::future::join( + async { + alice.behaviour_mut().say_hello(bob_peer_id); + + loop { + if let SwarmEvent::Behaviour(greetings) = alice.select_next_some().await { + assert_eq!(*greetings.get(&Name("Bob".to_owned())).unwrap(), 1); + assert_eq!(*greetings.get(&Name("Steve".to_owned())).unwrap(), 1); + break; + } + } + }, + async { + loop { + if let SwarmEvent::Behaviour(greetings) = bob.select_next_some().await { + assert_eq!(*greetings.get(&Name("Alice".to_owned())).unwrap(), 1); + assert_eq!(*greetings.get(&Name("Carol".to_owned())).unwrap(), 1); + break; + } + } + }, + ) + .await; + } + + struct HelloBehaviour { + state: Shared, + pending_messages: VecDeque, + pending_events: VecDeque>, + greeting_count: HashMap, + } + + #[derive(Debug, Clone)] + struct State { + name: Name, + } + + #[derive(Debug, Clone, PartialEq, Hash, Eq)] + struct Name(String); + + impl HelloBehaviour { + fn say_hello(&mut self, to: PeerId) { + self.pending_messages.push_back(to); + } + } + + impl NetworkBehaviour for HelloBehaviour { + type ConnectionHandler = FromFnProto, io::Result, (), State>; + type OutEvent = HashMap; + + fn new_handler(&mut self) -> Self::ConnectionHandler { + from_fn("/hello/1.0.0") + .with_state(&self.state) + .with_inbound_handler(5, |mut stream, _, _, state| async move { + let mut received_name = Vec::new(); + stream.read_to_end(&mut received_name).await?; + + stream.write_all(state.name.0.as_bytes()).await?; + stream.close().await?; + + Ok(Name(String::from_utf8(received_name).unwrap())) + }) + .with_outbound_handler(5, |mut stream, _, _, state, _| async move { + stream.write_all(state.name.0.as_bytes()).await?; + stream.flush().await?; + stream.close().await?; + + let mut received_name = Vec::new(); + stream.read_to_end(&mut received_name).await?; + + Ok(Name(String::from_utf8(received_name).unwrap())) + }) + } + + fn on_swarm_event(&mut self, event: FromSwarm) { + self.state.on_swarm_event(&event); + } + + fn on_connection_handler_event( + &mut self, + _peer_id: PeerId, + _connection_id: ConnectionId, + event: <::Handler as ConnectionHandler>::OutEvent, + ) { + match event { + OutEvent::InboundEmitted(Ok(name)) => { + self.greeting_count.entry(name).or_default().add_assign(1); + + self.pending_events.push_back(self.greeting_count.clone()) + } + OutEvent::OutboundEmitted(Ok(name)) => { + self.greeting_count.entry(name).or_default().add_assign(1); + + self.pending_events.push_back(self.greeting_count.clone()) + } + OutEvent::InboundEmitted(_) => {} + OutEvent::OutboundEmitted(_) => {} + OutEvent::FailedToOpen(OpenError::Timeout(_)) => {} + OutEvent::FailedToOpen(OpenError::NegotiationFailed(_, _)) => {} + OutEvent::FailedToOpen(OpenError::LimitExceeded { .. }) => {} + OutEvent::FailedToOpen(OpenError::Unsupported { .. 
}) => {} + } + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll> { + if let Some(greeting_count) = self.pending_events.pop_front() { + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(greeting_count)); + } + + if let Poll::Ready(action) = self.state.poll(cx) { + return Poll::Ready(action); + } + + if let Some(to) = self.pending_messages.pop_front() { + return Poll::Ready(NetworkBehaviourAction::NotifyHandler { + peer_id: to, + handler: NotifyHandler::Any, + event: InEvent::NewOutbound(()), + }); + } + + Poll::Pending + } + } + + fn make_swarm(name: &'static str) -> Swarm { + let identity = identity::Keypair::generate_ed25519(); + + let transport = MemoryTransport::new() + .upgrade(Version::V1) + .authenticate(PlainText2Config { + local_public_key: identity.public(), + }) + .multiplex(yamux::YamuxConfig::default()) + .boxed(); + + Swarm::without_executor( + transport, + HelloBehaviour { + state: Shared::new(State { + name: Name(name.to_owned()), + }), + pending_messages: Default::default(), + pending_events: Default::default(), + greeting_count: Default::default(), + }, + identity.public().to_peer_id(), + ) + } + + // TODO: Add test for max pending dials + // TODO: Add test for max inbound streams +} diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 7894f1e576a..f0477fae788 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -110,9 +110,9 @@ pub use connection::{ }; pub use executor::Executor; pub use handler::{ - ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerSelect, ConnectionHandlerUpgrErr, - IntoConnectionHandler, IntoConnectionHandlerSelect, KeepAlive, OneShotHandler, - OneShotHandlerConfig, SubstreamProtocol, + from_fn::from_fn, ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerSelect, + ConnectionHandlerUpgrErr, IntoConnectionHandler, IntoConnectionHandlerSelect, KeepAlive, + OneShotHandler, OneShotHandlerConfig, SubstreamProtocol, }; #[cfg(feature = "macros")] pub use libp2p_swarm_derive::NetworkBehaviour;
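To round off the `Shared` abstraction defined in `swarm/src/handler/from_fn.rs` above, here is a minimal sketch of the intended wiring on the behaviour side, mirroring the rendezvous behaviour and the `HelloBehaviour` test in this patch. `Registry` and `MyBehaviour` are illustrative assumptions; only the `Shared` calls (`new`, `on_swarm_event`, `poll`, and mutation through `DerefMut`) come from the patch.

// Editorial sketch, not part of the patch. `Registry` is a placeholder for any
// `Clone` state shared between the behaviour and its connection handlers.
#[derive(Debug, Clone, Default)]
struct Registry {
    peers_seen: u64,
}

struct MyBehaviour {
    // Constructed once via `Shared::new(Registry::default())`.
    state: Shared<Registry>,
}

impl MyBehaviour {
    fn record_peer(&mut self) {
        // Mutating through `DerefMut` marks the shared state dirty and wakes
        // the task that last polled it.
        self.state.peers_seen += 1;
    }
}

// Inside the `NetworkBehaviour` impl (other required items omitted), two hooks
// are all the wiring `Shared` needs:
//
//   on_swarm_event: forward every event so `Shared` can track which
//                   (peer, connection) pairs are currently alive:
//                       self.state.on_swarm_event(&event);
//
//   poll:           drain pending updates; after a mutation this yields one
//                   `NotifyHandler` action per live connection, carrying
//                   `InEvent::UpdateState(Arc<Registry>)` so every handler
//                   swaps in the fresh snapshot:
//                       if let Poll::Ready(action) = self.state.poll(cx) {
//                           return Poll::Ready(action);
//                       }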