diff --git a/app/scripts/lib/createDupeReqFilterMiddleware.test.ts b/app/scripts/lib/createDupeReqFilterMiddleware.test.ts deleted file mode 100644 index 18354eebaad3..000000000000 --- a/app/scripts/lib/createDupeReqFilterMiddleware.test.ts +++ /dev/null @@ -1,135 +0,0 @@ -import { jsonrpc2 } from '@metamask/utils'; -import createDupeReqFilterMiddleware, { - THREE_MINUTES, -} from './createDupeReqFilterMiddleware'; - -describe('createDupeReqFilterMiddleware', () => { - const getMockRequest = (id: number | string) => ({ - jsonrpc: jsonrpc2, - id, - method: 'foo', - }); - const getMockResponse = () => ({ jsonrpc: jsonrpc2, id: 'foo' }); - - beforeEach(() => { - jest.useFakeTimers({ now: 10 }); - }); - - it('forwards requests with ids seen for the first time', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - filterFn(getMockRequest(1), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(1); - expect(endMock).not.toHaveBeenCalled(); - }); - - it('ends the request if the id has been seen before', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - filterFn(getMockRequest(1), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(1); - expect(endMock).not.toHaveBeenCalled(); - - const response = getMockResponse(); - filterFn(getMockRequest(1), response, nextMock, endMock); - expect('result' in response).toBe(false); - expect(nextMock).toHaveBeenCalledTimes(1); - expect(endMock).toHaveBeenCalledTimes(1); - }); - - it('forwards JSON-RPC notifications (requests without ids)', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - const notification = getMockRequest(1); - // @ts-expect-error Intentional destructive testing - delete notification.id; - filterFn(notification, getMockResponse(), nextMock, endMock); - filterFn(notification, getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(2); - expect(endMock).not.toHaveBeenCalled(); - }); - - it('expires single id after three minutes', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - jest.advanceTimersByTime(THREE_MINUTES); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(2); - expect(endMock).not.toHaveBeenCalled(); - }); - - it('expires multiple ids after three minutes', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - filterFn(getMockRequest(1), getMockResponse(), nextMock, endMock); - - jest.advanceTimersByTime(1); - - filterFn(getMockRequest(2), getMockResponse(), nextMock, endMock); - - jest.advanceTimersByTime(THREE_MINUTES); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - filterFn(getMockRequest(1), getMockResponse(), nextMock, endMock); - // This should be ignored since id 2 has yet to expire. - filterFn(getMockRequest(2), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(5); - expect(endMock).toHaveBeenCalledTimes(1); - }); - - it('expires single id in three minute intervals', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - jest.advanceTimersByTime(THREE_MINUTES); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - // This should be ignored - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(2); - expect(endMock).toHaveBeenCalledTimes(1); - - jest.advanceTimersByTime(THREE_MINUTES); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(3); - expect(endMock).toHaveBeenCalledTimes(1); - }); - - it('handles running expiry job without seeing any ids', () => { - const filterFn = createDupeReqFilterMiddleware(); - const nextMock = jest.fn(); - const endMock = jest.fn(); - - jest.advanceTimersByTime(THREE_MINUTES + 1); - - filterFn(getMockRequest(0), getMockResponse(), nextMock, endMock); - - expect(nextMock).toHaveBeenCalledTimes(1); - expect(endMock).not.toHaveBeenCalled(); - }); -}); diff --git a/app/scripts/lib/createDupeReqFilterStream.test.ts b/app/scripts/lib/createDupeReqFilterStream.test.ts new file mode 100644 index 000000000000..1486ac7d1325 --- /dev/null +++ b/app/scripts/lib/createDupeReqFilterStream.test.ts @@ -0,0 +1,388 @@ +import NodeStream from 'node:stream'; +import OurReadableStream from 'readable-stream'; + +import type { JsonRpcRequest } from '@metamask/utils'; +import createDupeReqFilterStream, { + THREE_MINUTES, +} from './createDupeReqFilterStream'; + +const { Transform } = OurReadableStream; + +function createTestStream(output: JsonRpcRequest[] = [], S = Transform) { + const transformStream = createDupeReqFilterStream(); + const testOutStream = new S({ + transform: (chunk: JsonRpcRequest, _, cb) => { + output.push(chunk); + cb(); + }, + objectMode: true, + }); + + transformStream.pipe(testOutStream); + + return transformStream; +} + +function runStreamTest( + requests: JsonRpcRequest[] = [], + advanceTimersTime = 10, + S = Transform, +) { + return new Promise((resolve, reject) => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output, S); + + testStream + .on('finish', () => resolve(output)) + .on('error', (err) => reject(err)); + + requests.forEach((request) => testStream.write(request)); + testStream.end(); + + jest.advanceTimersByTime(advanceTimersTime); + }); +} + +describe('createDupeReqFilterStream', () => { + beforeEach(() => { + jest.useFakeTimers({ now: 10 }); + }); + + it('lets through requests with ids being seen for the first time', async () => { + const requests = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + ]; + + const expectedOutput = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + ]; + + const output = await runStreamTest(requests); + expect(output).toEqual(expectedOutput); + }); + + it('does not let through the request if the id has been seen before', async () => { + const requests = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, // duplicate + ]; + + const expectedOutput = [{ id: 1, method: 'foo' }]; + + const output = await runStreamTest(requests); + expect(output).toEqual(expectedOutput); + }); + + it("lets through requests if they don't have an id", async () => { + const requests = [{ method: 'notify1' }, { method: 'notify2' }]; + + const expectedOutput = [{ method: 'notify1' }, { method: 'notify2' }]; + + const output = await runStreamTest(requests); + expect(output).toEqual(expectedOutput); + }); + + it('handles a mix of request types', async () => { + const requests = [ + { id: 1, method: 'foo' }, + { method: 'notify1' }, + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { method: 'notify2' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + ]; + + const expectedOutput = [ + { id: 1, method: 'foo' }, + { method: 'notify1' }, + { id: 2, method: 'bar' }, + { method: 'notify2' }, + { id: 3, method: 'baz' }, + ]; + + const output = await runStreamTest(requests); + expect(output).toEqual(expectedOutput); + }); + + it('expires single id after three minutes', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputBeforeExpiryTime = [{ id: 1, method: 'foo' }]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeExpiryTime); + + const requests2 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputAfterExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES); + + requests2.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterExpiryTime); + }); + + it('does not expire single id after less than three', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputBeforeTimeElapses = [{ id: 1, method: 'foo' }]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeTimeElapses); + + const requests2 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputAfterTimeElapses = expectedOutputBeforeTimeElapses; + + jest.advanceTimersByTime(THREE_MINUTES - 1); + + requests2.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterTimeElapses); + }); + + it('expires multiple ids after three minutes', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + { id: 3, method: 'baz' }, + ]; + const expectedOutputBeforeExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + ]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeExpiryTime); + + const requests2 = [ + { id: 3, method: 'baz' }, + { id: 3, method: 'baz' }, + { id: 2, method: 'bar' }, + { id: 2, method: 'bar' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputAfterExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + { id: 3, method: 'baz' }, + { id: 2, method: 'bar' }, + { id: 1, method: 'foo' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES); + + requests2.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterExpiryTime); + }); + + it('expires single id in three minute intervals', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputBeforeExpiryTime = [{ id: 1, method: 'foo' }]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeExpiryTime); + + const requests2 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputAfterFirstExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES); + + requests2.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterFirstExpiryTime); + + const requests3 = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + const expectedOutputAfterSecondExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + { id: 1, method: 'foo' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES); + + requests3.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterSecondExpiryTime); + }); + + it('expires somes ids at intervals while not expiring others', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + ]; + const expectedOutputBeforeExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + ]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeExpiryTime); + + const requests2 = [{ id: 3, method: 'baz' }]; + const expectedOutputAfterFirstExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES - 1); + + requests2.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterFirstExpiryTime); + + const requests3 = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + { id: 4, method: 'buzz' }, + ]; + const expectedOutputAfterSecondExpiryTime = [ + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { id: 4, method: 'buzz' }, + ]; + + jest.advanceTimersByTime(THREE_MINUTES - 1); + + requests3.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputAfterSecondExpiryTime); + }); + + it('handles running expiry job without seeing any ids', () => { + const output: JsonRpcRequest[] = []; + const testStream = createTestStream(output); + + const requests1 = [{ id: 1, method: 'foo' }]; + const expectedOutputBeforeExpiryTime = [{ id: 1, method: 'foo' }]; + + requests1.forEach((request) => testStream.write(request)); + expect(output).toEqual(expectedOutputBeforeExpiryTime); + + jest.advanceTimersByTime(THREE_MINUTES + 1); + + expect(output).toEqual(expectedOutputBeforeExpiryTime); + }); + + [ + ['node:stream', NodeStream] as [string, typeof NodeStream], + // Redundantly include used version twice for regression-detection purposes + ['readable-stream', OurReadableStream] as [ + string, + typeof OurReadableStream, + ], + ].forEach(([name, streamsImpl]) => { + describe(`Using Streams implementation: ${name}`, () => { + [ + ['Duplex', streamsImpl.Duplex] as [string, typeof streamsImpl.Duplex], + ['Transform', streamsImpl.Transform] as [ + string, + typeof streamsImpl.Transform, + ], + ['Writable', streamsImpl.Writable] as [ + string, + typeof streamsImpl.Writable, + ], + ].forEach(([className, S]) => { + it(`handles a mix of request types coming through a ${className} stream`, async () => { + const requests = [ + { id: 1, method: 'foo' }, + { method: 'notify1' }, + { id: 1, method: 'foo' }, + { id: 2, method: 'bar' }, + { method: 'notify2' }, + { id: 2, method: 'bar' }, + { id: 3, method: 'baz' }, + ]; + + const expectedOutput = [ + { id: 1, method: 'foo' }, + { method: 'notify1' }, + { id: 2, method: 'bar' }, + { method: 'notify2' }, + { id: 3, method: 'baz' }, + ]; + + const output: JsonRpcRequest[] = []; + const testStream = createDupeReqFilterStream(); + const testOutStream = new S({ + transform: (chunk: JsonRpcRequest, _, cb) => { + output.push(chunk); + cb(); + }, + objectMode: true, + }); + + testOutStream._write = ( + chunk: JsonRpcRequest, + _: BufferEncoding, + callback: (error?: Error | null) => void, + ) => { + output.push(chunk); + callback(); + }; + + testStream.pipe(testOutStream); + + requests.forEach((request) => testStream.write(request)); + + expect(output).toEqual(expectedOutput); + }); + }); + }); + }); +}); diff --git a/app/scripts/lib/createDupeReqFilterMiddleware.ts b/app/scripts/lib/createDupeReqFilterStream.ts similarity index 63% rename from app/scripts/lib/createDupeReqFilterMiddleware.ts rename to app/scripts/lib/createDupeReqFilterStream.ts index fbaa1aa658b1..63d801e7f1e4 100644 --- a/app/scripts/lib/createDupeReqFilterMiddleware.ts +++ b/app/scripts/lib/createDupeReqFilterStream.ts @@ -1,5 +1,6 @@ -import { JsonRpcMiddleware } from 'json-rpc-engine'; +import { Transform } from 'readable-stream'; import log from 'loglevel'; +import type { JsonRpcRequest } from '@metamask/utils'; import { MINUTE } from '../../../shared/constants/time'; export const THREE_MINUTES = MINUTE * 3; @@ -42,24 +43,26 @@ const makeExpirySet = () => { }; /** - * Returns a middleware that filters out requests whose ids we've already seen. + * Returns a transform stream that filters out requests whose ids we've already seen. * Ignores JSON-RPC notifications, i.e. requests with an `undefined` id. * - * @returns The middleware function. + * @returns The stream object. */ -export default function createDupeReqFilterMiddleware(): JsonRpcMiddleware< - unknown, - void -> { +export default function createDupeReqFilterStream() { const seenRequestIds = makeExpirySet(); - return function filterDuplicateRequestMiddleware(req, _res, next, end) { - if (req.id === undefined) { + return new Transform({ + transform(chunk: JsonRpcRequest, _, cb) { // JSON-RPC notifications have no ids; our only recourse is to let them through. - return next(); - } else if (!seenRequestIds.add(req.id)) { - log.info(`RPC request with id "${req.id}" already seen.`); - return end(); - } - return next(); - }; + const hasNoId = chunk.id === undefined; + const requestNotYetSeen = seenRequestIds.add(chunk.id); + + if (hasNoId || requestNotYetSeen) { + cb(null, chunk); + } else { + log.debug(`RPC request with id "${chunk.id}" already seen.`); + cb(); + } + }, + objectMode: true, + }); } diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index 3d27df5651f9..d6580d6eece9 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -266,7 +266,7 @@ import { mmiKeyringBuilderFactory } from './mmi-keyring-builder-factory'; ///: END:ONLY_INCLUDE_IF import ComposableObservableStore from './lib/ComposableObservableStore'; import AccountTracker from './lib/account-tracker'; -import createDupeReqFilterMiddleware from './lib/createDupeReqFilterMiddleware'; +import createDupeReqFilterStream from './lib/createDupeReqFilterStream'; import createLoggerMiddleware from './lib/createLoggerMiddleware'; import { createMethodMiddleware } from './lib/rpc-method-middleware'; import createOriginMiddleware from './lib/createOriginMiddleware'; @@ -4829,12 +4829,14 @@ export default class MetamaskController extends EventEmitter { tabId, }); + const dupeReqFilterStream = createDupeReqFilterStream(); + // setup connection const providerStream = createEngineStream({ engine }); const connectionId = this.addConnection(origin, { engine }); - pump(outStream, providerStream, outStream, (err) => { + pump(outStream, dupeReqFilterStream, providerStream, outStream, (err) => { // handle any middleware cleanup engine._middleware.forEach((mid) => { if (mid.destroy && typeof mid.destroy === 'function') { @@ -4912,10 +4914,6 @@ export default class MetamaskController extends EventEmitter { engine.emit('notification', message), ); - if (isManifestV3) { - engine.push(createDupeReqFilterMiddleware()); - } - // append tabId to each request if it exists if (tabId) { engine.push(createTabIdMiddleware({ tabId }));