diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 626a6fc23..8af70ac23 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -183,9 +183,23 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, startBody: let startBody): - self.sendRequestHead(head, startBody: startBody, context: context) - + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + } case .sendBodyPart(let part, let writePromise): context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) @@ -320,32 +334,15 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } } - private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - - request.resumeRequestBodyStream() - } else { + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) } + self.run(self.state.headSent(), context: context) } private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index e7258611c..a908ded9a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -57,7 +57,11 @@ struct HTTP1ConnectionStateMachine { case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) case sendBodyPart(IOData, EventLoopPromise?) case sendRequestEnd(EventLoopPromise?) case failSendBodyPart(Error, EventLoopPromise?) @@ -350,6 +354,17 @@ struct HTTP1ConnectionStateMachine { return state.modify(with: action) } } + + mutating func headSent() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + return .wait + } + return self.avoidingStateMachineCoW { state in + let action = requestStateMachine.headSent() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } } extension HTTP1ConnectionStateMachine { @@ -390,8 +405,10 @@ extension HTTP1ConnectionStateMachine { extension HTTP1ConnectionStateMachine.State { fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action { switch action { - case .sendRequestHead(let head, let startBody): - return .sendRequestHead(head, startBody: startBody) + case .sendRequestHead(let head, let sendEnd): + return .sendRequestHead(head, sendEnd: sendEnd) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: resumeRequestBodyStream, startIdleTimer: startIdleTimer) case .pauseRequestBodyStream: return .pauseRequestBodyStream case .resumeRequestBodyStream: diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 578b83029..e7412f5c2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -140,9 +140,23 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, let startBody): - self.sendRequestHead(head, startBody: startBody, context: context) - + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + } case .pauseRequestBodyStream: // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet @@ -210,31 +224,15 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } - private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - request.resumeRequestBodyStream() - } else { + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) } + self.run(self.state.headSent(), context: context) } private func runSuccessfulFinalAction(_ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, context: ChannelHandlerContext) { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index aafa3d28b..4835feac3 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -20,21 +20,24 @@ struct HTTPRequestStateMachine { fileprivate enum State { /// The initial state machine state. The only valid mutation is `start()`. The state will /// transitions to: - /// - `.waitForChannelToBecomeWritable` - /// - `.running(.streaming, .initialized)` (if the Channel is writable and if a request body is expected) - /// - `.running(.endSent, .initialized)` (if the Channel is writable and no request body is expected) + /// - `.waitForChannelToBecomeWritable` (if the channel becomes non writable while sending the header) + /// - `.sendingHead` if the channel is writable case initialized + /// Waiting for the channel to be writable. Valid transitions are: - /// - `.running(.streaming, .initialized)` (once the Channel is writable again and if a request body is expected) - /// - `.running(.endSent, .initialized)` (once the Channel is writable again and no request body is expected) + /// - `.running(.streaming, .waitingForHead)` (once the Channel is writable again and if a request body is expected) + /// - `.running(.endSent, .waitingForHead)` (once the Channel is writable again and no request body is expected) /// - `.failed` (if a connection error occurred) case waitForChannelToBecomeWritable(HTTPRequestHead, RequestFramingMetadata) + /// A request is on the wire. Valid transitions are: /// - `.finished` /// - `.failed` case running(RequestState, ResponseState) + /// The request has completed successfully case finished + /// The request has failed case failed(Error) @@ -93,7 +96,11 @@ struct HTTPRequestStateMachine { case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) case sendBodyPart(IOData, EventLoopPromise?) case sendRequestEnd(EventLoopPromise?) case failSendBodyPart(Error, EventLoopPromise?) @@ -223,6 +230,7 @@ struct HTTPRequestStateMachine { // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none) + case .running: self.state = .failed(error) return .failRequest(error, .close(nil)) @@ -520,7 +528,7 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves") + preconditionFailure("How can we receive a response head before sending a request head ourselves \(self.state)") case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): self.state = .running( @@ -561,7 +569,7 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure("How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)") case .running(_, .waitingForHead): preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)") @@ -587,7 +595,7 @@ struct HTTPRequestStateMachine { private mutating func receivedHTTPResponseEnd() -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure("How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)") case .running(_, .waitingForHead): preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") @@ -654,7 +662,7 @@ struct HTTPRequestStateMachine { case .initialized, .running(_, .waitingForHead), .waitForChannelToBecomeWritable: - preconditionFailure("The response is expected to only ask for more data after the response head was forwarded") + preconditionFailure("The response is expected to only ask for more data after the response head was forwarded \(self.state)") case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -697,18 +705,51 @@ struct HTTPRequestStateMachine { } private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { - switch metadata.body { - case .stream: - self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) - case .fixedSize(0): + let length = metadata.body.expectedLength + if length == 0 { // no body self.state = .running(.endSent, .waitingForHead) - return .sendRequestHead(head, startBody: false) - case .fixedSize(let length): - // length is greater than zero and we therefore have a body to send - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) + return .sendRequestHead(head, sendEnd: true) + } else { + self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), .waitingForHead) + return .sendRequestHead(head, sendEnd: false) + } + } + + mutating func headSent() -> Action { + switch self.state { + case .initialized, .waitForChannelToBecomeWritable, .finished: + preconditionFailure("Not a valid transition after `.sendingHeader`: \(self.state)") + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), let responseState): + let startProducing = self.isChannelWritable && expectedBodyLength != sentBodyBytes + self.state = .running(.streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: startProducing ? .producing : .paused + ), responseState) + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: startProducing, + startIdleTimer: false + ) + case .running(.endSent, _): + return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + case .running(.streaming(_, _, producer: .producing), _): + preconditionFailure("request body producing can not start before we have successfully send the header \(self.state)") + case .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension RequestFramingMetadata.Body { + var expectedLength: Int? { + switch self { + case .fixedSize(let length): return length + case .stream: return nil } } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 820e6cf10..bdf897b3d 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -528,7 +528,6 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { } func testChannelBecomesNonWritableDuringHeaderWrite() throws { - try XCTSkipIf(true, "this currently fails and will be fixed in follow up PR") final class ChangeWritabilityOnFlush: ChannelOutboundHandler { typealias OutboundIn = Any func flush(context: ChannelHandlerContext) { diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index fd771aca0..ce8e6ed17 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -26,7 +26,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) @@ -64,7 +65,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "12"]) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -92,7 +94,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: ["connection": "close"]) let metadata = RequestFramingMetadata(connectionClose: true, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -108,7 +111,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4"]) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -124,7 +128,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4", "connection": "keep-alive"]) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -141,7 +146,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["connection": "close"]) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -170,9 +176,11 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + + XCTAssertEqual(state.headSent(), .wait) } func testRequestWasCancelledWhileUploadingData() { @@ -182,7 +190,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) @@ -235,7 +244,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) @@ -250,7 +260,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) @@ -262,7 +273,8 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .init(statusCode: 103, reasonPhrase: "Early Hints")) XCTAssertEqual(state.channelRead(.head(responseHead)), .wait) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) @@ -295,6 +307,12 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 5dfce3f9d..4873bc169 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -347,7 +347,6 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { } func testChannelBecomesNonWritableDuringHeaderWrite() throws { - try XCTSkipIf(true, "this currently fails and will be fixed in follow up PR") final class ChangeWritabilityOnFlush: ChannelOutboundHandler { typealias OutboundIn = Any func flush(context: ChannelHandlerContext) { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 6e84f9d29..d5a8160b6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -145,6 +145,8 @@ extension HTTPClientTests { ("testShutdownWithFutures", testShutdownWithFutures), ("testMassiveHeaderHTTP1", testMassiveHeaderHTTP1), ("testMassiveHeaderHTTP2", testMassiveHeaderHTTP2), + ("testCancelingHTTP1RequestAfterHeaderSend", testCancelingHTTP1RequestAfterHeaderSend), + ("testCancelingHTTP2RequestAfterHeaderSend", testCancelingHTTP2RequestAfterHeaderSend), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 54d854bf0..49f94a7d4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -3365,7 +3365,6 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testMassiveHeaderHTTP1() throws { - try XCTSkipIf(true, "this currently crashes and will be fixed in follow up PR") var request = try HTTPClient.Request(url: defaultHTTPBin.baseURL, method: .POST) // add ~64 KB header let headerValue = String(repeating: "0", count: 1024) @@ -3380,7 +3379,6 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testMassiveHeaderHTTP2() throws { - try XCTSkipIf(true, "this currently crashes and will be fixed in follow up PR") let bin = HTTPBin(.http2(settings: [ .init(parameter: .maxConcurrentStreams, value: 100), .init(parameter: .maxHeaderListSize, value: 1024 * 256), @@ -3407,4 +3405,36 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try client.execute(request: request).wait()) } + + func testCancelingHTTP1RequestAfterHeaderSend() throws { + var request = try HTTPClient.Request(url: self.defaultHTTPBin.baseURL + "/wait", method: .POST) + // non-empty body is important + request.body = .byteBuffer(ByteBuffer([1])) + + class CancelAfterHeadSend: HTTPClientResponseDelegate { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } + + func testCancelingHTTP2RequestAfterHeaderSend() throws { + let bin = HTTPBin(.http2()) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var request = try HTTPClient.Request(url: bin.baseURL + "/wait", method: .POST) + // non-empty body is important + request.body = .byteBuffer(ByteBuffer([1])) + + class CancelAfterHeadSend: HTTPClientResponseDelegate { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 61ea4702b..92bf42b1d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -24,7 +24,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -38,7 +38,8 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) @@ -72,7 +73,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -87,7 +88,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -98,7 +99,8 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) @@ -132,7 +134,8 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) @@ -157,7 +160,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -179,7 +182,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -200,7 +203,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -219,7 +222,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -239,7 +242,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -261,7 +264,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -288,7 +291,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -315,7 +318,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -362,7 +365,7 @@ class HTTPRequestStateMachineTests: XCTestCase { // --- sending request let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) // --- receiving response let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "4"]) @@ -377,7 +380,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) } @@ -385,7 +388,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), .failSendBodyPart(HTTPClientError.cancelled, nil)) } @@ -394,7 +397,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -411,7 +414,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertEqual(state.channelRead(.head(continueHead)), .wait) @@ -427,7 +430,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -439,7 +442,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -451,7 +454,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close(nil)) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") @@ -461,7 +464,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -477,7 +480,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) let body = ByteBuffer(string: "foo bar") @@ -495,7 +498,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .stream) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) let part1: ByteBuffer = .init(string: "foo") XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), .sendBodyPart(.byteBuffer(part1), nil)) @@ -515,7 +518,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) let body = ByteBuffer(string: "foo bar") @@ -531,7 +534,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) @@ -542,7 +545,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) @@ -552,7 +555,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "30"]) let body = ByteBuffer(string: "foo bar") @@ -570,7 +573,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") @@ -591,7 +594,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") @@ -612,7 +615,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") @@ -632,7 +635,7 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") @@ -668,6 +671,12 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult