diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs
index 0636c5b3639..5021e527847 100644
--- a/applications/tari_base_node/src/bootstrap.rs
+++ b/applications/tari_base_node/src/bootstrap.rs
@@ -188,6 +188,7 @@ where B: BlockchainBackend + 'static
         let base_node_service = handles.expect_handle::();
         let rpc_server = RpcServer::builder()
             .with_maximum_simultaneous_sessions(config.rpc_max_simultaneous_sessions)
+            .with_maximum_sessions_per_client(config.rpc_max_sessions_per_peer)
             .finish();
 
         // Add your RPC services here ‍🏴‍☠️️☮️🌊
diff --git a/applications/tari_miner/src/main.rs b/applications/tari_miner/src/main.rs
index 06a48ac577e..21ee1eff693 100644
--- a/applications/tari_miner/src/main.rs
+++ b/applications/tari_miner/src/main.rs
@@ -333,7 +333,7 @@ async fn display_report(report: &MiningReport, num_mining_threads: usize) {
     let hashrate = report.hashes as f64 / report.elapsed.as_micros() as f64;
     info!(
         target: LOG_TARGET,
-        "⛏ Miner {} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
+        "⛏ Miner {:0>2} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
         report.miner,
         hashrate,
         hashrate * num_mining_threads as f64,
diff --git a/base_layer/p2p/src/config.rs b/base_layer/p2p/src/config.rs
index 6d9a796d6e4..a6b45122919 100644
--- a/base_layer/p2p/src/config.rs
+++ b/base_layer/p2p/src/config.rs
@@ -118,6 +118,9 @@ pub struct P2pConfig {
     /// The global maximum allowed RPC sessions.
     /// Default: 100
     pub rpc_max_simultaneous_sessions: usize,
+    /// The maximum allowed RPC sessions per peer.
+    /// Default: 11 (one more than the pool size for wallets)
+    pub rpc_max_sessions_per_peer: usize,
 }
 
 impl Default for P2pConfig {
@@ -141,6 +144,7 @@ impl Default for P2pConfig {
             user_agent: "".to_string(),
             auxiliary_tcp_listener_address: None,
             rpc_max_simultaneous_sessions: 100,
+            rpc_max_sessions_per_peer: 2,
         }
     }
 }
diff --git a/comms/core/src/protocol/rpc/mod.rs b/comms/core/src/protocol/rpc/mod.rs
index 3d76b333db2..a48e56e891a 100644
--- a/comms/core/src/protocol/rpc/mod.rs
+++ b/comms/core/src/protocol/rpc/mod.rs
@@ -63,7 +63,7 @@ pub use body::{Body, ClientStreaming, IntoBody, Streaming};
 mod context;
 
 mod server;
-pub use server::{mock, NamedProtocolService, RpcServer, RpcServerError, RpcServerHandle};
+pub use server::{mock, NamedProtocolService, RpcServer, RpcServerBuilder, RpcServerError, RpcServerHandle};
 
 mod client;
 pub use client::{
diff --git a/comms/core/src/protocol/rpc/server/error.rs b/comms/core/src/protocol/rpc/server/error.rs
index 1049480f885..38f257b4238 100644
--- a/comms/core/src/protocol/rpc/server/error.rs
+++ b/comms/core/src/protocol/rpc/server/error.rs
@@ -25,7 +25,7 @@ use std::io;
 use prost::DecodeError;
 use tokio::sync::oneshot;
 
-use crate::{proto, protocol::rpc::handshake::RpcHandshakeError};
+use crate::{peer_manager::NodeId, proto, protocol::rpc::handshake::RpcHandshakeError};
 
 #[derive(Debug, thiserror::Error)]
 pub enum RpcServerError {
@@ -35,6 +35,8 @@ pub enum RpcServerError {
     Io(#[from] io::Error),
     #[error("Maximum number of RPC sessions reached")]
     MaximumSessionsReached,
+    #[error("Maximum number of client RPC sessions reached for node {node_id}")]
+    MaxSessionsPerClientReached { node_id: NodeId },
     #[error("Internal service request canceled")]
     RequestCanceled,
     #[error("Stream was closed by remote")]
diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs
index 79530247700..a85652dfe2a 100644
--- a/comms/core/src/protocol/rpc/server/mod.rs
+++ b/comms/core/src/protocol/rpc/server/mod.rs
@@ -35,14 +35,17 @@ mod metrics;
 pub mod mock;
 
 mod router;
+mod session;
+
 use std::{
     borrow::Cow,
     cmp,
+    collections::HashMap,
     convert::TryFrom,
     future::Future,
     io,
     pin::Pin,
-    sync::Arc,
+    sync::{atomic, atomic::AtomicUsize, Arc},
     task::Poll,
     time::{Duration, Instant},
 };
@@ -76,6 +79,7 @@ use crate::{
         rpc::{
             body::BodyBytes,
             message::{RpcMethod, RpcResponse},
+            server::session::SessionToken,
         },
         ProtocolEvent,
         ProtocolId,
@@ -169,6 +173,7 @@ impl Default for RpcServer {
 #[derive(Clone)]
 pub struct RpcServerBuilder {
     maximum_simultaneous_sessions: Option<usize>,
+    maximum_sessions_per_client: Option<usize>,
     minimum_client_deadline: Duration,
     handshake_timeout: Duration,
 }
@@ -188,6 +193,16 @@ impl RpcServerBuilder {
         self
     }
 
+    pub fn with_maximum_sessions_per_client(mut self, limit: usize) -> Self {
+        self.maximum_sessions_per_client = Some(cmp::min(limit, BoundedExecutor::max_theoretical_tasks()));
+        self
+    }
+
+    pub fn with_unlimited_sessions_per_client(mut self) -> Self {
+        self.maximum_sessions_per_client = None;
+        self
+    }
+
     pub fn with_minimum_client_deadline(mut self, deadline: Duration) -> Self {
         self.minimum_client_deadline = deadline;
         self
     }
@@ -206,7 +221,8 @@ impl RpcServerBuilder {
 impl Default for RpcServerBuilder {
     fn default() -> Self {
         Self {
-            maximum_simultaneous_sessions: Some(1000),
+            maximum_simultaneous_sessions: None,
+            maximum_sessions_per_client: None,
             minimum_client_deadline: Duration::from_secs(1),
             handshake_timeout: Duration::from_secs(15),
         }
     }
 }
@@ -220,6 +236,7 @@ pub(super) struct PeerRpcServer {
     protocol_notifications: Option>,
     comms_provider: TCommsProvider,
     request_rx: mpsc::Receiver,
+    sessions: HashMap<NodeId, Arc<AtomicUsize>>,
 }
 
 impl PeerRpcServer
 where
@@ -255,6 +272,7 @@ where
             protocol_notifications: Some(protocol_notifications),
             comms_provider,
             request_rx,
+            sessions: HashMap::new(),
         }
     }
@@ -336,6 +354,28 @@ where
         Ok(())
     }
 
+    fn new_session_for(&mut self, node_id: NodeId) -> Result<SessionToken, RpcServerError> {
+        match self.config.maximum_sessions_per_client {
+            Some(max) => {
+                let counter = self
+                    .sessions
+                    .entry(node_id.clone())
+                    .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
+
+                let count = counter.load(atomic::Ordering::Acquire);
+
+                debug_assert!(count <= max);
+                if count >= max {
+                    // metrics::max_sessions_per_client_reached_counter(&node_id).inc();
+                    return Err(RpcServerError::MaxSessionsPerClientReached { node_id });
+                }
+                counter.fetch_add(1, atomic::Ordering::Release);
+                Ok(SessionToken::new(counter.clone()))
+            },
+            None => Ok(SessionToken::nop()),
+        }
+    }
+
     #[tracing::instrument(name = "rpc::server::try_initiate_service", skip(self, framed), err)]
     async fn try_initiate_service(
         &mut self,
@@ -374,6 +414,16 @@ where
             },
         };
 
+        let session_token = match self.new_session_for(node_id.clone()) {
+            Ok(token) => token,
+            Err(err) => {
+                handshake
+                    .reject_with_reason(HandshakeRejectReason::NoSessionsAvailable)
+                    .await?;
+                return Err(err);
+            },
+        };
+
         let version = handshake.perform_server_handshake().await?;
         debug!(
             target: LOG_TARGET,
@@ -396,6 +446,7 @@ where
                 num_sessions.inc();
                 service.start().await;
                 num_sessions.dec();
+                drop(session_token);
             })
             .map_err(|_| RpcServerError::MaximumSessionsReached)?;
diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs
index 699b4360a32..515ba4f41cb 100644
--- a/comms/core/src/protocol/rpc/test/smoke.rs
+++ b/comms/core/src/protocol/rpc/test/smoke.rs
@@ -55,6 +55,7 @@ use crate::{
             },
             RpcError,
             RpcServer,
+            RpcServerBuilder,
             RpcStatusCode,
         },
         ProtocolEvent,
@@ -67,25 +68,23 @@ use crate::{
     Substream,
 };
 
-pub(super) async fn setup_service(
+pub(super) async fn setup_service_with_builder(
     service_impl: T,
-    num_concurrent_sessions: usize,
+    builder: RpcServerBuilder,
 ) -> (
     mpsc::Sender<ProtocolNotification<Substream>>,
     task::JoinHandle<()>,
     RpcCommsBackend,
     Shutdown,
 ) {
-    let (notif_tx, notif_rx) = mpsc::channel(1);
+    let (notif_tx, notif_rx) = mpsc::channel(10);
     let shutdown = Shutdown::new();
     let (context, _) = create_mocked_rpc_context();
     let server_hnd = task::spawn({
         let context = context.clone();
         let shutdown_signal = shutdown.to_signal();
         async move {
-            let fut = RpcServer::builder()
-                .with_maximum_simultaneous_sessions(num_concurrent_sessions)
-                .with_minimum_client_deadline(Duration::from_secs(0))
+            let fut = builder
                 .finish()
                 .add_service(GreetingServer::new(service_impl))
                 .serve(notif_rx, context);
@@ -97,9 +96,25 @@ pub(super) async fn setup_service(
             }
         }
     });
+
     (notif_tx, server_hnd, context, shutdown)
 }
 
+pub(super) async fn setup_service(
+    service_impl: T,
+    num_concurrent_sessions: usize,
+) -> (
+    mpsc::Sender<ProtocolNotification<Substream>>,
+    task::JoinHandle<()>,
+    RpcCommsBackend,
+    Shutdown,
+) {
+    let builder = RpcServer::builder()
+        .with_maximum_simultaneous_sessions(num_concurrent_sessions)
+        .with_minimum_client_deadline(Duration::from_secs(0));
+    setup_service_with_builder(service_impl, builder).await
+}
+
 pub(super) async fn setup(
     service_impl: T,
     num_concurrent_sessions: usize,
@@ -453,3 +468,121 @@ async fn stream_interruption_handling() {
         .unwrap()
         .unwrap();
 }
+
+#[runtime::test]
+async fn max_global_sessions() {
+    let builder = RpcServer::builder().with_maximum_simultaneous_sessions(1);
+    let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await;
+    let (_, mut inbound, outbound) = build_multiplexed_connections().await;
+
+    let node_identity = build_node_identity(Default::default());
+    // Notify that a peer wants to speak the greeting RPC protocol
+    context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap();
+
+    for _ in 0..2 {
+        let substream = outbound.get_yamux_control().open_stream().await.unwrap();
+        muxer
+            .send(ProtocolNotification::new(
+                ProtocolId::from_static(b"/test/greeting/1.0"),
+                ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
+            ))
+            .await
+            .unwrap();
+    }
+
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let mut client = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap();
+
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let err = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap_err();
+
+    unpack_enum!(RpcError::HandshakeError(err) = err);
+    unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err);
+
+    client.close().await;
+    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
+    muxer
+        .send(ProtocolNotification::new(
+            ProtocolId::from_static(b"/test/greeting/1.0"),
+            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
+        ))
+        .await
+        .unwrap();
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let _client = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap();
+}
+
+#[runtime::test]
+async fn max_per_client_sessions() {
+    let builder = RpcServer::builder()
+        .with_maximum_simultaneous_sessions(3)
+        .with_maximum_sessions_per_client(1);
+    let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await;
+    let (_, mut inbound, outbound) = build_multiplexed_connections().await;
+
+    let node_identity = build_node_identity(Default::default());
+    // Notify that a peer wants to speak the greeting RPC protocol
+    context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap();
+    for _ in 0..2 {
+        let substream = outbound.get_yamux_control().open_stream().await.unwrap();
+        muxer
+            .send(ProtocolNotification::new(
+                ProtocolId::from_static(b"/test/greeting/1.0"),
+                ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
+            ))
+            .await
+            .unwrap();
+    }
+
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let mut client = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap();
+
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let err = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap_err();
+
+    unpack_enum!(RpcError::HandshakeError(err) = err);
+    unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err);
+
+    client.close().await;
+    drop(client);
+    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
+    muxer
+        .send(ProtocolNotification::new(
+            ProtocolId::from_static(b"/test/greeting/1.0"),
+            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
+        ))
+        .await
+        .unwrap();
+    let socket = inbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(socket, 1024);
+    let _client = GreetingClient::builder()
+        .with_deadline(Duration::from_secs(5))
+        .connect(framed)
+        .await
+        .unwrap();
+}