diff --git a/lib/core/util.js b/lib/core/util.js index fcdd5da0483..dfefac6d15c 100644 --- a/lib/core/util.js +++ b/lib/core/util.js @@ -511,6 +511,11 @@ function assertRequestHandler (handler, method, upgrade) { throw new InvalidArgumentError('handler must be an object') } + if (typeof handler.onRequestStart === 'function') { + // TODO (fix): More checks... + return + } + if (typeof handler.onConnect !== 'function') { throw new InvalidArgumentError('invalid onConnect method') } diff --git a/lib/dispatcher/dispatcher-base.js b/lib/dispatcher/dispatcher-base.js index afe4e9086db..cb3f0e02a3c 100644 --- a/lib/dispatcher/dispatcher-base.js +++ b/lib/dispatcher/dispatcher-base.js @@ -1,6 +1,7 @@ 'use strict' const Dispatcher = require('./dispatcher') +const UnwrapHandler = require('../handler/unwrap-handler') const { ClientDestroyedError, ClientClosedError, @@ -142,7 +143,7 @@ class DispatcherBase extends Dispatcher { throw new ClientClosedError() } - return this[kDispatch](opts, handler) + return this[kDispatch](opts, UnwrapHandler.unwrap(handler)) } catch (err) { if (typeof handler.onError !== 'function') { throw new InvalidArgumentError('invalid onError method') diff --git a/lib/dispatcher/dispatcher.js b/lib/dispatcher/dispatcher.js index ecff2a9b168..3019ad0e1da 100644 --- a/lib/dispatcher/dispatcher.js +++ b/lib/dispatcher/dispatcher.js @@ -1,5 +1,8 @@ 'use strict' const EventEmitter = require('node:events') +const WrapHandler = require('../handler/wrap-handler') + +const wrapInterceptor = (dispatch) => (opts, handler) => dispatch(opts, WrapHandler.wrap(handler)) class Dispatcher extends EventEmitter { dispatch () { @@ -28,6 +31,7 @@ class Dispatcher extends EventEmitter { throw new TypeError(`invalid interceptor, expected function received ${typeof interceptor}`) } + dispatch = wrapInterceptor(dispatch) dispatch = interceptor(dispatch) if (dispatch == null || typeof dispatch !== 'function' || dispatch.length !== 2) { diff --git a/lib/handler/redirect-handler.js b/lib/handler/redirect-handler.js index 02c302d9aa8..df2baee19d5 100644 --- a/lib/handler/redirect-handler.js +++ b/lib/handler/redirect-handler.js @@ -40,8 +40,6 @@ class RedirectHandler { throw new InvalidArgumentError('maxRedirections must be a positive number') } - util.assertRequestHandler(handler, opts.method, opts.upgrade) - this.dispatch = dispatch this.location = null this.abort = null diff --git a/lib/handler/unwrap-handler.js b/lib/handler/unwrap-handler.js new file mode 100644 index 00000000000..c7cef701b63 --- /dev/null +++ b/lib/handler/unwrap-handler.js @@ -0,0 +1,98 @@ +'use strict' + +const { parseHeaders } = require('../core/util') +const { InvalidArgumentError } = require('../core/errors') + +const kResume = Symbol('resume') + +class UnwrapController { + #paused = false + #reason = null + #aborted = false + #abort + + [kResume] = null + + constructor (abort) { + this.#abort = abort + } + + pause () { + this.#paused = true + } + + resume () { + if (this.#paused) { + this.#paused = false + this[kResume]?.() + } + } + + abort (reason) { + if (!this.#aborted) { + this.#aborted = true + this.#reason = reason + this.#abort(reason) + } + } + + get aborted () { + return this.#aborted + } + + get reason () { + return this.#reason + } + + get paused () { + return this.#paused + } +} + +module.exports = class UnwrapHandler { + #handler + #controller + + constructor (handler) { + this.#handler = handler + } + + static unwrap (handler) { + // TODO (fix): More checks... + return handler.onConnect ? handler : new UnwrapHandler(handler) + } + + onConnect (abort, context) { + this.#controller = new UnwrapController(abort) + this.#handler.onRequestStart?.(this.#controller, context) + } + + onUpgrade (statusCode, rawHeaders, socket) { + this.#handler.onRequestUpgrade?.(statusCode, parseHeaders(rawHeaders), socket) + } + + onHeaders (statusCode, rawHeaders, resume, statusMessage) { + this.#controller[kResume] = resume + this.#handler.onResponseStart?.(this.#controller, statusCode, statusMessage) + this.#handler.onResponseHeaders?.(this.#controller, parseHeaders(rawHeaders)) + return !this.#controller.paused + } + + onData (data) { + this.#handler.onResponseData?.(this.#controller, data) + return !this.#controller.paused + } + + onComplete (rawTrailers) { + this.#handler.onResponseTrailer?.(this.#controller, parseHeaders(rawTrailers)) + this.#handler.onResponseEnd?.(this.#controller) + } + + onError (err) { + if (!this.#handler.onError) { + throw new InvalidArgumentError('invalid onError method') + } + + this.#handler.onResponseError?.(this.#controller, err) + } +} diff --git a/lib/handler/wrap-handler.js b/lib/handler/wrap-handler.js new file mode 100644 index 00000000000..3b9a6abcb16 --- /dev/null +++ b/lib/handler/wrap-handler.js @@ -0,0 +1,126 @@ +'use strict' + +const { InvalidArgumentError } = require('../core/errors') + +module.exports = class WrapHandler { + #handler + #statusCode = 0 + #statusMessage = '' + #trailers = {} + + constructor (handler) { + this.#handler = handler + } + + static wrap (handler) { + // TODO (fix): More checks... + return handler.onRequestStart ? handler : new WrapHandler(handler) + } + + // Unwrap Interface + + onConnect (abort, context) { + return this.#handler.onConnect?.(abort, context) + } + + onHeaders (statusCode, rawHeaders, resume, statusMessage) { + return this.#handler.onHeaders?.(statusCode, rawHeaders, resume, statusMessage) + } + + onUpgrade (statusCode, rawHeaders, socket) { + return this.#handler.onUpgrade?.(statusCode, rawHeaders, socket) + } + + onData (data) { + return this.#handler.onData?.(data) + } + + onComplete (trailers) { + return this.#handler.onComplete?.(trailers) + } + + onError (err) { + if (!this.#handler.onError) { + throw new InvalidArgumentError('invalid onError method') + } + + return this.#handler.onError?.(err) + } + + // Wrap Interface + + onRequestStart (controller, context) { + this.#handler.onConnect?.((reason) => controller.abort(reason), context) + this.#statusCode = 0 + this.#statusMessage = '' + this.#trailers = {} + } + + onRequestError (controller, error) { + if (!this.#handler.onError) { + throw new InvalidArgumentError('invalid onError method') + } + + this.#handler.onError?.(error) + } + + onRequestUpgrade (statusCode, headers, socket) { + const rawHeaders = [] + for (const [key, val] of Object.entries(headers)) { + // TODO (fix): What if val is Array + rawHeaders.push(Buffer.from(key), Buffer.from(val)) + } + + this.#handler.onUpgrade?.(statusCode, rawHeaders, socket) + } + + onResponseStart (controller, statusCode, statusMessage) { + this.#statusCode = statusCode + this.#statusMessage = statusMessage + } + + onResponseHeaders (controller, headers) { + const rawHeaders = [] + for (const [key, val] of Object.entries(headers)) { + // TODO (fix): What if val is Array + rawHeaders.push(Buffer.from(key), Buffer.from(val)) + } + + if (this.#handler.onHeaders?.( + this.#statusCode, + rawHeaders, + () => controller.resume(), + this.#statusMessage + ) === false) { + controller.pause() + } + } + + onResponseData (controller, data) { + if (this.#handler.onData?.(data) === false) { + controller.pause() + } + } + + onResponseTrailer (controller, trailers) { + this.#trailers = trailers + } + + onResponseEnd (controller) { + const rawTrailers = [] + for (const [key, val] of Object.entries(this.#trailers)) { + // TODO (fix): What if val is Array + rawTrailers.push(Buffer.from(key), Buffer.from(val)) + } + + this.#handler.onComplete?.(rawTrailers) + } + + onResponseError (controller, error) { + if (!this.#handler.onError) { + throw new InvalidArgumentError('invalid onError method') + } + + this.#handler.onError?.(error) + } +} diff --git a/test/mock-agent.js b/test/mock-agent.js index e8afa8b00d6..d6fa744049d 100644 --- a/test/mock-agent.js +++ b/test/mock-agent.js @@ -142,89 +142,6 @@ describe('MockAgent - dispatch', () => { onError: () => {} })) }) - - test('should throw if handler is not valid on redirect', (t) => { - t = tspl(t, { plan: 7 }) - - const baseUrl = 'http://localhost:9999' - - const mockAgent = new MockAgent() - after(() => mockAgent.close()) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: 'INVALID' - }), new InvalidArgumentError('invalid onError method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: (err) => { throw err }, - onConnect: 'INVALID' - }), new InvalidArgumentError('invalid onConnect method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: (err) => { throw err }, - onConnect: () => {}, - onBodySent: 'INVALID' - }), new InvalidArgumentError('invalid onBodySent method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'CONNECT' - }, { - onError: (err) => { throw err }, - onConnect: () => {}, - onBodySent: () => {}, - onUpgrade: 'INVALID' - }), new InvalidArgumentError('invalid onUpgrade method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: (err) => { throw err }, - onConnect: () => {}, - onBodySent: () => {}, - onHeaders: 'INVALID' - }), new InvalidArgumentError('invalid onHeaders method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: (err) => { throw err }, - onConnect: () => {}, - onBodySent: () => {}, - onHeaders: () => {}, - onData: 'INVALID' - }), new InvalidArgumentError('invalid onData method')) - - t.throws(() => mockAgent.dispatch({ - origin: baseUrl, - path: '/foo', - method: 'GET' - }, { - onError: (err) => { throw err }, - onConnect: () => {}, - onBodySent: () => {}, - onHeaders: () => {}, - onData: () => {}, - onComplete: 'INVALID' - }), new InvalidArgumentError('invalid onComplete method')) - }) }) test('MockAgent - .close should clean up registered pools', async (t) => { diff --git a/test/node-test/client-dispatch.js b/test/node-test/client-dispatch.js index 296e3b8d075..1a0680916a3 100644 --- a/test/node-test/client-dispatch.js +++ b/test/node-test/client-dispatch.js @@ -419,40 +419,6 @@ test('connect call onUpgrade once', async (t) => { await p.completed }) -test('dispatch onConnect missing', async (t) => { - const p = tspl(t, { plan: 1 }) - - const server = http.createServer((req, res) => { - res.end('ad') - }) - t.after(closeServerAsPromise(server)) - - server.listen(0, () => { - const client = new Client(`http://localhost:${server.address().port}`) - t.after(() => { return client.close() }) - - client.dispatch({ - path: '/', - method: 'GET' - }, { - onHeaders (statusCode, headers) { - t.ok(true, 'should not throw') - }, - onData (buf) { - t.ok(true, 'should not throw') - }, - onComplete (trailers) { - t.ok(true, 'should not throw') - }, - onError (err) { - p.strictEqual(err.code, 'UND_ERR_INVALID_ARG') - } - }) - }) - - await p.completed -}) - test('dispatch onHeaders missing', async (t) => { const p = tspl(t, { plan: 1 }) @@ -667,7 +633,7 @@ test('dispatch pool onError missing', async (t) => { client.dispatch({ path: '/', method: 'GET', - upgrade: 'Websocket' + upgrade: 1 }, { }) } catch (err) {