Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make bad websocket parser states unrepresentable. #547

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 62 additions & 74 deletions Sources/NIOWebSocket/WebSocketFrameDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,21 @@ enum DecoderState {

/// The initial frame byte has been received, but the length byte
/// has not.
case firstByteReceived
case firstByteReceived(firstByte: UInt8)

/// The length byte indicates that we need to wait for the length word, and we're
/// currently waiting for it.
case waitingForLengthWord
case waitingForLengthWord(firstByte: UInt8, masked: Bool)

/// The length byte indicates that we need to wait for the length qword, and
/// we're currently waiting for it.
case waitingForLengthQWord
case waitingForLengthQWord(firstByte: UInt8, masked: Bool)

/// The mask bit indicates we are expecting a mask key.
case waitingForMask
case waitingForMask(firstByte: UInt8, length: Int)

/// All the header data is complete, we are waiting for the application data.
case waitingForData
case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?)
}

enum ParseResult {
Expand All @@ -101,105 +101,111 @@ enum ParseResult {
/// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that
/// we don't repeatedly partially parse the data.
struct WSParser {
internal private(set) var firstByte: UInt8? = nil
internal private(set) var length: Int? = nil
internal private(set) var masked: Bool = false
internal private(set) var maskingKey: WebSocketMaskingKey? = nil

/// The current state of the decoder during incremental parse.
var state: DecoderState = .idle

private mutating func reset() {
self.state = .idle
self.firstByte = nil
self.length = nil
self.masked = false
self.maskingKey = nil
}

mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult {
switch self.state {
case .idle:
// This is a new buffer. We want to find the first octet and save it off.
assert(self.firstByte == nil)
guard let firstByte = buffer.readInteger(as: UInt8.self) else {
return .insufficientData
}
self.firstByte = firstByte
self.state = .firstByteReceived
self.state = .firstByteReceived(firstByte: firstByte)
return .continueParsing

case .firstByteReceived:
case .firstByteReceived(let firstByte):
// Now we're looking for the length. We begin by finding the length byte to see if we
// need any more data.
assert(self.length == nil)
assert(self.firstByte != nil)
guard let lengthByte = buffer.readInteger(as: UInt8.self) else {
return .insufficientData
}

self.masked = (lengthByte & 0x80) != 0
let masked = (lengthByte & 0x80) != 0

switch lengthByte & 0x7F {
case 126:
self.state = .waitingForLengthWord
case 127:
self.state = .waitingForLengthQWord
case let len:
switch (lengthByte & 0x7F, masked) {
case (126, _):
self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked)
case (127, _):
self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked)
case (let len, true):
assert(len <= 125)
self.length = Int(len)
self.state = self.masked ? .waitingForMask : .waitingForData
self.state = .waitingForMask(firstByte: firstByte, length: Int(len))
case (let len, false):
assert(len <= 125)
self.state = .waitingForData(firstByte: firstByte, length: Int(len), maskingKey: nil)
}
return .continueParsing

case .waitingForLengthWord:
case .waitingForLengthWord(let firstByte, let masked):
// We've got a one-word length here.
assert(self.length == nil)
assert(self.firstByte != nil)
guard let lengthWord = buffer.readInteger(as: UInt16.self) else {
return .insufficientData
}

self.length = Int(lengthWord)
self.state = self.masked ? .waitingForMask : .waitingForData
if masked {
self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthWord))
} else {
self.state = .waitingForData(firstByte: firstByte, length: Int(lengthWord), maskingKey: nil)
}
return .continueParsing

case .waitingForLengthQWord:
case .waitingForLengthQWord(let firstByte, let masked):
// We've got a qword of length here.
assert(self.length == nil)
assert(self.firstByte != nil)
guard let lengthQWord = buffer.readInteger(as: UInt64.self) else {
return .insufficientData
}

self.length = Int(lengthQWord)
self.state = self.masked ? .waitingForMask : .waitingForData
if masked {
self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthQWord))
} else {
self.state = .waitingForData(firstByte: firstByte, length: Int(lengthQWord), maskingKey: nil)
}
return .continueParsing

case .waitingForMask:
case .waitingForMask(let firstByte, let length):
// We're waiting for the masking key.
assert(maskingKey == nil)
assert(self.firstByte != nil)
assert(self.length != nil)
guard let maskingKey = buffer.readInteger(as: UInt32.self) else {
return .insufficientData
}
self.maskingKey = WebSocketMaskingKey(networkRepresentation: maskingKey)
self.state = .waitingForData

self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey))
return .continueParsing

case .waitingForData:
assert(self.firstByte != nil)
assert(self.length != nil)
guard let data = buffer.readSlice(length: self.length!) else {
case .waitingForData(let firstByte, let length, let maskingKey):
guard let data = buffer.readSlice(length: length) else {
return .insufficientData
}

let frame = WebSocketFrame(firstByte: self.firstByte!, maskKey: self.maskingKey, applicationData: data)
self.reset()
let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data)
self.state = .idle
return .result(frame)
}
}

/// Apply a number of validations to the incremental state, ensuring that the frame we're
/// receiving is valid.
func validateState(maxFrameSize: Int) throws {
switch self.state {
case .waitingForMask(let firstByte, let length), .waitingForData(let firstByte, let length, _):
if length > maxFrameSize {
throw NIOWebSocketError.invalidFrameLength
}

let isControlFrame = (firstByte & 0x08) != 0
let isFragment = (firstByte & 0x80) == 0

if isControlFrame && isFragment {
throw NIOWebSocketError.fragmentedControlFrame
}
if isControlFrame && length > 125 {
throw NIOWebSocketError.multiByteControlFrameLength
}
case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord:
// No validation necessary in this state as we have no length to validate.
break
}
}
}

/// An inbound `ChannelHandler` that deserializes websocket frames into a structured
Expand Down Expand Up @@ -263,7 +269,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder {
ctx.fireChannelRead(self.wrapInboundOut(frame))
case .continueParsing:
do {
try self.validateState()
try self.parser.validateState(maxFrameSize: self.maxFrameSize)
} catch {
self.handleError(error, ctx: ctx)
}
Expand All @@ -281,25 +287,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder {
return .needMoreData
}

/// Apply a number of validations to the incremental state, ensuring that the frame we're
/// receiving is valid.
private func validateState() throws {
if let length = parser.length, length > self.maxFrameSize {
throw NIOWebSocketError.invalidFrameLength
}

if let length = parser.length, let firstByte = parser.firstByte {
let isControlFrame = (firstByte & 0x08) != 0
let isFragment = (firstByte & 0x80) == 0

if isControlFrame && isFragment {
throw NIOWebSocketError.fragmentedControlFrame
}
if isControlFrame && length > 125 {
throw NIOWebSocketError.multiByteControlFrameLength
}
}
}

/// We hit a decoding error, we're going to tear things down now. To do this we're
/// basically going to send an error frame and then close the connection. Once we're
Expand Down