diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index bc7ea8bfd..d45f6ae38 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -130,6 +130,64 @@ class RPCClient { await writer.close(); return output.value; } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withDuplexCaller( + method: string, + f: (output: AsyncGenerator) => AsyncGenerator, + metadata: POJO, + ): Promise { + const callerInterface = await this.duplexStreamCaller( + method, + metadata, + ); + const outputGenerator = async function* () { + for await (const value of callerInterface.readable) { + yield value; + } + }; + const writer = callerInterface.writable.getWriter(); + for await (const value of f(outputGenerator())) { + await writer.write(value); + } + await writer.close(); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withServerCaller( + method: string, + parameters: I, + f: (output: AsyncGenerator) => Promise, + metadata: POJO, + ) { + const callerInterface = await this.serverStreamCaller( + method, + parameters, + metadata, + ); + const outputGenerator = async function* () { + yield* callerInterface; + }; + await f(outputGenerator()); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withClientCaller( + method: string, + f: () => AsyncGenerator, + metadata: POJO, + ): Promise { + const callerInterface = await this.clientStreamCaller( + method, + metadata, + ); + const writer = callerInterface.writable.getWriter(); + for await (const value of f()) { + await writer.write(value); + } + await writer.close(); + return callerInterface.output; + } } export default RPCClient; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 5e29dc5b7..91afceb97 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -202,4 +202,116 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); + testProp( + 'withDuplexCaller', + [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withDuplexCaller( + methodName, + async function* (output) { + for await (const value of output) { + count += 1; + yield value; + } + }, + {}, + ); + const result = await outputResult; + // We're just checking that it consuming the messages as expected + expect(result.length).toEqual(messages.length); + expect(count).toEqual(messages.length); + await rpcClient.destroy(); + }, + ); + testProp( + 'withServerCaller', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 }), + rpcTestUtils.safeJsonValueArb, + ], + async (messages, params) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withServerCaller( + methodName, + params, + async (output) => { + for await (const _ of output) count += 1; + }, + {}, + ); + const result = await outputResult; + expect(count).toEqual(messages.length); + expect(result.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'withClientCaller', + [ + rpcTestUtils.jsonRpcResponseResultArb(), + fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 2 }).noShrink(), + ], + async (message, inputMessages) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + const result = await rpcClient.withClientCaller( + methodName, + async function* () { + for (const inputMessage of inputMessages) { + yield inputMessage; + } + }, + {}, + ); + const expectedResult = inputMessages.map((v) => { + return JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: v, + }); + }); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedResult, + ); + expect(result).toStrictEqual(message.result); + await rpcClient.destroy(); + }, + ); });