diff --git a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift index d3392d330d..a055d38d20 100644 --- a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift +++ b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift @@ -73,21 +73,21 @@ enum DecoderState { /// The initial frame byte has been received, but the length byte /// has not. - case firstByteReceived + case firstByteReceived(firstByte: UInt8) /// The length byte indicates that we need to wait for the length word, and we're /// currently waiting for it. - case waitingForLengthWord + case waitingForLengthWord(firstByte: UInt8, masked: Bool) /// The length byte indicates that we need to wait for the length qword, and /// we're currently waiting for it. - case waitingForLengthQWord + case waitingForLengthQWord(firstByte: UInt8, masked: Bool) /// The mask bit indicates we are expecting a mask key. - case waitingForMask + case waitingForMask(firstByte: UInt8, length: Int) /// All the header data is complete, we are waiting for the application data. - case waitingForData + case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?) } enum ParseResult { @@ -101,105 +101,111 @@ enum ParseResult { /// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that /// we don't repeatedly partially parse the data. struct WSParser { - internal private(set) var firstByte: UInt8? = nil - internal private(set) var length: Int? = nil - internal private(set) var masked: Bool = false - internal private(set) var maskingKey: WebSocketMaskingKey? = nil - /// The current state of the decoder during incremental parse. var state: DecoderState = .idle - private mutating func reset() { - self.state = .idle - self.firstByte = nil - self.length = nil - self.masked = false - self.maskingKey = nil - } - mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult { switch self.state { case .idle: // This is a new buffer. We want to find the first octet and save it off. - assert(self.firstByte == nil) guard let firstByte = buffer.readInteger(as: UInt8.self) else { return .insufficientData } - self.firstByte = firstByte - self.state = .firstByteReceived + self.state = .firstByteReceived(firstByte: firstByte) return .continueParsing - case .firstByteReceived: + case .firstByteReceived(let firstByte): // Now we're looking for the length. We begin by finding the length byte to see if we // need any more data. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthByte = buffer.readInteger(as: UInt8.self) else { return .insufficientData } - self.masked = (lengthByte & 0x80) != 0 + let masked = (lengthByte & 0x80) != 0 - switch lengthByte & 0x7F { - case 126: - self.state = .waitingForLengthWord - case 127: - self.state = .waitingForLengthQWord - case let len: + switch (lengthByte & 0x7F, masked) { + case (126, _): + self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked) + case (127, _): + self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked) + case (let len, true): assert(len <= 125) - self.length = Int(len) - self.state = self.masked ? .waitingForMask : .waitingForData + self.state = .waitingForMask(firstByte: firstByte, length: Int(len)) + case (let len, false): + assert(len <= 125) + self.state = .waitingForData(firstByte: firstByte, length: Int(len), maskingKey: nil) } return .continueParsing - case .waitingForLengthWord: + case .waitingForLengthWord(let firstByte, let masked): // We've got a one-word length here. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthWord = buffer.readInteger(as: UInt16.self) else { return .insufficientData } - self.length = Int(lengthWord) - self.state = self.masked ? .waitingForMask : .waitingForData + if masked { + self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthWord)) + } else { + self.state = .waitingForData(firstByte: firstByte, length: Int(lengthWord), maskingKey: nil) + } return .continueParsing - case .waitingForLengthQWord: + case .waitingForLengthQWord(let firstByte, let masked): // We've got a qword of length here. - assert(self.length == nil) - assert(self.firstByte != nil) guard let lengthQWord = buffer.readInteger(as: UInt64.self) else { return .insufficientData } - self.length = Int(lengthQWord) - self.state = self.masked ? .waitingForMask : .waitingForData + if masked { + self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthQWord)) + } else { + self.state = .waitingForData(firstByte: firstByte, length: Int(lengthQWord), maskingKey: nil) + } return .continueParsing - case .waitingForMask: + case .waitingForMask(let firstByte, let length): // We're waiting for the masking key. - assert(maskingKey == nil) - assert(self.firstByte != nil) - assert(self.length != nil) guard let maskingKey = buffer.readInteger(as: UInt32.self) else { return .insufficientData } - self.maskingKey = WebSocketMaskingKey(networkRepresentation: maskingKey) - self.state = .waitingForData + + self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey)) return .continueParsing - case .waitingForData: - assert(self.firstByte != nil) - assert(self.length != nil) - guard let data = buffer.readSlice(length: self.length!) else { + case .waitingForData(let firstByte, let length, let maskingKey): + guard let data = buffer.readSlice(length: length) else { return .insufficientData } - let frame = WebSocketFrame(firstByte: self.firstByte!, maskKey: self.maskingKey, applicationData: data) - self.reset() + let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data) + self.state = .idle return .result(frame) } } + + /// Apply a number of validations to the incremental state, ensuring that the frame we're + /// receiving is valid. + func validateState(maxFrameSize: Int) throws { + switch self.state { + case .waitingForMask(let firstByte, let length), .waitingForData(let firstByte, let length, _): + if length > maxFrameSize { + throw NIOWebSocketError.invalidFrameLength + } + + let isControlFrame = (firstByte & 0x08) != 0 + let isFragment = (firstByte & 0x80) == 0 + + if isControlFrame && isFragment { + throw NIOWebSocketError.fragmentedControlFrame + } + if isControlFrame && length > 125 { + throw NIOWebSocketError.multiByteControlFrameLength + } + case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord: + // No validation necessary in this state as we have no length to validate. + break + } + } } /// An inbound `ChannelHandler` that deserializes websocket frames into a structured @@ -263,7 +269,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { ctx.fireChannelRead(self.wrapInboundOut(frame)) case .continueParsing: do { - try self.validateState() + try self.parser.validateState(maxFrameSize: self.maxFrameSize) } catch { self.handleError(error, ctx: ctx) } @@ -281,25 +287,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { return .needMoreData } - /// Apply a number of validations to the incremental state, ensuring that the frame we're - /// receiving is valid. - private func validateState() throws { - if let length = parser.length, length > self.maxFrameSize { - throw NIOWebSocketError.invalidFrameLength - } - if let length = parser.length, let firstByte = parser.firstByte { - let isControlFrame = (firstByte & 0x08) != 0 - let isFragment = (firstByte & 0x80) == 0 - - if isControlFrame && isFragment { - throw NIOWebSocketError.fragmentedControlFrame - } - if isControlFrame && length > 125 { - throw NIOWebSocketError.multiByteControlFrameLength - } - } - } /// We hit a decoding error, we're going to tear things down now. To do this we're /// basically going to send an error frame and then close the connection. Once we're