diff --git a/client/src/connection_cache.rs b/client/src/connection_cache.rs
index 2c5ae30f954430..f0628d3e32b9de 100644
--- a/client/src/connection_cache.rs
+++ b/client/src/connection_cache.rs
@@ -9,8 +9,14 @@ use {
     indexmap::map::{Entry, IndexMap},
     rand::{thread_rng, Rng},
     solana_measure::measure::Measure,
-    solana_sdk::{quic::QUIC_PORT_OFFSET, signature::Keypair, timing::AtomicInterval},
-    solana_streamer::tls_certificates::new_self_signed_tls_certificate_chain,
+    solana_sdk::{
+        pubkey::Pubkey, quic::QUIC_PORT_OFFSET, signature::Keypair, timing::AtomicInterval,
+    },
+    solana_streamer::{
+        nonblocking::quic::{compute_max_allowed_uni_streams, ConnectionPeerType},
+        streamer::StakedNodes,
+        tls_certificates::new_self_signed_tls_certificate_chain,
+    },
     std::{
         error::Error,
         net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
@@ -228,6 +234,8 @@ pub struct ConnectionCache {
     tpu_udp_socket: Arc<UdpSocket>,
     client_certificate: Arc<QuicClientCertificate>,
     use_quic: bool,
+    maybe_staked_nodes: Option<Arc<RwLock<StakedNodes>>>,
+    maybe_client_pubkey: Option<Pubkey>,
 }
 
 /// Models the pool of connections
@@ -279,6 +287,15 @@ impl ConnectionCache {
         Ok(())
     }
 
+    pub fn set_staked_nodes(
+        &mut self,
+        staked_nodes: &Arc<RwLock<StakedNodes>>,
+        client_pubkey: &Pubkey,
+    ) {
+        self.maybe_staked_nodes = Some(staked_nodes.clone());
+        self.maybe_client_pubkey = Some(*client_pubkey);
+    }
+
     pub fn with_udp(connection_pool_size: usize) -> Self {
         // The minimum pool size is 1.
         let connection_pool_size = 1.max(connection_pool_size);
@@ -303,6 +320,24 @@ impl ConnectionCache {
         }
     }
 
+    fn compute_max_parallel_chunks(&self) -> usize {
+        let (client_type, stake, total_stake) =
+            self.maybe_client_pubkey
+                .map_or((ConnectionPeerType::Unstaked, 0, 0), |pubkey| {
+                    self.maybe_staked_nodes.as_ref().map_or(
+                        (ConnectionPeerType::Unstaked, 0, 0),
+                        |stakes| {
+                            let rstakes = stakes.read().unwrap();
+                            rstakes.pubkey_stake_map.get(&pubkey).map_or(
+                                (ConnectionPeerType::Unstaked, 0, rstakes.total_stake),
+                                |stake| (ConnectionPeerType::Staked, *stake, rstakes.total_stake),
+                            )
+                        },
+                    )
+                });
+        compute_max_allowed_uni_streams(client_type, stake, total_stake)
+    }
+
     /// Create a lazy connection object under the exclusive lock of the cache map if there is not
     /// enough used connections in the connection pool for the specified address.
     /// Returns CreateConnectionResult.
@@ -335,6 +370,7 @@ impl ConnectionCache {
                 BaseTpuConnection::Quic(Arc::new(QuicClient::new(
                     endpoint.as_ref().unwrap().clone(),
                     *addr,
+                    self.compute_max_parallel_chunks(),
                 )))
             };
 
@@ -534,6 +570,8 @@ impl Default for ConnectionCache {
                 key: priv_key,
             }),
             use_quic: DEFAULT_TPU_USE_QUIC,
+            maybe_staked_nodes: None,
+            maybe_client_pubkey: None,
         }
     }
 }
@@ -604,8 +642,18 @@ mod tests {
         },
         rand::{Rng, SeedableRng},
         rand_chacha::ChaChaRng,
-        solana_sdk::quic::QUIC_PORT_OFFSET,
-        std::net::{IpAddr, Ipv4Addr, SocketAddr},
+        solana_sdk::{
+            pubkey::Pubkey,
+            quic::{
+                QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_CONCURRENT_STREAMS,
+                QUIC_PORT_OFFSET,
+            },
+        },
+        solana_streamer::streamer::StakedNodes,
+        std::{
+            net::{IpAddr, Ipv4Addr, SocketAddr},
+            sync::{Arc, RwLock},
+        },
     };
 
     fn get_addr(rng: &mut ChaChaRng) -> SocketAddr {
@@ -661,6 +709,55 @@ mod tests {
         let _conn = map.get(&addr).expect("Address not found");
     }
 
+    #[test]
+    fn test_connection_cache_max_parallel_chunks() {
+        solana_logger::setup();
+        let mut connection_cache = ConnectionCache::default();
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let pubkey = Pubkey::new_unique();
+        connection_cache.set_staked_nodes(&staked_nodes, &pubkey);
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes.write().unwrap().total_stake = 10000;
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .insert(pubkey, 1);
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MIN_STAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .remove(&pubkey);
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .insert(pubkey, 1000);
+        assert_ne!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MIN_STAKED_CONCURRENT_STREAMS
+        );
+    }
+
     // Test that we can get_connection with a connection cache configured for quic
     // on an address with a port that, if QUIC_PORT_OFFSET were added to it, it would overflow to
     // an invalid port.
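[Note: for context on the connection_cache.rs changes above, here is a minimal standalone sketch of how a caller is expected to wire stake information into the cache. It is illustrative, not part of the patch; it assumes the crate is consumed as solana_client, and the direct StakedNodes field access mirrors the test above.]

use {
    solana_client::connection_cache::ConnectionCache,
    solana_sdk::pubkey::Pubkey,
    solana_streamer::streamer::StakedNodes,
    std::sync::{Arc, RwLock},
};

fn main() {
    // Shared stake map; in the validator this is kept fresh by
    // StakedNodesUpdaterService rather than being written by hand.
    let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
    let identity = Pubkey::new_unique();

    // Hand the cache a view of the stake map plus this client's identity;
    // QUIC connections created afterwards size their parallel-chunk budget
    // from the identity's share of total stake.
    let mut connection_cache = ConnectionCache::default();
    connection_cache.set_staked_nodes(&staked_nodes, &identity);

    // Simulate this node holding 10% of a 10_000 total stake.
    let mut stakes = staked_nodes.write().unwrap();
    stakes.total_stake = 10_000;
    stakes.pubkey_stake_map.insert(identity, 1_000);
}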
diff --git a/client/src/nonblocking/quic_client.rs b/client/src/nonblocking/quic_client.rs
index e1fa31d0c14413..160d85c9b551a7 100644
--- a/client/src/nonblocking/quic_client.rs
+++ b/client/src/nonblocking/quic_client.rs
@@ -267,15 +267,21 @@ pub struct QuicClient {
     connection: Arc<Mutex<Option<QuicNewConnection>>>,
     addr: SocketAddr,
     stats: Arc<ClientStats>,
+    num_chunks: usize,
 }
 
 impl QuicClient {
-    pub fn new(endpoint: Arc<QuicLazyInitializedEndpoint>, addr: SocketAddr) -> Self {
+    pub fn new(
+        endpoint: Arc<QuicLazyInitializedEndpoint>,
+        addr: SocketAddr,
+        num_chunks: usize,
+    ) -> Self {
         Self {
             endpoint,
             connection: Arc::new(Mutex::new(None)),
             addr,
             stats: Arc::new(ClientStats::default()),
+            num_chunks,
         }
     }
 
@@ -441,7 +447,7 @@ impl QuicClient {
 
     fn compute_chunk_length(num_buffers_to_chunk: usize, num_chunks: usize) -> usize {
         // The function is equivalent to checked div_ceil()
-        // Also, if num_chunks == 0 || num_buffers_per_chunk == 0, return 1
+        // Also, if num_chunks == 0 || num_buffers_to_chunk == 0, return 1
         num_buffers_to_chunk
             .checked_div(num_chunks)
             .map_or(1, |value| {
@@ -485,8 +491,7 @@ impl QuicClient {
         // by just getting a reference to the NewConnection once
         let connection_ref: &NewConnection = &connection;
 
-        let chunk_len =
-            Self::compute_chunk_length(buffers.len() - 1, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS);
+        let chunk_len = Self::compute_chunk_length(buffers.len() - 1, self.num_chunks);
         let chunks = buffers[1..buffers.len()].iter().chunks(chunk_len);
 
         let futures: Vec<_> = chunks
@@ -530,7 +535,11 @@ impl QuicTpuConnection {
         addr: SocketAddr,
         connection_stats: Arc<ConnectionCacheStats>,
     ) -> Self {
-        let client = Arc::new(QuicClient::new(endpoint, addr));
+        let client = Arc::new(QuicClient::new(
+            endpoint,
+            addr,
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
+        ));
         Self::new_with_client(client, connection_stats)
     }
 
diff --git a/core/src/tpu.rs b/core/src/tpu.rs
index fc3a18e77cb582..3ce676c9716cd1 100644
--- a/core/src/tpu.rs
+++ b/core/src/tpu.rs
@@ -100,6 +100,7 @@ impl Tpu {
         connection_cache: &Arc<ConnectionCache>,
         keypair: &Keypair,
         enable_quic_servers: bool,
+        staked_nodes: &Arc<RwLock<StakedNodes>>,
     ) -> Self {
         let TpuSockets {
             transactions: transactions_sockets,
@@ -127,7 +128,6 @@
             Some(bank_forks.read().unwrap().get_vote_only_mode_signal()),
         );
 
-        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         let staked_nodes_updater_service = StakedNodesUpdaterService::new(
             exit.clone(),
             cluster_info.clone(),
@@ -181,7 +181,7 @@
             forwarded_packet_sender,
             exit.clone(),
             MAX_QUIC_CONNECTIONS_PER_PEER,
-            staked_nodes,
+            staked_nodes.clone(),
             MAX_STAKED_CONNECTIONS.saturating_add(MAX_UNSTAKED_CONNECTIONS),
             0, // Prevent unstaked nodes from forwarding transactions
             stats,
diff --git a/core/src/validator.rs b/core/src/validator.rs
index 439f338a40ccc6..608b07770de983 100644
--- a/core/src/validator.rs
+++ b/core/src/validator.rs
@@ -95,7 +95,7 @@ use {
         timing::timestamp,
     },
     solana_send_transaction_service::send_transaction_service,
-    solana_streamer::socket::SocketAddrSpace,
+    solana_streamer::{socket::SocketAddrSpace, streamer::StakedNodes},
    solana_vote_program::vote_state::VoteState,
     std::{
         collections::{HashMap, HashSet},
@@ -672,12 +672,15 @@ impl Validator {
         );
         let poh_recorder = Arc::new(Mutex::new(poh_recorder));
 
+        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+
         let connection_cache = match use_quic {
             true => {
                 let mut connection_cache = ConnectionCache::new(tpu_connection_pool_size);
                 connection_cache
                     .update_client_certificate(&identity_keypair, node.info.gossip.ip())
                     .expect("Failed to update QUIC client certificates");
+                connection_cache.set_staked_nodes(&staked_nodes, &identity_keypair.pubkey());
                 Arc::new(connection_cache)
             }
             false => Arc::new(ConnectionCache::with_udp(tpu_connection_pool_size)),
@@ -994,6 +997,7 @@
             &connection_cache,
             &identity_keypair,
             config.enable_quic_servers,
+            &staked_nodes,
         );
 
         datapoint_info!(
diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs
index d103b6fba8b538..5f866d30e029ba 100644
--- a/streamer/src/nonblocking/quic.rs
+++ b/streamer/src/nonblocking/quic.rs
@@ -159,10 +159,10 @@ fn get_connection_stake(
     })
 }
 
-fn compute_max_allowed_uni_streams(
+pub fn compute_max_allowed_uni_streams(
     peer_type: ConnectionPeerType,
     peer_stake: u64,
-    staked_nodes: Arc<RwLock<StakedNodes>>,
+    total_stake: u64,
 ) -> usize {
     if peer_stake == 0 {
         // Treat stake = 0 as unstaked
@@ -170,13 +170,11 @@
     } else {
         match peer_type {
             ConnectionPeerType::Staked => {
-                let staked_nodes = staked_nodes.read().unwrap();
-
-                // No checked math for f64 type. So let's explicitly check for 0 here
-                if staked_nodes.total_stake == 0 {
+                if total_stake == 0 {
                     QUIC_MIN_STAKED_CONCURRENT_STREAMS
                 } else {
-                    (((peer_stake as f64 / staked_nodes.total_stake as f64)
+                    (((peer_stake as f64 / total_stake as f64)
                         * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS as f64)
                         as usize)
                         .max(QUIC_MIN_STAKED_CONCURRENT_STREAMS)
@@ -264,17 +262,19 @@ async fn setup_connection(
     if let Some((mut connection_table_l, stake)) = table_and_stake {
         let table_type = connection_table_l.peer_type;
 
-        let max_uni_streams = VarInt::from_u64(compute_max_allowed_uni_streams(
-            table_type,
-            stake,
-            staked_nodes.clone(),
-        ) as u64);
+        let total_stake = staked_nodes.read().map_or(0, |stakes| stakes.total_stake);
+        drop(staked_nodes);
+
+        let max_uni_streams =
+            VarInt::from_u64(
+                compute_max_allowed_uni_streams(table_type, stake, total_stake) as u64,
+            );
 
         debug!(
             "Peer type: {:?}, stake {}, total stake {}, max streams {}",
             table_type,
             stake,
-            staked_nodes.read().unwrap().total_stake,
+            total_stake,
             max_uni_streams.unwrap().into_inner()
         );
 
@@ -558,7 +558,7 @@ impl Drop for ConnectionEntry {
 }
 
 #[derive(Copy, Clone, Debug)]
-enum ConnectionPeerType {
+pub enum ConnectionPeerType {
     Unstaked,
     Staked,
 }
@@ -1406,58 +1406,52 @@ pub mod test {
 
     #[test]
     fn test_max_allowed_uni_streams() {
-        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, 0),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
        );
-        staked_nodes.write().unwrap().total_stake = 10000;
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1000, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1000, 10000),
             (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS / (10_f64)) as usize
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 100, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 100, 10000),
             (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS / (100_f64)) as usize
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, 10000),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1, 10000),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(
-                ConnectionPeerType::Unstaked,
-                1000,
-                staked_nodes.clone()
-            ),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1000, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, staked_nodes),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
     }
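[Note: after this change the stream cap is pure arithmetic on (peer_stake, total_stake). The following self-contained sketch restates that rule; the constant values are assumed placeholders, not the actual solana_sdk::quic definitions, and max_allowed_uni_streams is a local stand-in for the exported function.]

// Assumed placeholder values; the real constants live in solana_sdk::quic.
const QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS: usize = 128;
const QUIC_MIN_STAKED_CONCURRENT_STREAMS: usize = 128;
const QUIC_TOTAL_STAKED_CONCURRENT_STREAMS: usize = 100_000;

#[derive(Copy, Clone)]
enum ConnectionPeerType {
    Unstaked,
    Staked,
}

fn max_allowed_uni_streams(peer_type: ConnectionPeerType, peer_stake: u64, total_stake: u64) -> usize {
    match (peer_type, peer_stake) {
        // Zero stake is treated as unstaked, whatever the peer type claims.
        (_, 0) | (ConnectionPeerType::Unstaked, _) => QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
        (ConnectionPeerType::Staked, stake) => {
            if total_stake == 0 {
                QUIC_MIN_STAKED_CONCURRENT_STREAMS
            } else {
                // Pro-rata share of the staked budget, floored at the staked minimum.
                (((stake as f64 / total_stake as f64)
                    * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS as f64) as usize)
                    .max(QUIC_MIN_STAKED_CONCURRENT_STREAMS)
            }
        }
    }
}

fn main() {
    // 10% of total stake => 10% of the 100_000 budget = 10_000 streams.
    assert_eq!(
        max_allowed_uni_streams(ConnectionPeerType::Staked, 1_000, 10_000),
        10_000
    );
    // A negligible stake still gets the staked minimum.
    assert_eq!(
        max_allowed_uni_streams(ConnectionPeerType::Staked, 1, 10_000),
        QUIC_MIN_STAKED_CONCURRENT_STREAMS
    );
    // Unstaked peers always get the unstaked cap.
    assert_eq!(
        max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1_000, 10_000),
        QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
    );
}

This is also why the signature change matters: the caller snapshots total_stake once (and drops the lock) instead of passing the whole Arc<RwLock<StakedNodes>> into the computation.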