Skip to content

Commit

Permalink
Refactor errors (fixes #530)
Browse files Browse the repository at this point in the history
h2::Error now knows whether protocol errors happened because the user
sent them, because it was received from the remote peer, or because
the library itself emitted an error because it detected a protocol
violation.

It also keeps track of whether it came from a RST_STREAM or GO_AWAY
frame, and in the case of the latter, it includes the additional
debug data if any.
  • Loading branch information
nox committed Sep 13, 2021
1 parent 61b4f8f commit 77969a6
Show file tree
Hide file tree
Showing 26 changed files with 465 additions and 433 deletions.
11 changes: 4 additions & 7 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@
//! [`Builder`]: struct.Builder.html
//! [`Error`]: ../struct.Error.html
use crate::codec::{Codec, RecvError, SendError, UserError};
use crate::codec::{Codec, SendError, UserError};
use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId};
use crate::proto;
use crate::proto::{self, Error};
use crate::{FlowControl, PingPong, RecvStream, SendStream};

use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -1493,7 +1493,7 @@ impl proto::Peer for Peer {
pseudo: Pseudo,
fields: HeaderMap,
stream_id: StreamId,
) -> Result<Self::Poll, RecvError> {
) -> Result<Self::Poll, Error> {
let mut b = Response::builder();

b = b.version(Version::HTTP_2);
Expand All @@ -1507,10 +1507,7 @@ impl proto::Peer for Peer {
Err(_) => {
// TODO: Should there be more specialized handling for different
// kinds of errors
return Err(RecvError::Stream {
id: stream_id,
reason: Reason::PROTOCOL_ERROR,
});
return Err(Error::library_reset(stream_id, Reason::PROTOCOL_ERROR));
}
};

Expand Down
49 changes: 5 additions & 44 deletions src/codec/error.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
use crate::frame::{Reason, StreamId};
use crate::proto::Error;

use std::{error, fmt, io};

/// Errors that are received
#[derive(Debug)]
pub enum RecvError {
Connection(Reason),
Stream { id: StreamId, reason: Reason },
Io(io::Error),
}

/// Errors caused by sending a message
#[derive(Debug)]
pub enum SendError {
/// User error
Connection(Error),
User(UserError),

/// Connection error prevents sending.
Connection(Reason),

/// I/O error
Io(io::Error),
}

/// Errors caused by users of the library
Expand Down Expand Up @@ -65,47 +51,22 @@ pub enum UserError {
PeerDisabledServerPush,
}

// ===== impl RecvError =====

impl From<io::Error> for RecvError {
fn from(src: io::Error) -> Self {
RecvError::Io(src)
}
}

impl error::Error for RecvError {}

impl fmt::Display for RecvError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
use self::RecvError::*;

match *self {
Connection(ref reason) => reason.fmt(fmt),
Stream { ref reason, .. } => reason.fmt(fmt),
Io(ref e) => e.fmt(fmt),
}
}
}

// ===== impl SendError =====

impl error::Error for SendError {}

impl fmt::Display for SendError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
use self::SendError::*;

match *self {
User(ref e) => e.fmt(fmt),
Connection(ref reason) => reason.fmt(fmt),
Io(ref e) => e.fmt(fmt),
Self::Connection(ref e) => e.fmt(fmt),
Self::User(ref e) => e.fmt(fmt),
}
}
}

impl From<io::Error> for SendError {
fn from(src: io::Error) -> Self {
SendError::Io(src)
Self::Connection(src.into())
}
}

Expand Down
61 changes: 24 additions & 37 deletions src/codec/framed_read.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::codec::RecvError;
use crate::frame::{self, Frame, Kind, Reason};
use crate::frame::{
DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
};
use crate::proto::Error;

use crate::hpack;

