Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash for large HTTP request headers #661

Merged
merged 21 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,19 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {

private func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) {
switch action {
case .sendRequestHead(let head, startBody: let startBody):
self.sendRequestHead(head, startBody: startBody, context: context)
case .sendRequestHead(let head, let sendEnd):
self.sendRequestHead(head, sendEnd: sendEnd, context: context)
case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer):

self.request!.requestHeadSent()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this force-unwrap safe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment

if resumeRequestBodyStream {
self.request!.resumeRequestBodyStream()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relatedly, why is this safe? Can it end up nil after the preceding outcall?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! This is not safe. I added a test which crashed previously bit is fixed by chaining the implantation to unwrap safely: 18f6505

}
if startIdleTimer {
if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
}
}
case .sendBodyPart(let part, let writePromise):
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise)

Expand Down Expand Up @@ -320,32 +330,15 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
}
}

private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) {
if startBody {
context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil)

// The above write might trigger an error, which may lead to a call to `errorCaught`,
// which in turn, may fail the request and pop it from the handler. For this reason
// we must check if the request is still present here.
guard let request = self.request else { return }
request.requestHeadSent()

request.resumeRequestBodyStream()
} else {
private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) {
if sendEnd {
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
context.flush()

// The above write might trigger an error, which may lead to a call to `errorCaught`,
// which in turn, may fail the request and pop it from the handler. For this reason
// we must check if the request is still present here.
guard let request = self.request else { return }
request.requestHeadSent()

if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
}
} else {
context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil)
}
self.run(self.state.headSent(), context: context)
}

private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ struct HTTP1ConnectionStateMachine {
case none
}

case sendRequestHead(HTTPRequestHead, startBody: Bool)
case sendRequestHead(HTTPRequestHead, sendEnd: Bool)
case notifyRequestHeadSendSuccessfully(
resumeRequestBodyStream: Bool,
startIdleTimer: Bool
)
case sendBodyPart(IOData, EventLoopPromise<Void>?)
case sendRequestEnd(EventLoopPromise<Void>?)
case failSendBodyPart(Error, EventLoopPromise<Void>?)
Expand Down Expand Up @@ -350,6 +354,17 @@ struct HTTP1ConnectionStateMachine {
return state.modify(with: action)
}
}

mutating func headSent() -> Action {
guard case .inRequest(var requestStateMachine, let close) = self.state else {
return .wait
}
return self.avoidingStateMachineCoW { state in
let action = requestStateMachine.headSent()
state = .inRequest(requestStateMachine, close: close)
return state.modify(with: action)
}
}
}

extension HTTP1ConnectionStateMachine {
Expand Down Expand Up @@ -390,8 +405,10 @@ extension HTTP1ConnectionStateMachine {
extension HTTP1ConnectionStateMachine.State {
fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action {
switch action {
case .sendRequestHead(let head, let startBody):
return .sendRequestHead(head, startBody: startBody)
case .sendRequestHead(let head, let sendEnd):
return .sendRequestHead(head, sendEnd: sendEnd)
case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer):
return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: resumeRequestBodyStream, startIdleTimer: startIdleTimer)
case .pauseRequestBodyStream:
return .pauseRequestBodyStream
case .resumeRequestBodyStream:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,18 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler {

private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) {
switch action {
case .sendRequestHead(let head, let startBody):
self.sendRequestHead(head, startBody: startBody, context: context)

case .sendRequestHead(let head, let sendEnd):
self.sendRequestHead(head, sendEnd: sendEnd, context: context)
case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer):
self.request!.requestHeadSent()
Lukasa marked this conversation as resolved.
Show resolved Hide resolved
if resumeRequestBodyStream {
self.request!.resumeRequestBodyStream()
}
if startIdleTimer {
if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
}
}
case .pauseRequestBodyStream:
// We can force unwrap the request here, as we have just validated in the state machine,
// that the request is neither failed nor finished yet
Expand Down Expand Up @@ -210,31 +219,15 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler {
}
}

private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) {
if startBody {
context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil)

// The above write might trigger an error, which may lead to a call to `errorCaught`,
// which in turn, may fail the request and pop it from the handler. For this reason
// we must check if the request is still present here.
guard let request = self.request else { return }
request.requestHeadSent()
request.resumeRequestBodyStream()
} else {
private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) {
if sendEnd {
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
context.flush()

// The above write might trigger an error, which may lead to a call to `errorCaught`,
// which in turn, may fail the request and pop it from the handler. For this reason
// we must check if the request is still present here.
guard let request = self.request else { return }
request.requestHeadSent()

if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
}
} else {
context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil)
}
self.run(self.state.headSent(), context: context)
}

