From b7b03cbb056b0a918d4d26a114c852f0977cc2a9 Mon Sep 17 00:00:00 2001 From: ryleung-solana <91908731+ryleung-solana@users.noreply.github.com> Date: Fri, 26 Aug 2022 05:10:28 +0800 Subject: [PATCH] Fix quic staked chunking (#27402) --- tpu-client/src/connection_cache.rs | 14 +++---- tpu-client/src/nonblocking/quic_client.rs | 50 ++--------------------- 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/tpu-client/src/connection_cache.rs b/tpu-client/src/connection_cache.rs index 9e5efff3f3e061..28d8f10e5e8433 100644 --- a/tpu-client/src/connection_cache.rs +++ b/tpu-client/src/connection_cache.rs @@ -320,7 +320,7 @@ impl ConnectionCache { } } - fn compute_max_parallel_chunks(&self) -> usize { + fn compute_max_parallel_streams(&self) -> usize { let (client_type, stake, total_stake) = self.maybe_client_pubkey .map_or((ConnectionPeerType::Unstaked, 0, 0), |pubkey| { @@ -370,7 +370,7 @@ impl ConnectionCache { BaseTpuConnection::Quic(Arc::new(QuicClient::new( endpoint.as_ref().unwrap().clone(), *addr, - self.compute_max_parallel_chunks(), + self.compute_max_parallel_streams(), ))) }; @@ -730,7 +730,7 @@ mod tests { solana_logger::setup(); let mut connection_cache = ConnectionCache::default(); assert_eq!( - connection_cache.compute_max_parallel_chunks(), + connection_cache.compute_max_parallel_streams(), QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS ); @@ -738,13 +738,13 @@ mod tests { let pubkey = Pubkey::new_unique(); connection_cache.set_staked_nodes(&staked_nodes, &pubkey); assert_eq!( - connection_cache.compute_max_parallel_chunks(), + connection_cache.compute_max_parallel_streams(), QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS ); staked_nodes.write().unwrap().total_stake = 10000; assert_eq!( - connection_cache.compute_max_parallel_chunks(), + connection_cache.compute_max_parallel_streams(), QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS ); @@ -754,7 +754,7 @@ mod tests { .pubkey_stake_map .insert(pubkey, 1); assert_eq!( - connection_cache.compute_max_parallel_chunks(), + connection_cache.compute_max_parallel_streams(), QUIC_MIN_STAKED_CONCURRENT_STREAMS ); @@ -769,7 +769,7 @@ mod tests { .pubkey_stake_map .insert(pubkey, 1000); assert_ne!( - connection_cache.compute_max_parallel_chunks(), + connection_cache.compute_max_parallel_streams(), QUIC_MIN_STAKED_CONCURRENT_STREAMS ); } diff --git a/tpu-client/src/nonblocking/quic_client.rs b/tpu-client/src/nonblocking/quic_client.rs index 295460aabc8a61..a793ecb766e7f8 100644 --- a/tpu-client/src/nonblocking/quic_client.rs +++ b/tpu-client/src/nonblocking/quic_client.rs @@ -270,21 +270,21 @@ pub struct QuicClient { connection: Arc>>, addr: SocketAddr, stats: Arc, - num_chunks: usize, + chunk_size: usize, } impl QuicClient { pub fn new( endpoint: Arc, addr: SocketAddr, - num_chunks: usize, + chunk_size: usize, ) -> Self { Self { endpoint, connection: Arc::new(Mutex::new(None)), addr, stats: Arc::new(ClientStats::default()), - num_chunks, + chunk_size, } } @@ -451,21 +451,6 @@ impl QuicClient { Ok(()) } - fn compute_chunk_length(num_buffers_to_chunk: usize, num_chunks: usize) -> usize { - // The function is equivalent to checked div_ceil() - // Also, if num_chunks == 0 || num_buffers_to_chunk == 0, return 1 - num_buffers_to_chunk - .checked_div(num_chunks) - .map_or(1, |value| { - if num_buffers_to_chunk.checked_rem(num_chunks).unwrap_or(0) != 0 { - value.saturating_add(1) - } else { - value - } - }) - .max(1) - } - pub async fn send_batch( &self, buffers: &[T], @@ -498,8 +483,7 @@ impl QuicClient { // by just getting a reference to the NewConnection once let connection_ref: &NewConnection = &connection; - let chunk_len = Self::compute_chunk_length(buffers.len() - 1, self.num_chunks); - let chunks = buffers[1..buffers.len()].iter().chunks(chunk_len); + let chunks = buffers[1..buffers.len()].iter().chunks(self.chunk_size); let futures: Vec<_> = chunks .into_iter() @@ -608,29 +592,3 @@ impl TpuConnection for QuicTpuConnection { Ok(()) } } - -#[cfg(test)] -mod tests { - use crate::nonblocking::quic_client::QuicClient; - - #[test] - fn test_transaction_batch_chunking() { - assert_eq!(QuicClient::compute_chunk_length(0, 0), 1); - assert_eq!(QuicClient::compute_chunk_length(10, 0), 1); - assert_eq!(QuicClient::compute_chunk_length(0, 10), 1); - assert_eq!(QuicClient::compute_chunk_length(usize::MAX, usize::MAX), 1); - assert_eq!(QuicClient::compute_chunk_length(10, usize::MAX), 1); - assert!(QuicClient::compute_chunk_length(usize::MAX, 10) == (usize::MAX / 10) + 1); - assert_eq!(QuicClient::compute_chunk_length(10, 1), 10); - assert_eq!(QuicClient::compute_chunk_length(10, 2), 5); - assert_eq!(QuicClient::compute_chunk_length(10, 3), 4); - assert_eq!(QuicClient::compute_chunk_length(10, 4), 3); - assert_eq!(QuicClient::compute_chunk_length(10, 5), 2); - assert_eq!(QuicClient::compute_chunk_length(10, 6), 2); - assert_eq!(QuicClient::compute_chunk_length(10, 7), 2); - assert_eq!(QuicClient::compute_chunk_length(10, 8), 2); - assert_eq!(QuicClient::compute_chunk_length(10, 9), 2); - assert_eq!(QuicClient::compute_chunk_length(10, 10), 1); - assert_eq!(QuicClient::compute_chunk_length(10, 11), 1); - } -}