Expand Down Expand Up @@ -98,8 +98,7 @@ fn decode_frame(
max_header_list_size: usize,
partial_inout: &mut Option<Partial>,
mut bytes: BytesMut,
) -> Result<Option<Frame>, RecvError> {
use self::RecvError::*;
) -> Result<Option<Frame>, Error> {
let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
let _e = span.enter();

Expand All @@ -110,7 +109,7 @@ fn decode_frame(

if partial_inout.is_some() && head.kind() != Kind::Continuation {
proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into());
}

let kind = head.kind();
Expand All @@ -131,14 +130,11 @@ fn decode_frame(
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
return Err(Stream {
id: $head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
},
Err(e) => {
proto_err!(conn: "failed to load frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
};

Expand All @@ -151,14 +147,11 @@ fn decode_frame(
Err(frame::Error::MalformedMessage) => {
let id = $head.stream_id();
proto_err!(stream: "malformed header block; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
},
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}

Expand All @@ -183,7 +176,7 @@ fn decode_frame(

res.map_err(|e| {
proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Expand All @@ -192,7 +185,7 @@ fn decode_frame(

res.map_err(|e| {
proto_err!(conn: "failed to load PING frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Expand All @@ -201,7 +194,7 @@ fn decode_frame(

res.map_err(|e| {
proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Expand All @@ -212,7 +205,7 @@ fn decode_frame(
// TODO: Should this always be connection level? Probably not...
res.map_err(|e| {
proto_err!(conn: "failed to load DATA frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Expand All @@ -221,15 +214,15 @@ fn decode_frame(
let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load RESET frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::GoAway => {
let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Expand All @@ -238,7 +231,7 @@ fn decode_frame(
if head.stream_id() == 0 {
// Invalid stream identifier
proto_err!(conn: "invalid stream ID 0");
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into());
}

match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
Expand All @@ -249,14 +242,11 @@ fn decode_frame(
// `PROTOCOL_ERROR`.
let id = head.stream_id();
proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
}
Err(e) => {
proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}
}
Expand All @@ -267,14 +257,14 @@ fn decode_frame(
Some(partial) => partial,
None => {
proto_err!(conn: "received unexpected CONTINUATION frame");
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into());
}
};

// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into());
}

// Extend the buf
Expand All @@ -297,7 +287,7 @@ fn decode_frame(
// the attacker to go away.
if partial.buf.len() + bytes.len() > max_header_list_size {
proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
return Err(Connection(Reason::COMPRESSION_ERROR));
return Err(Error::library_go_away(Reason::COMPRESSION_ERROR).into());
}
}
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
Expand All @@ -312,14 +302,11 @@ fn decode_frame(
Err(frame::Error::MalformedMessage) => {
let id = head.stream_id();
proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
}
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}

Expand All @@ -343,7 +330,7 @@ impl<T> Stream for FramedRead<T>
where
T: AsyncRead + Unpin,
{
type Item = Result<Frame, RecvError>;
type Item = Result<Frame, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let span = tracing::trace_span!("FramedRead::poll_next");
Expand Down Expand Up @@ -371,11 +358,11 @@ where
}
}

fn map_err(err: io::Error) -> RecvError {
fn map_err(err: io::Error) -> Error {
if let io::ErrorKind::InvalidData = err.kind() {
if let Some(custom) = err.get_ref() {
if custom.is::<LengthDelimitedCodecError>() {
return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ mod error;
mod framed_read;
mod framed_write;

pub use self::error::{RecvError, SendError, UserError};
pub use self::error::{SendError, UserError};

use self::framed_read::FramedRead;
use self::framed_write::FramedWrite;

use crate::frame::{self, Data, Frame};
use crate::proto::Error;

use bytes::Buf;
use futures_core::Stream;
Expand Down Expand Up @@ -155,7 +156,7 @@ impl<T, B> Stream for Codec<T, B>
where
T: AsyncRead + Unpin,
{
type Item = Result<Frame, RecvError>;
type Item = Result<Frame, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
Expand Down
Loading

0 comments on commit 77969a6

Please sign in to comment.