From f65f45b72743a3246f5bd43cfe2ba27e02045f6b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 7 Feb 2023 17:18:56 +0100 Subject: [PATCH] Fix Request streaming memory leak (#665) --- Sources/AsyncHTTPClient/HTTPHandler.swift | 12 +++-- .../RequestBag+StateMachine.swift | 47 +++++++++--------- Sources/AsyncHTTPClient/RequestBag.swift | 3 +- .../RequestBagTests+XCTest.swift | 1 + .../RequestBagTests.swift | 49 +++++++++++++++++++ 5 files changed, 82 insertions(+), 30 deletions(-) diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 73f467d09..7b2c5c6ff 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -379,7 +379,8 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { } var state = State.idle - let request: HTTPClient.Request + let requestMethod: HTTPMethod + let requestHost: String static let maxByteBufferSize = Int(UInt32.max) @@ -408,14 +409,15 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { maxBodySize <= Self.maxByteBufferSize, "maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes" ) - self.request = request + self.requestMethod = request.method + self.requestHost = request.host self.maxBodySize = maxBodySize } public func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { switch self.state { case .idle: - if self.request.method != .HEAD, + if self.requestMethod != .HEAD, let contentLength = head.headers.first(name: "Content-Length"), let announcedBodySize = Int(contentLength), announcedBodySize > self.maxBodySize { @@ -481,9 +483,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { case .idle: preconditionFailure("no head received before end") case .head(let head): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil) + return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: nil) case .body(let head, let body): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body) + return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: body) case .end: preconditionFailure("request already processed") case .error(let error): diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 9509fa2e6..e7fad6850 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -29,15 +29,15 @@ extension HTTPClient { extension RequestBag { struct StateMachine { fileprivate enum State { - case initialized - case queued(HTTPRequestScheduler) + case initialized(RedirectHandler?) + case queued(HTTPRequestScheduler, RedirectHandler?) /// if the deadline was exceeded while in the `.queued(_:)` state, /// we wait until the request pool fails the request with a potential more descriptive error message, /// if a connection failure has occured while the request was queued. case deadlineExceededWhileQueued case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) case finished(error: Error?) - case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL) + case redirected(HTTPRequestExecutor, RedirectHandler, Int, HTTPResponseHead, URL) case modifying } @@ -55,23 +55,22 @@ extension RequestBag { case eof } - case initialized + case initialized(RedirectHandler?) case buffering(CircularBuffer, next: Next) case waitingForRemote } - private var state: State = .initialized - private let redirectHandler: RedirectHandler? + private var state: State init(redirectHandler: RedirectHandler?) { - self.redirectHandler = redirectHandler + self.state = .initialized(redirectHandler) } } } extension RequestBag.StateMachine { mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - guard case .initialized = self.state else { + guard case .initialized(let redirectHandler) = self.state else { // There might be a race between `requestWasQueued` and `willExecuteRequest`: // // If the request is created and passed to the HTTPClient on thread A, it will move into @@ -91,7 +90,7 @@ extension RequestBag.StateMachine { return } - self.state = .queued(scheduler) + self.state = .queued(scheduler, redirectHandler) } enum WillExecuteRequestAction { @@ -102,8 +101,8 @@ extension RequestBag.StateMachine { mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction { switch self.state { - case .initialized, .queued: - self.state = .executing(executor, .initialized, .initialized) + case .initialized(let redirectHandler), .queued(_, let redirectHandler): + self.state = .executing(executor, .initialized, .initialized(redirectHandler)) return .none case .deadlineExceededWhileQueued: let error: Error = HTTPClientError.deadlineExceeded @@ -127,8 +126,8 @@ extension RequestBag.StateMachine { case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be resumed, if the request was started") - case .executing(let executor, .initialized, .initialized): - self.state = .executing(executor, .producing, .initialized) + case .executing(let executor, .initialized, .initialized(let redirectHandler)): + self.state = .executing(executor, .producing, .initialized(redirectHandler)) return .startWriter case .executing(_, .producing, _): @@ -299,11 +298,11 @@ extension RequestBag.StateMachine { case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response, if the request hasn't started yet.") case .executing(let executor, let requestState, let responseState): - guard case .initialized = responseState else { + guard case .initialized(let redirectHandler) = responseState else { preconditionFailure("If we receive a response, we must not have received something else before") } - if let redirectURL = self.redirectHandler?.redirectTarget( + if let redirectHandler = redirectHandler, let redirectURL = redirectHandler.redirectTarget( status: head.status, responseHeaders: head.headers ) { @@ -312,11 +311,11 @@ extension RequestBag.StateMachine { // smaller than 3kb. switch head.contentLength { case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none: - self.state = .redirected(executor, 0, head, redirectURL) + self.state = .redirected(executor, redirectHandler, 0, head, redirectURL) return .signalBodyDemand(executor) case .some: self.state = .finished(error: HTTPClientError.cancelled) - return .redirect(executor, self.redirectHandler!, head, redirectURL) + return .redirect(executor, redirectHandler, head, redirectURL) } } else { self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore)) @@ -369,15 +368,15 @@ extension RequestBag.StateMachine { } else { return .none } - case .redirected(let executor, var receivedBytes, let head, let redirectURL): + case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL): let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes } receivedBytes += partsLength if receivedBytes > HTTPClient.maxBodySizeRedirectResponse { self.state = .finished(error: HTTPClientError.cancelled) - return .redirect(executor, self.redirectHandler!, head, redirectURL) + return .redirect(executor, redirectHandler, head, redirectURL) } else { - self.state = .redirected(executor, receivedBytes, head, redirectURL) + self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL) return .signalBodyDemand(executor) } @@ -428,9 +427,9 @@ extension RequestBag.StateMachine { self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof)) return .consume(first) - case .redirected(_, _, let head, let redirectURL): + case .redirected(_, let redirectHandler, _, let head, let redirectURL): self.state = .finished(error: nil) - return .redirect(self.redirectHandler!, head, redirectURL) + return .redirect(redirectHandler, head, redirectURL) case .finished(error: .some): return .none @@ -553,7 +552,7 @@ extension RequestBag.StateMachine { mutating func deadlineExceeded() -> DeadlineExceededAction { switch self.state { - case .queued(let queuer): + case .queued(let queuer, _): /// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message /// We therefore depend on the scheduler failing the request after we cancel the request. self.state = .deadlineExceededWhileQueued @@ -582,7 +581,7 @@ extension RequestBag.StateMachine { case .initialized: self.state = .finished(error: error) return .failTask(error, nil, nil) - case .queued(let queuer): + case .queued(let queuer, _): self.state = .finished(error: error) return .failTask(error, queuer, nil) case .executing(let executor, let requestState, .buffering(_, next: .eof)): diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 50c0057ba..1119236fb 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -33,7 +33,7 @@ final class RequestBag { } private let delegate: Delegate - private let request: HTTPClient.Request + private var request: HTTPClient.Request // the request state is synchronized on the task eventLoop private var state: StateMachine @@ -126,6 +126,7 @@ final class RequestBag { guard let body = self.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } + self.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift index b6a05733c..53c152c06 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift @@ -40,6 +40,7 @@ extension RequestBagTests { ("testRedirectWith3KBBody", testRedirectWith3KBBody), ("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead), ("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead), + ("testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise", testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise), ] } } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 43062405c..43134d453 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -17,6 +17,7 @@ import Logging import NIOCore import NIOEmbedded import NIOHTTP1 +import NIOPosix import XCTest final class RequestBagTests: XCTestCase { @@ -836,6 +837,54 @@ final class RequestBagTests: XCTestCase { XCTAssertTrue(redirectTriggered) } + + func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() { + final class LeakDetector {} + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) + defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) } + + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + var leakDetector = LeakDetector() + + do { + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST)) + guard var request = maybeRequest else { return XCTFail("Expected to have a request here") } + + let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self) + let donePromise = group.any().makePromise(of: Void.self) + request.body = .stream { [leakDetector] writer in + _ = leakDetector + writerPromise.succeed(writer) + return donePromise.futureResult + } + + let resultFuture = httpClient.execute(request: request) + request.body = nil + writerPromise.futureResult.whenSuccess { writer in + writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map { + print("written") + }.cascade(to: donePromise) + } + XCTAssertNoThrow(try donePromise.futureResult.wait()) + print("HTTP sent") + + var result: HTTPClient.Response? + XCTAssertNoThrow(result = try resultFuture.wait()) + + XCTAssertEqual(.ok, result?.status) + let body = result?.body.map { String(buffer: $0) } + XCTAssertNotNil(body) + print("HTTP done") + } + XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector)) + } } extension HTTPClient.Task {