diff --git a/src/execution/PromiseCanceller.ts b/src/execution/PromiseCanceller.ts index 60c3e3b6a3..16b438e17f 100644 --- a/src/execution/PromiseCanceller.ts +++ b/src/execution/PromiseCanceller.ts @@ -1,8 +1,8 @@ import { promiseWithResolvers } from '../jsutils/promiseWithResolvers.js'; /** - * A PromiseCanceller object can be used to cancel multiple promises - * using a single AbortSignal. + * A PromiseCanceller object can be used to trigger multiple responses + * in response to a single AbortSignal. * * @internal */ @@ -28,7 +28,7 @@ export class PromiseCanceller { this.abortSignal.removeEventListener('abort', this.abort); } - withCancellation(originalPromise: Promise): Promise { + cancellablePromise(originalPromise: Promise): Promise { if (this.abortSignal.aborted) { // eslint-disable-next-line @typescript-eslint/prefer-promise-reject-errors return Promise.reject(this.abortSignal.reason); @@ -50,4 +50,27 @@ export class PromiseCanceller { return promise; } + + cancellableIterable(iterable: AsyncIterable): AsyncIterable { + const iterator = iterable[Symbol.asyncIterator](); + + const _next = iterator.next.bind(iterator); + + if (iterator.return) { + const _return = iterator.return.bind(iterator); + + return { + [Symbol.asyncIterator]: () => ({ + next: () => this.cancellablePromise(_next()), + return: () => this.cancellablePromise(_return()), + }), + }; + } + + return { + [Symbol.asyncIterator]: () => ({ + next: () => this.cancellablePromise(_next()), + }), + }; + } } diff --git a/src/execution/__tests__/PromiseCanceller-test.ts b/src/execution/__tests__/PromiseCanceller-test.ts index 91fe6c40e5..5800c4ceac 100644 --- a/src/execution/__tests__/PromiseCanceller-test.ts +++ b/src/execution/__tests__/PromiseCanceller-test.ts @@ -5,52 +5,117 @@ import { expectPromise } from '../../__testUtils__/expectPromise.js'; import { PromiseCanceller } from '../PromiseCanceller.js'; describe('PromiseCanceller', () => { - it('works to cancel an already resolved promise', async () => { - const abortController = new AbortController(); - const abortSignal = abortController.signal; + describe('cancellablePromise', () => { + it('works to cancel an already resolved promise', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; - const promiseCanceller = new PromiseCanceller(abortSignal); + const promiseCanceller = new PromiseCanceller(abortSignal); - const promise = Promise.resolve(1); + const promise = Promise.resolve(1); - const withCancellation = promiseCanceller.withCancellation(promise); + const withCancellation = promiseCanceller.cancellablePromise(promise); - abortController.abort(new Error('Cancelled!')); + abortController.abort(new Error('Cancelled!')); - await expectPromise(withCancellation).toRejectWith('Cancelled!'); - }); + await expectPromise(withCancellation).toRejectWith('Cancelled!'); + }); + + it('works to cancel an already resolved promise after abort signal triggered', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + abortController.abort(new Error('Cancelled!')); - it('works to cancel a hanging promise', async () => { - const abortController = new AbortController(); - const abortSignal = abortController.signal; + const promiseCanceller = new PromiseCanceller(abortSignal); - const promiseCanceller = new PromiseCanceller(abortSignal); + const promise = Promise.resolve(1); - const promise = new Promise(() => { - /* never resolves */ + const withCancellation = promiseCanceller.cancellablePromise(promise); + + await expectPromise(withCancellation).toRejectWith('Cancelled!'); }); - const withCancellation = promiseCanceller.withCancellation(promise); + it('works to cancel a hanging promise', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + const promiseCanceller = new PromiseCanceller(abortSignal); + + const promise = new Promise(() => { + /* never resolves */ + }); + + const withCancellation = promiseCanceller.cancellablePromise(promise); + + abortController.abort(new Error('Cancelled!')); + + await expectPromise(withCancellation).toRejectWith('Cancelled!'); + }); + + it('works to cancel a hanging promise created after abort signal triggered', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + abortController.abort(new Error('Cancelled!')); - abortController.abort(new Error('Cancelled!')); + const promiseCanceller = new PromiseCanceller(abortSignal); - await expectPromise(withCancellation).toRejectWith('Cancelled!'); + const promise = new Promise(() => { + /* never resolves */ + }); + + const withCancellation = promiseCanceller.cancellablePromise(promise); + + await expectPromise(withCancellation).toRejectWith('Cancelled!'); + }); }); - it('works to cancel a hanging promise created after abort signal triggered', async () => { - const abortController = new AbortController(); - const abortSignal = abortController.signal; + describe('cancellableAsyncIterable', () => { + it('works to abort a next call', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + const promiseCanceller = new PromiseCanceller(abortSignal); + + const asyncIterable = { + [Symbol.asyncIterator]: () => ({ + next: () => Promise.resolve({ value: 1, done: false }), + }), + }; + + const cancellableAsyncIterable = + promiseCanceller.cancellableIterable(asyncIterable); - abortController.abort(new Error('Cancelled!')); + const nextPromise = + cancellableAsyncIterable[Symbol.asyncIterator]().next(); - const promiseCanceller = new PromiseCanceller(abortSignal); + abortController.abort(new Error('Cancelled!')); - const promise = new Promise(() => { - /* never resolves */ + await expectPromise(nextPromise).toRejectWith('Cancelled!'); }); - const withCancellation = promiseCanceller.withCancellation(promise); + it('works to abort a next call when already aborted', async () => { + const abortController = new AbortController(); + const abortSignal = abortController.signal; - await expectPromise(withCancellation).toRejectWith('Cancelled!'); + abortController.abort(new Error('Cancelled!')); + + const promiseCanceller = new PromiseCanceller(abortSignal); + + const asyncIterable = { + [Symbol.asyncIterator]: () => ({ + next: () => Promise.resolve({ value: 1, done: false }), + }), + }; + + const cancellableAsyncIterable = + promiseCanceller.cancellableIterable(asyncIterable); + + const nextPromise = + cancellableAsyncIterable[Symbol.asyncIterator]().next(); + + await expectPromise(nextPromise).toRejectWith('Cancelled!'); + }); }); }); diff --git a/src/execution/__tests__/abort-signal-test.ts b/src/execution/__tests__/abort-signal-test.ts index d12253b517..3c2f41553f 100644 --- a/src/execution/__tests__/abort-signal-test.ts +++ b/src/execution/__tests__/abort-signal-test.ts @@ -1,9 +1,12 @@ -import { expect } from 'chai'; +import { assert, expect } from 'chai'; import { describe, it } from 'mocha'; import { expectJSON } from '../../__testUtils__/expectJSON.js'; +import { expectPromise } from '../../__testUtils__/expectPromise.js'; import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick.js'; +import { isAsyncIterable } from '../../jsutils/isAsyncIterable.js'; + import type { DocumentNode } from '../../language/ast.js'; import { parse } from '../../language/parser.js'; @@ -400,6 +403,56 @@ describe('Execute: Cancellation', () => { }); }); + it('should stop the execution when aborted despite a hanging async item', async () => { + const abortController = new AbortController(); + const document = parse(` + query { + todo { + id + items + } + } + `); + + const resultPromise = execute({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + todo: () => ({ + id: '1', + async *items() { + yield await new Promise(() => { + /* will never resolve */ + }); /* c8 ignore start */ + } /* c8 ignore stop */, + }), + }, + }); + + abortController.abort(); + + const result = await resultPromise; + + expect(result.errors?.[0].originalError?.name).to.equal('AbortError'); + + expectJSON(result).toDeepEqual({ + data: { + todo: { + id: '1', + items: null, + }, + }, + errors: [ + { + message: 'This operation was aborted', + path: ['todo', 'items'], + locations: [{ line: 5, column: 11 }], + }, + ], + }); + }); + it('should stop the execution when aborted with proper null bubbling', async () => { const abortController = new AbortController(); const document = parse(` @@ -610,6 +663,63 @@ describe('Execute: Cancellation', () => { ]); }); + it('should stop streamed execution when aborted', async () => { + const abortController = new AbortController(); + const document = parse(` + query { + todo { + id + items @stream + } + } + `); + + const resultPromise = complete( + document, + { + todo: { + id: '1', + items: [Promise.resolve('item')], + }, + }, + abortController.signal, + ); + + abortController.abort(); + + const result = await resultPromise; + + expectJSON(result).toDeepEqual([ + { + data: { + todo: { + id: '1', + items: [], + }, + }, + pending: [{ id: '0', path: ['todo', 'items'] }], + hasNext: true, + }, + { + incremental: [ + { + items: [null], + errors: [ + { + message: 'This operation was aborted', + path: ['todo', 'items', 0], + locations: [{ line: 5, column: 11 }], + }, + ], + id: '0', + }, + ], + completed: [{ id: '0' }], + hasNext: false, + }, + ]); + }); + it('should stop the execution when aborted mid-mutation', async () => { const abortController = new AbortController(); const document = parse(` @@ -685,7 +795,7 @@ describe('Execute: Cancellation', () => { }); }); - it('should stop the execution when aborted during subscription', async () => { + it('should stop the execution when aborted prior to return of a subscription resolver', async () => { const abortController = new AbortController(); const document = parse(` subscription { @@ -693,7 +803,7 @@ describe('Execute: Cancellation', () => { } `); - const resultPromise = subscribe({ + const subscriptionPromise = subscribe({ document, schema, abortSignal: abortController.signal, @@ -707,7 +817,7 @@ describe('Execute: Cancellation', () => { abortController.abort(); - const result = await resultPromise; + const result = await subscriptionPromise; expectJSON(result).toDeepEqual({ errors: [ @@ -719,4 +829,120 @@ describe('Execute: Cancellation', () => { ], }); }); + + it('should successfully wrap the subscription', async () => { + const abortController = new AbortController(); + const document = parse(` + subscription { + foo + } + `); + + async function* foo() { + yield await Promise.resolve({ foo: 'foo' }); + } + + const subscription = await subscribe({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + foo: Promise.resolve(foo()), + }, + }); + + assert(isAsyncIterable(subscription)); + + expectJSON(await subscription.next()).toDeepEqual({ + value: { + data: { + foo: 'foo', + }, + }, + done: false, + }); + + expectJSON(await subscription.next()).toDeepEqual({ + value: undefined, + done: true, + }); + }); + + it('should stop the execution when aborted during subscription', async () => { + const abortController = new AbortController(); + const document = parse(` + subscription { + foo + } + `); + + async function* foo() { + yield await Promise.resolve({ foo: 'foo' }); + } + + const subscription = subscribe({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + foo: foo(), + }, + }); + + assert(isAsyncIterable(subscription)); + + expectJSON(await subscription.next()).toDeepEqual({ + value: { + data: { + foo: 'foo', + }, + }, + done: false, + }); + + abortController.abort(); + + await expectPromise(subscription.next()).toRejectWith( + 'This operation was aborted', + ); + }); + + it('should stop the execution when aborted during subscription returned asynchronously', async () => { + const abortController = new AbortController(); + const document = parse(` + subscription { + foo + } + `); + + async function* foo() { + yield await Promise.resolve({ foo: 'foo' }); + } + + const subscription = await subscribe({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + foo: Promise.resolve(foo()), + }, + }); + + assert(isAsyncIterable(subscription)); + + expectJSON(await subscription.next()).toDeepEqual({ + value: { + data: { + foo: 'foo', + }, + }, + done: false, + }); + + abortController.abort(); + + await expectPromise(subscription.next()).toRejectWith( + 'This operation was aborted', + ); + }); }); diff --git a/src/execution/__tests__/mapAsyncIterable-test.ts b/src/execution/__tests__/mapAsyncIterable-test.ts index dee53aa486..599e15f05e 100644 --- a/src/execution/__tests__/mapAsyncIterable-test.ts +++ b/src/execution/__tests__/mapAsyncIterable-test.ts @@ -89,6 +89,54 @@ describe('mapAsyncIterable', () => { }); }); + it('calls done when completes', async () => { + async function* source() { + yield 1; + yield 2; + yield 3; + } + + let done = false; + const doubles = mapAsyncIterable( + source(), + (x) => Promise.resolve(x + x), + () => { + done = true; + }, + ); + + expect(await doubles.next()).to.deep.equal({ value: 2, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 4, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 6, done: false }); + expect(done).to.equal(false); + expect(await doubles.next()).to.deep.equal({ + value: undefined, + done: true, + }); + expect(done).to.equal(true); + }); + + it('calls done when completes with error', async () => { + async function* source() { + yield 1; + throw new Error('Oops'); + } + + let done = false; + const doubles = mapAsyncIterable( + source(), + (x) => Promise.resolve(x + x), + () => { + done = true; + }, + ); + + expect(await doubles.next()).to.deep.equal({ value: 2, done: false }); + expect(done).to.equal(false); + await expectPromise(doubles.next()).toRejectWith('Oops'); + expect(done).to.equal(true); + }); + it('allows returning early from mapped async generator', async () => { async function* source() { try { diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 1258abf279..d847835d98 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -878,7 +878,7 @@ function executeField( fieldDetailsList, info, path, - promiseCanceller?.withCancellation(result) ?? result, + promiseCanceller?.cancellablePromise(result) ?? result, incrementalContext, deferMap, ); @@ -1386,7 +1386,9 @@ function completeListValue( const itemType = returnType.ofType; if (isAsyncIterable(result)) { - const asyncIterator = result[Symbol.asyncIterator](); + const maybeCancellableIterable = + exeContext.promiseCanceller?.cancellableIterable(result) ?? result; + const asyncIterator = maybeCancellableIterable[Symbol.asyncIterator](); return completeAsyncIteratorValue( exeContext, @@ -1597,7 +1599,7 @@ async function completePromisedListItemValue( deferMap: ReadonlyMap | undefined, ): Promise { try { - const resolved = await (exeContext.promiseCanceller?.withCancellation( + const resolved = await (exeContext.promiseCanceller?.cancellablePromise( item, ) ?? item); let completed = completeValue( @@ -2105,17 +2107,26 @@ function mapSourceToResponse( return resultOrStream; } + const abortSignal = validatedExecutionArgs.abortSignal; + const promiseCanceller = abortSignal + ? new PromiseCanceller(abortSignal) + : undefined; + // For each payload yielded from a subscription, map it over the normal // GraphQL `execute` function, with `payload` as the rootValue. // This implements the "MapSourceToResponseEvent" algorithm described in // the GraphQL specification.. - return mapAsyncIterable(resultOrStream, (payload: unknown) => { - const perEventExecutionArgs: ValidatedExecutionArgs = { - ...validatedExecutionArgs, - rootValue: payload, - }; - return validatedExecutionArgs.perEventExecutor(perEventExecutionArgs); - }); + return mapAsyncIterable( + promiseCanceller?.cancellableIterable(resultOrStream) ?? resultOrStream, + (payload: unknown) => { + const perEventExecutionArgs: ValidatedExecutionArgs = { + ...validatedExecutionArgs, + rootValue: payload, + }; + return validatedExecutionArgs.perEventExecutor(perEventExecutionArgs); + }, + () => promiseCanceller?.disconnect(), + ); } export function executeSubscriptionEvent( @@ -2267,11 +2278,10 @@ function executeSubscription( const promiseCanceller = abortSignal ? new PromiseCanceller(abortSignal) : undefined; - const promise = promiseCanceller?.withCancellation(result) ?? result; + + const promise = promiseCanceller?.cancellablePromise(result) ?? result; return promise.then(assertEventStream).then( (resolved) => { - // TODO: add test case - /* c8 ignore next */ promiseCanceller?.disconnect(); return resolved; }, @@ -2648,7 +2658,7 @@ function completeStreamItem( fieldDetailsList, info, itemPath, - exeContext.promiseCanceller?.withCancellation(item) ?? item, + exeContext.promiseCanceller?.cancellablePromise(item) ?? item, incrementalContext, new Map(), ).then( diff --git a/src/execution/mapAsyncIterable.ts b/src/execution/mapAsyncIterable.ts index 0f6fd78c2d..e0f942fd53 100644 --- a/src/execution/mapAsyncIterable.ts +++ b/src/execution/mapAsyncIterable.ts @@ -7,18 +7,28 @@ import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js'; export function mapAsyncIterable( iterable: AsyncGenerator | AsyncIterable, callback: (value: T) => PromiseOrValue, + onDone?: (() => void) | undefined, ): AsyncGenerator { const iterator = iterable[Symbol.asyncIterator](); async function mapResult( - result: IteratorResult, + promise: Promise>, ): Promise> { - if (result.done) { - return result; + let value: T; + try { + const result = await promise; + if (result.done) { + onDone?.(); + return result; + } + value = result.value; + } catch (error) { + onDone?.(); + throw error; } try { - return { value: await callback(result.value), done: false }; + return { value: await callback(value), done: false }; } catch (error) { /* c8 ignore start */ // FIXME: add test case @@ -36,17 +46,17 @@ export function mapAsyncIterable( return { async next() { - return mapResult(await iterator.next()); + return mapResult(iterator.next()); }, async return(): Promise> { // If iterator.return() does not exist, then type R must be undefined. return typeof iterator.return === 'function' - ? mapResult(await iterator.return()) + ? mapResult(iterator.return()) : { value: undefined as any, done: true }; }, async throw(error?: unknown) { if (typeof iterator.throw === 'function') { - return mapResult(await iterator.throw(error)); + return mapResult(iterator.throw(error)); } throw error; },