Skip to content

Commit

Permalink
prunes turbine QUIC connections
Browse files Browse the repository at this point in the history
The commit implements lazy eviction for turbine QUIC connections.
The cache is allowed to grow to 2 x capacity at which point at least
half of the entries with lowest stake are evicted, resulting in an
amortized O(1) performance.
  • Loading branch information
behzadnouri committed Oct 11, 2023
1 parent 42fa18d commit a8e91ec
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 12 deletions.
1 change: 1 addition & 0 deletions core/src/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,7 @@ impl Validator {
.expect("Operator must spin up node with valid QUIC TVU address")
.ip(),
turbine_quic_endpoint_sender,
bank_forks.clone(),
)
.unwrap();

Expand Down
155 changes: 143 additions & 12 deletions turbine/src/quic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@ use {
rcgen::RcgenError,
rustls::{Certificate, PrivateKey},
solana_quic_client::nonblocking::quic_client::SkipServerVerification,
solana_runtime::bank_forks::BankForks,
solana_sdk::{pubkey::Pubkey, signature::Keypair},
solana_streamer::{
quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
},
std::{
cmp::Reverse,
collections::{hash_map::Entry, HashMap},
io::Error as IoError,
net::{IpAddr, SocketAddr, UdpSocket},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
},
thiserror::Error,
tokio::{
Expand All @@ -32,6 +37,7 @@ use {

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280;
const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine";
const CONNECT_SERVER_NAME: &str = "solana-turbine";
Expand Down Expand Up @@ -75,6 +81,7 @@ pub fn new_quic_endpoint(
socket: UdpSocket,
address: IpAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<
(
Endpoint,
Expand All @@ -98,19 +105,24 @@ pub fn new_quic_endpoint(
)?
};
endpoint.set_default_client_config(client_config);
let prune_cache_pending = Arc::<AtomicBool>::default();
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let server_task = runtime.spawn(run_server(
endpoint.clone(),
sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
sender,
bank_forks,
prune_cache_pending,
router,
cache,
));
Expand Down Expand Up @@ -163,6 +175,8 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -171,6 +185,8 @@ async fn run_server(
endpoint.clone(),
connecting,
sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand All @@ -181,6 +197,8 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<(SocketAddr, Bytes)>,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand Down Expand Up @@ -219,6 +237,8 @@ async fn run_client(
remote_address,
sender.clone(),
receiver,
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand All @@ -232,10 +252,22 @@ async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await {
if let Err(err) = handle_connecting(
endpoint,
connecting,
sender,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("handle_connecting: {err:?}");
}
}
Expand All @@ -244,6 +276,8 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -262,24 +296,37 @@ async fn handle_connecting(
connection,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await;
Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_pubkey, connection.clone(), &cache).await;
cache_connection(
remote_pubkey,
connection.clone(),
bank_forks,
prune_cache_pending,
router.clone(),
cache.clone(),
)
.await;
let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver));
let read_datagram_task = tokio::task::spawn(read_datagram_task(
endpoint,
Expand Down Expand Up @@ -349,11 +396,22 @@ async fn make_connection_task(
remote_address: SocketAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) =
make_connection(endpoint, remote_address, sender, receiver, router, cache).await
if let Err(err) = make_connection(
endpoint,
remote_address,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("make_connection: {remote_address}, {err:?}");
}
Expand All @@ -364,6 +422,8 @@ async fn make_connection(
remote_address: SocketAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -377,6 +437,8 @@ async fn make_connection(
connection,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
Expand All @@ -400,15 +462,32 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
async fn cache_connection(
remote_pubkey: Pubkey,
connection: Connection,
cache: &Mutex<HashMap<Pubkey, Connection>>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
let Some(old) = cache.lock().await.insert(remote_pubkey, connection) else {
return;
let (old, should_prune_cache) = {
let mut cache = cache.lock().await;
(
cache.insert(remote_pubkey, connection),
cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
)
};
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
tokio::task::spawn(prune_connection_cache(
bank_forks,
prune_cache_pending,
router,
cache,
));
}
}

async fn drop_connection(
Expand All @@ -427,6 +506,49 @@ async fn drop_connection(
}
}

async fn prune_connection_cache(
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
let staked_nodes = {
let root_bank = bank_forks.read().unwrap().root_bank();
root_bank.staked_nodes()
};
{
let mut cache = cache.lock().await;
if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
return;
}
let mut connections: Vec<_> = cache
.drain()
.filter(|(_, connection)| connection.close_reason().is_none())
.map(|entry @ (pubkey, _)| {
let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
(stake, entry)
})
.collect();
connections
.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
}
cache.extend(
connections
.into_iter()
.take(CONNECTION_CACHE_CAPACITY)
.map(|(_, entry)| entry),
);
prune_cache_pending.store(false, Ordering::Relaxed);
}
router.write().await.retain(|_, sender| !sender.is_closed());
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_: crossbeam_channel::SendError<T>) -> Self {
Error::ChannelSendError
Expand All @@ -438,6 +560,8 @@ mod tests {
use {
super::*,
itertools::{izip, multiunzip},
solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
solana_runtime::bank::Bank,
solana_sdk::signature::Signer,
std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
};
Expand Down Expand Up @@ -465,6 +589,12 @@ mod tests {
repeat_with(crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>)
.take(NUM_ENDPOINTS)
.unzip();
let bank_forks = {
let GenesisConfigInfo { genesis_config, .. } =
create_genesis_config(/*mint_lamports:*/ 100_000);
let bank = Bank::new_for_tests(&genesis_config);
Arc::new(RwLock::new(BankForks::new(bank)))
};
let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(keypairs.iter().zip(sockets).zip(senders).map(
|((keypair, socket), sender)| {
Expand All @@ -474,6 +604,7 @@ mod tests {
socket,
IpAddr::V4(Ipv4Addr::LOCALHOST),
sender,
bank_forks.clone(),
)
.unwrap()
},
Expand Down

0 comments on commit a8e91ec

Please sign in to comment.