Skip to content

Commit

Permalink
Fix Request streaming memory leak (#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett authored Feb 7, 2023
1 parent 59bfb96 commit f65f45b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 30 deletions.
12 changes: 7 additions & 5 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
}

var state = State.idle
let request: HTTPClient.Request
let requestMethod: HTTPMethod
let requestHost: String

static let maxByteBufferSize = Int(UInt32.max)

Expand Down Expand Up @@ -408,14 +409,15 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
maxBodySize <= Self.maxByteBufferSize,
"maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes"
)
self.request = request
self.requestMethod = request.method
self.requestHost = request.host
self.maxBodySize = maxBodySize
}

public func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
switch self.state {
case .idle:
if self.request.method != .HEAD,
if self.requestMethod != .HEAD,
let contentLength = head.headers.first(name: "Content-Length"),
let announcedBodySize = Int(contentLength),
announcedBodySize > self.maxBodySize {
Expand Down Expand Up @@ -481,9 +483,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
case .idle:
preconditionFailure("no head received before end")
case .head(let head):
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil)
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: nil)
case .body(let head, let body):
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body)
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: body)
case .end:
preconditionFailure("request already processed")
case .error(let error):
Expand Down
47 changes: 23 additions & 24 deletions Sources/AsyncHTTPClient/RequestBag+StateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ extension HTTPClient {
extension RequestBag {
struct StateMachine {
fileprivate enum State {
case initialized
case queued(HTTPRequestScheduler)
case initialized(RedirectHandler<Delegate.Response>?)
case queued(HTTPRequestScheduler, RedirectHandler<Delegate.Response>?)
/// if the deadline was exceeded while in the `.queued(_:)` state,
/// we wait until the request pool fails the request with a potential more descriptive error message,
/// if a connection failure has occured while the request was queued.
case deadlineExceededWhileQueued
case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState)
case finished(error: Error?)
case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL)
case redirected(HTTPRequestExecutor, RedirectHandler<Delegate.Response>, Int, HTTPResponseHead, URL)
case modifying
}

Expand All @@ -55,23 +55,22 @@ extension RequestBag {
case eof
}

case initialized
case initialized(RedirectHandler<Delegate.Response>?)
case buffering(CircularBuffer<ByteBuffer>, next: Next)
case waitingForRemote
}

private var state: State = .initialized
private let redirectHandler: RedirectHandler<Delegate.Response>?
private var state: State

init(redirectHandler: RedirectHandler<Delegate.Response>?) {
self.redirectHandler = redirectHandler
self.state = .initialized(redirectHandler)
}
}
}

