Skip to content

Commit

Permalink
feat(comms/rpc): restrict rpc session per peer tari-project#4497
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbondi committed Aug 26, 2022
1 parent 90f0034 commit 923304c
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 11 deletions.
1 change: 1 addition & 0 deletions applications/tari_base_node/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ where B: BlockchainBackend + 'static
let base_node_service = handles.expect_handle::<LocalNodeCommsInterface>();
let rpc_server = RpcServer::builder()
.with_maximum_simultaneous_sessions(config.rpc_max_simultaneous_sessions)
.with_maximum_sessions_per_client(config.rpc_max_sessions_per_peer)
.finish();

// Add your RPC services here ‍🏴‍☠️️☮️🌊
Expand Down
2 changes: 1 addition & 1 deletion applications/tari_miner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ async fn display_report(report: &MiningReport, num_mining_threads: usize) {
let hashrate = report.hashes as f64 / report.elapsed.as_micros() as f64;
info!(
target: LOG_TARGET,
"⛏ Miner {} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
"⛏ Miner {:0>2} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
report.miner,
hashrate,
hashrate * num_mining_threads as f64,
Expand Down
4 changes: 4 additions & 0 deletions base_layer/p2p/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub struct P2pConfig {
/// The global maximum allowed RPC sessions.
/// Default: 100
pub rpc_max_simultaneous_sessions: usize,
/// The maximum allowed RPC sessions per peer.
/// Default: 11 (one more than the pool size for wallets)
pub rpc_max_sessions_per_peer: usize,
}

impl Default for P2pConfig {
Expand All @@ -141,6 +144,7 @@ impl Default for P2pConfig {
user_agent: "".to_string(),
auxiliary_tcp_listener_address: None,
rpc_max_simultaneous_sessions: 100,
rpc_max_sessions_per_peer: 2,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion comms/core/src/protocol/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub use body::{Body, ClientStreaming, IntoBody, Streaming};
mod context;

mod server;
pub use server::{mock, NamedProtocolService, RpcServer, RpcServerError, RpcServerHandle};
pub use server::{mock, NamedProtocolService, RpcServer, RpcServerBuilder, RpcServerError, RpcServerHandle};

mod client;
pub use client::{
Expand Down
4 changes: 3 additions & 1 deletion comms/core/src/protocol/rpc/server/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::io;
use prost::DecodeError;
use tokio::sync::oneshot;

use crate::{proto, protocol::rpc::handshake::RpcHandshakeError};
use crate::{peer_manager::NodeId, proto, protocol::rpc::handshake::RpcHandshakeError};

#[derive(Debug, thiserror::Error)]
pub enum RpcServerError {
Expand All @@ -35,6 +35,8 @@ pub enum RpcServerError {
Io(#[from] io::Error),
#[error("Maximum number of RPC sessions reached")]
MaximumSessionsReached,
#[error("Maximum number of client RPC sessions reached for node {node_id}")]
MaxSessionsPerClientReached { node_id: NodeId },
#[error("Internal service request canceled")]
RequestCanceled,
#[error("Stream was closed by remote")]
Expand Down
55 changes: 53 additions & 2 deletions comms/core/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@ mod metrics;
pub mod mock;

mod router;
mod session;

use std::{
borrow::Cow,
cmp,
collections::HashMap,
convert::TryFrom,
future::Future,
io,
pin::Pin,
sync::Arc,
sync::{atomic, atomic::AtomicUsize, Arc},
task::Poll,
time::{Duration, Instant},
};
Expand Down Expand Up @@ -76,6 +79,7 @@ use crate::{
rpc::{
body::BodyBytes,
message::{RpcMethod, RpcResponse},
server::session::SessionToken,
},
ProtocolEvent,
ProtocolId,
Expand Down Expand Up @@ -169,6 +173,7 @@ impl Default for RpcServer {
#[derive(Clone)]
pub struct RpcServerBuilder {
maximum_simultaneous_sessions: Option<usize>,
maximum_sessions_per_client: Option<usize>,
minimum_client_deadline: Duration,
handshake_timeout: Duration,
}
Expand All @@ -188,6 +193,16 @@ impl RpcServerBuilder {
self
}

pub fn with_maximum_sessions_per_client(mut self, limit: usize) -> Self {
self.maximum_sessions_per_client = Some(cmp::min(limit, BoundedExecutor::max_theoretical_tasks()));
self
}

pub fn with_unlimited_sessions_per_client(mut self) -> Self {
self.maximum_sessions_per_client = None;
self
}

pub fn with_minimum_client_deadline(mut self, deadline: Duration) -> Self {
self.minimum_client_deadline = deadline;
self
Expand All @@ -206,7 +221,8 @@ impl RpcServerBuilder {
impl Default for RpcServerBuilder {
fn default() -> Self {
Self {
maximum_simultaneous_sessions: Some(1000),
maximum_simultaneous_sessions: None,
maximum_sessions_per_client: None,
minimum_client_deadline: Duration::from_secs(1),
handshake_timeout: Duration::from_secs(15),
}
Expand All @@ -220,6 +236,7 @@ pub(super) struct PeerRpcServer<TSvc, TCommsProvider> {
protocol_notifications: Option<ProtocolNotificationRx<Substream>>,
comms_provider: TCommsProvider,
request_rx: mpsc::Receiver<RpcServerRequest>,
sessions: HashMap<NodeId, Arc<AtomicUsize>>,
}

impl<TSvc, TCommsProvider> PeerRpcServer<TSvc, TCommsProvider>
Expand Down Expand Up @@ -255,6 +272,7 @@ where
protocol_notifications: Some(protocol_notifications),
comms_provider,
request_rx,
sessions: HashMap::new(),
}
}

Expand Down Expand Up @@ -336,6 +354,28 @@ where
Ok(())
}

fn new_session_for(&mut self, node_id: NodeId) -> Result<SessionToken, RpcServerError> {
match self.config.maximum_sessions_per_client {
Some(max) => {
let counter = self
.sessions
.entry(node_id.clone())
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));

let count = counter.load(atomic::Ordering::Acquire);

debug_assert!(count <= max);
if count >= max {
// metrics::max_sessions_per_client_reached_counter(&node_id).inc();
return Err(RpcServerError::MaxSessionsPerClientReached { node_id });
}
counter.fetch_add(1, atomic::Ordering::Release);
Ok(SessionToken::new(counter.clone()))
},
None => Ok(SessionToken::nop()),
}
}

#[tracing::instrument(name = "rpc::server::try_initiate_service", skip(self, framed), err)]
async fn try_initiate_service(
&mut self,
Expand Down Expand Up @@ -374,6 +414,16 @@ where
},
};

let session_token = match self.new_session_for(node_id.clone()) {
Ok(token) => token,
Err(err) => {
handshake
.reject_with_reason(HandshakeRejectReason::NoSessionsAvailable)
.await?;
return Err(err);
},
};

let version = handshake.perform_server_handshake().await?;
debug!(
target: LOG_TARGET,
Expand All @@ -396,6 +446,7 @@ where
num_sessions.inc();
service.start().await;
num_sessions.dec();
drop(session_token);
})
.map_err(|_| RpcServerError::MaximumSessionsReached)?;

Expand Down
145 changes: 139 additions & 6 deletions comms/core/src/protocol/rpc/test/smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use crate::{
},
RpcError,
RpcServer,
RpcServerBuilder,
RpcStatusCode,
},
ProtocolEvent,
Expand All @@ -67,25 +68,23 @@ use crate::{
Substream,
};

pub(super) async fn setup_service<T: GreetingRpc>(
pub(super) async fn setup_service_with_builder<T: GreetingRpc>(
service_impl: T,
num_concurrent_sessions: usize,
builder: RpcServerBuilder,
) -> (
mpsc::Sender<ProtocolNotification<Substream>>,
task::JoinHandle<()>,
RpcCommsBackend,
Shutdown,
) {
let (notif_tx, notif_rx) = mpsc::channel(1);
let (notif_tx, notif_rx) = mpsc::channel(10);
let shutdown = Shutdown::new();
let (context, _) = create_mocked_rpc_context();
let server_hnd = task::spawn({
let context = context.clone();
let shutdown_signal = shutdown.to_signal();
async move {
let fut = RpcServer::builder()
.with_maximum_simultaneous_sessions(num_concurrent_sessions)
.with_minimum_client_deadline(Duration::from_secs(0))
let fut = builder
.finish()
.add_service(GreetingServer::new(service_impl))
.serve(notif_rx, context);
Expand All @@ -97,9 +96,25 @@ pub(super) async fn setup_service<T: GreetingRpc>(
}
}
});

(notif_tx, server_hnd, context, shutdown)
}

pub(super) async fn setup_service<T: GreetingRpc>(
service_impl: T,
num_concurrent_sessions: usize,
) -> (
mpsc::Sender<ProtocolNotification<Substream>>,
task::JoinHandle<()>,
RpcCommsBackend,
Shutdown,
) {
let builder = RpcServer::builder()
.with_maximum_simultaneous_sessions(num_concurrent_sessions)
.with_minimum_client_deadline(Duration::from_secs(0));
setup_service_with_builder(service_impl, builder).await
}

pub(super) async fn setup<T: GreetingRpc>(
service_impl: T,
num_concurrent_sessions: usize,
Expand Down Expand Up @@ -453,3 +468,121 @@ async fn stream_interruption_handling() {
.unwrap()
.unwrap();
}

#[runtime::test]
async fn max_global_sessions() {
let builder = RpcServer::builder().with_maximum_simultaneous_sessions(1);
let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await;
let (_, mut inbound, outbound) = build_multiplexed_connections().await;

let node_identity = build_node_identity(Default::default());
// Notify that a peer wants to speak the greeting RPC protocol
context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap();

for _ in 0..2 {
let substream = outbound.get_yamux_control().open_stream().await.unwrap();
muxer
.send(ProtocolNotification::new(
ProtocolId::from_static(b"/test/greeting/1.0"),
ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
))
.await
.unwrap();
}

let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let mut client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();

let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let err = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap_err();

unpack_enum!(RpcError::HandshakeError(err) = err);
unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err);

client.close().await;
let substream = outbound.get_yamux_control().open_stream().await.unwrap();
muxer
.send(ProtocolNotification::new(
ProtocolId::from_static(b"/test/greeting/1.0"),
ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
))
.await
.unwrap();
let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let _client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();
}

#[runtime::test]
async fn max_per_client_sessions() {
let builder = RpcServer::builder()
.with_maximum_simultaneous_sessions(3)
.with_maximum_sessions_per_client(1);
let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await;
let (_, mut inbound, outbound) = build_multiplexed_connections().await;

let node_identity = build_node_identity(Default::default());
// Notify that a peer wants to speak the greeting RPC protocol
context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap();
for _ in 0..2 {
let substream = outbound.get_yamux_control().open_stream().await.unwrap();
muxer
.send(ProtocolNotification::new(
ProtocolId::from_static(b"/test/greeting/1.0"),
ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
))
.await
.unwrap();
}

let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let mut client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();

let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let err = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap_err();

unpack_enum!(RpcError::HandshakeError(err) = err);
unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err);

client.close().await;
drop(client);
let substream = outbound.get_yamux_control().open_stream().await.unwrap();
muxer
.send(ProtocolNotification::new(
ProtocolId::from_static(b"/test/greeting/1.0"),
ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
))
.await
.unwrap();
let socket = inbound.incoming_mut().next().await.unwrap();
let framed = framing::canonical(socket, 1024);
let _client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();
}

0 comments on commit 923304c

Please sign in to comment.