private func runSuccessfulFinalAction(_ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, context: ChannelHandlerContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,24 @@ struct HTTPRequestStateMachine {
fileprivate enum State {
/// The initial state machine state. The only valid mutation is `start()`. The state will
/// transitions to:
/// - `.waitForChannelToBecomeWritable`
/// - `.running(.streaming, .initialized)` (if the Channel is writable and if a request body is expected)
/// - `.running(.endSent, .initialized)` (if the Channel is writable and no request body is expected)
/// - `.waitForChannelToBecomeWritable` (if the channel becomes non writable while sending the header)
/// - `.sendingHead` if the channel is writable
case initialized

/// Waiting for the channel to be writable. Valid transitions are:
/// - `.running(.streaming, .initialized)` (once the Channel is writable again and if a request body is expected)
/// - `.running(.endSent, .initialized)` (once the Channel is writable again and no request body is expected)
/// - `.running(.streaming, .waitingForHead)` (once the Channel is writable again and if a request body is expected)
/// - `.running(.endSent, .waitingForHead)` (once the Channel is writable again and no request body is expected)
/// - `.failed` (if a connection error occurred)
case waitForChannelToBecomeWritable(HTTPRequestHead, RequestFramingMetadata)

/// A request is on the wire. Valid transitions are:
/// - `.finished`
/// - `.failed`
case running(RequestState, ResponseState)

/// The request has completed successfully
case finished

/// The request has failed
case failed(Error)

Expand Down Expand Up @@ -93,7 +96,11 @@ struct HTTPRequestStateMachine {
case none
}

case sendRequestHead(HTTPRequestHead, startBody: Bool)
case sendRequestHead(HTTPRequestHead, sendEnd: Bool)
case notifyRequestHeadSendSuccessfully(
resumeRequestBodyStream: Bool,
startIdleTimer: Bool
)
case sendBodyPart(IOData, EventLoopPromise<Void>?)
case sendRequestEnd(EventLoopPromise<Void>?)
case failSendBodyPart(Error, EventLoopPromise<Void>?)
Expand Down Expand Up @@ -223,6 +230,7 @@ struct HTTPRequestStateMachine {
// the request failed, before it was sent onto the wire.
self.state = .failed(error)
return .failRequest(error, .none)

case .running:
self.state = .failed(error)
return .failRequest(error, .close(nil))
Expand Down Expand Up @@ -520,7 +528,7 @@ struct HTTPRequestStateMachine {

switch self.state {
case .initialized, .waitForChannelToBecomeWritable:
preconditionFailure("How can we receive a response head before sending a request head ourselves")
preconditionFailure("How can we receive a response head before sending a request head ourselves \(self.state)")

case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead):
self.state = .running(
Expand Down Expand Up @@ -561,7 +569,7 @@ struct HTTPRequestStateMachine {
mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action {
switch self.state {
case .initialized, .waitForChannelToBecomeWritable:
preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)")
preconditionFailure("How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)")

case .running(_, .waitingForHead):
preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)")
Expand All @@ -587,7 +595,7 @@ struct HTTPRequestStateMachine {
private mutating func receivedHTTPResponseEnd() -> Action {
switch self.state {
case .initialized, .waitForChannelToBecomeWritable:
preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)")
preconditionFailure("How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)")

case .running(_, .waitingForHead):
preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)")
Expand Down Expand Up @@ -654,7 +662,7 @@ struct HTTPRequestStateMachine {
case .initialized,
.running(_, .waitingForHead),
.waitForChannelToBecomeWritable:
preconditionFailure("The response is expected to only ask for more data after the response head was forwarded")
preconditionFailure("The response is expected to only ask for more data after the response head was forwarded \(self.state)")

case .running(let requestState, .receivingBody(let head, var responseStreamState)):
return self.avoidingStateMachineCoW { state -> Action in
Expand Down Expand Up @@ -697,18 +705,51 @@ struct HTTPRequestStateMachine {
}

private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action {
switch metadata.body {
case .stream:
self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead)
return .sendRequestHead(head, startBody: true)
case .fixedSize(0):
let length = metadata.body.expectedLength
if length == 0 {
// no body
self.state = .running(.endSent, .waitingForHead)
return .sendRequestHead(head, startBody: false)
case .fixedSize(let length):
// length is greater than zero and we therefore have a body to send
self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead)
return .sendRequestHead(head, startBody: true)
return .sendRequestHead(head, sendEnd: true)
} else {
self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), .waitingForHead)
return .sendRequestHead(head, sendEnd: false)
}
}

mutating func headSent() -> Action {
switch self.state {
case .initialized, .waitForChannelToBecomeWritable, .finished:
preconditionFailure("Not a valid transition after `.sendingHeader`: \(self.state)")

case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), let responseState):
let startProducing = self.isChannelWritable && expectedBodyLength != sentBodyBytes
self.state = .running(.streaming(
expectedBodyLength: expectedBodyLength,
sentBodyBytes: sentBodyBytes,
producer: startProducing ? .producing : .paused
), responseState)
return .notifyRequestHeadSendSuccessfully(
resumeRequestBodyStream: startProducing,
startIdleTimer: false
)
case .running(.endSent, _):
return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)
case .running(.streaming(_, _, producer: .producing), _):
preconditionFailure("request body producing can not start before we have successfully send the header \(self.state)")
case .failed:
return .wait

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}
}

extension RequestFramingMetadata.Body {
var expectedLength: Int? {
switch self {
case .fixedSize(let length): return length
case .stream: return nil
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ class HTTP1ClientChannelHandlerTests: XCTestCase {
}

func testChannelBecomesNonWritableDuringHeaderWrite() throws {
try XCTSkipIf(true, "this currently fails and will be fixed in follow up PR")
final class ChangeWritabilityOnFlush: ChannelOutboundHandler {
typealias OutboundIn = Any
func flush(context: ChannelHandlerContext) {
Expand Down
Loading