extension RequestBag.StateMachine {
mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) {
guard case .initialized = self.state else {
guard case .initialized(let redirectHandler) = self.state else {
// There might be a race between `requestWasQueued` and `willExecuteRequest`:
//
// If the request is created and passed to the HTTPClient on thread A, it will move into
Expand All @@ -91,7 +90,7 @@ extension RequestBag.StateMachine {
return
}

self.state = .queued(scheduler)
self.state = .queued(scheduler, redirectHandler)
}

enum WillExecuteRequestAction {
Expand All @@ -102,8 +101,8 @@ extension RequestBag.StateMachine {

mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction {
switch self.state {
case .initialized, .queued:
self.state = .executing(executor, .initialized, .initialized)
case .initialized(let redirectHandler), .queued(_, let redirectHandler):
self.state = .executing(executor, .initialized, .initialized(redirectHandler))
return .none
case .deadlineExceededWhileQueued:
let error: Error = HTTPClientError.deadlineExceeded
Expand All @@ -127,8 +126,8 @@ extension RequestBag.StateMachine {
case .initialized, .queued, .deadlineExceededWhileQueued:
preconditionFailure("A request stream can only be resumed, if the request was started")

case .executing(let executor, .initialized, .initialized):
self.state = .executing(executor, .producing, .initialized)
case .executing(let executor, .initialized, .initialized(let redirectHandler)):
self.state = .executing(executor, .producing, .initialized(redirectHandler))
return .startWriter

case .executing(_, .producing, _):
Expand Down Expand Up @@ -299,11 +298,11 @@ extension RequestBag.StateMachine {
case .initialized, .queued, .deadlineExceededWhileQueued:
preconditionFailure("How can we receive a response, if the request hasn't started yet.")
case .executing(let executor, let requestState, let responseState):
guard case .initialized = responseState else {
guard case .initialized(let redirectHandler) = responseState else {
preconditionFailure("If we receive a response, we must not have received something else before")
}

if let redirectURL = self.redirectHandler?.redirectTarget(
if let redirectHandler = redirectHandler, let redirectURL = redirectHandler.redirectTarget(
status: head.status,
responseHeaders: head.headers
) {
Expand All @@ -312,11 +311,11 @@ extension RequestBag.StateMachine {
// smaller than 3kb.
switch head.contentLength {
case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none:
self.state = .redirected(executor, 0, head, redirectURL)
self.state = .redirected(executor, redirectHandler, 0, head, redirectURL)
return .signalBodyDemand(executor)
case .some:
self.state = .finished(error: HTTPClientError.cancelled)
return .redirect(executor, self.redirectHandler!, head, redirectURL)
return .redirect(executor, redirectHandler, head, redirectURL)
}
} else {
self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore))
Expand Down Expand Up @@ -369,15 +368,15 @@ extension RequestBag.StateMachine {
} else {
return .none
}
case .redirected(let executor, var receivedBytes, let head, let redirectURL):
case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL):
let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes }
receivedBytes += partsLength

if receivedBytes > HTTPClient.maxBodySizeRedirectResponse {
self.state = .finished(error: HTTPClientError.cancelled)
return .redirect(executor, self.redirectHandler!, head, redirectURL)
return .redirect(executor, redirectHandler, head, redirectURL)
} else {
self.state = .redirected(executor, receivedBytes, head, redirectURL)
self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL)
return .signalBodyDemand(executor)
}

Expand Down Expand Up @@ -428,9 +427,9 @@ extension RequestBag.StateMachine {
self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof))
return .consume(first)

case .redirected(_, _, let head, let redirectURL):
case .redirected(_, let redirectHandler, _, let head, let redirectURL):
self.state = .finished(error: nil)
return .redirect(self.redirectHandler!, head, redirectURL)
return .redirect(redirectHandler, head, redirectURL)

case .finished(error: .some):
return .none
Expand Down Expand Up @@ -553,7 +552,7 @@ extension RequestBag.StateMachine {

mutating func deadlineExceeded() -> DeadlineExceededAction {
switch self.state {
case .queued(let queuer):
case .queued(let queuer, _):
/// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message
/// We therefore depend on the scheduler failing the request after we cancel the request.
self.state = .deadlineExceededWhileQueued
Expand Down Expand Up @@ -582,7 +581,7 @@ extension RequestBag.StateMachine {
case .initialized:
self.state = .finished(error: error)
return .failTask(error, nil, nil)
case .queued(let queuer):
case .queued(let queuer, _):
self.state = .finished(error: error)
return .failTask(error, queuer, nil)
case .executing(let executor, let requestState, .buffering(_, next: .eof)):
Expand Down
3 changes: 2 additions & 1 deletion Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private let delegate: Delegate
private let request: HTTPClient.Request
private var request: HTTPClient.Request

// the request state is synchronized on the task eventLoop
private var state: StateMachine
Expand Down Expand Up @@ -126,6 +126,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
guard let body = self.request.body else {
preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream")
}
self.request.body = nil

let writer = HTTPClient.Body.StreamWriter {
self.writeNextRequestPart($0)
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ extension RequestBagTests {
("testRedirectWith3KBBody", testRedirectWith3KBBody),
("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead),
("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead),
("testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise", testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise),
]
}
}
49 changes: 49 additions & 0 deletions Tests/AsyncHTTPClientTests/RequestBagTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Logging
import NIOCore
import NIOEmbedded
import NIOHTTP1
import NIOPosix
import XCTest

final class RequestBagTests: XCTestCase {
Expand Down Expand Up @@ -836,6 +837,54 @@ final class RequestBagTests: XCTestCase {

XCTAssertTrue(redirectTriggered)
}

func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() {
final class LeakDetector {}

let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) }

let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group))
defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) }

let httpBin = HTTPBin()
defer { XCTAssertNoThrow(try httpBin.shutdown()) }

var leakDetector = LeakDetector()

do {
var maybeRequest: HTTPClient.Request?
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST))
guard var request = maybeRequest else { return XCTFail("Expected to have a request here") }

let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self)
let donePromise = group.any().makePromise(of: Void.self)
request.body = .stream { [leakDetector] writer in
_ = leakDetector
writerPromise.succeed(writer)
return donePromise.futureResult
}

let resultFuture = httpClient.execute(request: request)
request.body = nil
writerPromise.futureResult.whenSuccess { writer in
writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map {
print("written")
}.cascade(to: donePromise)
}
XCTAssertNoThrow(try donePromise.futureResult.wait())
print("HTTP sent")

var result: HTTPClient.Response?
XCTAssertNoThrow(result = try resultFuture.wait())

XCTAssertEqual(.ok, result?.status)
let body = result?.body.map { String(buffer: $0) }
XCTAssertNotNil(body)
print("HTTP done")
}
XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector))
}
}

extension HTTPClient.Task {
Expand Down

0 comments on commit f65f45b

Please sign in to comment.