From c3be20c86e1a6dfa3523b2d77e8c503d0f5b2ce3 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 23 Aug 2024 10:33:53 +0200 Subject: [PATCH] server: unify accept error handling (#1882) --- tonic/src/transport/server/incoming.rs | 41 ++++++++++++++++---------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index b7d2f8d6c..8b9230630 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,14 +1,12 @@ -use super::service::ServerIo; -#[cfg(feature = "tls")] -use super::service::TlsAcceptor; -#[cfg(not(feature = "tls"))] -use std::io; use std::{ + io, net::{SocketAddr, TcpListener as StdTcpListener}, + ops::ControlFlow, pin::{pin, Pin}, task::{ready, Context, Poll}, time::Duration, }; + use tokio::{ io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, @@ -17,6 +15,10 @@ use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; use tracing::warn; +use super::service::ServerIo; +#[cfg(feature = "tls")] +use super::service::TlsAcceptor; + #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( incoming: impl Stream>, @@ -31,15 +33,9 @@ where while let Some(item) = incoming.next().await { yield match item { Ok(_) => item.map(ServerIo::new_io)?, - Err(e) => { - let e = e.into(); - tracing::debug!(error = %e, "accept loop error"); - if let Some(e) = e.downcast_ref::() { - if e.kind() == io::ErrorKind::ConnectionAborted { - continue; - } - } - Err(e)? + Err(e) => match handle_accept_error(e) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(e) => Err(e)?, } } } @@ -78,8 +74,9 @@ where yield io; } - SelectOutput::Err(e) => { - tracing::debug!(error = %e, "accept loop error"); + SelectOutput::Err(e) => match handle_accept_error(e) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(e) => Err(e)?, } SelectOutput::Done => { @@ -90,6 +87,18 @@ where } } +fn handle_accept_error(e: impl Into) -> ControlFlow { + let e = e.into(); + tracing::debug!(error = %e, "accept loop error"); + if let Some(e) = e.downcast_ref::() { + if e.kind() == io::ErrorKind::ConnectionAborted { + return ControlFlow::Continue(()); + } + } + + ControlFlow::Break(e) +} + #[cfg(feature = "tls")] async fn select( incoming: &mut (impl Stream> + Unpin),