From 2436a2bea8b0f7fdae33e5c6e7e418c17d23e994 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Tue, 28 Jun 2022 11:01:49 -0400 Subject: [PATCH] client: Use async connection in async TPU client (#25969) * client: Add nonblocking QuicTpuConnection implementation * Remove integer arithmetic * client: Support sync and async connections in cache * client: Use async connection in async TPU client * Address feedback * Rename Connection -> BaseTpuConnection --- Cargo.lock | 1 + client/Cargo.toml | 1 + client/src/connection_cache.rs | 163 +++++++++++++++-------- client/src/nonblocking/quic_client.rs | 83 +++++++++++- client/src/nonblocking/tpu_client.rs | 42 ++++-- client/src/nonblocking/tpu_connection.rs | 3 +- client/src/nonblocking/udp_client.rs | 4 +- client/src/quic_client.rs | 89 +++++-------- client/src/tpu_connection.rs | 8 +- client/src/udp_client.rs | 3 +- client/tests/quic_client.rs | 128 +++++++++++++----- 11 files changed, 353 insertions(+), 172 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 84afbe36a489fd..578b9abf6055c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4859,6 +4859,7 @@ dependencies = [ "solana-measure", "solana-metrics", "solana-net-utils", + "solana-perf", "solana-sdk 1.11.2", "solana-streamer", "solana-transaction-status", diff --git a/client/Cargo.toml b/client/Cargo.toml index ef1e20fd459a87..b56824a7499418 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -62,6 +62,7 @@ anyhow = "1.0.57" assert_matches = "1.5.0" jsonrpc-http-server = "18.0.0" solana-logger = { path = "../logger", version = "=1.11.2" } +solana-perf = { path = "../perf", version = "=1.11.2" } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/client/src/connection_cache.rs b/client/src/connection_cache.rs index 547e8e2a8de3f0..98da0f85bbbbae 100644 --- a/client/src/connection_cache.rs +++ b/client/src/connection_cache.rs @@ -1,9 +1,10 @@ use { crate::{ - nonblocking::quic_client::QuicLazyInitializedEndpoint, - quic_client::QuicTpuConnection, - tpu_connection::{ClientStats, Connection}, - udp_client::UdpTpuConnection, + nonblocking::{ + quic_client::{QuicClient, QuicLazyInitializedEndpoint}, + tpu_connection::NonblockingConnection, + }, + tpu_connection::{BlockingConnection, ClientStats}, }, indexmap::map::{Entry, IndexMap}, rand::{thread_rng, Rng}, @@ -246,7 +247,7 @@ pub struct ConnectionCache { /// Models the pool of connections struct ConnectionPool { /// The connections in the pool - connections: Vec>, + connections: Vec>, /// Connections in this pool share the same endpoint endpoint: Option>, @@ -255,7 +256,7 @@ struct ConnectionPool { impl ConnectionPool { /// Get a connection from the pool. It must have at least one connection in the pool. /// This randomly picks a connection in the pool. - fn borrow_connection(&self) -> Arc { + fn borrow_connection(&self) -> Arc { let mut rng = thread_rng(); let n = rng.gen_range(0, self.connections.len()); self.connections[n].clone() @@ -318,55 +319,49 @@ impl ConnectionCache { ) }); - let (cache_hit, connection_cache_stats, num_evictions, eviction_timing_ms) = - if to_create_connection { - let connection: Connection = match &self.use_quic { - UseQUIC::Yes => QuicTpuConnection::new( - endpoint.as_ref().unwrap().clone(), - *addr, - self.stats.clone(), - ) - .into(), - UseQUIC::No(socket) => { - UdpTpuConnection::new(socket.clone(), *addr, self.stats.clone()).into() - } - }; - - let connection = Arc::new(connection); - - // evict a connection if the cache is reaching upper bounds - let mut num_evictions = 0; - let mut get_connection_cache_eviction_measure = - Measure::start("get_connection_cache_eviction_measure"); - while map.len() >= MAX_CONNECTIONS { - let mut rng = thread_rng(); - let n = rng.gen_range(0, MAX_CONNECTIONS); - map.swap_remove_index(n); - num_evictions += 1; - } - get_connection_cache_eviction_measure.stop(); + let (cache_hit, num_evictions, eviction_timing_ms) = if to_create_connection { + let connection = match &self.use_quic { + UseQUIC::Yes => BaseTpuConnection::Quic(Arc::new(QuicClient::new( + endpoint.as_ref().unwrap().clone(), + *addr, + ))), + UseQUIC::No(socket) => BaseTpuConnection::Udp(socket.clone()), + }; - match map.entry(*addr) { - Entry::Occupied(mut entry) => { - let pool = entry.get_mut(); - pool.connections.push(connection); - } - Entry::Vacant(entry) => { - entry.insert(ConnectionPool { - connections: vec![connection], - endpoint, - }); - } + let connection = Arc::new(connection); + + // evict a connection if the cache is reaching upper bounds + let mut num_evictions = 0; + let mut get_connection_cache_eviction_measure = + Measure::start("get_connection_cache_eviction_measure"); + while map.len() >= MAX_CONNECTIONS { + let mut rng = thread_rng(); + let n = rng.gen_range(0, MAX_CONNECTIONS); + map.swap_remove_index(n); + num_evictions += 1; + } + get_connection_cache_eviction_measure.stop(); + + match map.entry(*addr) { + Entry::Occupied(mut entry) => { + let pool = entry.get_mut(); + pool.connections.push(connection); } - ( - false, - self.stats.clone(), - num_evictions, - get_connection_cache_eviction_measure.as_ms(), - ) - } else { - (true, self.stats.clone(), 0, 0) - }; + Entry::Vacant(entry) => { + entry.insert(ConnectionPool { + connections: vec![connection], + endpoint, + }); + } + } + ( + false, + num_evictions, + get_connection_cache_eviction_measure.as_ms(), + ) + } else { + (true, 0, 0) + }; let pool = map.get(addr).unwrap(); let connection = pool.borrow_connection(); @@ -374,7 +369,7 @@ impl ConnectionCache { CreateConnectionResult { connection, cache_hit, - connection_cache_stats, + connection_cache_stats: self.stats.clone(), num_evictions, eviction_timing_ms, } @@ -443,7 +438,10 @@ impl ConnectionCache { } } - pub fn get_connection(&self, addr: &SocketAddr) -> Arc { + fn get_connection_and_log_stats( + &self, + addr: &SocketAddr, + ) -> (Arc, Arc) { let mut get_connection_measure = Measure::start("get_connection_measure"); let GetConnectionResult { connection, @@ -490,7 +488,17 @@ impl ConnectionCache { .get_connection_ms .fetch_add(get_connection_measure.as_ms(), Ordering::Relaxed); - connection + (connection, connection_cache_stats) + } + + pub fn get_connection(&self, addr: &SocketAddr) -> BlockingConnection { + let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr); + connection.new_blocking_connection(*addr, connection_cache_stats) + } + + pub fn get_nonblocking_connection(&self, addr: &SocketAddr) -> NonblockingConnection { + let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr); + connection.new_nonblocking_connection(*addr, connection_cache_stats) } } @@ -507,8 +515,46 @@ impl Default for ConnectionCache { } } +enum BaseTpuConnection { + Udp(Arc), + Quic(Arc), +} +impl BaseTpuConnection { + fn new_blocking_connection( + &self, + addr: SocketAddr, + stats: Arc, + ) -> BlockingConnection { + use crate::{quic_client::QuicTpuConnection, udp_client::UdpTpuConnection}; + match self { + BaseTpuConnection::Udp(udp_socket) => { + UdpTpuConnection::new_from_addr(udp_socket.clone(), addr).into() + } + BaseTpuConnection::Quic(quic_client) => { + QuicTpuConnection::new_with_client(quic_client.clone(), stats).into() + } + } + } + + fn new_nonblocking_connection( + &self, + addr: SocketAddr, + stats: Arc, + ) -> NonblockingConnection { + use crate::nonblocking::{quic_client::QuicTpuConnection, udp_client::UdpTpuConnection}; + match self { + BaseTpuConnection::Udp(udp_socket) => { + UdpTpuConnection::new_from_addr(udp_socket.try_clone().unwrap(), addr).into() + } + BaseTpuConnection::Quic(quic_client) => { + QuicTpuConnection::new_with_client(quic_client.clone(), stats).into() + } + } + } +} + struct GetConnectionResult { - connection: Arc, + connection: Arc, cache_hit: bool, report_stats: bool, map_timing_ms: u64, @@ -519,7 +565,7 @@ struct GetConnectionResult { } struct CreateConnectionResult { - connection: Arc, + connection: Arc, cache_hit: bool, connection_cache_stats: Arc, num_evictions: u64, @@ -578,6 +624,7 @@ mod tests { assert!(map.len() == MAX_CONNECTIONS); addrs.iter().for_each(|a| { let conn = &map.get(a).expect("Address not found").connections[0]; + let conn = conn.new_blocking_connection(*a, connection_cache.stats.clone()); assert!(a.ip() == conn.tpu_addr().ip()); }); } diff --git a/client/src/nonblocking/quic_client.rs b/client/src/nonblocking/quic_client.rs index 53e20d83a1e49a..c68904babe869b 100644 --- a/client/src/nonblocking/quic_client.rs +++ b/client/src/nonblocking/quic_client.rs @@ -4,9 +4,10 @@ use { crate::{ client_error::ClientErrorKind, connection_cache::ConnectionCacheStats, - tpu_connection::ClientStats, + nonblocking::tpu_connection::TpuConnection, tpu_connection::ClientStats, }, async_mutex::Mutex, + async_trait::async_trait, futures::future::join_all, itertools::Itertools, log::*, @@ -15,8 +16,9 @@ use { }, solana_measure::measure::Measure, solana_net_utils::VALIDATOR_PORT_RANGE, - solana_sdk::quic::{ - QUIC_KEEP_ALIVE_MS, QUIC_MAX_TIMEOUT_MS, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS, + solana_sdk::{ + quic::{QUIC_KEEP_ALIVE_MS, QUIC_MAX_TIMEOUT_MS, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS}, + transport::Result as TransportResult, }, std::{ net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, @@ -424,3 +426,78 @@ impl QuicClient { self.stats.clone() } } + +pub struct QuicTpuConnection { + client: Arc, + connection_stats: Arc, +} + +impl QuicTpuConnection { + pub fn base_stats(&self) -> Arc { + self.client.stats() + } + + pub fn new( + endpoint: Arc, + addr: SocketAddr, + connection_stats: Arc, + ) -> Self { + let client = Arc::new(QuicClient::new(endpoint, addr)); + Self::new_with_client(client, connection_stats) + } + + pub fn new_with_client( + client: Arc, + connection_stats: Arc, + ) -> Self { + Self { + client, + connection_stats, + } + } +} + +#[async_trait] +impl TpuConnection for QuicTpuConnection { + fn tpu_addr(&self) -> &SocketAddr { + self.client.tpu_addr() + } + + async fn send_wire_transaction_batch(&self, buffers: &[T]) -> TransportResult<()> + where + T: AsRef<[u8]> + Send + Sync, + { + let stats = ClientStats::default(); + let len = buffers.len(); + let res = self + .client + .send_batch(buffers, &stats, self.connection_stats.clone()) + .await; + self.connection_stats + .add_client_stats(&stats, len, res.is_ok()); + res?; + Ok(()) + } + + async fn send_wire_transaction(&self, wire_transaction: T) -> TransportResult<()> + where + T: AsRef<[u8]> + Send + Sync, + { + let stats = Arc::new(ClientStats::default()); + let send_buffer = + self.client + .send_buffer(wire_transaction, &stats, self.connection_stats.clone()); + if let Err(e) = send_buffer.await { + warn!( + "Failed to send transaction async to {}, error: {:?} ", + self.tpu_addr(), + e + ); + datapoint_warn!("send-wire-async", ("failure", 1, i64),); + self.connection_stats.add_client_stats(&stats, 1, false); + } else { + self.connection_stats.add_client_stats(&stats, 1, true); + } + Ok(()) + } +} diff --git a/client/src/nonblocking/tpu_client.rs b/client/src/nonblocking/tpu_client.rs index fdff5322b24b2b..4dd39c7685407b 100644 --- a/client/src/nonblocking/tpu_client.rs +++ b/client/src/nonblocking/tpu_client.rs @@ -5,6 +5,7 @@ use { nonblocking::{ pubsub_client::{PubsubClient, PubsubClientError}, rpc_client::RpcClient, + tpu_connection::TpuConnection, }, rpc_request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS, rpc_response::SlotUpdate, @@ -13,10 +14,9 @@ use { LeaderTpuCache, LeaderTpuCacheUpdateInfo, RecentLeaderSlots, TpuClientConfig, MAX_FANOUT_SLOTS, SEND_TRANSACTION_INTERVAL, TRANSACTION_RESEND_INTERVAL, }, - tpu_connection::TpuConnection, }, bincode::serialize, - futures_util::stream::StreamExt, + futures_util::{future::join_all, stream::StreamExt}, log::*, solana_sdk::{ clock::Slot, @@ -68,6 +68,15 @@ pub struct TpuClient { connection_cache: Arc, } +async fn send_wire_transaction_to_addr( + connection_cache: &ConnectionCache, + addr: &SocketAddr, + wire_transaction: Vec, +) -> TransportResult<()> { + let conn = connection_cache.get_nonblocking_connection(addr); + conn.send_wire_transaction(wire_transaction.clone()).await +} + impl TpuClient { /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout /// size @@ -94,17 +103,28 @@ impl TpuClient { /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size /// Returns the last error if all sends fail async fn try_send_wire_transaction(&self, wire_transaction: Vec) -> TransportResult<()> { + let leaders = self + .leader_tpu_service + .leader_tpu_sockets(self.fanout_slots); + let futures = leaders + .iter() + .map(|addr| { + send_wire_transaction_to_addr( + &self.connection_cache, + addr, + wire_transaction.clone(), + ) + }) + .collect::>(); + let results: Vec> = join_all(futures).await; + let mut last_error: Option = None; let mut some_success = false; - for tpu_address in self - .leader_tpu_service - .leader_tpu_sockets(self.fanout_slots) - { - let conn = self.connection_cache.get_connection(&tpu_address); - // Fake async - let result = conn.send_wire_transaction_async(wire_transaction.clone()); - if let Err(err) = result { - last_error = Some(err); + for result in results { + if let Err(e) = result { + if last_error.is_none() { + last_error = Some(e); + } } else { some_success = true; } diff --git a/client/src/nonblocking/tpu_connection.rs b/client/src/nonblocking/tpu_connection.rs index 25190c64a8c1f9..9e819070bc0c47 100644 --- a/client/src/nonblocking/tpu_connection.rs +++ b/client/src/nonblocking/tpu_connection.rs @@ -1,7 +1,7 @@ //! Trait defining async send functions, to be used for UDP or QUIC sending use { - crate::nonblocking::udp_client::UdpTpuConnection, + crate::nonblocking::{quic_client::QuicTpuConnection, udp_client::UdpTpuConnection}, async_trait::async_trait, enum_dispatch::enum_dispatch, solana_sdk::{transaction::VersionedTransaction, transport::Result as TransportResult}, @@ -13,6 +13,7 @@ use { // trying to convert later. #[enum_dispatch] pub enum NonblockingConnection { + QuicTpuConnection, UdpTpuConnection, } diff --git a/client/src/nonblocking/udp_client.rs b/client/src/nonblocking/udp_client.rs index 1f607430e8959e..1a418765042dad 100644 --- a/client/src/nonblocking/udp_client.rs +++ b/client/src/nonblocking/udp_client.rs @@ -14,7 +14,7 @@ pub struct UdpTpuConnection { } impl UdpTpuConnection { - pub fn new(tpu_addr: SocketAddr, socket: std::net::UdpSocket) -> Self { + pub fn new_from_addr(socket: std::net::UdpSocket, tpu_addr: SocketAddr) -> Self { socket.set_nonblocking(true).unwrap(); let socket = UdpSocket::from_std(socket).unwrap(); Self { @@ -85,7 +85,7 @@ mod tests { let addr = addr_str.parse().unwrap(); let socket = solana_net_utils::bind_with_any_port(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))).unwrap(); - let connection = UdpTpuConnection::new(addr, socket); + let connection = UdpTpuConnection::new_from_addr(socket, addr); let reader = UdpSocket::bind(addr_str).await.expect("bind"); check_send_one(&connection, &reader).await; check_send_batch(&connection, &reader).await; diff --git a/client/src/quic_client.rs b/client/src/quic_client.rs index c91a20d99332cc..7db1ecd3dbcf5b 100644 --- a/client/src/quic_client.rs +++ b/client/src/quic_client.rs @@ -4,11 +4,16 @@ use { crate::{ connection_cache::ConnectionCacheStats, - nonblocking::quic_client::{QuicClient, QuicLazyInitializedEndpoint}, - tpu_connection::{ClientStats, TpuConnection}, + nonblocking::{ + quic_client::{ + QuicClient, QuicLazyInitializedEndpoint, + QuicTpuConnection as NonblockingQuicTpuConnection, + }, + tpu_connection::TpuConnection as NonblockingTpuConnection, + }, + tpu_connection::TpuConnection, }, lazy_static::lazy_static, - log::*, solana_sdk::transport::Result as TransportResult, std::{net::SocketAddr, sync::Arc}, tokio::runtime::Runtime, @@ -22,92 +27,58 @@ lazy_static! { } pub struct QuicTpuConnection { - client: Arc, - connection_stats: Arc, + inner: Arc, } - impl QuicTpuConnection { - pub fn base_stats(&self) -> Arc { - self.client.stats() - } - pub fn new( endpoint: Arc, tpu_addr: SocketAddr, connection_stats: Arc, ) -> Self { - let client = Arc::new(QuicClient::new(endpoint, tpu_addr)); + let inner = Arc::new(NonblockingQuicTpuConnection::new( + endpoint, + tpu_addr, + connection_stats, + )); + Self { inner } + } - Self { + pub fn new_with_client( + client: Arc, + connection_stats: Arc, + ) -> Self { + let inner = Arc::new(NonblockingQuicTpuConnection::new_with_client( client, connection_stats, - } + )); + Self { inner } } } impl TpuConnection for QuicTpuConnection { fn tpu_addr(&self) -> &SocketAddr { - self.client.tpu_addr() + self.inner.tpu_addr() } fn send_wire_transaction_batch(&self, buffers: &[T]) -> TransportResult<()> where - T: AsRef<[u8]>, + T: AsRef<[u8]> + Send + Sync, { - let stats = ClientStats::default(); - let len = buffers.len(); - let _guard = RUNTIME.enter(); - let send_batch = self - .client - .send_batch(buffers, &stats, self.connection_stats.clone()); - let res = RUNTIME.block_on(send_batch); - self.connection_stats - .add_client_stats(&stats, len, res.is_ok()); - res?; + let _res = RUNTIME.block_on(self.inner.send_wire_transaction_batch(buffers))?; Ok(()) } fn send_wire_transaction_async(&self, wire_transaction: Vec) -> TransportResult<()> { - let stats = Arc::new(ClientStats::default()); - let _guard = RUNTIME.enter(); - let client = self.client.clone(); - let connection_stats = self.connection_stats.clone(); + let inner = self.inner.clone(); //drop and detach the task - let _ = RUNTIME.spawn(async move { - let send_buffer = - client.send_buffer(wire_transaction, &stats, connection_stats.clone()); - if let Err(e) = send_buffer.await { - warn!( - "Failed to send transaction async to {}, error: {:?} ", - client.tpu_addr(), - e - ); - datapoint_warn!("send-wire-async", ("failure", 1, i64),); - connection_stats.add_client_stats(&stats, 1, false); - } else { - connection_stats.add_client_stats(&stats, 1, true); - } - }); + let _ = RUNTIME.spawn(async move { inner.send_wire_transaction(wire_transaction).await }); Ok(()) } fn send_wire_transaction_batch_async(&self, buffers: Vec>) -> TransportResult<()> { - let stats = Arc::new(ClientStats::default()); - let _guard = RUNTIME.enter(); - let client = self.client.clone(); - let connection_stats = self.connection_stats.clone(); - let len = buffers.len(); + let inner = self.inner.clone(); //drop and detach the task - let _ = RUNTIME.spawn(async move { - let send_batch = client.send_batch(&buffers, &stats, connection_stats.clone()); - if let Err(e) = send_batch.await { - warn!("Failed to send transaction batch async to {:?}", e); - datapoint_warn!("send-wire-batch-async", ("failure", 1, i64),); - connection_stats.add_client_stats(&stats, len, false); - } else { - connection_stats.add_client_stats(&stats, len, true); - } - }); + let _ = RUNTIME.spawn(async move { inner.send_wire_transaction_batch(&buffers).await }); Ok(()) } } diff --git a/client/src/tpu_connection.rs b/client/src/tpu_connection.rs index 1f694a1ecc8770..9f02319379c942 100644 --- a/client/src/tpu_connection.rs +++ b/client/src/tpu_connection.rs @@ -24,12 +24,12 @@ pub struct ClientStats { } #[enum_dispatch] -pub enum Connection { +pub enum BlockingConnection { UdpTpuConnection, QuicTpuConnection, } -#[enum_dispatch(Connection)] +#[enum_dispatch(BlockingConnection)] pub trait TpuConnection { fn tpu_addr(&self) -> &SocketAddr; @@ -44,7 +44,7 @@ pub trait TpuConnection { fn send_wire_transaction(&self, wire_transaction: T) -> TransportResult<()> where - T: AsRef<[u8]>, + T: AsRef<[u8]> + Send + Sync, { self.send_wire_transaction_batch(&[wire_transaction]) } @@ -65,7 +65,7 @@ pub trait TpuConnection { fn send_wire_transaction_batch(&self, buffers: &[T]) -> TransportResult<()> where - T: AsRef<[u8]>; + T: AsRef<[u8]> + Send + Sync; fn send_wire_transaction_batch_async(&self, buffers: Vec>) -> TransportResult<()>; } diff --git a/client/src/udp_client.rs b/client/src/udp_client.rs index b0fa879f0222de..2346eb213ba993 100644 --- a/client/src/udp_client.rs +++ b/client/src/udp_client.rs @@ -46,12 +46,13 @@ impl TpuConnection for UdpTpuConnection { fn send_wire_transaction_batch(&self, buffers: &[T]) -> TransportResult<()> where - T: AsRef<[u8]>, + T: AsRef<[u8]> + Send + Sync, { let pkts: Vec<_> = buffers.iter().zip(repeat(self.tpu_addr())).collect(); batch_send(&self.socket, &pkts)?; Ok(()) } + fn send_wire_transaction_batch_async(&self, buffers: Vec>) -> TransportResult<()> { let pkts: Vec<_> = buffers.into_iter().zip(repeat(self.tpu_addr())).collect(); batch_send(&self.socket, &pkts)?; diff --git a/client/tests/quic_client.rs b/client/tests/quic_client.rs index 40d02cbbbe1dcf..2ccce3606080c1 100644 --- a/client/tests/quic_client.rs +++ b/client/tests/quic_client.rs @@ -1,19 +1,16 @@ #[cfg(test)] mod tests { use { - crossbeam_channel::unbounded, + crossbeam_channel::{unbounded, Receiver}, solana_client::{ connection_cache::ConnectionCacheStats, - nonblocking::quic_client::QuicLazyInitializedEndpoint, quic_client::QuicTpuConnection, - tpu_connection::TpuConnection, + nonblocking::quic_client::QuicLazyInitializedEndpoint, }, + solana_perf::packet::PacketBatch, solana_sdk::{packet::PACKET_DATA_SIZE, signature::Keypair}, - solana_streamer::{ - quic::{spawn_server, StreamStats}, - streamer::StakedNodes, - }, + solana_streamer::{quic::StreamStats, streamer::StakedNodes}, std::{ - net::{SocketAddr, UdpSocket}, + net::{IpAddr, SocketAddr, UdpSocket}, sync::{ atomic::{AtomicBool, Ordering}, Arc, RwLock, @@ -22,17 +19,55 @@ mod tests { }, }; + fn check_packets( + receiver: Receiver, + num_bytes: usize, + num_expected_packets: usize, + ) { + let mut all_packets = vec![]; + let now = Instant::now(); + let mut total_packets: usize = 0; + while now.elapsed().as_secs() < 5 { + if let Ok(packets) = receiver.recv_timeout(Duration::from_secs(1)) { + total_packets = total_packets.saturating_add(packets.len()); + all_packets.push(packets) + } + if total_packets >= num_expected_packets { + break; + } + } + for batch in all_packets { + for p in &batch { + assert_eq!(p.meta.size, num_bytes); + } + } + assert_eq!(total_packets, num_expected_packets); + } + + fn server_args() -> ( + UdpSocket, + Arc, + Keypair, + IpAddr, + Arc, + ) { + ( + UdpSocket::bind("127.0.0.1:0").unwrap(), + Arc::new(AtomicBool::new(false)), + Keypair::new(), + "127.0.0.1".parse().unwrap(), + Arc::new(StreamStats::default()), + ) + } + #[test] fn test_quic_client_multiple_writes() { + use solana_client::{quic_client::QuicTpuConnection, tpu_connection::TpuConnection}; solana_logger::setup(); - let s = UdpSocket::bind("127.0.0.1:0").unwrap(); - let exit = Arc::new(AtomicBool::new(false)); let (sender, receiver) = unbounded(); - let keypair = Keypair::new(); - let ip = "127.0.0.1".parse().unwrap(); let staked_nodes = Arc::new(RwLock::new(StakedNodes::default())); - let stats = Arc::new(StreamStats::default()); - let t = spawn_server( + let (s, exit, keypair, ip, stats) = server_args(); + let t = solana_streamer::quic::spawn_server( s.try_clone().unwrap(), &keypair, ip, @@ -63,26 +98,53 @@ mod tests { assert!(client.send_wire_transaction_batch_async(packets).is_ok()); - let mut all_packets = vec![]; - let now = Instant::now(); - let mut total_packets = 0; - while now.elapsed().as_secs() < 5 { - if let Ok(packets) = receiver.recv_timeout(Duration::from_secs(1)) { - total_packets += packets.len(); - all_packets.push(packets) - } - if total_packets >= num_expected_packets { - break; - } - } - for batch in all_packets { - for p in &batch { - assert_eq!(p.meta.size, num_bytes); - } - } - assert_eq!(total_packets, num_expected_packets); - + check_packets(receiver, num_bytes, num_expected_packets); exit.store(true, Ordering::Relaxed); t.join().unwrap(); } + + #[tokio::test] + async fn test_nonblocking_quic_client_multiple_writes() { + use solana_client::nonblocking::{ + quic_client::QuicTpuConnection, tpu_connection::TpuConnection, + }; + solana_logger::setup(); + let (sender, receiver) = unbounded(); + let staked_nodes = Arc::new(RwLock::new(StakedNodes::default())); + let (s, exit, keypair, ip, stats) = server_args(); + let t = solana_streamer::nonblocking::quic::spawn_server( + s.try_clone().unwrap(), + &keypair, + ip, + sender, + exit.clone(), + 1, + staked_nodes, + 10, + 10, + stats, + ) + .unwrap(); + + let addr = s.local_addr().unwrap().ip(); + let port = s.local_addr().unwrap().port(); + let tpu_addr = SocketAddr::new(addr, port); + let connection_cache_stats = Arc::new(ConnectionCacheStats::default()); + let client = QuicTpuConnection::new( + Arc::new(QuicLazyInitializedEndpoint::default()), + tpu_addr, + connection_cache_stats, + ); + + // Send a full size packet with single byte writes. + let num_bytes = PACKET_DATA_SIZE; + let num_expected_packets: usize = 4000; + let packets = vec![vec![0u8; PACKET_DATA_SIZE]; num_expected_packets]; + + assert!(client.send_wire_transaction_batch(&packets).await.is_ok()); + + check_packets(receiver, num_bytes, num_expected_packets); + exit.store(true, Ordering::Relaxed); + t.await.unwrap(); + } }