diff --git a/comms/core/src/protocol/messaging/inbound.rs b/comms/core/src/protocol/messaging/inbound.rs index 895ba1db32..f2c393af86 100644 --- a/comms/core/src/protocol/messaging/inbound.rs +++ b/comms/core/src/protocol/messaging/inbound.rs @@ -22,7 +22,7 @@ use std::io; -use futures::{future::Either, SinkExt, StreamExt}; +use futures::{future, future::Either, SinkExt, StreamExt}; use log::*; use tari_shutdown::ShutdownSignal; use tokio::{ @@ -33,7 +33,7 @@ use tokio::{ #[cfg(feature = "metrics")] use super::metrics; use super::{MessagingEvent, MessagingProtocol}; -use crate::{message::InboundMessage, peer_manager::NodeId, protocol::rpc::__macro_reexports::future}; +use crate::{message::InboundMessage, peer_manager::NodeId}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 292c397930..1995715100 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -75,7 +75,6 @@ use crate::{ RpcError, RpcServerError, RpcStatus, - RPC_CHUNKING_MAX_CHUNKS, }, ProtocolId, }, @@ -932,53 +931,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin pub async fn read_response(&mut self) -> Result { let timer = Instant::now(); - let mut resp = self.next().await?; + let resp = self.next().await?; self.time_to_first_msg = Some(timer.elapsed()); self.check_response(&resp)?; - let mut chunk_count = 1; - let mut last_chunk_flags = - RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - let mut last_chunk_size = resp.payload.len(); - self.bytes_read += last_chunk_size; - loop { - trace!( - target: LOG_TARGET, - "Chunk {} received (flags={:?}, {} bytes, {} total)", - chunk_count, - last_chunk_flags, - last_chunk_size, - resp.payload.len() - ); - if !last_chunk_flags.is_more() { - return Ok(resp); - } - - if chunk_count >= RPC_CHUNKING_MAX_CHUNKS { - return Err(RpcError::RemotePeerExceededMaxChunkCount { - expected: RPC_CHUNKING_MAX_CHUNKS, - }); - } - - let msg = self.next().await?; - last_chunk_flags = RpcMessageFlags::from_bits(u8::try_from(msg.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - last_chunk_size = msg.payload.len(); - self.bytes_read += last_chunk_size; - self.check_response(&resp)?; - resp.payload.extend(msg.payload); - chunk_count += 1; - } + self.bytes_read = resp.payload.len(); + trace!( + target: LOG_TARGET, + "Received {} bytes in {:.2?}", + resp.payload.len(), + self.time_to_first_msg.unwrap_or_default() + ); + Ok(resp) } pub async fn read_ack(&mut self) -> Result { diff --git a/comms/core/src/protocol/rpc/message.rs b/comms/core/src/protocol/rpc/message.rs index ed377fe50f..105e524359 100644 --- a/comms/core/src/protocol/rpc/message.rs +++ b/comms/core/src/protocol/rpc/message.rs @@ -24,19 +24,24 @@ use std::{convert::TryFrom, fmt, time::Duration}; use bitflags::bitflags; use bytes::Bytes; +use log::warn; use super::RpcError; use crate::{ proto, proto::rpc::rpc_session_reply::SessionResult, - protocol::rpc::{ - body::{Body, IntoBody}, - context::RequestContext, - error::HandshakeRejectReason, - RpcStatusCode, + protocol::{ + rpc, + rpc::{ + body::{Body, IntoBody}, + context::RequestContext, + error::HandshakeRejectReason, + RpcStatusCode, + }, }, }; +const LOG_TARGET: &str = "comms::rpc::message"; #[derive(Debug)] pub struct Request { pub(super) context: Option, @@ -203,8 +208,6 @@ bitflags! { const FIN = 0x01; /// Typically sent with empty contents and used to confirm a substream is alive. const ACK = 0x02; - /// Another chunk to be received - const MORE = 0x04; } } impl RpcMessageFlags { @@ -215,10 +218,6 @@ impl RpcMessageFlags { pub fn is_ack(self) -> bool { self.contains(Self::ACK) } - - pub fn is_more(self) -> bool { - self.contains(Self::MORE) - } } impl Default for RpcMessageFlags { @@ -276,6 +275,21 @@ impl RpcResponse { payload: self.payload.to_vec(), } } + + pub fn exceeded_message_size(self) -> RpcResponse { + let msg = format!( + "The response size exceeded the maximum allowed payload size. Max = {} bytes, Got = {} bytes", + rpc::max_response_payload_size() as f32, + self.payload.len() as f32, + ); + warn!(target: LOG_TARGET, "{}", msg); + RpcResponse { + request_id: self.request_id, + status: RpcStatusCode::MalformedResponse, + flags: RpcMessageFlags::FIN, + payload: msg.into_bytes().into(), + } + } } impl Default for RpcResponse { diff --git a/comms/core/src/protocol/rpc/mod.rs b/comms/core/src/protocol/rpc/mod.rs index 35b5ebda6f..23a0e6e1d4 100644 --- a/comms/core/src/protocol/rpc/mod.rs +++ b/comms/core/src/protocol/rpc/mod.rs @@ -30,23 +30,14 @@ mod test; /// Maximum frame size of each RPC message. This is enforced in tokio's length delimited codec. /// This can be thought of as the hard limit on message size. -pub const RPC_MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB -/// Maximum number of chunks into which a message can be broken up. -const RPC_CHUNKING_MAX_CHUNKS: usize = 16; // 16 x 256 Kib = 4 MiB max combined message size -const RPC_CHUNKING_THRESHOLD: usize = 256 * 1024; -const RPC_CHUNKING_SIZE_LIMIT: usize = 384 * 1024; +pub const RPC_MAX_FRAME_SIZE: usize = 4 * 1024 * 1024; // 4 MiB /// The maximum request payload size const fn max_request_size() -> usize { RPC_MAX_FRAME_SIZE } -/// The maximum size for a single RPC response message -const fn max_response_size() -> usize { - RPC_CHUNKING_MAX_CHUNKS * RPC_CHUNKING_THRESHOLD -} - -/// The maximum size for a single RPC response excluding overhead +/// The maximum size for a single RPC response body excluding response header overhead const fn max_response_payload_size() -> usize { // RpcResponse overhead is: // - 4 varint protobuf fields, each field ID is 1 byte @@ -54,7 +45,7 @@ const fn max_response_payload_size() -> usize { // - 1 length varint for the payload, allow for 5 bytes to be safe (max_payload_size being technically too small is // fine, being too large isn't) const MAX_HEADER_SIZE: usize = 4 + 4 * 5; - max_response_size() - MAX_HEADER_SIZE + RPC_MAX_FRAME_SIZE - MAX_HEADER_SIZE } mod body; diff --git a/comms/core/src/protocol/rpc/server/chunking.rs b/comms/core/src/protocol/rpc/server/chunking.rs deleted file mode 100644 index d721af9c99..0000000000 --- a/comms/core/src/protocol/rpc/server/chunking.rs +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2021, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use std::cmp; - -use bytes::Bytes; -use log::*; - -use super::LOG_TARGET; -use crate::{ - proto, - protocol::{ - rpc, - rpc::{ - message::{RpcMessageFlags, RpcResponse}, - RpcStatusCode, - RPC_CHUNKING_SIZE_LIMIT, - RPC_CHUNKING_THRESHOLD, - }, - }, -}; - -pub(super) struct ChunkedResponseIter { - message: RpcResponse, - initial_payload_size: usize, - has_emitted_once: bool, - num_chunks: usize, - total_chunks: usize, -} - -fn calculate_total_chunk_count(payload_len: usize) -> usize { - let mut total_chunks = payload_len / RPC_CHUNKING_THRESHOLD; - let excess = (payload_len % RPC_CHUNKING_THRESHOLD) + RPC_CHUNKING_THRESHOLD; - if total_chunks == 0 || excess > RPC_CHUNKING_SIZE_LIMIT { - // If the chunk (threshold size) + excess cannot fit in the RPC_CHUNKING_SIZE_LIMIT, then we'll emit another - // frame smaller than threshold size - total_chunks += 1; - } - - total_chunks -} - -impl ChunkedResponseIter { - pub fn new(message: RpcResponse) -> Self { - let len = message.payload.len(); - Self { - initial_payload_size: message.payload.len(), - message, - has_emitted_once: false, - num_chunks: 0, - total_chunks: calculate_total_chunk_count(len), - } - } - - fn remaining(&self) -> usize { - self.message.payload.len() - } - - fn payload_mut(&mut self) -> &mut Bytes { - &mut self.message.payload - } - - fn payload(&self) -> &Bytes { - &self.message.payload - } - - fn get_next_chunk(&mut self) -> Option { - let len = self.payload().len(); - if len == 0 { - if self.num_chunks > 1 { - debug!( - target: LOG_TARGET, - "Emitted {} chunks (Avg.Size: {} bytes, Total: {} bytes)", - self.num_chunks, - self.initial_payload_size / self.num_chunks, - self.initial_payload_size - ); - } - return None; - } - - // If the payload is within the maximum chunk size, simply return the rest of it - if len <= RPC_CHUNKING_SIZE_LIMIT { - let chunk = self.payload_mut().split_to(len); - self.num_chunks += 1; - trace!( - target: LOG_TARGET, - "Emitting chunk {}/{} ({} bytes)", - self.num_chunks, - self.total_chunks, - chunk.len() - ); - return Some(chunk); - } - - let chunk_size = cmp::min(len, RPC_CHUNKING_THRESHOLD); - let chunk = self.payload_mut().split_to(chunk_size); - - self.num_chunks += 1; - trace!( - target: LOG_TARGET, - "Emitting chunk {}/{} ({} bytes)", - self.num_chunks, - self.total_chunks, - chunk.len() - ); - Some(chunk) - } - - fn is_last_chunk(&self) -> bool { - self.num_chunks == self.total_chunks - } - - fn exceeded_message_size(&self) -> proto::rpc::RpcResponse { - const BYTES_PER_MB: f32 = 1024.0 * 1024.0; - // Precision loss is acceptable because this is for display purposes only - let msg = format!( - "The response size exceeded the maximum allowed payload size. Max = {:.4} MiB, Got = {:.4} MiB", - rpc::max_response_payload_size() as f32 / BYTES_PER_MB, - self.message.payload.len() as f32 / BYTES_PER_MB, - ); - warn!(target: LOG_TARGET, "{}", msg); - proto::rpc::RpcResponse { - request_id: self.message.request_id, - status: RpcStatusCode::MalformedResponse as u32, - flags: RpcMessageFlags::FIN.bits().into(), - payload: msg.into_bytes(), - } - } -} - -impl Iterator for ChunkedResponseIter { - type Item = proto::rpc::RpcResponse; - - fn next(&mut self) -> Option { - // Edge case: the initial message has an empty payload. - if self.initial_payload_size == 0 { - if self.has_emitted_once { - return None; - } - self.has_emitted_once = true; - return Some(self.message.to_proto()); - } - - // Edge case: the total message size cannot fit into the maximum allowed chunks - if self.remaining() > rpc::max_response_payload_size() { - if self.has_emitted_once { - return None; - } - self.has_emitted_once = true; - return Some(self.exceeded_message_size()); - } - - let request_id = self.message.request_id; - let chunk = self.get_next_chunk()?; - - // status MUST be set for the first chunked message, all subsequent chunk messages MUST have a status of 0 - let mut status = 0; - if !self.has_emitted_once { - status = self.message.status as u32; - } - self.has_emitted_once = true; - - let mut flags = self.message.flags; - if !self.is_last_chunk() { - // For all chunks except the last the MORE flag MUST be set - flags |= RpcMessageFlags::MORE; - } - let msg = proto::rpc::RpcResponse { - request_id, - status, - flags: flags.bits().into(), - payload: chunk.to_vec(), - }; - - Some(msg) - } -} - -#[cfg(test)] -mod test { - use std::{convert::TryFrom, iter}; - - use super::*; - - fn create(size: usize) -> ChunkedResponseIter { - let msg = RpcResponse { - payload: iter::repeat(0).take(size).collect(), - ..Default::default() - }; - ChunkedResponseIter::new(msg) - } - - #[test] - fn it_emits_a_zero_size_message() { - let iter = create(0); - assert_eq!(iter.total_chunks, 1); - let msgs = iter.collect::>(); - assert_eq!(msgs.len(), 1); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - } - - #[test] - fn it_emits_one_message_below_threshold() { - let iter = create(RPC_CHUNKING_THRESHOLD - 1); - assert_eq!(iter.total_chunks, 1); - let msgs = iter.collect::>(); - assert_eq!(msgs.len(), 1); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - } - - #[test] - fn it_emits_a_single_message() { - let iter = create(RPC_CHUNKING_SIZE_LIMIT - 1); - assert_eq!(iter.count(), 1); - - let iter = create(RPC_CHUNKING_SIZE_LIMIT); - assert_eq!(iter.count(), 1); - } - - #[test] - fn it_emits_an_expected_number_of_chunks() { - let iter = create(RPC_CHUNKING_THRESHOLD * 2); - assert_eq!(iter.count(), 2); - - let diff = RPC_CHUNKING_SIZE_LIMIT - RPC_CHUNKING_THRESHOLD; - let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff); - assert_eq!(iter.count(), 2); - - let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff + 1); - assert_eq!(iter.count(), 3); - } - - #[test] - fn it_sets_the_more_flag_except_last() { - use std::convert::TryFrom; - let iter = create(RPC_CHUNKING_THRESHOLD * 3); - let msgs = iter.collect::>(); - assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap()) - .unwrap() - .is_more()); - assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[1].flags).unwrap()) - .unwrap() - .is_more()); - assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[2].flags).unwrap()) - .unwrap() - .is_more()); - } -} diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs index 84996596f4..ae9cbb23ee 100644 --- a/comms/core/src/protocol/rpc/server/mod.rs +++ b/comms/core/src/protocol/rpc/server/mod.rs @@ -20,8 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -mod chunking; -use chunking::ChunkedResponseIter; +// mod chunking; mod error; pub use error::RpcServerError; @@ -52,7 +51,7 @@ use std::{ time::{Duration, Instant}, }; -use futures::{future, stream, stream::FuturesUnordered, SinkExt, StreamExt}; +use futures::{future, stream::FuturesUnordered, SinkExt, StreamExt}; use log::*; use prost::Message; use router::Router; @@ -79,6 +78,7 @@ use crate::{ peer_manager::NodeId, proto, protocol::{ + rpc, rpc::{ body::BodyBytes, message::{RpcMethod, RpcResponse}, @@ -749,12 +749,15 @@ where let mut stream = body .into_message() .map(|result| into_response(request_id, result)) - .flat_map(move |message| { + .map(move |mut message| { + if message.payload.len() > rpc::max_response_payload_size() { + message = message.exceeded_message_size(); + } #[cfg(feature = "metrics")] if !message.status.is_ok() { metrics::status_error_counter(&node_id, &protocol, message.status).inc(); } - stream::iter(ChunkedResponseIter::new(message)) + message.to_proto() }) .map(|resp| Bytes::from(resp.to_encoded_bytes())); diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs index 344d29d2e1..e5e1f52dde 100644 --- a/comms/core/src/protocol/rpc/test/smoke.rs +++ b/comms/core/src/protocol/rpc/test/smoke.rs @@ -282,7 +282,7 @@ async fn response_too_big() { let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; let socket = outbound.get_yamux_control().open_stream().await.unwrap(); - let framed = framing::canonical(socket, rpc::max_response_size()); + let framed = framing::canonical(socket, rpc::max_request_size()); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) .connect(framed) @@ -291,7 +291,7 @@ async fn response_too_big() { // RPC_MAX_FRAME_SIZE bytes will always be too large because of the overhead of the RpcResponse proto message let err = client - .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 + 1) + .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 - 4) .await .unwrap_err(); unpack_enum!(RpcError::RequestFailed(status) = err); @@ -299,7 +299,7 @@ async fn response_too_big() { // Check that the exact frame size boundary works and that the session is still going let _string = client - .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 - 9) + .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 - 5) .await .unwrap(); }