diff --git a/src/__tests__/http-test.ts b/src/__tests__/http-test.ts index eb97113b..190c608a 100644 --- a/src/__tests__/http-test.ts +++ b/src/__tests__/http-test.ts @@ -1,4 +1,5 @@ import zlib from 'zlib'; +import type http from 'http'; import type { Server as Restify } from 'restify'; import connect from 'connect'; @@ -81,6 +82,12 @@ function urlString(urlParams?: { [param: string]: string }): string { return string; } +function sleep() { + return new Promise((r) => { + setTimeout(r, 1); + }); +} + describe('GraphQL-HTTP tests for connect', () => { runTests(() => { const app = connect(); @@ -2389,9 +2396,7 @@ function runTests(server: Server) { graphqlHTTP(() => ({ schema: TestSchema, async *customExecuteFn() { - await new Promise((r) => { - setTimeout(r, 1); - }); + await sleep(); yield { data: { test2: 'Modification', @@ -2436,6 +2441,141 @@ function runTests(server: Server) { ].join('\r\n'), ); }); + + it('calls return on underlying async iterable when connection is closed', async () => { + const app = server(); + const fakeReturn = sinon.fake(); + + app.get( + urlString(), + graphqlHTTP(() => ({ + schema: TestSchema, + // custom iterable keeps yielding until return is called + customExecuteFn() { + let returned = false; + return { + [Symbol.asyncIterator]: () => ({ + next: async () => { + await sleep(); + if (returned) { + return { value: undefined, done: true }; + } + return { + value: { data: { test: 'Hello, World' }, hasNext: true }, + done: false, + }; + }, + return: () => { + returned = true; + fakeReturn(); + return Promise.resolve({ value: undefined, done: true }); + }, + }), + }; + }, + })), + ); + + let text = ''; + const request = app + .request() + .get(urlString({ query: '{test}' })) + .parse((res, cb) => { + res.on('data', (data) => { + text = `${text}${data.toString('utf8') as string}`; + ((res as unknown) as http.IncomingMessage).destroy(); + }); + res.on('end', (err) => { + cb(err, null); + }); + }); + + try { + await request; + } catch (e: unknown) { + // ignore aborted error + } + // sleep to allow return function to be called + await sleep(); + expect(text).to.equal( + [ + '', + '---', + 'Content-Type: application/json; charset=utf-8', + 'Content-Length: 47', + '', + '{"data":{"test":"Hello, World"},"hasNext":true}', + '', + ].join('\r\n'), + ); + expect(fakeReturn.callCount).to.equal(1); + }); + + it('handles return function on async iterable that throws', async () => { + const app = server(); + + app.get( + urlString(), + graphqlHTTP(() => ({ + schema: TestSchema, + // custom iterable keeps yielding until return is called + customExecuteFn() { + let returned = false; + return { + [Symbol.asyncIterator]: () => ({ + next: async () => { + await sleep(); + if (returned) { + return { value: undefined, done: true }; + } + return { + value: { data: { test: 'Hello, World' }, hasNext: true }, + done: false, + }; + }, + return: () => { + returned = true; + return Promise.reject(new Error('Throws!')); + }, + }), + }; + }, + })), + ); + + let text = ''; + const request = app + .request() + .get(urlString({ query: '{test}' })) + .parse((res, cb) => { + res.on('data', (data) => { + text = `${text}${data.toString('utf8') as string}`; + ((res as unknown) as http.IncomingMessage).destroy(); + }); + res.on('end', (err) => { + cb(err, null); + }); + }); + + try { + await request; + } catch (e: unknown) { + // ignore aborted error + } + // sleep to allow return function to be called + await sleep(); + expect(text).to.equal( + [ + '', + '---', + 'Content-Type: application/json; charset=utf-8', + 'Content-Length: 47', + '', + '{"data":{"test":"Hello, World"},"hasNext":true}', + '', + ].join('\r\n'), + ); + }); }); describe('Custom parse function', () => { diff --git a/src/index.ts b/src/index.ts index 7624e168..05161631 100644 --- a/src/index.ts +++ b/src/index.ts @@ -213,6 +213,7 @@ export function graphqlHTTP(options: Options): Middleware { let documentAST: DocumentNode; let executeResult; let result: ExecutionResult; + let finishedIterable = false; try { // Parse the Request to get GraphQL request parameters. @@ -371,6 +372,23 @@ export function graphqlHTTP(options: Options): Middleware { const asyncIterator = getAsyncIterator( executeResult, ); + + response.on('close', () => { + if ( + !finishedIterable && + typeof asyncIterator.return === 'function' + ) { + asyncIterator.return().then(null, (rawError: unknown) => { + const graphqlError = getGraphQlError(rawError); + sendPartialResponse(pretty, response, { + data: undefined, + errors: [formatErrorFn(graphqlError)], + hasNext: false, + }); + }); + } + }); + const { value } = await asyncIterator.next(); result = value; } else { @@ -398,6 +416,7 @@ export function graphqlHTTP(options: Options): Middleware { rawError instanceof Error ? rawError : String(rawError), ); + // eslint-disable-next-line require-atomic-updates response.statusCode = error.status; const { headers } = error; @@ -431,6 +450,7 @@ export function graphqlHTTP(options: Options): Middleware { // the resulting JSON payload. // https://graphql.github.io/graphql-spec/#sec-Data if (response.statusCode === 200 && result.data == null) { + // eslint-disable-next-line require-atomic-updates response.statusCode = 500; } @@ -462,17 +482,7 @@ export function graphqlHTTP(options: Options): Middleware { sendPartialResponse(pretty, response, formattedPayload); } } catch (rawError: unknown) { - /* istanbul ignore next: Thrown by underlying library. */ - const error = - rawError instanceof Error ? rawError : new Error(String(rawError)); - const graphqlError = new GraphQLError( - error.message, - undefined, - undefined, - undefined, - undefined, - error, - ); + const graphqlError = getGraphQlError(rawError); sendPartialResponse(pretty, response, { data: undefined, errors: [formatErrorFn(graphqlError)], @@ -481,6 +491,7 @@ export function graphqlHTTP(options: Options): Middleware { } response.write('\r\n-----\r\n'); response.end(); + finishedIterable = true; return; } @@ -657,3 +668,17 @@ function getAsyncIterator( const method = asyncIterable[Symbol.asyncIterator]; return method.call(asyncIterable); } + +function getGraphQlError(rawError: unknown) { + /* istanbul ignore next: Thrown by underlying library. */ + const error = + rawError instanceof Error ? rawError : new Error(String(rawError)); + return new GraphQLError( + error.message, + undefined, + undefined, + undefined, + undefined, + error, + ); +}