Skip to content

Commit

Permalink
Fix quic staked chunking (solana-labs#27402)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryleung-solana authored Aug 25, 2022
1 parent d1522fc commit b7b03cb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 53 deletions.
14 changes: 7 additions & 7 deletions tpu-client/src/connection_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl ConnectionCache {
}
}

fn compute_max_parallel_chunks(&self) -> usize {
fn compute_max_parallel_streams(&self) -> usize {
let (client_type, stake, total_stake) =
self.maybe_client_pubkey
.map_or((ConnectionPeerType::Unstaked, 0, 0), |pubkey| {
Expand Down Expand Up @@ -370,7 +370,7 @@ impl ConnectionCache {
BaseTpuConnection::Quic(Arc::new(QuicClient::new(
endpoint.as_ref().unwrap().clone(),
*addr,
self.compute_max_parallel_chunks(),
self.compute_max_parallel_streams(),
)))
};

Expand Down Expand Up @@ -730,21 +730,21 @@ mod tests {
solana_logger::setup();
let mut connection_cache = ConnectionCache::default();
assert_eq!(
connection_cache.compute_max_parallel_chunks(),
connection_cache.compute_max_parallel_streams(),
QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
);

let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
let pubkey = Pubkey::new_unique();
connection_cache.set_staked_nodes(&staked_nodes, &pubkey);
assert_eq!(
connection_cache.compute_max_parallel_chunks(),
connection_cache.compute_max_parallel_streams(),
QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
);

staked_nodes.write().unwrap().total_stake = 10000;
assert_eq!(
connection_cache.compute_max_parallel_chunks(),
connection_cache.compute_max_parallel_streams(),
QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
);

Expand All @@ -754,7 +754,7 @@ mod tests {
.pubkey_stake_map
.insert(pubkey, 1);
assert_eq!(
connection_cache.compute_max_parallel_chunks(),
connection_cache.compute_max_parallel_streams(),
QUIC_MIN_STAKED_CONCURRENT_STREAMS
);

Expand All @@ -769,7 +769,7 @@ mod tests {
.pubkey_stake_map
.insert(pubkey, 1000);
assert_ne!(
connection_cache.compute_max_parallel_chunks(),
connection_cache.compute_max_parallel_streams(),
QUIC_MIN_STAKED_CONCURRENT_STREAMS
);
}
Expand Down
50 changes: 4 additions & 46 deletions tpu-client/src/nonblocking/quic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,21 +270,21 @@ pub struct QuicClient {
connection: Arc<Mutex<Option<QuicNewConnection>>>,
addr: SocketAddr,
stats: Arc<ClientStats>,
num_chunks: usize,
chunk_size: usize,
}

impl QuicClient {
pub fn new(
endpoint: Arc<QuicLazyInitializedEndpoint>,
addr: SocketAddr,
num_chunks: usize,
chunk_size: usize,
) -> Self {
Self {
endpoint,
connection: Arc::new(Mutex::new(None)),
addr,
stats: Arc::new(ClientStats::default()),
num_chunks,
chunk_size,
}
}

Expand Down Expand Up @@ -451,21 +451,6 @@ impl QuicClient {
Ok(())
}

fn compute_chunk_length(num_buffers_to_chunk: usize, num_chunks: usize) -> usize {
// The function is equivalent to checked div_ceil()
// Also, if num_chunks == 0 || num_buffers_to_chunk == 0, return 1
num_buffers_to_chunk
.checked_div(num_chunks)
.map_or(1, |value| {
if num_buffers_to_chunk.checked_rem(num_chunks).unwrap_or(0) != 0 {
value.saturating_add(1)
} else {
value
}
})
.max(1)
}

pub async fn send_batch<T>(
&self,
buffers: &[T],
Expand Down Expand Up @@ -498,8 +483,7 @@ impl QuicClient {
// by just getting a reference to the NewConnection once
let connection_ref: &NewConnection = &connection;

let chunk_len = Self::compute_chunk_length(buffers.len() - 1, self.num_chunks);
let chunks = buffers[1..buffers.len()].iter().chunks(chunk_len);
let chunks = buffers[1..buffers.len()].iter().chunks(self.chunk_size);

let futures: Vec<_> = chunks
.into_iter()
Expand Down Expand Up @@ -608,29 +592,3 @@ impl TpuConnection for QuicTpuConnection {
Ok(())
}
}

#[cfg(test)]
mod tests {
use crate::nonblocking::quic_client::QuicClient;

#[test]
fn test_transaction_batch_chunking() {
assert_eq!(QuicClient::compute_chunk_length(0, 0), 1);
assert_eq!(QuicClient::compute_chunk_length(10, 0), 1);
assert_eq!(QuicClient::compute_chunk_length(0, 10), 1);
assert_eq!(QuicClient::compute_chunk_length(usize::MAX, usize::MAX), 1);
assert_eq!(QuicClient::compute_chunk_length(10, usize::MAX), 1);
assert!(QuicClient::compute_chunk_length(usize::MAX, 10) == (usize::MAX / 10) + 1);
assert_eq!(QuicClient::compute_chunk_length(10, 1), 10);
assert_eq!(QuicClient::compute_chunk_length(10, 2), 5);
assert_eq!(QuicClient::compute_chunk_length(10, 3), 4);
assert_eq!(QuicClient::compute_chunk_length(10, 4), 3);
assert_eq!(QuicClient::compute_chunk_length(10, 5), 2);
assert_eq!(QuicClient::compute_chunk_length(10, 6), 2);
assert_eq!(QuicClient::compute_chunk_length(10, 7), 2);
assert_eq!(QuicClient::compute_chunk_length(10, 8), 2);
assert_eq!(QuicClient::compute_chunk_length(10, 9), 2);
assert_eq!(QuicClient::compute_chunk_length(10, 10), 1);
assert_eq!(QuicClient::compute_chunk_length(10, 11), 1);
}
}

0 comments on commit b7b03cb

Please sign in to comment.