diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 5c1a64858..3793c00c1 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,7 +1,7 @@ import type { ClientCallerInterface, DuplexCallerInterface, - JsonRpcRequest, + JsonRpcRequestMessage, ServerCallerInterface, StreamPairCreateCallback, } from './types'; @@ -64,7 +64,7 @@ class RPCClient { ): Promise> { const streamPair = await this.streamPairCreateCallback(); const inputStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(), + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), ); const outputTransform = new rpcUtils.JsonMessageToJsonStream(); void outputTransform.readable.pipeTo(streamPair.writable); @@ -75,9 +75,8 @@ class RPCClient { try { while (true) { value = yield; - const message: JsonRpcRequest = { + const message: JsonRpcRequestMessage = { method, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: value, @@ -96,12 +95,12 @@ class RPCClient { while (true) { const { value, done } = await reader.read(); if (done) break; - if ( - value?.type === 'JsonRpcRequest' || - value?.type === 'JsonRpcNotification' - ) { - yield value.params as O; + if ('error' in value) { + throw Error('TMP message was an error message', { + cause: value.error, + }); } + yield value.result as O; } }; const output = outputGen(); diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 8335d623b..1e683dd61 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -17,7 +17,6 @@ import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -import * as grpcUtils from '../grpc/utils'; interface RPCServer extends CreateDestroy {} @CreateDestroy() @@ -151,18 +150,10 @@ class RPCServer { // a generator. const inputGen = async function* () { const pojoStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(), + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), ); for await (const dataMessage of pojoStream) { - // FIXME: don't bother filtering, we should assume all input messages are request or notification. - // These should be checked by parsing, no need for a type field. - // Filtering for request and notification messages - if ( - dataMessage.type === 'JsonRpcRequest' || - dataMessage.type === 'JsonRpcNotification' - ) { - yield dataMessage; - } + yield dataMessage; } }; const container = this.container; @@ -203,7 +194,6 @@ class RPCServer { ctx, )) { const responseMessage: JsonRpcResponseResult = { - type: 'JsonRpcResponseResult', jsonrpc: '2.0', result: response, id: null, @@ -216,10 +206,9 @@ class RPCServer { const rpcError: JsonRpcError = { code: e.exitCode, message: e.description, - data: grpcUtils.fromError(e), + data: rpcUtils.fromError(e), }; const rpcErrorMessage: JsonRpcResponseError = { - type: 'JsonRpcResponseError', jsonrpc: '2.0', error: rpcError, id: null, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 712b96d8a..3e1cfaa66 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -6,8 +6,7 @@ import type { ReadableWritablePair } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. */ -type JsonRpcRequest = { - type: 'JsonRpcRequest'; +type JsonRpcRequestMessage = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -23,8 +22,7 @@ type JsonRpcRequest = { id: string | number | null; }; -type JsonRpcNotification = { - type: 'JsonRpcNotification'; +type JsonRpcRequestNotification = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -37,7 +35,6 @@ type JsonRpcNotification = { }; type JsonRpcResponseResult = { - type: 'JsonRpcResponseResult'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; // This member is REQUIRED on success. @@ -51,14 +48,13 @@ type JsonRpcResponseResult = { id: string | number | null; }; -type JsonRpcResponseError = { - type: 'JsonRpcResponseError'; +type JsonRpcResponseError = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; // This member is REQUIRED on error. // This member MUST NOT exist if there was no error triggered during invocation. // The value for this member MUST be an Object as defined in section 5.1. - error: JsonRpcError; + error: JsonRpcError; // This member is REQUIRED. // It MUST be the same as the value of the id member in the Request Object. // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), @@ -78,7 +74,7 @@ type JsonRpcResponseError = { // -32603 Internal error Internal JSON-RPC error. // -32000 to -32099 -type JsonRpcError = { +type JsonRpcError = { // A Number that indicates the error type that occurred. // This MUST be an integer. code: number; @@ -88,19 +84,20 @@ type JsonRpcError = { // A Primitive or Structured value that contains additional information about the error. // This may be omitted. // The value of this member is defined by the Server (e.g. detailed error information, nested errors etc.). - data?: T; + data?: JSONValue; }; -type JsonRpcResponse< - T extends JSONValue | unknown = unknown, - K extends JSONValue | unknown = unknown, -> = JsonRpcResponseResult | JsonRpcResponseError; +type JsonRpcRequest = + | JsonRpcRequestMessage + | JsonRpcRequestNotification; + +type JsonRpcResponse = + | JsonRpcResponseResult + | JsonRpcResponseError; type JsonRpcMessage = | JsonRpcRequest - | JsonRpcNotification - | JsonRpcResponseResult - | JsonRpcResponseError; + | JsonRpcResponse; // Handler types type Handler = ( @@ -165,11 +162,12 @@ type StreamPairCreateCallback = () => Promise< >; export type { - JsonRpcRequest, - JsonRpcNotification, + JsonRpcRequestMessage, + JsonRpcRequestNotification, JsonRpcResponseResult, JsonRpcResponseError, JsonRpcError, + JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, DuplexStreamHandler, diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9edd3e30e..8881b40af 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -6,41 +6,44 @@ import type { import type { JsonRpcError, JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, + JsonRpcRequestNotification, + JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseResult, + JsonRpcRequest, + JsonRpcResponse, } from 'RPC/types'; import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; -import * as rpcUtils from './utils'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; const jsonStreamParsers = require('@streamparser/json'); -class JsonToJsonMessage implements Transformer { +class JsonToJsonMessage + implements Transformer +{ protected bytesWritten: number = 0; - constructor(protected byteLimit: number) {} + constructor( + protected messageParser: (message: unknown) => T, + protected byteLimit: number, + ) {} protected parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'], }); - start: TransformerStartCallback = async (controller) => { + start: TransformerStartCallback = async (controller) => { this.parser.onValue = (value) => { - const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); + const jsonMessage = this.messageParser(value.value); controller.enqueue(jsonMessage); this.bytesWritten = 0; }; }; - transform: TransformerTransformCallback = async ( - chunk, - _controller, - ) => { + transform: TransformerTransformCallback = async (chunk) => { try { this.bytesWritten += chunk.byteLength; this.parser.write(chunk); @@ -54,9 +57,15 @@ class JsonToJsonMessage implements Transformer { } // TODO: rename to something more descriptive? -class JsonToJsonMessageStream extends TransformStream { - constructor(byteLimit: number = 1024 * 1024) { - super(new JsonToJsonMessage(byteLimit)); +class JsonToJsonMessageStream extends TransformStream< + Buffer, + T +> { + constructor( + messageParser: (message: unknown) => T, + byteLimit: number = 1024 * 1024, + ) { + super(new JsonToJsonMessage(messageParser, byteLimit)); } } @@ -82,17 +91,6 @@ function parseJsonRpcRequest( if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcRequest') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } if (!('method' in message)) { throw new validationErrors.ErrorParse('`method` property must be defined'); } @@ -102,51 +100,36 @@ function parseJsonRpcRequest( // If ('params' in message && !utils.isObject(message.params)) { // throw new validationErrors.ErrorParse('`params` property must be a POJO'); // } - if (!('id' in message)) { + return message as JsonRpcRequest; +} + +function parseJsonRpcRequestMessage( + message: unknown, +): JsonRpcRequestMessage { + const jsonRequest = parseJsonRpcRequest(message); + if (!('id' in jsonRequest)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } if ( - typeof message.id !== 'string' && - typeof message.id !== 'number' && - message.id !== null + typeof jsonRequest.id !== 'string' && + typeof jsonRequest.id !== 'number' && + jsonRequest.id !== null ) { throw new validationErrors.ErrorParse( '`id` property must be a string, number or null', ); } - return message as JsonRpcRequest; + return jsonRequest as JsonRpcRequestMessage; } -function parseJsonRpcNotification( +function parseJsonRpcRequestNotification( message: unknown, -): JsonRpcNotification { - if (!utils.isObject(message)) { - throw new validationErrors.ErrorParse('must be a JSON POJO'); - } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcNotification') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } - if (!('method' in message)) { - throw new validationErrors.ErrorParse('`method` property must be defined'); - } - if (typeof message.method !== 'string') { - throw new validationErrors.ErrorParse('`method` property must be a string'); - } - // If ('params' in message && !utils.isObject(message.params)) { - // throw new validationErrors.ErrorParse('`params` property must be a POJO'); - // } - if ('id' in message) { +): JsonRpcRequestNotification { + const jsonRequest = parseJsonRpcRequest(message); + if ('id' in jsonRequest) { throw new validationErrors.ErrorParse('`id` property must not be defined'); } - return message as JsonRpcNotification; + return jsonRequest as JsonRpcRequestNotification; } function parseJsonRpcResponseResult( @@ -155,17 +138,6 @@ function parseJsonRpcResponseResult( if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcResponseResult') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } if (!('result' in message)) { throw new validationErrors.ErrorParse('`result` property must be defined'); } @@ -192,23 +164,10 @@ function parseJsonRpcResponseResult( return message as JsonRpcResponseResult; } -function parseJsonRpcResponseError( - message: unknown, -): JsonRpcResponseError { +function parseJsonRpcResponseError(message: unknown): JsonRpcResponseError { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcResponseError') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcResponseError"', - ); - } if ('result' in message) { throw new validationErrors.ErrorParse( '`result` property must not be defined', @@ -217,7 +176,7 @@ function parseJsonRpcResponseError( if (!('error' in message)) { throw new validationErrors.ErrorParse('`error` property must be defined'); } - parseJsonRpcError(message.error); + parseJsonRpcError(message.error); if (!('id' in message)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } @@ -230,12 +189,10 @@ function parseJsonRpcResponseError( '`id` property must be a string, number or null', ); } - return message as JsonRpcResponseError; + return message as JsonRpcResponseError; } -function parseJsonRpcError( - message: unknown, -): JsonRpcError { +function parseJsonRpcError(message: unknown): JsonRpcError { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } @@ -256,20 +213,35 @@ function parseJsonRpcError( // If ('data' in message && !utils.isObject(message.data)) { // throw new validationErrors.ErrorParse('`data` property must be a POJO'); // } - return message as JsonRpcError; + return message as JsonRpcError; } -function parseJsonRpcMessage( +function parseJsonRpcResponse( message: unknown, -): JsonRpcMessage { +): JsonRpcResponse { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); + try { + return parseJsonRpcResponseResult(message); + } catch (e) { + // Do nothing } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); + try { + return parseJsonRpcResponseError(message); + } catch (e) { + // Do nothing + } + throw new validationErrors.ErrorParse( + 'structure did not match a `JsonRpcResponse`', + ); +} + +function parseJsonRpcMessage( + message: unknown, +): JsonRpcMessage { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); } if (!('jsonrpc' in message)) { throw new validationErrors.ErrorParse('`jsonrpc` property must be defined'); @@ -279,19 +251,79 @@ function parseJsonRpcMessage( '`jsonrpc` property must be a string of "2.0"', ); } - switch (message.type) { - case 'JsonRpcRequest': - return parseJsonRpcRequest(message); - case 'JsonRpcNotification': - return parseJsonRpcNotification(message); - case 'JsonRpcResponseResult': - return parseJsonRpcResponseResult(message); - case 'JsonRpcResponseError': - return parseJsonRpcResponseError(message); - default: - throw new validationErrors.ErrorParse( - '`type` property must be a valid type', - ); + try { + return parseJsonRpcRequest(message); + } catch { + // Do nothing + } + try { + return parseJsonRpcResponse(message); + } catch { + // Do nothing + } + throw new validationErrors.ErrorParse( + 'Message structure did not match a `JsonRpcMessage`', + ); +} + +/** + * Replacer function for serialising errors over GRPC (used by `JSON.stringify` + * in `fromError`) + * Polykey errors are handled by their inbuilt `toJSON` method , so this only + * serialises other errors + */ +function replacer(key: string, value: any): any { + if (value instanceof AggregateError) { + // AggregateError has an `errors` property + return { + type: value.constructor.name, + data: { + errors: value.errors, + message: value.message, + stack: value.stack, + }, + }; + } else if (value instanceof Error) { + // If it's some other type of error then only serialise the message and + // stack (and the type of the error) + return { + type: value.name, + data: { + message: value.message, + stack: value.stack, + }, + }; + } else { + // If it's not an error then just leave as is + return value; + } +} + +/** + * The same as `replacer`, however this will additionally filter out any + * sensitive data that should not be sent over the network when sending to an + * agent (as opposed to a client) + */ +function sensitiveReplacer(key: string, value: any) { + if (key === 'stack') { + return; + } else { + return replacer(key, value); + } +} + +/** + * Serializes Error instances into GRPC errors + * Use this on the sending side to send exceptions + * Do not send exceptions to clients you do not trust + * If sending to an agent (rather than a client), set sensitive to true to + * prevent sensitive information from being sent over the network + */ +function fromError(error: Error, sensitive: boolean = false) { + if (sensitive) { + return { error: JSON.stringify(error, sensitiveReplacer) }; + } else { + return { error: JSON.stringify(error, replacer) }; } } @@ -299,8 +331,11 @@ export { JsonToJsonMessageStream, JsonMessageToJsonStream, parseJsonRpcRequest, - parseJsonRpcNotification, + parseJsonRpcRequestMessage, + parseJsonRpcRequestNotification, parseJsonRpcResponseResult, parseJsonRpcResponseError, + parseJsonRpcResponse, parseJsonRpcMessage, + fromError, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index dbb51cba5..8cfc69732 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -1,6 +1,6 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue } from '@/types'; -import type { JsonRpcRequest } from '@/RPC/types'; +import type { JsonRpcRequestMessage } from '@/RPC/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; @@ -14,7 +14,7 @@ describe(`${RPCClient.name}`, () => { const methodName = 'testMethod'; const specificMessageArb = fc - .array(rpcTestUtils.jsonRpcRequestArb(), { + .array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 5, }) .noShrink(); @@ -43,13 +43,12 @@ describe(`${RPCClient.name}`, () => { } await callerInterface.write(value); } - const expectedMessages: Array = messages.map((v) => { - const request: JsonRpcRequest = { - type: 'JsonRpcRequest', + const expectedMessages: Array = messages.map((v) => { + const request: JsonRpcRequestMessage = { jsonrpc: '2.0', method: methodName, id: null, - ...(v.params === undefined ? {} : { params: v.params }), + ...(v.result === undefined ? {} : { params: v.result }), }; return request; }); @@ -80,12 +79,11 @@ describe(`${RPCClient.name}`, () => { for await (const value of callerInterface.outputGenerator) { values.push(value); } - const expectedValues = messages.map((v) => v.params); + const expectedValues = messages.map((v) => v.result); expect(values).toStrictEqual(expectedValues); expect((await outputResult)[0]?.toString()).toStrictEqual( JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params, @@ -95,7 +93,7 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic client stream caller', - [rpcTestUtils.jsonRpcRequestArb(), fc.array(fc.jsonValue())], + [rpcTestUtils.jsonRpcResponseResultArb(), fc.array(fc.jsonValue())], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -115,11 +113,10 @@ describe(`${RPCClient.name}`, () => { await callerInterface.write(param as JSONValue); } await callerInterface.end(); - expect(await callerInterface.result).toStrictEqual(message.params); + expect(await callerInterface.result).toStrictEqual(message.result); const expectedOutput = params.map((v) => JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: v, @@ -132,7 +129,7 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic unary caller', - [rpcTestUtils.jsonRpcRequestArb(), fc.jsonValue()], + [rpcTestUtils.jsonRpcResponseResultArb(), fc.jsonValue()], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -149,11 +146,10 @@ describe(`${RPCClient.name}`, () => { params as JSONValue, {}, ); - expect(result).toStrictEqual(message.params); + expect(result).toStrictEqual(message.result); expect((await outputResult)[0]?.toString()).toStrictEqual( JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: params, diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 528c0565b..0df1a8497 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -12,7 +12,6 @@ import type { ReadableWritablePair } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; -import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -22,7 +21,7 @@ describe(`${RPCServer.name}`, () => { const methodName = 'testMethod'; const specificMessageArb = fc - .array(rpcTestUtils.jsonRpcRequestArb(fc.constant(methodName)), { + .array(rpcTestUtils.jsonRpcRequestMessageArb(fc.constant(methodName)), { minLength: 5, }) .noShrink(); @@ -83,7 +82,7 @@ describe(`${RPCServer.name}`, () => { ); const singleNumberMessageArb = fc.array( - rpcTestUtils.jsonRpcRequestArb( + rpcTestUtils.jsonRpcRequestMessageArb( fc.constant(methodName), fc.integer({ min: 1, max: 20 }), ), @@ -224,18 +223,19 @@ describe(`${RPCServer.name}`, () => { const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; - const tapStream = new rpcTestUtils.TapStream( - async (_, iteration) => { - if (iteration === 2) { - // @ts-ignore: kidnap private property - const activeStreams = rpc.activeStreams.values(); - for (const activeStream of activeStreams) { - thing = activeStream; - activeStream.cancel(new rpcErrors.ErrorRpcStopping()); - } - } - }, - ); + const tapStream: any = {}; + // Const tapStream = new rpcTestUtils.TapStream( + // async (_, iteration) => { + // if (iteration === 2) { + // // @ts-ignore: kidnap private property + // const activeStreams = rpc.activeStreams.values(); + // for (const activeStream of activeStreams) { + // thing = activeStream; + // activeStream.cancel(new rpcErrors.ErrorRpcStopping()); + // } + // } + // }, + // ); await tapStream.readable.pipeTo(outputStream); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -255,7 +255,6 @@ describe(`${RPCServer.name}`, () => { await expect(thing).toResolve(); // Last message should be an error message expect(lastMessage).toBeDefined(); - expect(lastMessage?.type).toBe('JsonRpcResponseError'); }, ); diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index e028c6e97..bd737b505 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -12,7 +12,9 @@ describe('utils tests', () => { async (messages) => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toEqual(messages); @@ -27,7 +29,9 @@ describe('utils tests', () => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toStrictEqual(messages); @@ -50,7 +54,9 @@ describe('utils tests', () => { .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here .pipeThrough(new rpcTestUtils.BufferStreamToNoisyStream(noise)) // Adding bad data to the stream - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( rpcErrors.ErrorRpcParse, @@ -61,7 +67,7 @@ describe('utils tests', () => { testProp( 'can parse messages', - [rpcTestUtils.jsonRpcMessageArb], + [rpcTestUtils.jsonRpcMessageArb()], async (message) => { rpcUtils.parseJsonRpcMessage(message); }, @@ -71,15 +77,23 @@ describe('utils tests', () => { testProp( 'Message size limit is enforced', [ - fc.array(rpcTestUtils.jsonRpcRequestArb(fc.string({ minLength: 100 })), { - minLength: 1, - }), + fc.array( + rpcTestUtils.jsonRpcRequestMessageArb(fc.string({ minLength: 100 })), + { + minLength: 1, + }, + ), ], async (messages) => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([10])) - .pipeThrough(new rpcUtils.JsonToJsonMessageStream(50)); + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream( + rpcUtils.parseJsonRpcMessage, + 50, + ), + ); const doThing = async () => { for await (const _ of parsedStream) { diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index 81ce3913e..376f5705a 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -7,10 +7,12 @@ import type { POJO } from '@/types'; import type { JsonRpcError, JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, + JsonRpcRequestNotification, + JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseResult, + JsonRpcResponse, + JsonRpcRequest, } from '@/RPC/types'; import type { JsonValue } from 'fast-check'; import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; @@ -100,11 +102,13 @@ const jsonRpcStream = (messages: Array) => { }); }; -const jsonRpcRequestArb = ( +const safeJsonValueArb = fc + .jsonValue() + .map((value) => JSON.parse(JSON.stringify(value))); + +const jsonRpcRequestMessageArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = fc - .jsonValue() - .map((value) => JSON.parse(JSON.stringify(value))), + params: fc.Arbitrary = safeJsonValueArb, ) => fc .record( @@ -119,37 +123,55 @@ const jsonRpcRequestArb = ( requiredKeys: ['type', 'jsonrpc', 'method', 'id'], }, ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestNotificationArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record( + { + type: fc.constant('JsonRpcNotification'), + jsonrpc: fc.constant('2.0'), + method: method, + params: params, + }, + { + requiredKeys: ['type', 'jsonrpc', 'method'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof( + jsonRpcRequestMessageArb(method, params), + jsonRpcRequestNotificationArb(method, params), + ) .noShrink() as fc.Arbitrary; -const jsonRpcNotificationArb = fc - .record( - { - type: fc.constant('JsonRpcNotification'), +const jsonRpcResponseResultArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record({ + type: fc.constant('JsonRpcResponseResult'), jsonrpc: fc.constant('2.0'), - method: fc.string(), - params: fc.jsonValue(), - }, - { - requiredKeys: ['type', 'jsonrpc', 'method'], - }, - ) - .noShrink() as fc.Arbitrary; - -const jsonRpcResponseResultArb = fc - .record({ - type: fc.constant('JsonRpcResponseResult'), - jsonrpc: fc.constant('2.0'), - result: fc.jsonValue(), - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }) - .noShrink() as fc.Arbitrary; + result: result, + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; const jsonRpcErrorArb = fc .record( { code: fc.integer(), message: fc.string(), - data: fc.jsonValue(), + data: safeJsonValueArb, }, { requiredKeys: ['code', 'message'], @@ -166,21 +188,28 @@ const jsonRpcResponseErrorArb = fc }) .noShrink() as fc.Arbitrary; -const jsonRpcMessageArb = fc - .oneof( - jsonRpcRequestArb(), - jsonRpcNotificationArb, - jsonRpcResponseResultArb, - jsonRpcResponseErrorArb, - ) - .noShrink() as fc.Arbitrary; +const jsonRpcResponseArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb) + .noShrink() as fc.Arbitrary; + +const jsonRpcMessageArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcRequestArb(method, params), jsonRpcResponseArb(result)) + .noShrink() as fc.Arbitrary; const snippingPatternArb = fc .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) .noShrink(); const jsonMessagesArb = fc - .array(jsonRpcRequestArb(), { minLength: 2 }) + .array(jsonRpcRequestMessageArb(), { minLength: 2 }) .noShrink(); function streamToArray(): [Promise>, WritableStream] { @@ -200,44 +229,20 @@ function streamToArray(): [Promise>, WritableStream] { return [result.p, outputStream]; } -class Tap implements Transformer { - protected iteration = 0; - protected tapIterator; - - constructor(tapIterator: (chunk: T, iteration: number) => Promise) { - this.tapIterator = tapIterator; - } - - transform: TransformerTransformCallback = async (chunk, controller) => { - await this.tapIterator(chunk, this.iteration); - controller.enqueue(chunk); - this.iteration += 1; - }; -} - -/** - * This is used to convert regular chunks into randomly sized chunks based on - * a provided pattern. This is to replicate randomness introduced by packets - * splitting up the data. - */ -class TapStream extends TransformStream { - constructor(tapIterator: (chunk: T, iteration: number) => Promise) { - super(new Tap(tapIterator)); - } -} - export { BufferStreamToSnippedStream, BufferStreamToNoisyStream, jsonRpcStream, + safeJsonValueArb, + jsonRpcRequestMessageArb, + jsonRpcRequestNotificationArb, jsonRpcRequestArb, - jsonRpcNotificationArb, jsonRpcResponseResultArb, jsonRpcErrorArb, jsonRpcResponseErrorArb, + jsonRpcResponseArb, jsonRpcMessageArb, snippingPatternArb, jsonMessagesArb, streamToArray, - TapStream, };