Skip to content

Commit

Permalink
fix: refactoring middleware
Browse files Browse the repository at this point in the history
Related #500
Related #502

[ci skip]
  • Loading branch information
tegefaulkes committed Feb 14, 2023
1 parent 586c990 commit fb198ff
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 190 deletions.
74 changes: 31 additions & 43 deletions src/RPC/RPCClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import type {
JsonRpcRequest,
JsonRpcResponse,
MiddlewareFactory,
Middleware,
} from './types';
import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
Expand Down Expand Up @@ -55,34 +54,37 @@ class RPCClient {
method: string,
_metadata: POJO,
): Promise<ReadableWritablePair<O, I>> {
const streamPair = await this.streamPairCreateCallback();
let reverseMiddlewareStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse),
);
for (const middleWare of this.reverseMiddleware) {
const middle = middleWare();
reverseMiddlewareStream = middle(reverseMiddlewareStream);
}
const outputStream = reverseMiddlewareStream.pipeThrough(
new rpcUtils.ClientOutputTransformerStream<O>(),
);
const inputMessageTransformer =
// Creating caller side transforms
const outputMessageTransforStream =
new rpcUtils.ClientOutputTransformerStream<O>();
const inputMessageTransformStream =
new rpcUtils.ClientInputTransformerStream<I>(method);
let forwardMiddlewareStream = inputMessageTransformer.readable;
for (const middleware of this.forwardMiddleWare) {
const middle = middleware();
forwardMiddlewareStream = middle(forwardMiddlewareStream);
let reverseStream = outputMessageTransforStream.writable;
let forwardStream = inputMessageTransformStream.readable;
// Setting up middleware chains
for (const middlewareFactory of this.middleware) {
const middleware = middlewareFactory();
forwardStream = forwardStream.pipeThrough(middleware.forward);
void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {});
reverseStream = middleware.reverse.writable;
}
void forwardMiddlewareStream
// Hooking up agnostic stream side
const streamPair = await this.streamPairCreateCallback();
void streamPair.readable
.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse),
)
.pipeTo(reverseStream)
.catch(() => {});
void forwardStream
.pipeThrough(new rpcUtils.JsonMessageToJsonStream())
.pipeTo(streamPair.writable)
.catch(() => {});
const inputStream = inputMessageTransformer.writable;

// Returning interface
return {
readable: outputStream,
writable: inputStream,
readable: outputMessageTransforStream.readable,
writable: inputMessageTransformStream.writable,
};
}

Expand Down Expand Up @@ -205,37 +207,23 @@ class RPCClient {
return callerInterface.output;
}

