From d53f1ff0eb750ef965b06be099f4c0dd462a8aff Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 7 Feb 2023 20:34:42 +1100 Subject: [PATCH] feat: static middleware for server - Related #500 [ci skip] --- src/RPC/RPCClient.ts | 9 +- src/RPC/RPCServer.ts | 64 +++++------- src/RPC/types.ts | 8 +- src/RPC/utils.ts | 56 ++++++++++ src/clientRPC/utils.ts | 4 + tests/RPC/RPCServer.test.ts | 102 ++++++++++--------- tests/clientRPC/handlers/agentUnlock.test.ts | 12 +-- tests/clientRPC/websocket.test.ts | 2 +- 8 files changed, 160 insertions(+), 97 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 06513e0e33..b515baa9dc 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -299,13 +299,20 @@ class RPCClient { } protected middleware: Array< - MiddlewareFactory, JsonRpcResponse> + MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > > = []; @ready(new rpcErrors.ErrorRpcDestroyed()) public registerMiddleware( middlewareFactory: MiddlewareFactory< JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, JsonRpcResponse >, ) { diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d547f66f48..81a8ea4b72 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -31,16 +31,24 @@ class RPCServer { static async createRPCServer({ manifest, container, + middleware = rpcUtils.defaultMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: Manifest; container: POJO; + middleware?: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); const rpcServer = new this({ manifest, container, + middleware, logger, }); logger.info(`Created ${this.name}`); @@ -53,14 +61,27 @@ class RPCServer { protected handlerMap: Map = new Map(); protected activeStreams: Set> = new Set(); protected events: EventTarget = new EventTarget(); + protected middleware: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; public constructor({ manifest, container, + middleware, logger, }: { manifest: Manifest; container: POJO; + middleware: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; logger: Logger; }) { for (const [key, manifestItem] of Object.entries(manifest)) { @@ -85,6 +106,7 @@ class RPCServer { } } this.container = container; + this.middleware = middleware; this.logger = logger; } @@ -120,23 +142,10 @@ class RPCServer { connectionInfo, ctx, ) => { - // Middleware - const outputTransformStream = new rpcUtils.JsonMessageToJsonStream(); - const outputReadableSteam = outputTransformStream.readable; - let forwardStream = input.pipeThrough( - new rpcUtils.JsonToJsonMessageStream( - rpcUtils.parseJsonRpcRequest, - undefined, - header, - ), - ); - let reverseStream = outputTransformStream.writable; - for (const middlewareFactory of this.middleware) { - const middleware = middlewareFactory(); - forwardStream = forwardStream.pipeThrough(middleware.forward); - void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); - reverseStream = middleware.reverse.writable; - } + // Setting up middleware + const middleware = this.middleware(header); + const forwardStream = input.pipeThrough(middleware.forward); + const reverseStream = middleware.reverse.writable; const events = this.events; const outputGen = async function* (): AsyncGenerator< JsonRpcResponse @@ -206,7 +215,7 @@ class RPCServer { }); void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); - return outputReadableSteam; + return middleware.reverse.readable; }; this.registerRawStreamHandler(method, rawSteamHandler); @@ -344,25 +353,6 @@ class RPCServer { ) { this.events.removeEventListener(type, callback, options); } - - protected middleware: Array< - MiddlewareFactory, JsonRpcResponse> - > = []; - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerMiddleware( - middlewareFactory: MiddlewareFactory< - JsonRpcRequest, - JsonRpcResponse - >, - ) { - this.middleware.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearMiddleware() { - this.middleware = []; - } } export default RPCServer; diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 1995030ddc..312e659731 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -135,9 +135,11 @@ type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; -type MiddlewareFactory = () => { - forward: ReadableWritablePair; - reverse: ReadableWritablePair; +type MiddlewareFactory = ( + header?: JsonRpcRequest, +) => { + forward: ReadableWritablePair; + reverse: ReadableWritablePair; }; type DuplexStreamCaller< diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9a05e9b28b..9cb9479ce5 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -13,6 +13,7 @@ import type { JsonRpcResponseResult, JsonRpcRequest, JsonRpcResponse, + MiddlewareFactory, } from 'RPC/types'; import type { JSONValue } from '../types'; import type { JsonValue } from 'fast-check'; @@ -637,6 +638,59 @@ function getHandlerTypes(manifest: Manifest): Record { return out; } +const defaultMiddleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse +> = () => { + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; +}; + +const defaultMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +) => { + return (header: JsonRpcRequest) => { + const inputTransformStream = new JsonToJsonMessageStream( + parseJsonRpcRequest, + undefined, + header, + ); + const outputTransformStream = new TransformStream< + JsonRpcResponseResult, + JsonRpcResponseResult + >(); + + const middleMiddleware = middleware(header); + + const forwardReadable = inputTransformStream.readable.pipeThrough( + middleMiddleware.forward, + ); // Usual middleware here + const reverseReadable = outputTransformStream.readable + .pipeThrough(middleMiddleware.reverse) // Usual middleware here + .pipeThrough(new JsonMessageToJsonStream()); + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -657,4 +711,6 @@ export { QueueMergingTransformStream, extractFirstMessageTransform, getHandlerTypes, + defaultMiddleware, + defaultMiddlewareWrapper, }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 631825de31..2a7d074d0a 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -77,6 +77,8 @@ function authenticationMiddlewareServer( keyRing: KeyRing, ): MiddlewareFactory< JsonRpcRequest>, + JsonRpcRequest>, + JsonRpcResponse>, JsonRpcResponse> > { return () => { @@ -127,6 +129,8 @@ function authenticationMiddlewareClient( session: Session, ): MiddlewareFactory< JsonRpcRequest>, + JsonRpcRequest>, + JsonRpcResponse>, JsonRpcResponse> > { return () => { diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 5e2cdfaa22..c4f35b6dc0 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -14,6 +14,7 @@ import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; +import * as rpcUtils from '@/RPC/utils'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -472,10 +473,22 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + chunk.params = 1; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream(), + }; + }); const rpcServer = await RPCServer.createRPCServer({ manifest: { testMethod, }, + middleware, container, logger, }); @@ -484,17 +497,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - rpcServer.registerMiddleware(() => { - return { - forward: new TransformStream({ - transform: (chunk, controller) => { - chunk.params = 1; - controller.enqueue(chunk); - }, - }), - reverse: new TransformStream(), - }; - }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; expect(out.map((v) => v!.toString())).toStrictEqual( @@ -519,10 +521,22 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + if ('result' in chunk) chunk.result = 1; + controller.enqueue(chunk); + }, + }), + }; + }); const rpcServer = await RPCServer.createRPCServer({ manifest: { testMethod, }, + middleware, container, logger, }); @@ -531,17 +545,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - rpcServer.registerMiddleware(() => { - return { - forward: new TransformStream(), - reverse: new TransformStream({ - transform: (chunk, controller) => { - if ('result' in chunk) chunk.result = 1; - controller.enqueue(chunk); - }, - }), - }; - }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; expect(out.map((v) => v!.toString())).toStrictEqual( @@ -569,33 +572,7 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ - manifest: { - testMethod, - }, - container, - logger, - }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; - type TestType = { - metadata: { - token: string; - }; - data: JSONValue; - }; - const failureMessage: JsonRpcResponseError = { - jsonrpc: '2.0', - id: null, - error: { - code: 1, - message: 'failure of somekind', - }, - }; - rpcServer.registerMiddleware(() => { + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController< JsonRpcResponse @@ -627,6 +604,33 @@ describe(`${RPCServer.name}`, () => { }), }; }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + middleware, + container, + logger, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + type TestType = { + metadata: { + token: string; + }; + data: JSONValue; + }; + const failureMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + id: null, + error: { + code: 1, + message: 'failure of somekind', + }, + }; rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); expect((await outputResult).toString()).toEqual( JSON.stringify(failureMessage), diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index aeca087587..4397df0012 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -14,8 +14,8 @@ import CertManager from '@/keys/CertManager'; import { agentUnlock } from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; -import * as abcUtils from '@/clientRPC/utils'; import * as clientRPCUtils from '@/clientRPC/utils'; +import * as rpcUtils from '@/RPC/utils'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -95,14 +95,14 @@ describe('agentUnlock', () => { }; const rpcServer = await RPCServer.createRPCServer({ manifest, + middleware: rpcUtils.defaultMiddlewareWrapper( + clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), + ), container: { logger, }, logger, }); - rpcServer.registerMiddleware( - abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), - ); wss = clientRPCUtils.createClientServer( server, rpcServer, @@ -120,13 +120,13 @@ describe('agentUnlock', () => { logger, }); rpcClient.registerMiddleware( - abcUtils.authenticationMiddlewareClient(session), + clientRPCUtils.authenticationMiddlewareClient(session), ); // Doing the test const result = await rpcClient.methods.agentUnlock({ metadata: { - Authorization: abcUtils.encodeAuthFromPassword(password), + Authorization: clientRPCUtils.encodeAuthFromPassword(password), }, data: null, }); diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index e66526bfc3..9b1c643804 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -70,7 +70,7 @@ describe('websocket', () => { test1, test2, }; - const rpcServer = new RPCServer({ + const rpcServer = await RPCServer.createRPCServer({ manifest, container: {}, logger: logger.getChild('RPCServer'),