Skip to content

Commit

Permalink
Add checking for header sanity
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Abramov <[email protected]>
  • Loading branch information
agalakhov and daniel-abramov committed Sep 23, 2023
1 parent f916b33 commit 2e50292
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions src/handshake/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
Expand All @@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round.");
match self.state {
HandshakeState::Reading(mut buf) => {
HandshakeState::Reading(mut buf, mut attack_check) => {
let read = buf.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
Some(count) => {
attack_check.check_incoming_packet_size(count)?;
// TODO: this is slow for big headers with too many small packets.
// The parser has to be reworked in order to work on streams instead
// of buffers.
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf, attack_check),
..self
})
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}),
}
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
state: HandshakeState::Reading(buf, attack_check),
..self
})),
}
Expand Down Expand Up @@ -119,7 +125,54 @@ pub trait TryParse: Sized {
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer.
Reading(ReadBuffer),
Reading(ReadBuffer, AttackCheck),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}

/// Attack mitigation. Contains counters needed to prevent DoS attacks
/// and reject valid but useless headers.
#[derive(Debug)]
pub(crate) struct AttackCheck {
/// Number of HTTP header successful reads (TCP packets).
number_of_packets: usize,
/// Total number of bytes in HTTP header.
number_of_bytes: usize,
}

impl AttackCheck {
/// Initialize attack checking for incoming buffer.
fn new() -> Self {
Self { number_of_packets: 0, number_of_bytes: 0 }
}

/// Check the size of an incoming packet. To be called immediately after `read()`
/// passing its returned bytes count as `size`.
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
self.number_of_packets += 1;
self.number_of_bytes += size;

// TODO: these values are hardcoded. Instead of making them configurable,
// rework the way HTTP header is parsed to remove this check at all.
const MAX_BYTES: usize = 65536;
const MAX_PACKETS: usize = 512;
const MIN_PACKET_SIZE: usize = 128;
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;

if self.number_of_bytes > MAX_BYTES {
return Err(Error::AttackAttempt);
}

if self.number_of_packets > MAX_PACKETS {
return Err(Error::AttackAttempt);
}

if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD {
if self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes {
return Err(Error::AttackAttempt);
}
}

Ok(())
}
}

0 comments on commit 2e50292

Please sign in to comment.