diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 76d79551ed..ef8fd3a57c 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -123,10 +123,8 @@ impl CommandHandler { self.executor.spawn(async move { let mut status_line = StatusLine::new(); - let version = format!("v{}", consts::APP_VERSION_NUMBER); - status_line.add_field("", version); - let network = format!("{}", config.network); - status_line.add_field("", network); + status_line.add_field("", format!("v{}", consts::APP_VERSION_NUMBER)); + status_line.add_field("", config.network); status_line.add_field("State", state_info.borrow().state_info.short_desc()); let metadata = node.get_metadata().await.unwrap(); diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index 5f44066f19..a6f4ac1917 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -194,15 +194,15 @@ impl RpcCodeGenerator { .collect::<TokenStream>(); let client_struct_body = quote! { - pub async fn connect<TSubstream>(framed: #dep_mod::CanonicalFraming<TSubstream>) -> Result<Self, #dep_mod::RpcError> - where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + 'static { + pub async fn connect(framed: #dep_mod::CanonicalFraming<#dep_mod::Substream>) -> Result<Self, #dep_mod::RpcError> { use #dep_mod::NamedProtocolService; let inner = #dep_mod::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?; Ok(Self { inner }) } pub fn builder() -> #dep_mod::RpcClientBuilder<Self> { - #dep_mod::RpcClientBuilder::new() + use #dep_mod::NamedProtocolService; + #dep_mod::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into()) } #client_methods diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs index ed77fc6146..caaa683593 100644 --- a/comms/src/memsocket/mod.rs +++ b/comms/src/memsocket/mod.rs @@ -30,6 +30,7 @@ use futures::{ stream::{FusedStream, Stream}, task::{Context, Poll}, }; +use log::*; use std::{ cmp, collections::{hash_map::Entry, HashMap}, @@ -433,6 +434,7 @@ impl AsyncRead for MemorySocket { buf.advance(bytes_to_read); current_buffer.advance(bytes_to_read); + trace!("reading {} bytes", bytes_to_read); bytes_read += bytes_to_read; } @@ -462,11 +464,12 @@ impl AsyncRead for MemorySocket { impl AsyncWrite for MemorySocket { /// Attempt to write bytes from `buf` into the outgoing channel. - fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { let len = buf.len(); - match self.outgoing.poll_ready(context) { + match self.outgoing.poll_ready(cx) { Poll::Ready(Ok(())) => { + trace!("writing {} bytes", len); if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) { if e.is_disconnected() { return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); @@ -475,6 +478,7 @@ impl AsyncWrite for MemorySocket { // Unbounded channels should only ever have "Disconnected" errors unreachable!(); } + Poll::Ready(Ok(len)) }, Poll::Ready(Err(e)) => { if e.is_disconnected() { @@ -484,19 +488,18 @@ impl AsyncWrite for MemorySocket { // Unbounded channels should only ever have "Disconnected" errors unreachable!(); }, - Poll::Pending => return Poll::Pending, + Poll::Pending => Poll::Pending, } - - Poll::Ready(Ok(len)) } /// Attempt to flush the channel. Cannot Fail. - fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> { + trace!("flush"); Poll::Ready(Ok(())) } /// Attempt to close the channel. Cannot Fail. - fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> { self.outgoing.close_channel(); Poll::Ready(Ok(())) @@ -506,7 +509,8 @@ impl AsyncWrite for MemorySocket { #[cfg(test)] mod test { use super::*; - use crate::runtime; + use crate::{framing, runtime}; + use futures::SinkExt; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_stream::StreamExt; @@ -705,4 +709,21 @@ mod test { Ok(()) } + + #[runtime::test] + async fn read_and_write_canonical_framing() -> io::Result<()> { + let (a, b) = MemorySocket::new_pair(); + let mut a = framing::canonical(a, 1024); + let mut b = framing::canonical(b, 1024); + + a.send(Bytes::from_static(b"frame-1")).await?; + b.send(Bytes::from_static(b"frame-2")).await?; + let msg = b.next().await.unwrap()?; + assert_eq!(&msg[..], b"frame-1"); + + let msg = a.next().await.unwrap()?; + assert_eq!(&msg[..], b"frame-2"); + + Ok(()) + } } diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 1723033739..28a14dfff7 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -166,7 +166,7 @@ pub struct IncomingSubstreams { } impl IncomingSubstreams { - pub fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self { + pub(self) fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self { Self { inner, substream_counter, @@ -205,6 +205,12 @@ pub struct Substream { counter_guard: CounterGuard, } +impl Substream { + pub fn id(&self) -> yamux::StreamId { + self.stream.get_ref().id() + } +} + impl tokio::io::AsyncRead for Substream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { Pin::new(&mut self.stream).poll_read(cx, buf) @@ -242,13 +248,17 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static } } - #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))] + #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self), fields(connection = %self.connection))] pub async fn run(mut self) { loop { tokio::select! { biased; - _ = &mut self.shutdown_signal => { + _ = self.shutdown_signal.wait() => { + debug!( + target: LOG_TARGET, + "{} Yamux connection shutdown", self.connection + ); let mut control = self.connection.control(); if let Err(err) = control.close().await { error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); @@ -259,11 +269,13 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static result = self.connection.next_stream() => { match result { Ok(Some(stream)) => { - event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() { + event!(Level::TRACE, "yamux::incoming_worker::new_stream {}", stream); + if self.sender.send(stream).await.is_err() { debug!( target: LOG_TARGET, - "Incoming peer substream task is shutting down because the internal stream sender channel \ - was closed" + "{} Incoming peer substream task is shutting down because the internal stream sender channel \ + was closed", + self.connection ); break; } @@ -271,19 +283,23 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static Ok(None) =>{ debug!( target: LOG_TARGET, - "Incoming peer substream completed. IncomingWorker exiting" + "{} Incoming peer substream completed. IncomingWorker exiting", + self.connection ); break; } Err(err) => { event!( - Level::ERROR, - "Incoming peer substream task received an error because '{}'", - err - ); - error!( + Level::ERROR, + "{} Incoming peer substream task received an error because '{}'", + self.connection, + err + ); + error!( target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err + "{} Incoming peer substream task received an error because '{}'", + self.connection, + err ); break; }, diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 5467befeba..a6d6554f3e 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -38,6 +38,7 @@ use crate::{ ProtocolId, }, runtime::task, + Substream, }; use bytes::Bytes; use futures::{ @@ -60,7 +61,6 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ - io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot, Mutex}, time, }; @@ -76,14 +76,11 @@ pub struct RpcClient { impl RpcClient { /// Create a new RpcClient using the given framed substream and perform the RPC handshake. - pub async fn connect<TSubstream>( + pub async fn connect( config: RpcClientConfig, - framed: CanonicalFraming<TSubstream>, + framed: CanonicalFraming<Substream>, protocol_name: ProtocolId, - ) -> Result<Self, RpcError> - where - TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, - { + ) -> Result<Self, RpcError> { let (request_tx, request_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let shutdown_signal = shutdown.to_signal(); @@ -224,14 +221,14 @@ where TClient: From<RpcClient> + NamedProtocolService self } - pub(crate) fn with_protocol_id(mut self, protocol_id: ProtocolId) -> Self { + /// Set the protocol ID associated with this client. This is used for logging purposes only. + pub fn with_protocol_id(mut self, protocol_id: ProtocolId) -> Self { self.protocol_id = Some(protocol_id); self } /// Negotiates and establishes a session to the peer's RPC service - pub async fn connect<TSubstream>(self, framed: CanonicalFraming<TSubstream>) -> Result<TClient, RpcError> - where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static { + pub async fn connect(self, framed: CanonicalFraming<Substream>) -> Result<TClient, RpcError> { RpcClient::connect( self.config, framed, @@ -346,10 +343,10 @@ impl Service<BaseRequest<Bytes>> for ClientConnector { } } -pub struct RpcClientWorker<TSubstream> { +struct RpcClientWorker { config: RpcClientConfig, request_rx: mpsc::Receiver<ClientRequest>, - framed: CanonicalFraming<TSubstream>, + framed: CanonicalFraming<Substream>, // Request ids are limited to u16::MAX because varint encoding is used over the wire and the magnitude of the value // sent determines the byte size. A u16 will be more than enough for the purpose next_request_id: u16, @@ -359,13 +356,11 @@ pub struct RpcClientWorker<TSubstream> { shutdown_signal: ShutdownSignal, } -impl<TSubstream> RpcClientWorker<TSubstream> -where TSubstream: AsyncRead + AsyncWrite + Unpin + Send -{ - pub fn new( +impl RpcClientWorker { + pub(self) fn new( config: RpcClientConfig, request_rx: mpsc::Receiver<ClientRequest>, - framed: CanonicalFraming<TSubstream>, + framed: CanonicalFraming<Substream>, ready_tx: oneshot::Sender<Result<(), RpcError>>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, @@ -386,11 +381,16 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send String::from_utf8_lossy(&self.protocol_id) } - #[tracing::instrument(name = "rpc_client_worker run", skip(self), fields(next_request_id= self.next_request_id))] + fn stream_id(&self) -> yamux::StreamId { + self.framed.get_ref().id() + } + + #[tracing::instrument(name = "rpc_client_worker run", skip(self), fields(next_request_id = self.next_request_id))] async fn run(mut self) { debug!( target: LOG_TARGET, - "Performing client handshake for '{}'", + "(stream={}) Performing client handshake for '{}'", + self.stream_id(), self.protocol_name() ); let start = Instant::now(); @@ -400,7 +400,8 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let latency = start.elapsed(); debug!( target: LOG_TARGET, - "RPC Session ({}) negotiation completed. Latency: {:.0?}", + "(stream={}) RPC Session ({}) negotiation completed. Latency: {:.0?}", + self.stream_id(), self.protocol_name(), latency ); @@ -428,7 +429,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send match req { Some(req) => { if let Err(err) = self.handle_request(req).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); + error!(target: LOG_TARGET, "(stream={}) Unexpected error: {}. Worker is terminating.", self.stream_id(), err); break; } } @@ -439,12 +440,18 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send } if let Err(err) = self.framed.close().await { - debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); + debug!( + target: LOG_TARGET, + "(stream={}) IO Error when closing substream: {}", + self.stream_id(), + err + ); } debug!( target: LOG_TARGET, - "RpcClientWorker ({}) terminated.", + "(stream={}) RpcClientWorker ({}) terminated.", + self.stream_id(), self.protocol_name() ); } @@ -477,14 +484,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send debug!( target: LOG_TARGET, - "Ping (protocol {}) sent in {:.2?}", + "(stream={}) Ping (protocol {}) sent in {:.2?}", + self.stream_id(), self.protocol_name(), start.elapsed() ); let resp = match self.read_reply().await { Ok(resp) => resp, Err(RpcError::ReplyTimeout) => { - debug!(target: LOG_TARGET, "Ping timed out after {:.0?}", start.elapsed()); + debug!( + target: LOG_TARGET, + "(stream={}) Ping timed out after {:.0?}", + self.stream_id(), + start.elapsed() + ); let _ = reply.send(Err(RpcStatus::timed_out("Response timed out"))); return Ok(()); }, @@ -499,7 +512,12 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let resp_flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8); if !resp_flags.contains(RpcMessageFlags::ACK) { - warn!(target: LOG_TARGET, "Invalid ping response {:?}", resp); + warn!( + target: LOG_TARGET, + "(stream={}) Invalid ping response {:?}", + self.stream_id(), + resp + ); let _ = reply.send(Err(RpcStatus::protocol_error(format!( "Received invalid ping response on protocol '{}'", self.protocol_name() @@ -613,8 +631,9 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if response_tx.is_closed() { warn!( target: LOG_TARGET, - "Response receiver was dropped before the response/stream could complete for protocol {}, \ - the stream will continue until completed", + "(stream={}) Response receiver was dropped before the response/stream could complete for \ + protocol {}, the stream will continue until completed", + self.framed.get_ref().id(), self.protocol_name() ); } else { diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs index b39c15e6d7..4e07a0294d 100644 --- a/comms/src/protocol/rpc/handshake.rs +++ b/comms/src/protocol/rpc/handshake.rs @@ -138,15 +138,19 @@ where T: AsyncRead + AsyncWrite + Unpin let msg = proto::rpc::RpcSession { supported_versions: SUPPORTED_RPC_VERSIONS.to_vec(), }; + let payload = msg.to_encoded_bytes(); + debug!(target: LOG_TARGET, "Sending client handshake ({} bytes)", payload.len()); // It is possible that the server rejects the session and closes the substream before we've had a chance to send // anything. Rather than returning an IO error, let's ignore the send error and see if we can receive anything, // or return an IO error similarly to what send would have done. - if let Err(err) = self.framed.send(msg.to_encoded_bytes().into()).await { + if let Err(err) = self.framed.send(payload.into()).await { warn!( target: LOG_TARGET, "IO error when sending new session handshake to peer: {}", err ); + panic!(); } + self.framed.flush().await?; match self.recv_next_frame().await { Ok(Some(Ok(msg))) => { let msg = proto::rpc::RpcSessionReply::decode(&mut msg.freeze())?; diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index 2244979adf..33208df391 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -63,6 +63,7 @@ pub const RPC_MAX_FRAME_SIZE: usize = 4 * 1024 * 1024; // 4 MiB pub mod __macro_reexports { pub use crate::{ framing::CanonicalFraming, + multiplexing::Substream, protocol::{ rpc::{ client_pool::RpcPoolClient, diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 69659ba03b..dae0f9ce93 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -194,14 +194,14 @@ impl RpcCommsProvider for MockCommsProvider { } } -pub struct MockRpcServer<TSvc, TSubstream> { - inner: Option<PeerRpcServer<TSvc, TSubstream, MockCommsProvider>>, - protocol_tx: ProtocolNotificationTx<TSubstream>, +pub struct MockRpcServer<TSvc> { + inner: Option<PeerRpcServer<TSvc, MockCommsProvider>>, + protocol_tx: ProtocolNotificationTx<Substream>, our_node: Arc<NodeIdentity>, request_tx: mpsc::Sender<RpcServerRequest>, } -impl<TSvc> MockRpcServer<TSvc, Substream> +impl<TSvc> MockRpcServer<TSvc> where TSvc: MakeService< ProtocolId, @@ -259,7 +259,7 @@ where } } -impl MockRpcServer<MockRpcImpl, Substream> { +impl MockRpcServer<MockRpcImpl> { pub async fn create_mockimpl_connection(&self, peer: Peer) -> PeerConnection { // MockRpcImpl accepts any protocol self.create_connection(peer, ProtocolId::new()).await diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index 88fdb7ee61..2633d435c0 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -52,19 +52,15 @@ use crate::{ proto, protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, + Substream, }; use futures::SinkExt; use prost::Message; use std::{ - borrow::Cow, future::Future, time::{Duration, Instant}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc, - time, -}; +use tokio::{sync::mpsc, time}; use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; @@ -116,14 +112,13 @@ impl RpcServer { RpcServerHandle::new(self.request_tx.clone()) } - pub(super) async fn serve<S, TSubstream, TCommsProvider>( + pub(super) async fn serve<S, TCommsProvider>( self, service: S, - notifications: ProtocolNotificationRx<TSubstream>, + notifications: ProtocolNotificationRx<Substream>, comms_provider: TCommsProvider, ) -> Result<(), RpcServerError> where - TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: MakeService< ProtocolId, Request<Bytes>, @@ -197,18 +192,17 @@ impl Default for RpcServerBuilder { } } -pub(super) struct PeerRpcServer<TSvc, TSubstream, TCommsProvider> { +pub(super) struct PeerRpcServer<TSvc, TCommsProvider> { executor: BoundedExecutor, config: RpcServerBuilder, service: TSvc, - protocol_notifications: Option<ProtocolNotificationRx<TSubstream>>, + protocol_notifications: Option<ProtocolNotificationRx<Substream>>, comms_provider: TCommsProvider, request_rx: mpsc::Receiver<RpcServerRequest>, } -impl<TSvc, TSubstream, TCommsProvider> PeerRpcServer<TSvc, TSubstream, TCommsProvider> +impl<TSvc, TCommsProvider> PeerRpcServer<TSvc, TCommsProvider> where - TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, TSvc: MakeService< ProtocolId, Request<Bytes>, @@ -225,7 +219,7 @@ where fn new( config: RpcServerBuilder, service: TSvc, - protocol_notifications: ProtocolNotificationRx<TSubstream>, + protocol_notifications: ProtocolNotificationRx<Substream>, comms_provider: TCommsProvider, request_rx: mpsc::Receiver<RpcServerRequest>, ) -> Self { @@ -289,7 +283,7 @@ where #[tracing::instrument(name = "rpc::server::new_client_connection", skip(self, notification), err)] async fn handle_protocol_notification( &mut self, - notification: ProtocolNotification<TSubstream>, + notification: ProtocolNotification<Substream>, ) -> Result<(), RpcServerError> { match notification.event { ProtocolEvent::NewInboundSubstream(node_id, substream) => { @@ -318,7 +312,7 @@ where &mut self, protocol: ProtocolId, node_id: NodeId, - mut framed: CanonicalFraming<TSubstream>, + mut framed: CanonicalFraming<Substream>, ) -> Result<(), RpcServerError> { let mut handshake = Handshake::new(&mut framed).with_timeout(self.config.handshake_timeout); @@ -357,14 +351,14 @@ where "Server negotiated RPC v{} with client node `{}`", version, node_id ); - let service = ActivePeerRpcService { - config: self.config.clone(), + let service = ActivePeerRpcService::new( + self.config.clone(), protocol, - node_id: node_id.clone(), - framed, + node_id.clone(), service, - comms_provider: self.comms_provider.clone(), - }; + framed, + self.comms_provider.clone(), + ); self.executor .try_spawn(service.start()) @@ -374,64 +368,91 @@ where } } -struct ActivePeerRpcService<TSvc, TSubstream, TCommsProvider> { +struct ActivePeerRpcService<TSvc, TCommsProvider> { config: RpcServerBuilder, protocol: ProtocolId, node_id: NodeId, service: TSvc, - framed: CanonicalFraming<TSubstream>, + framed: CanonicalFraming<Substream>, comms_provider: TCommsProvider, + logging_context_string: String, } -impl<TSvc, TSubstream, TCommsProvider> ActivePeerRpcService<TSvc, TSubstream, TCommsProvider> +impl<TSvc, TCommsProvider> ActivePeerRpcService<TSvc, TCommsProvider> where - TSubstream: AsyncRead + AsyncWrite + Unpin, TSvc: Service<Request<Bytes>, Response = Response<Body>, Error = RpcStatus>, TCommsProvider: RpcCommsProvider + Send + Clone + 'static, { + pub(self) fn new( + config: RpcServerBuilder, + protocol: ProtocolId, + node_id: NodeId, + service: TSvc, + framed: CanonicalFraming<Substream>, + comms_provider: TCommsProvider, + ) -> Self { + Self { + logging_context_string: format!( + "stream_id: {}, peer: {}, protocol: {}", + framed.get_ref().id(), + node_id, + String::from_utf8_lossy(&protocol) + ), + + config, + protocol, + node_id, + service, + framed, + comms_provider, + } + } + async fn start(mut self) { debug!( target: LOG_TARGET, - "(Peer = `{}`) Rpc server ({}) started.", - self.node_id, - self.protocol_name() + "({}) Rpc server started.", self.logging_context_string, ); if let Err(err) = self.run().await { error!( target: LOG_TARGET, - "(Peer = `{}`) Rpc server ({}) exited with an error: {}", - self.node_id, - self.protocol_name(), - err + "({}) Rpc server exited with an error: {}", self.logging_context_string, err ); } debug!( target: LOG_TARGET, - "(Peer = {}) Rpc service ({}) shutdown", - self.node_id, - self.protocol_name() + "({}) Rpc service shutdown", self.logging_context_string ); } - fn protocol_name(&self) -> Cow<'_, str> { - String::from_utf8_lossy(&self.protocol) - } - async fn run(&mut self) -> Result<(), RpcServerError> { while let Some(result) = self.framed.next().await { - let start = Instant::now(); - if let Err(err) = self.handle(result?.freeze()).await { - self.framed.close().await?; - return Err(err); + match result { + Ok(frame) => { + let start = Instant::now(); + if let Err(err) = self.handle(frame.freeze()).await { + self.framed.close().await?; + return Err(err); + } + let elapsed = start.elapsed(); + debug!( + target: LOG_TARGET, + "({}) RPC request completed in {:.0?}{}", + self.logging_context_string, + elapsed, + if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } + ); + }, + Err(err) => { + if let Err(err) = self.framed.close().await { + error!( + target: LOG_TARGET, + "({}) Failed to close substream after socket error: {}", self.logging_context_string, err + ); + } + return Err(err.into()); + }, } - let elapsed = start.elapsed(); - debug!( - target: LOG_TARGET, - "RPC ({}) request completed in {:.0?}{}", - self.protocol_name(), - elapsed, - if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } - ); } self.framed.close().await?; @@ -450,7 +471,7 @@ where if deadline < self.config.minimum_client_deadline { debug!( target: LOG_TARGET, - "[Peer=`{}`] Client has an invalid deadline. {}", self.node_id, decoded_msg + "({}) Client has an invalid deadline. {}", self.logging_context_string, decoded_msg ); // Let the client know that they have disobeyed the spec let status = RpcStatus::bad_request(format!( @@ -471,9 +492,7 @@ where if msg_flags.contains(RpcMessageFlags::ACK) { debug!( target: LOG_TARGET, - "[Peer=`{}` {}] sending ACK response.", - self.node_id, - self.protocol_name() + "({}) sending ACK response.", self.logging_context_string ); let ack = proto::rpc::RpcResponse { request_id, @@ -487,7 +506,7 @@ where debug!( target: LOG_TARGET, - "[Peer=`{}`] Got request {}", self.node_id, decoded_msg + "({}) Got request {}", self.logging_context_string, decoded_msg ); let req = Request::with_context( @@ -496,7 +515,12 @@ where decoded_msg.message.into(), ); - let service_call = log_timing(request_id, "service call", self.service.call(req)); + let service_call = log_timing( + self.logging_context_string.clone(), + request_id, + "service call", + self.service.call(req), + ); let service_result = time::timeout(deadline, service_call).await; let service_result = match service_result { Ok(v) => v, @@ -545,7 +569,12 @@ where let mut message = body.into_message(); loop { - let msg_read = log_timing(request_id, "message read", message.next()); + let msg_read = log_timing( + self.logging_context_string.clone(), + request_id, + "message read", + message.next(), + ); match time::timeout(deadline, msg_read).await { Ok(Some(msg)) => { let resp = match msg { @@ -573,8 +602,13 @@ where }, }; - let is_valid = - log_timing(request_id, "transmit", self.send_response(request_id, resp)).await?; + let is_valid = log_timing( + self.logging_context_string.clone(), + request_id, + "transmit", + self.send_response(request_id, resp), + ) + .await?; if !is_valid { break; @@ -647,14 +681,15 @@ where } } -async fn log_timing<R, F: Future<Output = R>>(request_id: u32, tag: &str, fut: F) -> R { +async fn log_timing<R, F: Future<Output = R>>(context_str: String, request_id: u32, tag: &str, fut: F) -> R { let t = Instant::now(); let span = span!(Level::TRACE, "rpc::internal::timing::{}::{}", request_id, tag); let ret = fut.instrument(span).await; let elapsed = t.elapsed(); trace!( target: LOG_TARGET, - "RPC TIMING(REQ_ID={}): '{}' took {:.2}s{}", + "({}) RPC TIMING(REQ_ID={}): '{}' took {:.2}s{}", + context_str, request_id, tag, elapsed.as_secs_f32(), diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs index 1d40988075..342454e122 100644 --- a/comms/src/protocol/rpc/server/router.rs +++ b/comms/src/protocol/rpc/server/router.rs @@ -42,6 +42,7 @@ use crate::{ }, runtime::task, Bytes, + Substream, }; use futures::{ future::BoxFuture, @@ -49,10 +50,7 @@ use futures::{ FutureExt, }; use std::sync::Arc; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc, -}; +use tokio::sync::mpsc; use tower::Service; use tower_make::MakeService; @@ -133,13 +131,12 @@ where <B::Service as Service<Request<Bytes>>>::Future: Send + 'static, { /// Start all services - pub(crate) async fn serve<TSubstream, TCommsProvider>( + pub(crate) async fn serve<TCommsProvider>( self, - protocol_notifications: ProtocolNotificationRx<TSubstream>, + protocol_notifications: ProtocolNotificationRx<Substream>, comms_provider: TCommsProvider, ) -> Result<(), RpcError> where - TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, TCommsProvider: RpcCommsProvider + Clone + Send + 'static, { self.server diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index b303ce5fc7..f66221b5ae 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -27,6 +27,7 @@ use crate::{ ProtocolId, }, utils, + Substream, }; use core::iter; use std::{ @@ -393,14 +394,13 @@ impl __rpc_deps::NamedProtocolService for GreetingClient { } impl GreetingClient { - pub async fn connect<TSubstream>(framed: __rpc_deps::CanonicalFraming<TSubstream>) -> Result<Self, RpcError> - where TSubstream: __rpc_deps::AsyncRead + __rpc_deps::AsyncWrite + Unpin + Send + 'static { + pub async fn connect(framed: __rpc_deps::CanonicalFraming<Substream>) -> Result<Self, RpcError> { let inner = __rpc_deps::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?; Ok(Self { inner }) } pub fn builder() -> __rpc_deps::RpcClientBuilder<Self> { - __rpc_deps::RpcClientBuilder::new() + __rpc_deps::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into()) } pub async fn say_hello(&mut self, request: SayHelloRequest) -> Result<SayHelloResponse, RpcError> { diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index bc0f1bb25d..553c0001cd 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -22,7 +22,7 @@ use crate::{ framing, - memsocket::MemorySocket, + multiplexing::Yamux, protocol::{ rpc::{ context::RpcCommsBackend, @@ -50,10 +50,11 @@ use crate::{ ProtocolNotification, }, runtime, - test_utils::node_identity::build_node_identity, + test_utils::{node_identity::build_node_identity, transport::build_multiplexed_connections}, NodeIdentity, + Substream, }; -use futures::{future, future::Either, StreamExt}; +use futures::StreamExt; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; @@ -67,7 +68,7 @@ pub(super) async fn setup_service<T: GreetingRpc>( service_impl: T, num_concurrent_sessions: usize, ) -> ( - mpsc::Sender<ProtocolNotification<MemorySocket>>, + mpsc::Sender<ProtocolNotification<Substream>>, task::JoinHandle<()>, RpcCommsBackend, Shutdown, @@ -86,11 +87,10 @@ pub(super) async fn setup_service<T: GreetingRpc>( .add_service(GreetingServer::new(service_impl)) .serve(notif_rx, context); - futures::pin_mut!(fut); - - match future::select(shutdown_signal, fut).await { - Either::Left(_) => {}, - Either::Right((r, _)) => r.unwrap(), + tokio::select! { + biased; + _ = shutdown_signal => {}, + r = fut => r.unwrap(), } } }); @@ -100,31 +100,35 @@ pub(super) async fn setup_service<T: GreetingRpc>( pub(super) async fn setup<T: GreetingRpc>( service_impl: T, num_concurrent_sessions: usize, -) -> (MemorySocket, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown) { +) -> (Yamux, Yamux, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown) { let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; - let (inbound, outbound) = MemorySocket::new_pair(); - let node_identity = build_node_identity(Default::default()); + let (_, inbound, outbound) = build_multiplexed_connections().await; + let substream = outbound.get_yamux_control().open_stream().await.unwrap(); + let node_identity = build_node_identity(Default::default()); // Notify that a peer wants to speak the greeting RPC protocol context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap(); notif_tx .send(ProtocolNotification::new( ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), inbound), + ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), )) .await .unwrap(); - (outbound, server_hnd, node_identity, shutdown) + (inbound, outbound, server_hnd, node_identity, shutdown) } #[runtime::test] async fn request_response_errors_and_streaming() { - let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; + let (mut muxer, _outbound, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) + .with_deadline_grace_period(Duration::from_secs(5)) + .with_handshake_timeout(Duration::from_secs(5)) .connect(framed) .await .unwrap(); @@ -200,7 +204,8 @@ async fn request_response_errors_and_streaming() { #[runtime::test] async fn concurrent_requests() { - let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::default(), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() @@ -240,7 +245,8 @@ async fn concurrent_requests() { #[runtime::test] async fn response_too_big() { - let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE); let mut client = GreetingClient::builder().connect(framed).await.unwrap(); @@ -261,7 +267,8 @@ async fn response_too_big() { #[runtime::test] async fn ping_latency() { - let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE); let mut client = GreetingClient::builder().connect(framed).await.unwrap(); @@ -274,7 +281,8 @@ async fn ping_latency() { #[runtime::test] async fn server_shutdown_before_connect() { - let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; + let (mut muxer, _outbound, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); shutdown.trigger(); @@ -288,7 +296,8 @@ async fn server_shutdown_before_connect() { #[runtime::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); - let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(1)) @@ -313,7 +322,9 @@ async fn timeout() { async fn unknown_protocol() { let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; - let (inbound, socket) = MemorySocket::new_pair(); + let (_, inbound, mut outbound) = build_multiplexed_connections().await; + let in_substream = inbound.get_yamux_control().open_stream().await.unwrap(); + let node_identity = build_node_identity(Default::default()); // This case should never happen because protocols are preregistered with the connection manager and so a @@ -322,12 +333,13 @@ async fn unknown_protocol() { notif_tx .send(ProtocolNotification::new( ProtocolId::from_static(b"this-is-junk"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), inbound), + ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), in_substream), )) .await .unwrap(); - let framed = framing::canonical(socket, 1024); + let out_socket = outbound.incoming_mut().next().await.unwrap(); + let framed = framing::canonical(out_socket, 1024); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( err, @@ -337,7 +349,8 @@ async fn unknown_protocol() { #[runtime::test] async fn rejected_no_sessions_available() { - let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); let err = GreetingClient::builder().connect(framed).await.unwrap_err(); assert!(matches!( @@ -349,7 +362,8 @@ async fn rejected_no_sessions_available() { #[runtime::test] async fn stream_still_works_after_cancel() { let service_impl = GreetingService::default(); - let (socket, _, _, _shutdown) = setup(service_impl.clone(), 1).await; + let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder()