diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index bd0c352397eb52..f668a08edd38c3 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -33,13 +33,26 @@ use { std::{ iter::repeat_with, net::{IpAddr, SocketAddr, UdpSocket}, + // CAUTION: be careful not to introduce any awaits while holding an RwLock. sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, Mutex, MutexGuard, RwLock, + Arc, RwLock, }, time::{Duration, Instant}, }, - tokio::{task::JoinHandle, time::timeout}, + tokio::{ + // CAUTION: It's kind of sketch that we're mixing async and sync locks (see the RwLock above). + // This is done so that sync code can also access the stake table. + // Make sure we don't hold a sync lock across an await - including the await to + // lock an async Mutex. This does not happen now and should not happen as long as we + // don't hold an async Mutex and sync RwLock at the same time (currently true) + // but if we do, the scope of the RwLock must always be a subset of the async Mutex + // (i.e. lock order is always async Mutex -> RwLock). Also, be careful not to + // introduce any other awaits while holding the RwLock. + sync::{Mutex, MutexGuard}, + task::JoinHandle, + time::timeout, + }, }; const WAIT_FOR_STREAM_TIMEOUT: Duration = Duration::from_millis(100); @@ -384,7 +397,7 @@ fn handle_and_cache_new_connection( } } -fn prune_unstaked_connections_and_add_new_connection( +async fn prune_unstaked_connections_and_add_new_connection( connection: Connection, connection_table: Arc>, max_connections: usize, @@ -395,7 +408,7 @@ fn prune_unstaked_connections_and_add_new_connection( let stats = params.stats.clone(); if max_connections > 0 { let connection_table_clone = connection_table.clone(); - let mut connection_table = connection_table.lock().unwrap(); + let mut connection_table = connection_table.lock().await; prune_unstaked_connection_table(&mut connection_table, max_connections, stats); handle_and_cache_new_connection( connection, @@ -505,7 +518,8 @@ async fn setup_connection( match params.peer_type { ConnectionPeerType::Staked(stake) => { - let mut connection_table_l = staked_connection_table.lock().unwrap(); + let mut connection_table_l = staked_connection_table.lock().await; + if connection_table_l.total_size >= max_staked_connections { let num_pruned = connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake); @@ -536,7 +550,9 @@ async fn setup_connection( ¶ms, wait_for_chunk_timeout, stream_load_ema.clone(), - ) { + ) + .await + { stats .connection_added_from_staked_peer .fetch_add(1, Ordering::Relaxed); @@ -558,7 +574,9 @@ async fn setup_connection( ¶ms, wait_for_chunk_timeout, stream_load_ema.clone(), - ) { + ) + .await + { stats .connection_added_from_unstaked_peer .fetch_add(1, Ordering::Relaxed); @@ -801,7 +819,7 @@ async fn handle_connection( } } - let removed_connection_count = connection_table.lock().unwrap().remove_connection( + let removed_connection_count = connection_table.lock().await.remove_connection( ConnectionTableKey::new(remote_addr.ip(), params.remote_pubkey), remote_addr.port(), stable_id,