protected forwardMiddleWare: Array<
MiddlewareFactory<Middleware<JsonRpcRequest<JSONValue>>>
> = [];
protected reverseMiddleware: Array<
MiddlewareFactory<Middleware<JsonRpcResponse<JSONValue>>>
protected middleware: Array<
MiddlewareFactory<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerForwardMiddleware(
middlewareFactory: MiddlewareFactory<Middleware<JsonRpcRequest<JSONValue>>>,
) {
this.forwardMiddleWare.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearForwardMiddleware() {
this.reverseMiddleware = [];
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerReverseMiddleware(
public registerMiddleware(
middlewareFactory: MiddlewareFactory<
Middleware<JsonRpcResponse<JSONValue>>
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>
>,
) {
this.reverseMiddleware.push(middlewareFactory);
this.middleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearReverseMiddleware() {
this.reverseMiddleware = [];
public clearMiddleware() {
this.middleware = [];
}
}

Expand Down
89 changes: 36 additions & 53 deletions src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type { ReadableWritablePair } from 'stream/web';
import type { JSONValue, POJO } from '../types';
import type { ConnectionInfo } from '../network/types';
import type { RPCErrorEvent } from './utils';
import type { MiddlewareFactory, MiddlewareShort, Middleware } from './types';
import type { MiddlewareFactory } from './types';
import { ReadableStream } from 'stream/web';
import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
Expand Down Expand Up @@ -154,22 +154,25 @@ class RPCServer {
void handlerProm
.finally(() => this.activeStreams.delete(handlerProm))
.catch(() => {});
// Setting up forward middleware
let middlewareStream = streamPair.readable.pipeThrough(
// Setting up middleware
let forwardStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest),
);
const shortMessageQueue: Array<JsonRpcResponse> = [];
for (const forwardMiddleWareFactory of this.forwardMiddleWare) {
const middleware = forwardMiddleWareFactory();
middlewareStream = middleware(
middlewareStream,
(value: JsonRpcResponse) => shortMessageQueue.push(value),
);
const outputTransformStream = new rpcUtils.JsonMessageToJsonStream();
void outputTransformStream.readable
.pipeTo(streamPair.writable)
.catch(() => {});
let reverseStream = outputTransformStream.writable;
for (const middlewareFactory of this.middleware) {
const middleware = middlewareFactory();
forwardStream = forwardStream.pipeThrough(middleware.forward);
void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {});
reverseStream = middleware.reverse.writable;
}
// While ReadableStream can be converted to AsyncIterable, we want it as
// a generator.
const inputGen = async function* () {
for await (const dataMessage of middlewareStream) {
for await (const dataMessage of forwardStream) {
yield dataMessage;
}
};
Expand All @@ -186,9 +189,8 @@ class RPCServer {
const input = inputGen();
if (ctx.signal.aborted) throw ctx.signal.reason;
const leadingMetadataMessage = await input.next();
if (leadingMetadataMessage.done === true) {
throw new rpcErrors.ErrorRpcProtocal('Stream ended before response');
}
// If the stream ends early then we just stop processing
if (leadingMetadataMessage.done === true) return;
const method = leadingMetadataMessage.value.method;
const initialParams = leadingMetadataMessage.value.params;
const dataGen = async function* () {
Expand Down Expand Up @@ -223,14 +225,18 @@ class RPCServer {

const outputGenerator = outputGen();

let reverseMiddlewareStream = new ReadableStream<
const reverseMiddlewareStream = new ReadableStream<
JsonRpcResponse<JSONValue>
>({
pull: async (controller) => {
try {
const { value, done } = await outputGenerator.next();
if (done) {
controller.close();
try {
controller.close();
} catch {
// Ignore already closed error
}
resolve();
return;
}
Expand Down Expand Up @@ -259,24 +265,19 @@ class RPCServer {
}),
);
}
controller.close();
try {
controller.close();
} catch {
// Ignore already closed error
}
resolve();
}
},
cancel: async (reason) => {
await outputGenerator.throw(reason);
},
});
// Setting up reverse middleware
for (const reverseMiddleWareFactory of this.reverseMiddleware) {
const middleware = reverseMiddleWareFactory();
reverseMiddlewareStream = middleware(reverseMiddlewareStream);
}
reverseMiddlewareStream
.pipeThrough(new rpcUtils.QueueMergingTransformStream(shortMessageQueue))
.pipeThrough(new rpcUtils.JsonMessageToJsonStream())
.pipeTo(streamPair.writable)
.catch(() => {});
void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {});
}

@ready(new rpcErrors.ErrorRpcDestroyed())
Expand All @@ -297,41 +298,23 @@ class RPCServer {
this.events.removeEventListener(type, callback, options);
}

protected forwardMiddleWare: Array<
MiddlewareFactory<
MiddlewareShort<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>
> = [];
protected reverseMiddleware: Array<
MiddlewareFactory<Middleware<JsonRpcResponse<JSONValue>>>
protected middleware: Array<
MiddlewareFactory<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerForwardMiddleware(
public registerMiddleware(
middlewareFactory: MiddlewareFactory<
MiddlewareShort<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>,
) {
this.forwardMiddleWare.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearForwardMiddleware() {
this.reverseMiddleware = [];
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerReverseMiddleware(
middlewareFactory: MiddlewareFactory<
Middleware<JsonRpcResponse<JSONValue>>
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>
>,
) {
this.reverseMiddleware.push(middlewareFactory);
this.middleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearReverseMiddleware() {
this.reverseMiddleware = [];
public clearMiddleware() {
this.middleware = [];
}
}

Expand Down
13 changes: 4 additions & 9 deletions src/RPC/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import type { JSONValue, POJO } from '../types';
import type { ConnectionInfo } from '../network/types';
import type { ContextCancellable } from '../contexts/types';
import type { ReadableWritablePair } from 'stream/web';
import type { ReadableStream } from 'stream/web';

/**
* This is the JSON RPC request object. this is the generic message type used for the RPC.
Expand Down Expand Up @@ -128,12 +127,10 @@ type StreamPairCreateCallback = () => Promise<
ReadableWritablePair<Uint8Array, Uint8Array>
>;

type MiddlewareShort<T, K> = (
input: ReadableStream<T>,
short: (value: K) => void,
) => ReadableStream<T>;
type Middleware<T> = (input: ReadableStream<T>) => ReadableStream<T>;
type MiddlewareFactory<T> = () => T;
type MiddlewareFactory<F, R> = () => {
forward: ReadableWritablePair<F, F>;
reverse: ReadableWritablePair<R, R>;
};

export type {
JsonRpcRequestMessage,
Expand All @@ -149,7 +146,5 @@ export type {
ClientStreamHandler,
UnaryHandler,
StreamPairCreateCallback,
MiddlewareShort,
Middleware,
MiddlewareFactory,
};
62 changes: 31 additions & 31 deletions tests/RPC/RPCClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -341,21 +341,21 @@ describe(`${RPCClient.name}`, () => {
logger,
});

rpcClient.registerForwardMiddleware(() => {
return (input) =>
input.pipeThrough(
new TransformStream<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>
>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
params: 'one',
});
},
}),
);
rpcClient.registerMiddleware(() => {
return {
forward: new TransformStream<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>
>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
params: 'one',
});
},
}),
reverse: new TransformStream(),
};
});
const callerInterface = await rpcClient.duplexStreamCaller<
JSONValue,
Expand Down Expand Up @@ -391,7 +391,7 @@ describe(`${RPCClient.name}`, () => {
await rpcClient.destroy();
},
);
testProp.only(
testProp(
'generic duplex caller with reverse Middleware',
[specificMessageArb],
async (messages) => {
Expand All @@ -407,21 +407,21 @@ describe(`${RPCClient.name}`, () => {
logger,
});

rpcClient.registerReverseMiddleware(() => {
return (input) =>
input.pipeThrough(
new TransformStream<
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
result: 'one',
});
},
}),
);
rpcClient.registerMiddleware(() => {
return {
forward: new TransformStream(),
reverse: new TransformStream<
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
result: 'one',
});
},
}),
};
});
const callerInterface = await rpcClient.duplexStreamCaller<
JSONValue,
Expand Down
Loading

0 comments on commit fb198ff

Please sign in to comment.