From ee4e89a0157ce66afa598da5480d706c28c005aa Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 2 Aug 2018 11:23:20 +0100 Subject: [PATCH] Make bad websocket parser states unrepresentable. Motivation: Apparently when I wrote the WebSocket parser I forgot that enums are great, and so I added a bunch of optional properties. That was silly. This patch changes the WSParser structure to use an enum with associated data to ensure that we only store state when we are supposed to, and to guarantee that the state is good. Modifications: Move all state to enum case associated data. Result: Easier to validate the correctness of the WSParser code. --- .../NIOWebSocket/WebSocketFrameDecoder.swift | 136 ++++++++---------- 1 file changed, 62 insertions(+), 74 deletions(-) diff --git a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift index d3392d330d..a055d38d20 100644 --- a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift +++ b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift @@ -73,21 +73,21 @@ enum DecoderState { /// The initial frame byte has been received, but the length byte /// has not. - case firstByteReceived + case firstByteReceived(firstByte: UInt8) /// The length byte indicates that we need to wait for the length word, and we're /// currently waiting for it. - case waitingForLengthWord + case waitingForLengthWord(firstByte: UInt8, masked: Bool) /// The length byte indicates that we need to wait for the length qword, and /// we're currently waiting for it. - case waitingForLengthQWord + case waitingForLengthQWord(firstByte: UInt8, masked: Bool) /// The mask bit indicates we are expecting a mask key. - case waitingForMask + case waitingForMask(firstByte: UInt8, length: Int) /// All the header data is complete, we are waiting for the application data. - case waitingForData + case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?) } enum ParseResult { @@ -101,105 +101,111 @@ enum ParseResult { /// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that /// we don't repeatedly partially parse the data. struct WSParser { - internal private(set) var firstByte: UInt8? = nil - internal private(set) var length: Int? = nil - internal private(set) var masked: Bool = false - internal private(set) var maskingKey: WebSocketMaskingKey? = nil - /// The current state of the decoder during incremental parse. var state: DecoderState = .idle - private mutating func reset() { - self.state = .idle - self.firstByte = nil - self.length = nil - self.masked = false - self.maskingKey = nil - } - mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult { switch self.state { case .idle: // This is a new buffer. We want to find the first octet and save it off. - assert(self.firstByte == nil) guard let firstByte = buffer.readInteger(as: UInt8.self) else { return .insufficientData } - self.firstByte = firstByte - self.state = .firstByteReceived + self.state = .firstByteReceived(firstByte: firstByte) return .continueParsing - case .firstByteReceived: + case .firstByteReceived(let firstByte): // Now we're looking for the length. We begin by finding the length byte to see if we // need any more data. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthByte = buffer.readInteger(as: UInt8.self) else { return .insufficientData } - self.masked = (lengthByte & 0x80) != 0 + let masked = (lengthByte & 0x80) != 0 - switch lengthByte & 0x7F { - case 126: - self.state = .waitingForLengthWord - case 127: - self.state = .waitingForLengthQWord - case let len: + switch (lengthByte & 0x7F, masked) { + case (126, _): + self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked) + case (127, _): + self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked) + case (let len, true): assert(len <= 125) - self.length = Int(len) - self.state = self.masked ? .waitingForMask : .waitingForData + self.state = .waitingForMask(firstByte: firstByte, length: Int(len)) + case (let len, false): + assert(len <= 125) + self.state = .waitingForData(firstByte: firstByte, length: Int(len), maskingKey: nil) } return .continueParsing - case .waitingForLengthWord: + case .waitingForLengthWord(let firstByte, let masked): // We've got a one-word length here. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthWord = buffer.readInteger(as: UInt16.self) else { return .insufficientData } - self.length = Int(lengthWord) - self.state = self.masked ? .waitingForMask : .waitingForData + if masked { + self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthWord)) + } else { + self.state = .waitingForData(firstByte: firstByte, length: Int(lengthWord), maskingKey: nil) + } return .continueParsing - case .waitingForLengthQWord: + case .waitingForLengthQWord(let firstByte, let masked): // We've got a qword of length here. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthQWord = buffer.readInteger(as: UInt64.self) else { return .insufficientData } - self.length = Int(lengthQWord) - self.state = self.masked ? .waitingForMask : .waitingForData + if masked { + self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthQWord)) + } else { + self.state = .waitingForData(firstByte: firstByte, length: Int(lengthQWord), maskingKey: nil) + } return .continueParsing - case .waitingForMask: + case .waitingForMask(let firstByte, let length): // We're waiting for the masking key. - assert(maskingKey == nil) - assert(self.firstByte != nil) - assert(self.length != nil) guard let maskingKey = buffer.readInteger(as: UInt32.self) else { return .insufficientData } - self.maskingKey = WebSocketMaskingKey(networkRepresentation: maskingKey) - self.state = .waitingForData + + self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey)) return .continueParsing - case .waitingForData: - assert(self.firstByte != nil) - assert(self.length != nil) - guard let data = buffer.readSlice(length: self.length!) else { + case .waitingForData(let firstByte, let length, let maskingKey): + guard let data = buffer.readSlice(length: length) else { return .insufficientData } - let frame = WebSocketFrame(firstByte: self.firstByte!, maskKey: self.maskingKey, applicationData: data) - self.reset() + let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data) + self.state = .idle return .result(frame) } } + + /// Apply a number of validations to the incremental state, ensuring that the frame we're + /// receiving is valid. + func validateState(maxFrameSize: Int) throws { + switch self.state { + case .waitingForMask(let firstByte, let length), .waitingForData(let firstByte, let length, _): + if length > maxFrameSize { + throw NIOWebSocketError.invalidFrameLength + } + + let isControlFrame = (firstByte & 0x08) != 0 + let isFragment = (firstByte & 0x80) == 0 + + if isControlFrame && isFragment { + throw NIOWebSocketError.fragmentedControlFrame + } + if isControlFrame && length > 125 { + throw NIOWebSocketError.multiByteControlFrameLength + } + case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord: + // No validation necessary in this state as we have no length to validate. + break + } + } } /// An inbound `ChannelHandler` that deserializes websocket frames into a structured @@ -263,7 +269,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { ctx.fireChannelRead(self.wrapInboundOut(frame)) case .continueParsing: do { - try self.validateState() + try self.parser.validateState(maxFrameSize: self.maxFrameSize) } catch { self.handleError(error, ctx: ctx) } @@ -281,25 +287,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { return .needMoreData } - /// Apply a number of validations to the incremental state, ensuring that the frame we're - /// receiving is valid. - private func validateState() throws { - if let length = parser.length, length > self.maxFrameSize { - throw NIOWebSocketError.invalidFrameLength - } - if let length = parser.length, let firstByte = parser.firstByte { - let isControlFrame = (firstByte & 0x08) != 0 - let isFragment = (firstByte & 0x80) == 0 - - if isControlFrame && isFragment { - throw NIOWebSocketError.fragmentedControlFrame - } - if isControlFrame && length > 125 { - throw NIOWebSocketError.multiByteControlFrameLength - } - } - } /// We hit a decoding error, we're going to tear things down now. To do this we're /// basically going to send an error frame and then close the connection. Once we're