Skip to content

Commit

Permalink
feat(comms/rpc): restrict rpc session per peer #4497 (#4549)
Browse files Browse the repository at this point in the history
Description
---
- restrict the max number of active rpc sessions per peer

Motivation and Context
---
fixes #4497 

How Has This Been Tested?
---
Additional unit tests
  • Loading branch information
sdbondi authored Aug 29, 2022
1 parent 7e7d053 commit 080bccf
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 13 deletions.
1 change: 1 addition & 0 deletions applications/tari_base_node/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ where B: BlockchainBackend + 'static
let base_node_service = handles.expect_handle::<LocalNodeCommsInterface>();
let rpc_server = RpcServer::builder()
.with_maximum_simultaneous_sessions(config.rpc_max_simultaneous_sessions)
.with_maximum_sessions_per_client(config.rpc_max_sessions_per_peer)
.finish();

// Add your RPC services here ‍🏴‍☠️️☮️🌊
Expand Down
2 changes: 1 addition & 1 deletion applications/tari_miner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ async fn display_report(report: &MiningReport, num_mining_threads: usize) {
let hashrate = report.hashes as f64 / report.elapsed.as_micros() as f64;
info!(
target: LOG_TARGET,
"⛏ Miner {} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
"⛏ Miner {:0>2} reported {:.2}MH/s with total {:.2}MH/s over {} threads. Height: {}. Target: {})",
report.miner,
hashrate,
hashrate * num_mining_threads as f64,
Expand Down
4 changes: 4 additions & 0 deletions base_layer/p2p/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub struct P2pConfig {
/// The global maximum allowed RPC sessions.
/// Default: 100
pub rpc_max_simultaneous_sessions: usize,
/// The maximum allowed RPC sessions per peer.
/// Default: 10
pub rpc_max_sessions_per_peer: usize,
}

impl Default for P2pConfig {
Expand All @@ -141,6 +144,7 @@ impl Default for P2pConfig {
user_agent: "".to_string(),
auxiliary_tcp_listener_address: None,
rpc_max_simultaneous_sessions: 100,
rpc_max_sessions_per_peer: 10,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions base_layer/wallet/tests/contacts_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub fn setup_contacts_service<T: ContactsBackend + 'static>(
listener_liveness_max_sessions: 0,
user_agent: "tari/test-wallet".to_string(),
rpc_max_simultaneous_sessions: 0,
rpc_max_sessions_per_peer: 0,
};
let peer_message_subscription_factory = Arc::new(subscription_factory);
let shutdown = Shutdown::new();
Expand Down
3 changes: 3 additions & 0 deletions base_layer/wallet/tests/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ async fn create_wallet(
user_agent: "tari/test-wallet".to_string(),
auxiliary_tcp_listener_address: None,
rpc_max_simultaneous_sessions: 0,
rpc_max_sessions_per_peer: 0,
};

let sql_database_path = comms_config
Expand Down Expand Up @@ -642,6 +643,7 @@ async fn test_store_and_forward_send_tx() {
assert!(tx_recv, "Must have received a tx from alice");
}

#[allow(clippy::too_many_lines)]
#[tokio::test]
async fn test_import_utxo() {
let factories = CryptoFactories::default();
Expand Down Expand Up @@ -678,6 +680,7 @@ async fn test_import_utxo() {
user_agent: "tari/test-wallet".to_string(),
auxiliary_tcp_listener_address: None,
rpc_max_simultaneous_sessions: 0,
rpc_max_sessions_per_peer: 0,
};
let config = WalletConfig {
p2p: comms_config,
Expand Down
1 change: 1 addition & 0 deletions base_layer/wallet_ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3917,6 +3917,7 @@ pub unsafe extern "C" fn comms_config_create(
listener_liveness_max_sessions: 0,
user_agent: format!("tari/mobile_wallet/{}", env!("CARGO_PKG_VERSION")),
rpc_max_simultaneous_sessions: 0,
rpc_max_sessions_per_peer: 0,
};

Box::into_raw(Box::new(config))
Expand Down
2 changes: 2 additions & 0 deletions common/config/presets/c_base_node.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ track_reorgs = true
# The maximum simultaneous comms RPC sessions allowed (default value = 100). Setting this to -1 will allow unlimited
# sessions.
#rpc_max_simultaneous_sessions = 100
# The maximum comms RPC sessions allowed per peer (default value = 10).
#rpc_max_sessions_per_peer = 10

[base_node.p2p.transport]
# -------------- Transport configuration --------------
Expand Down
2 changes: 1 addition & 1 deletion comms/core/src/protocol/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub use body::{Body, ClientStreaming, IntoBody, Streaming};
mod context;

mod server;
pub use server::{mock, NamedProtocolService, RpcServer, RpcServerError, RpcServerHandle};
pub use server::{mock, NamedProtocolService, RpcServer, RpcServerBuilder, RpcServerError, RpcServerHandle};

mod client;
pub use client::{
Expand Down
4 changes: 3 additions & 1 deletion comms/core/src/protocol/rpc/server/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::io;
use prost::DecodeError;
use tokio::sync::oneshot;

use crate::{proto, protocol::rpc::handshake::RpcHandshakeError};
use crate::{peer_manager::NodeId, proto, protocol::rpc::handshake::RpcHandshakeError};

#[derive(Debug, thiserror::Error)]
pub enum RpcServerError {
Expand All @@ -35,6 +35,8 @@ pub enum RpcServerError {
Io(#[from] io::Error),
#[error("Maximum number of RPC sessions reached")]
MaximumSessionsReached,
#[error("Maximum number of client RPC sessions reached for node {node_id}")]
MaxSessionsPerClientReached { node_id: NodeId },
#[error("Internal service request canceled")]
RequestCanceled,
#[error("Stream was closed by remote")]
Expand Down
67 changes: 63 additions & 4 deletions comms/core/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ mod metrics;
pub mod mock;

mod router;

use std::{
borrow::Cow,
cmp,
collections::HashMap,
convert::TryFrom,
future::Future,
io,
Expand All @@ -47,10 +49,10 @@ use std::{
time::{Duration, Instant},
};

use futures::{future, stream, SinkExt, StreamExt};
use futures::{future, stream, stream::FuturesUnordered, SinkExt, StreamExt};
use prost::Message;
use router::Router;
use tokio::{sync::mpsc, time};
use tokio::{sync::mpsc, task::JoinHandle, time};
use tokio_stream::Stream;
use tower::{make::MakeService, Service};
use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level};
Expand Down Expand Up @@ -169,6 +171,7 @@ impl Default for RpcServer {
#[derive(Clone)]
pub struct RpcServerBuilder {
maximum_simultaneous_sessions: Option<usize>,
maximum_sessions_per_client: Option<usize>,
minimum_client_deadline: Duration,
handshake_timeout: Duration,
}
Expand All @@ -188,6 +191,16 @@ impl RpcServerBuilder {
self
}

pub fn with_maximum_sessions_per_client(mut self, limit: usize) -> Self {
self.maximum_sessions_per_client = Some(cmp::min(limit, BoundedExecutor::max_theoretical_tasks()));
self
}

pub fn with_unlimited_sessions_per_client(mut self) -> Self {
self.maximum_sessions_per_client = None;
self
}

pub fn with_minimum_client_deadline(mut self, deadline: Duration) -> Self {
self.minimum_client_deadline = deadline;
self
Expand All @@ -206,7 +219,8 @@ impl RpcServerBuilder {
impl Default for RpcServerBuilder {
fn default() -> Self {
Self {
maximum_simultaneous_sessions: Some(1000),
maximum_simultaneous_sessions: None,
maximum_sessions_per_client: None,
minimum_client_deadline: Duration::from_secs(1),
handshake_timeout: Duration::from_secs(15),
}
Expand All @@ -220,6 +234,8 @@ pub(super) struct PeerRpcServer<TSvc, TCommsProvider> {
protocol_notifications: Option<ProtocolNotificationRx<Substream>>,
comms_provider: TCommsProvider,
request_rx: mpsc::Receiver<RpcServerRequest>,
sessions: HashMap<NodeId, usize>,
tasks: FuturesUnordered<JoinHandle<NodeId>>,
}

impl<TSvc, TCommsProvider> PeerRpcServer<TSvc, TCommsProvider>
Expand Down Expand Up @@ -255,6 +271,8 @@ where
protocol_notifications: Some(protocol_notifications),
comms_provider,
request_rx,
sessions: HashMap::new(),
tasks: FuturesUnordered::new(),
}
}

Expand All @@ -274,6 +292,10 @@ where
}
}

Some(Ok(node_id)) = self.tasks.next() => {
self.on_session_complete(&node_id);
},

Some(req) = self.request_rx.recv() => {
self.handle_request(req).await;
},
Expand Down Expand Up @@ -336,6 +358,32 @@ where
Ok(())
}

fn new_session_for(&mut self, node_id: NodeId) -> Result<usize, RpcServerError> {
match self.config.maximum_sessions_per_client {
Some(max) if max > 0 => {
let count = self.sessions.entry(node_id.clone()).or_insert(0);

debug_assert!(*count <= max);
if *count >= max {
return Err(RpcServerError::MaxSessionsPerClientReached { node_id });
}
*count += 1;
Ok(*count)
},
Some(_) => Ok(0),
None => Ok(0),
}
}

fn on_session_complete(&mut self, node_id: &NodeId) {
if let Some(v) = self.sessions.get_mut(node_id) {
*v -= 1;
if *v == 0 {
self.sessions.remove(node_id);
}
}
}

#[tracing::instrument(name = "rpc::server::try_initiate_service", skip(self, framed), err)]
async fn try_initiate_service(
&mut self,
Expand Down Expand Up @@ -374,6 +422,13 @@ where
},
};

if let Err(err) = self.new_session_for(node_id.clone()) {
handshake
.reject_with_reason(HandshakeRejectReason::NoSessionsAvailable)
.await?;
return Err(err);
}

let version = handshake.perform_server_handshake().await?;
debug!(
target: LOG_TARGET,
Expand All @@ -390,15 +445,19 @@ where
);

let node_id = node_id.clone();
self.executor
let handle = self
.executor
.try_spawn(async move {
let num_sessions = metrics::num_sessions(&node_id, &service.protocol);
num_sessions.inc();
service.start().await;
num_sessions.dec();
node_id
})
.map_err(|_| RpcServerError::MaximumSessionsReached)?;

self.tasks.push(handle);

Ok(())
}
}
Expand Down
Loading

0 comments on commit 080bccf

Please sign in to comment.