diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs
index c35bfbae47f..9f5e79fd81d 100644
--- a/zebra-network/src/peer/handshake.rs
+++ b/zebra-network/src/peer/handshake.rs
@@ -10,7 +10,7 @@ use std::{
 use chrono::{TimeZone, Utc};
 use futures::{
     channel::{mpsc, oneshot},
-    prelude::*,
+    future, FutureExt, SinkExt, StreamExt,
 };
 use tokio::{net::TcpStream, sync::broadcast, time::timeout};
 use tokio_util::codec::Framed;
@@ -30,7 +30,7 @@ use crate::{
     BoxError, Config,
 };
 
-use super::{Client, Connection, ErrorSlot, HandshakeError, PeerError};
+use super::{Client, ClientRequest, Connection, ErrorSlot, HandshakeError, PeerError};
 
 /// A [`Service`] that handshakes with a remote peer and constructs a
 /// client/server pair.
@@ -487,22 +487,22 @@ where
         // - every error/shutdown must update the address book state and return
         //
         // The address book state can be updated via `ClientRequest.tx`, or the
-        // timestamp_collector.
+        // heartbeat_ts_collector.
        //
         // Returning from the spawned closure terminates the connection's heartbeat task.
         let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat");
+        let heartbeat_ts_collector = timestamp_collector.clone();
         tokio::spawn(
             async move {
-                use super::ClientRequest;
                 use futures::future::Either;
 
                 let mut shutdown_rx = shutdown_rx;
                 let mut server_tx = server_tx;
-                let mut timestamp_collector = timestamp_collector.clone();
+                let mut timestamp_collector = heartbeat_ts_collector.clone();
                 let mut interval_stream = tokio::time::interval(constants::HEARTBEAT_INTERVAL);
+
                 loop {
                     let shutdown_rx_ref = Pin::new(&mut shutdown_rx);
-                    let mut send_addr_err = false;
 
                     // CORRECTNESS
                     //
@@ -513,107 +513,37 @@ where
                     // slow rate, and shutdown is a oneshot. If both futures
                     // are ready, we want the shutdown to take priority over
                     // sending a useless heartbeat.
-                    match future::select(shutdown_rx_ref, interval_stream.next()).await {
-                        Either::Right(_) => {
-                            let (tx, rx) = oneshot::channel();
-                            let request = Request::Ping(Nonce::default());
-                            tracing::trace!(?request, "queueing heartbeat request");
-                            match server_tx.try_send(ClientRequest {
-                                request,
-                                tx,
-                                span: tracing::Span::current(),
-                            }) {
-                                Ok(()) => {
-                                    // TODO: also wait on the shutdown_rx here
-                                    match timeout(
-                                        constants::HEARTBEAT_INTERVAL,
-                                        server_tx.flush(),
-                                    )
-                                    .await
-                                    {
-                                        Ok(Ok(())) => {
-                                        }
-                                        Ok(Err(e)) => {
-                                            tracing::warn!(
-                                                ?e,
-                                                "flushing client request failed, shutting down"
-                                            );
-                                            send_addr_err = true;
-                                        }
-                                        Err(e) => {
-                                            tracing::warn!(
-                                                ?e,
-                                                "flushing client request timed out, shutting down"
-                                            );
-                                            send_addr_err = true;
-                                        }
-                                    }
-                                }
-                                Err(e) => {
-                                    tracing::trace!(
-                                        ?e,
-                                        "error sending heartbeat request, shutting down"
-                                    );
-                                    if e.is_disconnected() {
-                                        let ClientRequest { tx, .. } = e.into_inner();
-                                        let _ =
-                                            tx.send(Err(PeerError::ConnectionClosed.into()));
-                                    } else if e.is_full() {
-                                        // TODO: wait for the sink to be ready, or wait for a timeout,
-                                        // then close the connection with an overloaded error (#1551)
-                                        let ClientRequest { tx, .. } = e.into_inner();
-                                        let _ = tx.send(Err(PeerError::Overloaded.into()));
-                                    } else {
-                                        // we need to map unexpected error types to PeerErrors
-                                        panic!("unexpected try_send error: {:?}", e);
-                                    }
-                                    return;
-                                }
-                            }
-                            // Heartbeats are checked internally to the
-                            // connection logic, but we need to wait on the
-                            // response to avoid canceling the request.
-                            //
-                            // TODO: also wait on the shutdown_rx here
-                            match timeout(constants::HEARTBEAT_INTERVAL, rx).await {
-                                Ok(Ok(_)) => tracing::trace!("got heartbeat response"),
-                                Ok(Err(e)) => {
-                                    tracing::warn!(
-                                        ?e,
-                                        "error awaiting heartbeat response, shutting down"
-                                    );
-                                    send_addr_err = true;
-                                }
-                                Err(e) => {
-                                    tracing::warn!(
-                                        ?e,
-                                        "heartbeat response timed out, shutting down"
-                                    );
-                                    send_addr_err = true;
-                                }
-                            }
-                        }
-                        Either::Left(_) => {
-                            tracing::trace!("shutting down due to Client shut down");
-                            // awaiting a local task won't hang
-                            let _ = timestamp_collector
-                                .send(MetaAddr::new_shutdown(&addr, &remote_services))
-                                .await;
-                            return;
-                        }
-                    }
-                    if send_addr_err {
-                        // We can't get the client request for this failure,
-                        // so we can't send an error back on `tx`. So
-                        // we just update the address book with a failure.
+                    if matches!(
+                        future::select(shutdown_rx_ref, interval_stream.next()).await,
+                        Either::Left(_)
+                    ) {
+                        tracing::trace!("shutting down due to Client shut down");
+                        // awaiting a local task won't hang
                         let _ = timestamp_collector
-                            .send(MetaAddr::new_errored(
-                                &addr,
-                                &remote_services,
-                            ))
+                            .send(MetaAddr::new_shutdown(&addr, &remote_services))
                             .await;
                         return;
                     }
+
+                    // We've reached another heartbeat interval without
+                    // shutting down, so do a heartbeat request.
+                    //
+                    // TODO: await heartbeat and shutdown. The select
+                    // function has some strict lifetime requirements,
+                    // try the select! macro with a custom enum mapping
+                    // (#1783, #1678)
+                    let heartbeat = send_one_heartbeat(&mut server_tx);
+                    if heartbeat_timeout(
+                        heartbeat,
+                        &mut timestamp_collector,
+                        &addr,
+                        &remote_services,
+                    )
+                    .await
+                    .is_err()
+                    {
+                        return;
+                    }
                 }
             }
             .instrument(heartbeat_span)
@@ -635,3 +565,106 @@ where
         .boxed()
     }
 }
+
+/// Send one heartbeat using `server_tx`.
+async fn send_one_heartbeat(server_tx: &mut mpsc::Sender<ClientRequest>) -> Result<(), BoxError> {
+    // We just reached a heartbeat interval, so start sending
+    // a heartbeat.
+    let (tx, rx) = oneshot::channel();
+
+    // Try to send the heartbeat request
+    let request = Request::Ping(Nonce::default());
+    tracing::trace!(?request, "queueing heartbeat request");
+    match server_tx.try_send(ClientRequest {
+        request,
+        tx,
+        span: tracing::Span::current(),
+    }) {
+        Ok(()) => {}
+        Err(e) => {
+            if e.is_disconnected() {
+                Err(PeerError::ConnectionClosed)?;
+            } else if e.is_full() {
+                // Send the message when the Client becomes ready.
+                // If sending takes too long, the heartbeat timeout will elapse
+                // and close the connection, reducing our load to busy peers.
+                server_tx.send(e.into_inner()).await?;
+            } else {
+                // we need to map unexpected error types to PeerErrors
+                warn!(?e, "unexpected try_send error");
+                Err(e)?;
+            };
+        }
+    }
+
+    // Flush the heartbeat request from the queue
+    server_tx.flush().await?;
+    tracing::trace!("sent heartbeat request");
+
+    // Heartbeats are checked internally to the
+    // connection logic, but we need to wait on the
+    // response to avoid canceling the request.
+    rx.await??;
+    tracing::trace!("got heartbeat response");
+
+    Ok(())
+}
+
+/// Wrap `fut` in a timeout, handling any inner or outer errors using
+/// `handle_heartbeat_error`.
+async fn heartbeat_timeout<F, T>(
+    fut: F,
+    timestamp_collector: &mut mpsc::Sender<MetaAddr>,
+    addr: &SocketAddr,
+    remote_services: &PeerServices,
+) -> Result<T, BoxError>
+where
+    F: Future<Output = Result<T, BoxError>>,
+{
+    let t = match timeout(constants::HEARTBEAT_INTERVAL, fut).await {
+        Ok(inner_result) => {
+            handle_heartbeat_error(
+                inner_result,
+                timestamp_collector,
+                addr,
+                remote_services,
+            )
+            .await?
+        }
+        Err(elapsed) => {
+            handle_heartbeat_error(
+                Err(elapsed),
+                timestamp_collector,
+                addr,
+                remote_services,
+            )
+            .await?
+        }
+    };
+
+    Ok(t)
+}
+
+/// If `result.is_err()`, mark `addr` as failed using `timestamp_collector`.
+async fn handle_heartbeat_error<T, E>(
+    result: Result<T, E>,
+    timestamp_collector: &mut mpsc::Sender<MetaAddr>,
+    addr: &SocketAddr,
+    remote_services: &PeerServices,
+) -> Result<T, E>
+where
+    E: std::fmt::Debug,
+{
+    match result {
+        Ok(t) => Ok(t),
+        Err(err) => {
+            tracing::debug!(?err, "heartbeat error, shutting down");
+
+            let _ = timestamp_collector
+                .send(MetaAddr::new_errored(&addr, &remote_services))
+                .await;
+
+            Err(err)
+        }
+    }
+}
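
For readers skimming the patch, here is a minimal, self-contained sketch of the pattern the new helpers implement: run a heartbeat future under `tokio::time::timeout`, and on either an inner error or an elapsed deadline, record the failure before propagating the error. The names `with_heartbeat_timeout`, `mark_failed`, the deadline values, and the local `BoxError` alias are illustrative stand-ins for this example only, not Zebra APIs; the real code above reports failures by sending `MetaAddr::new_errored` to the timestamp collector and uses `constants::HEARTBEAT_INTERVAL` as the deadline.

```rust
use std::{future::Future, time::Duration};

use tokio::time::timeout;

type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Run `fut` with a deadline; on an inner error or a timeout, call
/// `mark_failed` before returning the error.
async fn with_heartbeat_timeout<F, T, M>(
    fut: F,
    deadline: Duration,
    mut mark_failed: M,
) -> Result<T, BoxError>
where
    F: Future<Output = Result<T, BoxError>>,
    M: FnMut(),
{
    match timeout(deadline, fut).await {
        // The heartbeat finished in time and succeeded.
        Ok(Ok(t)) => Ok(t),
        // The heartbeat finished in time but failed.
        Ok(Err(e)) => {
            mark_failed();
            Err(e)
        }
        // The deadline elapsed before the heartbeat finished.
        Err(elapsed) => {
            mark_failed();
            Err(elapsed.into())
        }
    }
}

#[tokio::main]
async fn main() {
    // A "heartbeat" that takes longer than the deadline, so the timeout fires
    // and the failure callback runs.
    let slow_heartbeat = async {
        tokio::time::sleep(Duration::from_millis(50)).await;
        Ok::<(), BoxError>(())
    };

    let result = with_heartbeat_timeout(slow_heartbeat, Duration::from_millis(10), || {
        eprintln!("marking peer address as failed");
    })
    .await;

    assert!(result.is_err());
}
```

Splitting the timeout wrapper from the failure reporting, as the patch does with `heartbeat_timeout` and `handle_heartbeat_error`, lets the heartbeat loop collapse to a single early return on any heartbeat failure.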