diff --git a/networking/rpc_framework/src/client/mod.rs b/networking/rpc_framework/src/client/mod.rs index a570dc3d7..8ce870ac5 100644 --- a/networking/rpc_framework/src/client/mod.rs +++ b/networking/rpc_framework/src/client/mod.rs @@ -73,7 +73,6 @@ use crate::{ RpcHandshakeError, RpcServerError, RpcStatus, - RPC_CHUNKING_MAX_CHUNKS, }; const LOG_TARGET: &str = "comms::rpc::client"; @@ -934,53 +933,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin pub async fn read_response(&mut self) -> Result<proto::RpcResponse, RpcError> { let timer = Instant::now(); - let mut resp = self.next().await?; + let resp = self.next().await?; self.time_to_first_msg = Some(timer.elapsed()); self.check_response(&resp)?; - let mut chunk_count = 1; - let mut last_chunk_flags = - RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - let mut last_chunk_size = resp.payload.len(); - self.bytes_read += last_chunk_size; - loop { - trace!( - target: LOG_TARGET, - "Chunk {} received (flags={:?}, {} bytes, {} total)", - chunk_count, - last_chunk_flags, - last_chunk_size, - resp.payload.len() - ); - if !last_chunk_flags.is_more() { - return Ok(resp); - } - - if chunk_count >= RPC_CHUNKING_MAX_CHUNKS { - return Err(RpcError::RemotePeerExceededMaxChunkCount { - expected: RPC_CHUNKING_MAX_CHUNKS, - }); - } - - let msg = self.next().await?; - last_chunk_flags = RpcMessageFlags::from_bits(u8::try_from(msg.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - last_chunk_size = msg.payload.len(); - self.bytes_read += last_chunk_size; - self.check_response(&resp)?; - resp.payload.extend(msg.payload); - chunk_count += 1; - } + self.bytes_read = resp.payload.len(); + trace!( + target: LOG_TARGET, + "Received {} bytes in {:.2?}", + resp.payload.len(), + self.time_to_first_msg.unwrap_or_default() + ); + Ok(resp) } pub async fn read_ack(&mut self) -> Result<proto::RpcResponse, RpcError> { diff --git a/networking/rpc_framework/src/lib.rs b/networking/rpc_framework/src/lib.rs index d6fbf84e6..1678b6bfa 100644 --- a/networking/rpc_framework/src/lib.rs +++ b/networking/rpc_framework/src/lib.rs @@ -25,17 +25,12 @@ //! Provides a request/response protocol that supports streaming. //! Available with the `rpc` crate feature. -// TODO: fix all tests // #[cfg(test)] // mod test; /// Maximum frame size of each RPC message. This is enforced in tokio's length delimited codec. /// This can be thought of as the hard limit on message size. pub const RPC_MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB -/// Maximum number of chunks into which a message can be broken up. -const RPC_CHUNKING_MAX_CHUNKS: usize = 16; // 16 x 256 Kib = 4 MiB max combined message size -const RPC_CHUNKING_THRESHOLD: usize = 256 * 1024; -const RPC_CHUNKING_SIZE_LIMIT: usize = 384 * 1024; /// The maximum request payload size const fn max_request_size() -> usize { @@ -44,7 +39,7 @@ const fn max_request_size() -> usize { /// The maximum size for a single RPC response message const fn max_response_size() -> usize { - RPC_CHUNKING_MAX_CHUNKS * RPC_CHUNKING_THRESHOLD + RPC_MAX_FRAME_SIZE } /// The maximum size for a single RPC response excluding overhead diff --git a/networking/rpc_framework/src/message.rs b/networking/rpc_framework/src/message.rs index 6f930f7a0..ca195ec1a 100644 --- a/networking/rpc_framework/src/message.rs +++ b/networking/rpc_framework/src/message.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use crate::{ body::{Body, IntoBody}, error::HandshakeRejectReason, + max_response_payload_size, proto, proto::rpc_session_reply::SessionResult, RpcError, @@ -145,8 +146,6 @@ bitflags! { const FIN = 0x01; /// Typically sent with empty contents and used to confirm a substream is alive. const ACK = 0x02; - /// Another chunk to be received - const MORE = 0x04; } } impl RpcMessageFlags { @@ -157,10 +156,6 @@ impl RpcMessageFlags { pub fn is_ack(self) -> bool { self.contains(Self::ACK) } - - pub fn is_more(self) -> bool { - self.contains(Self::MORE) - } } impl Default for RpcMessageFlags { @@ -218,6 +213,20 @@ impl RpcResponse { payload: self.payload.to_vec(), } } + + pub fn exceeded_message_size(self) -> RpcResponse { + let msg = format!( + "The response size exceeded the maximum allowed payload size. Max = {} bytes, Got = {} bytes", + max_response_payload_size() as f32, + self.payload.len() as f32, + ); + RpcResponse { + request_id: self.request_id, + status: RpcStatusCode::MalformedResponse, + flags: RpcMessageFlags::FIN, + payload: msg.into_bytes().into(), + } + } } impl Default for RpcResponse { diff --git a/networking/rpc_framework/src/server/chunking.rs b/networking/rpc_framework/src/server/chunking.rs deleted file mode 100644 index 6720181cc..000000000 --- a/networking/rpc_framework/src/server/chunking.rs +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2021, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use std::cmp; - -use bytes::Bytes; -use log::*; - -const LOG_TARGET: &str = "tari::rpc::server::chunking"; - -use crate::{ - message::{RpcMessageFlags, RpcResponse}, - proto, - RpcStatusCode, - RPC_CHUNKING_SIZE_LIMIT, - RPC_CHUNKING_THRESHOLD, -}; - -pub(super) struct ChunkedResponseIter { - message: RpcResponse, - initial_payload_size: usize, - has_emitted_once: bool, - num_chunks: usize, - total_chunks: usize, -} - -fn calculate_total_chunk_count(payload_len: usize) -> usize { - let mut total_chunks = payload_len / RPC_CHUNKING_THRESHOLD; - let excess = (payload_len % RPC_CHUNKING_THRESHOLD) + RPC_CHUNKING_THRESHOLD; - if total_chunks == 0 || excess > RPC_CHUNKING_SIZE_LIMIT { - // If the chunk (threshold size) + excess cannot fit in the RPC_CHUNKING_SIZE_LIMIT, then we'll emit another - // frame smaller than threshold size - total_chunks += 1; - } - - total_chunks -} - -impl ChunkedResponseIter { - pub fn new(message: RpcResponse) -> Self { - let len = message.payload.len(); - Self { - initial_payload_size: message.payload.len(), - message, - has_emitted_once: false, - num_chunks: 0, - total_chunks: calculate_total_chunk_count(len), - } - } - - fn remaining(&self) -> usize { - self.message.payload.len() - } - - fn payload_mut(&mut self) -> &mut Bytes { - &mut self.message.payload - } - - fn payload(&self) -> &Bytes { - &self.message.payload - } - - fn get_next_chunk(&mut self) -> Option<Bytes> { - let len = self.payload().len(); - if len == 0 { - if self.num_chunks > 1 { - debug!( - target: LOG_TARGET, - "Emitted {} chunks (Avg.Size: {} bytes, Total: {} bytes)", - self.num_chunks, - self.initial_payload_size / self.num_chunks, - self.initial_payload_size - ); - } - return None; - } - - // If the payload is within the maximum chunk size, simply return the rest of it - if len <= RPC_CHUNKING_SIZE_LIMIT { - let chunk = self.payload_mut().split_to(len); - self.num_chunks += 1; - trace!( - target: LOG_TARGET, - "Emitting chunk {}/{} ({} bytes)", - self.num_chunks, - self.total_chunks, - chunk.len() - ); - return Some(chunk); - } - - let chunk_size = cmp::min(len, RPC_CHUNKING_THRESHOLD); - let chunk = self.payload_mut().split_to(chunk_size); - - self.num_chunks += 1; - trace!( - target: LOG_TARGET, - "Emitting chunk {}/{} ({} bytes)", - self.num_chunks, - self.total_chunks, - chunk.len() - ); - Some(chunk) - } - - fn is_last_chunk(&self) -> bool { - self.num_chunks == self.total_chunks - } - - fn exceeded_message_size(&self) -> proto::RpcResponse { - const BYTES_PER_MB: f32 = 1024.0 * 1024.0; - // Precision loss is acceptable because this is for display purposes only - let msg = format!( - "The response size exceeded the maximum allowed payload size. Max = {:.4} MiB, Got = {:.4} MiB", - crate::max_response_payload_size() as f32 / BYTES_PER_MB, - self.message.payload.len() as f32 / BYTES_PER_MB, - ); - warn!(target: LOG_TARGET, "{}", msg); - proto::RpcResponse { - request_id: self.message.request_id, - status: RpcStatusCode::MalformedResponse as u32, - flags: RpcMessageFlags::FIN.bits().into(), - payload: msg.into_bytes(), - } - } -} - -impl Iterator for ChunkedResponseIter { - type Item = proto::RpcResponse; - - fn next(&mut self) -> Option<Self::Item> { - // Edge case: the initial message has an empty payload. - if self.initial_payload_size == 0 { - if self.has_emitted_once { - return None; - } - self.has_emitted_once = true; - return Some(self.message.to_proto()); - } - - // Edge case: the total message size cannot fit into the maximum allowed chunks - if self.remaining() > crate::max_response_payload_size() { - if self.has_emitted_once { - return None; - } - self.has_emitted_once = true; - return Some(self.exceeded_message_size()); - } - - let request_id = self.message.request_id; - let chunk = self.get_next_chunk()?; - - // status MUST be set for the first chunked message, all subsequent chunk messages MUST have a status of 0 - let mut status = 0; - if !self.has_emitted_once { - status = self.message.status as u32; - } - self.has_emitted_once = true; - - let mut flags = self.message.flags; - if !self.is_last_chunk() { - // For all chunks except the last the MORE flag MUST be set - flags |= RpcMessageFlags::MORE; - } - let msg = proto::RpcResponse { - request_id, - status, - flags: flags.bits().into(), - payload: chunk.to_vec(), - }; - - Some(msg) - } -} - -#[cfg(test)] -mod test { - use std::{convert::TryFrom, iter}; - - use super::*; - - fn create(size: usize) -> ChunkedResponseIter { - let msg = RpcResponse { - payload: iter::repeat(0).take(size).collect(), - ..Default::default() - }; - ChunkedResponseIter::new(msg) - } - - #[test] - fn it_emits_a_zero_size_message() { - let iter = create(0); - assert_eq!(iter.total_chunks, 1); - let msgs = iter.collect::<Vec<_>>(); - assert_eq!(msgs.len(), 1); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - } - - #[test] - fn it_emits_one_message_below_threshold() { - let iter = create(RPC_CHUNKING_THRESHOLD - 1); - assert_eq!(iter.total_chunks, 1); - let msgs = iter.collect::<Vec<_>>(); - assert_eq!(msgs.len(), 1); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - } - - #[test] - fn it_emits_a_single_message() { - let iter = create(RPC_CHUNKING_SIZE_LIMIT - 1); - assert_eq!(iter.count(), 1); - - let iter = create(RPC_CHUNKING_SIZE_LIMIT); - assert_eq!(iter.count(), 1); - } - - #[test] - fn it_emits_an_expected_number_of_chunks() { - let iter = create(RPC_CHUNKING_THRESHOLD * 2); - assert_eq!(iter.count(), 2); - - let diff = RPC_CHUNKING_SIZE_LIMIT - RPC_CHUNKING_THRESHOLD; - let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff); - assert_eq!(iter.count(), 2); - - let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff + 1); - assert_eq!(iter.count(), 3); - } - - #[test] - fn it_sets_the_more_flag_except_last() { - use std::convert::TryFrom; - let iter = create(RPC_CHUNKING_THRESHOLD * 3); - let msgs = iter.collect::<Vec<_>>(); - assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[1].flags).unwrap()) - .unwrap() - .is_more()); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[2].flags).unwrap()) - .unwrap() - .is_more()); - } -} diff --git a/networking/rpc_framework/src/server/mod.rs b/networking/rpc_framework/src/server/mod.rs index bb9809cb7..8ea7e5d5c 100644 --- a/networking/rpc_framework/src/server/mod.rs +++ b/networking/rpc_framework/src/server/mod.rs @@ -20,9 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -mod chunking; -use chunking::ChunkedResponseIter; - mod error; pub use error::RpcServerError; @@ -53,7 +50,7 @@ use std::{ }; use bytes::Bytes; -use futures::{future, stream, stream::FuturesUnordered, SinkExt, Stream, StreamExt}; +use futures::{future, stream::FuturesUnordered, SinkExt, Stream, StreamExt}; use libp2p::{PeerId, StreamProtocol}; use libp2p_substream::{ProtocolEvent, ProtocolNotification}; use log::*; @@ -66,6 +63,7 @@ use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; use super::{ body::Body, error::HandshakeRejectReason, + max_response_payload_size, message::{Request, Response, RpcMessageFlags}, not_found::ProtocolServiceNotFound, status::RpcStatus, @@ -710,12 +708,15 @@ where TSvc: Service<Request<Bytes>, Response = Response<Body>, Error = RpcStatus let mut stream = body .into_message() .map(|result| into_response(request_id, result)) - .flat_map(move |message| { + .map(move |mut message| { + if message.payload.len() > max_response_payload_size() { + message = message.exceeded_message_size(); + } #[cfg(feature = "metrics")] if !message.status.is_ok() { metrics::status_error_counter(&peer_id, &protocol, message.status).inc(); } - stream::iter(ChunkedResponseIter::new(message)) + message.to_proto() }) .map(|resp| Bytes::from(resp.encode_to_vec())); diff --git a/networking/rpc_framework/src/test/comms_integration.rs b/networking/rpc_framework/src/test/comms_integration.rs deleted file mode 100644 index f2d6f6dae..000000000 --- a/networking/rpc_framework/src/test/comms_integration.rs +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2020, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use tari_shutdown::Shutdown; -use tari_test_utils::unpack_enum; - -use crate::{ - protocol::rpc::{ - test::mock::{MockRpcClient, MockRpcService}, - RpcError, - RpcServer, - RpcStatus, - RpcStatusCode, - }, - test_utils::node_identity::build_node_identity, - transports::MemoryTransport, - types::CommsDatabase, - CommsBuilder, -}; - -#[tokio::test] -async fn run_service() { - let node_identity1 = build_node_identity(Default::default()); - let rpc_service = MockRpcService::new(); - let mock_state = rpc_service.shared_state(); - let shutdown = Shutdown::new(); - let comms1 = CommsBuilder::new() - .with_listener_address(node_identity1.first_public_address().unwrap()) - .with_node_identity(node_identity1) - .with_shutdown_signal(shutdown.to_signal()) - .with_peer_storage(CommsDatabase::new(), None) - .build() - .unwrap() - .add_rpc_server(RpcServer::new().add_service(rpc_service)) - .spawn_with_transport(MemoryTransport) - .await - .unwrap(); - - let node_identity2 = build_node_identity(Default::default()); - let comms2 = CommsBuilder::new() - .with_listener_address(node_identity2.first_public_address().unwrap()) - .with_shutdown_signal(shutdown.to_signal()) - .with_node_identity(node_identity2.clone()) - .with_peer_storage(CommsDatabase::new(), None) - .build() - .unwrap(); - - comms2 - .peer_manager() - .add_peer(comms1.node_identity().to_peer()) - .await - .unwrap(); - - let comms2 = comms2.spawn_with_transport(MemoryTransport).await.unwrap(); - - let mut conn = comms2 - .connectivity() - .dial_peer(comms1.node_identity().node_id().clone()) - .await - .unwrap(); - - let mut client = conn.connect_rpc::<MockRpcClient>().await.unwrap(); - - mock_state.set_response_ok(&()); - client.request_response::<_, ()>((), 0.into()).await.unwrap(); - assert_eq!(mock_state.call_count(), 1); - - mock_state.set_response_err(RpcStatus::bad_request("Insert 💾")); - let err = client.request_response::<_, ()>((), 0.into()).await.unwrap_err(); - unpack_enum!(RpcError::RequestFailed(status) = err); - unpack_enum!(RpcStatusCode::BadRequest = status.as_status_code()); - assert_eq!(mock_state.call_count(), 2); -} diff --git a/networking/rpc_framework/src/test/greeting_service.rs b/networking/rpc_framework/src/test/greeting_service.rs index 885e2d13a..b33d83a1f 100644 --- a/networking/rpc_framework/src/test/greeting_service.rs +++ b/networking/rpc_framework/src/test/greeting_service.rs @@ -31,6 +31,7 @@ use std::{ time::Duration, }; +use async_trait::async_trait; use tari_utilities::hex::Hex; use tokio::{ sync::{mpsc, RwLock}, @@ -38,15 +39,7 @@ use tokio::{ time, }; -use crate::{ - async_trait, - protocol::{ - rpc::{NamedProtocolService, Request, Response, RpcError, RpcServerError, RpcStatus, Streaming}, - ProtocolId, - }, - utils, - Substream, -}; +use crate::{Request, Response, RpcStatus, Streaming}; #[async_trait] // #[tari_rpc(protocol_name = "/tari/greeting/1.0", server_struct = GreetingServer, client_struct = GreetingClient)] @@ -126,7 +119,9 @@ impl GreetingRpc for GreetingService { let num = *request.message(); let greetings = self.greetings[..cmp::min(num as usize, self.greetings.len())].to_vec(); task::spawn(async move { - let _result = utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await; + for greeting in greetings { + tx.send(Ok(greeting)).await.unwrap(); + } }); Ok(Streaming::new(rx)) diff --git a/networking/rpc_framework/src/test/mock.rs b/networking/rpc_framework/src/test/mock.rs index 14725712b..66ab1a092 100644 --- a/networking/rpc_framework/src/test/mock.rs +++ b/networking/rpc_framework/src/test/mock.rs @@ -172,12 +172,3 @@ impl From<RpcClient> for MockRpcClient { Self { inner } } } - -pub(crate) fn create_mocked_rpc_context() -> (RpcCommsBackend, ConnectivityManagerMockState) { - let (connectivity, mock) = create_connectivity_mock(); - let mock_state = mock.get_shared_state(); - mock.spawn(); - let peer_manager = build_peer_manager(); - - (RpcCommsBackend::new(peer_manager, connectivity), mock_state) -} diff --git a/networking/rpc_framework/src/test/mod.rs b/networking/rpc_framework/src/test/mod.rs index 206e84c7c..9273b7142 100644 --- a/networking/rpc_framework/src/test/mod.rs +++ b/networking/rpc_framework/src/test/mod.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -mod comms_integration; pub(super) mod greeting_service; mod handshake; pub(super) mod mock; diff --git a/networking/rpc_framework/src/test/smoke.rs b/networking/rpc_framework/src/test/smoke.rs index 595771535..bef0a8def 100644 --- a/networking/rpc_framework/src/test/smoke.rs +++ b/networking/rpc_framework/src/test/smoke.rs @@ -23,6 +23,8 @@ use std::{sync::Arc, time::Duration}; use futures::StreamExt; +use libp2p::{PeerId, StreamProtocol}; +use libp2p_substream::{ProtocolEvent, ProtocolNotification}; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; use tari_utilities::hex::Hex; @@ -33,37 +35,22 @@ use tokio::{ }; use crate::{ + error::HandshakeRejectReason, framing, - multiplexing::Yamux, - protocol::{ - rpc, - rpc::{ - context::RpcCommsBackend, - error::HandshakeRejectReason, - handshake::RpcHandshakeError, - test::{ - greeting_service::{ - GreetingClient, - GreetingRpc, - GreetingServer, - GreetingService, - SayHelloRequest, - SlowGreetingService, - SlowStreamRequest, - }, - mock::create_mocked_rpc_context, - }, - RpcError, - RpcServer, - RpcServerBuilder, - RpcStatusCode, - }, - ProtocolEvent, - ProtocolId, - ProtocolNotification, + handshake::RpcHandshakeError, + test::greeting_service::{ + GreetingClient, + GreetingRpc, + GreetingServer, + GreetingService, + SayHelloRequest, + SlowGreetingService, + SlowStreamRequest, }, - test_utils::{node_identity::build_node_identity, transport::build_multiplexed_connections}, - NodeIdentity, + RpcError, + RpcServer, + RpcServerBuilder, + RpcStatusCode, Substream, }; @@ -71,22 +58,19 @@ pub(super) async fn setup_service_with_builder<T: GreetingRpc>( service_impl: T, builder: RpcServerBuilder, ) -> ( - mpsc::Sender<ProtocolNotification<Substream>>, + mpsc::UnboundedSender<ProtocolNotification<Substream>>, task::JoinHandle<()>, - RpcCommsBackend, Shutdown, ) { - let (notif_tx, notif_rx) = mpsc::channel(10); + let (notif_tx, notif_rx) = mpsc::unbounded_channel(); let shutdown = Shutdown::new(); - let (context, _) = create_mocked_rpc_context(); let server_hnd = task::spawn({ - let context = context.clone(); let shutdown_signal = shutdown.to_signal(); async move { let fut = builder .finish() .add_service(GreetingServer::new(service_impl)) - .serve(notif_rx, context); + .serve(notif_rx); tokio::select! { biased; @@ -96,16 +80,15 @@ pub(super) async fn setup_service_with_builder<T: GreetingRpc>( } }); - (notif_tx, server_hnd, context, shutdown) + (notif_tx, server_hnd, shutdown) } pub(super) async fn setup_service<T: GreetingRpc>( service_impl: T, num_concurrent_sessions: usize, ) -> ( - mpsc::Sender<ProtocolNotification<Substream>>, + mpsc::UnboundedSender<ProtocolNotification<Substream>>, task::JoinHandle<()>, - RpcCommsBackend, Shutdown, ) { let builder = RpcServer::builder() @@ -121,14 +104,13 @@ pub(super) async fn setup<T: GreetingRpc>( let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (_, inbound, outbound) = build_multiplexed_connections().await; let substream = outbound.get_yamux_control().open_stream().await.unwrap(); + let peer_id = PeerId::random(); - let node_identity = build_node_identity(Default::default()); // Notify that a peer wants to speak the greeting RPC protocol - context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap(); notif_tx .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), + StreamProtocol::new(b"/test/greeting/1.0"), + ProtocolEvent::NewInboundSubstream { peer_id, substream }, )) .await .unwrap();