From e7fd2c23963f421a571dd912f2d05923a4961da8 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Fri, 6 Oct 2023 10:31:04 -0400 Subject: [PATCH] separates out routing shreds from establishing connections Currently each outgoing shred will attempt to establish a connection if one does not already exist. This is very wasteful and consumes many tokio tasks if the remote node is down or unresponsive. The commit decouples routing packets from establishing connections by adding a buffering channel for each remote address. Outgoing packets are always sent down this channel to be processed once the connection is established. If connecting attempt fails, all packets already pushed to the channel are dropped at once, reducing the number of attempts to make a connection if the remote node is down or unresponsive. --- turbine/src/quic_endpoint.rs | 274 +++++++++++++++++++---------------- 1 file changed, 151 insertions(+), 123 deletions(-) diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs index 0f93391e042b47..df8d56437084c4 100644 --- a/turbine/src/quic_endpoint.rs +++ b/turbine/src/quic_endpoint.rs @@ -18,20 +18,20 @@ use { collections::{hash_map::Entry, HashMap}, io::Error as IoError, net::{IpAddr, SocketAddr, UdpSocket}, - ops::Deref, sync::Arc, }, thiserror::Error, tokio::{ sync::{ - mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender}, - RwLock, + mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender}, + Mutex, RwLock as AsyncRwLock, }, task::JoinHandle, }, }; -const CLIENT_CHANNEL_CAPACITY: usize = 1 << 20; +const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; +const ROUTER_CHANNEL_BUFFER: usize = 64; const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280; const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine"; const CONNECT_SERVER_NAME: &str = "solana-turbine"; @@ -47,7 +47,6 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY"; const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED"; pub type AsyncTryJoinHandle = TryJoin, JoinHandle<()>>; -type ConnectionCache = HashMap<(SocketAddr, Option), Arc>>>; #[derive(Error, Debug)] pub enum Error { @@ -99,10 +98,22 @@ pub fn new_quic_endpoint( )? }; endpoint.set_default_client_config(client_config); - let cache = Arc::>::default(); - let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY); - let server_task = runtime.spawn(run_server(endpoint.clone(), sender.clone(), cache.clone())); - let client_task = runtime.spawn(run_client(endpoint.clone(), client_receiver, sender, cache)); + let cache = Arc::>>::default(); + let router = Arc::>>>::default(); + let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); + let server_task = runtime.spawn(run_server( + endpoint.clone(), + sender.clone(), + router.clone(), + cache.clone(), + )); + let client_task = runtime.spawn(run_client( + endpoint.clone(), + client_receiver, + sender, + router, + cache, + )); let task = futures::future::try_join(server_task, client_task); Ok((endpoint, client_sender, task)) } @@ -152,13 +163,15 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some(connecting) = endpoint.accept().await { tokio::task::spawn(handle_connecting_error( endpoint.clone(), connecting, sender.clone(), + router.clone(), cache.clone(), )); } @@ -168,27 +181,61 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver<(SocketAddr, Bytes)>, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some((remote_address, bytes)) = receiver.recv().await { - tokio::task::spawn(send_datagram_task( + let bytes = match router.read().await.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let receiver = { + let mut router = router.write().await; + let bytes = match router.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + sender.try_send(bytes).unwrap(); + router.insert(remote_address, sender); + receiver + }; + tokio::task::spawn(make_connection_task( endpoint.clone(), remote_address, - bytes, sender.clone(), + receiver, + router.clone(), cache.clone(), )); } close_quic_endpoint(&endpoint); + // Drop sender channels to unblock threads waiting on the receiving end. + router.write().await.clear(); } async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { - if let Err(err) = handle_connecting(endpoint, connecting, sender, cache).await { + if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await { error!("handle_connecting: {err:?}"); } } @@ -197,52 +244,75 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) -> Result<(), Error> { let connection = connecting.await?; let remote_address = connection.remote_address(); let remote_pubkey = get_remote_pubkey(&connection)?; - handle_connection_error( + let receiver = { + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + router.write().await.insert(remote_address, sender); + receiver + }; + handle_connection( endpoint, remote_address, remote_pubkey, connection, sender, + receiver, + router, cache, ) .await; Ok(()) } -async fn handle_connection_error( +async fn handle_connection( endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) { - cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await; - if let Err(err) = handle_connection( - &endpoint, + cache_connection(remote_pubkey, connection.clone(), &cache).await; + let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver)); + let read_datagram_task = tokio::task::spawn(read_datagram_task( + endpoint, remote_address, remote_pubkey, - &connection, - &sender, - ) - .await - { - drop_connection(remote_address, remote_pubkey, &connection, &cache).await; - error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"); + connection.clone(), + sender, + )); + match futures::future::try_join(send_datagram_task, read_datagram_task).await { + Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"), + Ok(out) => { + if let (Err(ref err), _) = out { + error!("send_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + if let (_, Err(ref err)) = out { + error!("read_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + } + } + drop_connection(remote_pubkey, &connection, &cache).await; + if let Entry::Occupied(entry) = router.write().await.entry(remote_address) { + if entry.get().is_closed() { + entry.remove(); + } } } -async fn handle_connection( - endpoint: &Endpoint, +async fn read_datagram_task( + endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, - connection: &Connection, - sender: &Sender<(Pubkey, SocketAddr, Bytes)>, + connection: Connection, + sender: Sender<(Pubkey, SocketAddr, Bytes)>, ) -> Result<(), Error> { // Assert that send won't block. debug_assert_eq!(sender.capacity(), None); @@ -250,7 +320,7 @@ async fn handle_connection( match connection.read_datagram().await { Ok(bytes) => { if let Err(err) = sender.send((remote_pubkey, remote_address, bytes)) { - close_quic_endpoint(endpoint); + close_quic_endpoint(&endpoint); return Err(Error::from(err)); } } @@ -265,67 +335,53 @@ async fn handle_connection( } async fn send_datagram_task( + connection: Connection, + mut receiver: AsyncReceiver, +) -> Result<(), Error> { + while let Some(bytes) = receiver.recv().await { + connection.send_datagram(bytes)?; + } + Ok(()) +} + +async fn make_connection_task( endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) { - if let Err(err) = send_datagram(&endpoint, remote_address, bytes, sender, cache).await { - error!("send_datagram: {remote_address}, {err:?}"); + if let Err(err) = + make_connection(endpoint, remote_address, sender, receiver, router, cache).await + { + error!("make_connection: {remote_address}, {err:?}"); } } -async fn send_datagram( - endpoint: &Endpoint, +async fn make_connection( + endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) -> Result<(), Error> { - let connection = get_connection(endpoint, remote_address, sender, cache).await?; - connection.send_datagram(bytes)?; - Ok(()) -} - -async fn get_connection( - endpoint: &Endpoint, - remote_address: SocketAddr, - sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, -) -> Result { - let entry = get_cache_entry(remote_address, &cache).await; - { - let connection: Option = entry.read().await.clone(); - if let Some(connection) = connection { - if connection.close_reason().is_none() { - return Ok(connection); - } - } - } - let connection = { - // Need to write lock here so that only one task initiates - // a new connection to the same remote_address. - let mut entry = entry.write().await; - if let Some(connection) = entry.deref() { - if connection.close_reason().is_none() { - return Ok(connection.clone()); - } - } - let connection = endpoint - .connect(remote_address, CONNECT_SERVER_NAME)? - .await?; - entry.insert(connection).clone() - }; - tokio::task::spawn(handle_connection_error( - endpoint.clone(), + let connection = endpoint + .connect(remote_address, CONNECT_SERVER_NAME)? + .await?; + handle_connection( + endpoint, connection.remote_address(), get_remote_pubkey(&connection)?, - connection.clone(), + connection, sender, + receiver, + router, cache, - )); - Ok(connection) + ) + .await; + Ok(()) } fn get_remote_pubkey(connection: &Connection) -> Result { @@ -341,62 +397,34 @@ fn get_remote_pubkey(connection: &Connection) -> Result { } } -async fn get_cache_entry( - remote_address: SocketAddr, - cache: &RwLock, -) -> Arc>> { - let key = (remote_address, /*remote_pubkey:*/ None); - if let Some(entry) = cache.read().await.get(&key) { - return entry.clone(); - } - cache.write().await.entry(key).or_default().clone() -} - async fn cache_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, - cache: &RwLock, + cache: &Mutex>, ) { - let entries: [Arc>>; 2] = { - let mut cache = cache.write().await; - [Some(remote_pubkey), None].map(|remote_pubkey| { - let key = (remote_address, remote_pubkey); - cache.entry(key).or_default().clone() - }) + let Some(old) = cache.lock().await.insert(remote_pubkey, connection) else { + return; }; - let mut entry = entries[0].write().await; - *entries[1].write().await = Some(connection.clone()); - if let Some(old) = entry.replace(connection) { - drop(entry); - old.close( - CONNECTION_CLOSE_ERROR_CODE_REPLACED, - CONNECTION_CLOSE_REASON_REPLACED, - ); - } + old.close( + CONNECTION_CLOSE_ERROR_CODE_REPLACED, + CONNECTION_CLOSE_REASON_REPLACED, + ); } async fn drop_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: &Connection, - cache: &RwLock, + cache: &Mutex>, ) { - if connection.close_reason().is_none() { - connection.close( - CONNECTION_CLOSE_ERROR_CODE_DROPPED, - CONNECTION_CLOSE_REASON_DROPPED, - ); - } - let key = (remote_address, Some(remote_pubkey)); - if let Entry::Occupied(entry) = cache.write().await.entry(key) { - if matches!(entry.get().read().await.deref(), - Some(entry) if entry.stable_id() == connection.stable_id()) - { + connection.close( + CONNECTION_CLOSE_ERROR_CODE_DROPPED, + CONNECTION_CLOSE_REASON_DROPPED, + ); + if let Entry::Occupied(entry) = cache.lock().await.entry(remote_pubkey) { + if entry.get().stable_id() == connection.stable_id() { entry.remove(); } } - // Cache entry for (remote_address, None) will be lazily evicted. } impl From> for Error {