diff --git a/core/src/validator.rs b/core/src/validator.rs index b206cf87b30d8c..f48eff357fe29a 100644 --- a/core/src/validator.rs +++ b/core/src/validator.rs @@ -1182,6 +1182,7 @@ impl Validator { .expect("Operator must spin up node with valid QUIC TVU address") .ip(), turbine_quic_endpoint_sender, + bank_forks.clone(), ) .unwrap(); diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs index df8d56437084c4..776a7c7d7f5c04 100644 --- a/turbine/src/quic_endpoint.rs +++ b/turbine/src/quic_endpoint.rs @@ -10,15 +10,20 @@ use { rcgen::RcgenError, rustls::{Certificate, PrivateKey}, solana_quic_client::nonblocking::quic_client::SkipServerVerification, + solana_runtime::bank_forks::BankForks, solana_sdk::{pubkey::Pubkey, signature::Keypair}, solana_streamer::{ quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate, }, std::{ + cmp::Reverse, collections::{hash_map::Entry, HashMap}, io::Error as IoError, net::{IpAddr, SocketAddr, UdpSocket}, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock, + }, }, thiserror::Error, tokio::{ @@ -32,6 +37,7 @@ use { const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; const ROUTER_CHANNEL_BUFFER: usize = 64; +const CONNECTION_CACHE_CAPACITY: usize = 3072; const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280; const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine"; const CONNECT_SERVER_NAME: &str = "solana-turbine"; @@ -75,6 +81,7 @@ pub fn new_quic_endpoint( socket: UdpSocket, address: IpAddr, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + bank_forks: Arc>, ) -> Result< ( Endpoint, @@ -98,12 +105,15 @@ pub fn new_quic_endpoint( )? }; endpoint.set_default_client_config(client_config); + let prune_cache_pending = Arc::::default(); let cache = Arc::>>::default(); let router = Arc::>>>::default(); let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); let server_task = runtime.spawn(run_server( endpoint.clone(), sender.clone(), + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -111,6 +121,8 @@ pub fn new_quic_endpoint( endpoint.clone(), client_receiver, sender, + bank_forks, + prune_cache_pending, router, cache, )); @@ -163,6 +175,8 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { @@ -171,6 +185,8 @@ async fn run_server( endpoint.clone(), connecting, sender.clone(), + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -181,6 +197,8 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver<(SocketAddr, Bytes)>, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { @@ -219,6 +237,8 @@ async fn run_client( remote_address, sender.clone(), receiver, + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -232,10 +252,12 @@ async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { - if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await { + if let Err(err) = handle_connecting(endpoint, connecting, sender, bank_forks, prune_cache_pending, router, cache).await { error!("handle_connecting: {err:?}"); } } @@ -244,6 +266,8 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) -> Result<(), Error> { @@ -262,6 +286,8 @@ async fn handle_connecting( connection, sender, receiver, + bank_forks, + prune_cache_pending, router, cache, ) @@ -276,10 +302,12 @@ async fn handle_connection( connection: Connection, sender: Sender<(Pubkey, SocketAddr, Bytes)>, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { - cache_connection(remote_pubkey, connection.clone(), &cache).await; + cache_connection(remote_pubkey, connection.clone(), bank_forks, prune_cache_pending, router.clone(), cache.clone()).await; let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver)); let read_datagram_task = tokio::task::spawn(read_datagram_task( endpoint, @@ -349,11 +377,13 @@ async fn make_connection_task( remote_address: SocketAddr, sender: Sender<(Pubkey, SocketAddr, Bytes)>, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { if let Err(err) = - make_connection(endpoint, remote_address, sender, receiver, router, cache).await + make_connection(endpoint, remote_address, sender, receiver, bank_forks, prune_cache_pending, router, cache).await { error!("make_connection: {remote_address}, {err:?}"); } @@ -364,6 +394,8 @@ async fn make_connection( remote_address: SocketAddr, sender: Sender<(Pubkey, SocketAddr, Bytes)>, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) -> Result<(), Error> { @@ -377,6 +409,8 @@ async fn make_connection( connection, sender, receiver, + bank_forks, + prune_cache_pending, router, cache, ) @@ -400,15 +434,34 @@ fn get_remote_pubkey(connection: &Connection) -> Result { async fn cache_connection( remote_pubkey: Pubkey, connection: Connection, - cache: &Mutex>, + bank_forks: Arc>, + prune_cache_pending: Arc, + router: Arc>>>, + cache: Arc>>, ) { - let Some(old) = cache.lock().await.insert(remote_pubkey, connection) else { - return; + let (old, should_prune_cache) = { + let mut cache = cache.lock().await; + ( + cache.insert(remote_pubkey, connection), + cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2), + ) }; - old.close( - CONNECTION_CLOSE_ERROR_CODE_REPLACED, - CONNECTION_CLOSE_REASON_REPLACED, - ); + if let Some(old) = old { + old.close( + CONNECTION_CLOSE_ERROR_CODE_REPLACED, + CONNECTION_CLOSE_REASON_REPLACED, + ); + } + if should_prune_cache { + if !prune_cache_pending.swap(true, Ordering::Relaxed) { + tokio::task::spawn(prune_connection_cache( + bank_forks, + prune_cache_pending, + router, + cache, + )); + } + } } async fn drop_connection( @@ -427,6 +480,49 @@ async fn drop_connection( } } +async fn prune_connection_cache( + bank_forks: Arc>, + prune_cache_pending: Arc, + router: Arc>>>, + cache: Arc>>, +) { + debug_assert!(prune_cache_pending.load(Ordering::Relaxed)); + let staked_nodes = { + let root_bank = bank_forks.read().unwrap().root_bank(); + root_bank.staked_nodes() + }; + { + let mut cache = cache.lock().await; + if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) { + return; + } + let mut connections: Vec<_> = cache + .drain() + .filter(|(_, connection)| connection.close_reason().is_none()) + .map(|entry @ (pubkey, _)| { + let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default(); + (stake, entry) + }) + .collect(); + connections + .select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake)); + for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] { + connection.close( + CONNECTION_CLOSE_ERROR_CODE_DROPPED, + CONNECTION_CLOSE_REASON_DROPPED, + ); + } + cache.extend( + connections + .into_iter() + .take(CONNECTION_CACHE_CAPACITY) + .map(|(_, entry)| entry), + ); + prune_cache_pending.store(false, Ordering::Relaxed); + } + router.write().await.retain(|_, sender| !sender.is_closed()); +} + impl From> for Error { fn from(_: crossbeam_channel::SendError) -> Self { Error::ChannelSendError @@ -438,6 +534,8 @@ mod tests { use { super::*, itertools::{izip, multiunzip}, + solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo}, + solana_runtime::bank::Bank, solana_sdk::signature::Signer, std::{iter::repeat_with, net::Ipv4Addr, time::Duration}, }; @@ -465,6 +563,12 @@ mod tests { repeat_with(crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>) .take(NUM_ENDPOINTS) .unzip(); + let bank_forks = { + let GenesisConfigInfo { genesis_config, .. } = + create_genesis_config(/*mint_lamports:*/ 100_000); + let bank = Bank::new_for_tests(&genesis_config); + Arc::new(RwLock::new(BankForks::new(bank))) + }; let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(keypairs.iter().zip(sockets).zip(senders).map( |((keypair, socket), sender)| { @@ -474,6 +578,7 @@ mod tests { socket, IpAddr::V4(Ipv4Addr::LOCALHOST), sender, + bank_forks.clone(), ) .unwrap() },