Skip to content

Commit

Permalink
fix: unwraps in rpc client (#5770)
Browse files Browse the repository at this point in the history
Description
---
Removes unwraps in the rpc client code

Motivation and Context
---
There should not be unwraps, and they should propagate errors on invalid
data
  • Loading branch information
SWvheerden authored Sep 14, 2023
1 parent 558e6f2 commit 6f0d20a
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions comms/core/src/protocol/rpc/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId

async fn do_ping_pong(&mut self, reply: oneshot::Sender<Result<Duration, RpcStatus>>) -> Result<(), RpcError> {
let ack = proto::rpc::RpcRequest {
flags: u32::try_from(RpcMessageFlags::ACK.bits()).unwrap(),
flags: u32::from(RpcMessageFlags::ACK.bits()),
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
..Default::default()
};
Expand Down Expand Up @@ -573,7 +573,10 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
return Err(status.into());
}

let resp_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).unwrap());
let resp_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
if !resp_flags.contains(RpcMessageFlags::ACK) {
warn!(
target: LOG_TARGET,
Expand Down Expand Up @@ -603,7 +606,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
let request_id = self.next_request_id();
let method = request.method.into();
let req = proto::rpc::RpcRequest {
request_id: u32::try_from(request_id).unwrap(),
request_id: u32::from(request_id),
method,
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
flags: 0,
Expand Down Expand Up @@ -777,7 +780,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
self.protocol_name()
);
let req = proto::rpc::RpcRequest {
request_id: u32::try_from(request_id).unwrap(),
request_id: u32::from(request_id),
method,
flags: RpcMessageFlags::FIN.bits().into(),
deadline: self.config.deadline.map(|d| d.as_secs()).unwrap_or(0),
Expand Down Expand Up @@ -921,7 +924,10 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
self.time_to_first_msg = Some(timer.elapsed());
self.check_response(&resp)?;
let mut chunk_count = 1;
let mut last_chunk_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).unwrap());
let mut last_chunk_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
let mut last_chunk_size = resp.payload.len();
self.bytes_read += last_chunk_size;
loop {
Expand All @@ -944,7 +950,9 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
}

let msg = self.next().await?;
last_chunk_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(msg.flags).unwrap());
last_chunk_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(msg.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
last_chunk_size = msg.payload.len();
self.bytes_read += last_chunk_size;
self.check_response(&resp)?;
Expand All @@ -962,15 +970,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
let resp_id = u16::try_from(resp.request_id)
.map_err(|_| RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags = RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).unwrap());
let flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}

if resp_id != self.request_id {
return Err(RpcError::ResponseIdDidNotMatchRequest {
expected: self.request_id,
actual: u16::try_from(resp.request_id).unwrap(),
actual: u16::try_from(resp.request_id).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX))
})?,
});
}

Expand Down

0 comments on commit 6f0d20a

Please sign in to comment.