diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index b3187cbf30..30dae4c37d 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -244,6 +244,22 @@ async function complete(document: DocumentNode, rootValue: unknown = {}) { return result; } +async function completeAsync(document: DocumentNode, numCalls: number) { + const schema = new GraphQLSchema({ query }); + + const result = await execute({ schema, document, rootValue: {} }); + + invariant(isAsyncIterable(result)); + + const iterator = result[Symbol.asyncIterator](); + + const promises = []; + for (let i = 0; i < numCalls; i++) { + promises.push(iterator.next()); + } + return Promise.all(promises); +} + describe('Execute: stream directive', () => { it('Can stream a list field', async () => { const document = parse('{ scalarList @stream(initialCount: 1) }'); @@ -684,6 +700,60 @@ describe('Execute: stream directive', () => { }, }); }); + it('Can handle concurrent calls to .next() without waiting', async () => { + const document = parse(` + query { + asyncIterableList @stream(initialCount: 2) { + name + id + } + } + `); + const result = await completeAsync(document, 4); + expectJSON(result).toDeepEqual([ + { + done: false, + value: { + data: { + asyncIterableList: [ + { + name: 'Luke', + id: '1', + }, + { + name: 'Han', + id: '2', + }, + ], + }, + hasNext: true, + }, + }, + { + done: false, + value: { + data: [ + { + name: 'Leia', + id: '3', + }, + ], + path: ['asyncIterableList', 2], + hasNext: true, + }, + }, + { + done: false, + value: { + hasNext: false, + }, + }, + { + done: true, + value: undefined, + }, + ]); + }); it('Handles error thrown in async iterable before initialCount is reached', async () => { const document = parse(` query { diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 92bba2aa99..bc74ac75cb 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -1648,7 +1648,18 @@ function yieldSubsequentPayloads( const data = await asyncPayloadRecord.data; + if (exeContext.subsequentPayloads.length === 0) { + // a different call to next has exhausted all payloads + return { value: undefined, done: true }; + } + const index = exeContext.subsequentPayloads.indexOf(asyncPayloadRecord); + + if (index === -1) { + // a different call to next has consumed this payload + return race(); + } + exeContext.subsequentPayloads.splice(index, 1); if (asyncPayloadRecord.isCompletedIterator) {