diff --git a/Sources/Apollo/Promise.swift b/Sources/Apollo/Promise.swift index 202dee7694..28ca10d284 100644 --- a/Sources/Apollo/Promise.swift +++ b/Sources/Apollo/Promise.swift @@ -3,6 +3,7 @@ import Dispatch func whenAll(_ promises: [Promise], notifyOn queue: DispatchQueue = .global()) -> Promise<[Value]> { return Promise { (fulfill, reject) in let group = DispatchGroup() + var rejected = false for promise in promises { group.enter() @@ -11,11 +12,15 @@ func whenAll(_ promises: [Promise], notifyOn queue: DispatchQueue group.leave() }.catch { error in reject(error) + rejected = true + group.leave() } } group.notify(queue: queue) { - fulfill(promises.map { $0.result!.value! }) + if !rejected { + fulfill(promises.map { $0.result!.value! }) + } } } } diff --git a/Sources/Apollo/ResultOrPromise.swift b/Sources/Apollo/ResultOrPromise.swift index 9d70e420b0..c2814ecaa2 100644 --- a/Sources/Apollo/ResultOrPromise.swift +++ b/Sources/Apollo/ResultOrPromise.swift @@ -19,6 +19,7 @@ func whenAll(_ resultsOrPromises: [ResultOrPromise], notifyOn queu return .promise(Promise { (fulfill, reject) in let group = DispatchGroup() + var rejected = false for resultOrPromise in resultsOrPromises { group.enter() @@ -27,11 +28,15 @@ func whenAll(_ resultsOrPromises: [ResultOrPromise], notifyOn queu group.leave() }.catch { error in reject(error) + rejected = true + group.leave() } } group.notify(queue: queue) { - fulfill(resultsOrPromises.map { $0.result!.value! }) + if !rejected { + fulfill(resultsOrPromises.map { $0.result!.value! }) + } } }) } diff --git a/Tests/ApolloTests/PromiseTests.swift b/Tests/ApolloTests/PromiseTests.swift index aeab6fb9fe..a601a2c803 100644 --- a/Tests/ApolloTests/PromiseTests.swift +++ b/Tests/ApolloTests/PromiseTests.swift @@ -249,4 +249,34 @@ class PromiseTests: XCTestCase { waitForExpectations(timeout: 1) } + + func testWhenAllRejectsWhenAnyOfThePromisesRejectsAsync_doesNotCreateMemoryLeak() throws { + var executeReject: ((Error) -> Void)? + + var promises: [Promise] = [Promise(fulfilled: "foo"), + Promise { _, reject in executeReject = reject }, + Promise(fulfilled: "bar")] + weak var weakPromise0 = promises[0] + weak var weakPromise1 = promises[1] + weak var weakPromise2 = promises[2] + + let expectation = self.expectation(description: "whenAll catch handler invoked") + + whenAll(promises).catch { error in + XCTAssert(error is TestError) + + expectation.fulfill() + } + + promises = [] + executeReject?(TestError()) + executeReject = nil + + waitForExpectations(timeout: 1) + + XCTAssertNil(weakPromise0) + XCTAssertNil(weakPromise1) + XCTAssertNil(weakPromise2) + } + }