Skip to content

Commit

Permalink
feat: static middleware for server
Browse files Browse the repository at this point in the history
- Related #500

[ci skip]
  • Loading branch information
tegefaulkes committed Feb 7, 2023
1 parent f78a7f9 commit d53f1ff
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 97 deletions.
9 changes: 8 additions & 1 deletion src/RPC/RPCClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,20 @@ class RPCClient<M extends Manifest> {
}

protected middleware: Array<
MiddlewareFactory<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
MiddlewareFactory<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerMiddleware(
middlewareFactory: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
>,
) {
Expand Down
64 changes: 27 additions & 37 deletions src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,24 @@ class RPCServer {
static async createRPCServer({
manifest,
container,
middleware = rpcUtils.defaultMiddlewareWrapper(),
logger = new Logger(this.name),
}: {
manifest: Manifest;
container: POJO;
middleware?: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
Uint8Array,
Uint8Array,
JsonRpcResponseResult<JSONValue>
>;
logger?: Logger;
}): Promise<RPCServer> {
logger.info(`Creating ${this.name}`);
const rpcServer = new this({
manifest,
container,
middleware,
logger,
});
logger.info(`Created ${this.name}`);
Expand All @@ -53,14 +61,27 @@ class RPCServer {
protected handlerMap: Map<string, RawDuplexStreamHandler> = new Map();
protected activeStreams: Set<PromiseCancellable<void>> = new Set();
protected events: EventTarget = new EventTarget();
protected middleware: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
Uint8Array,
Uint8Array,
JsonRpcResponseResult<JSONValue>
>;

public constructor({
manifest,
container,
middleware,
logger,
}: {
manifest: Manifest;
container: POJO;
middleware: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
Uint8Array,
Uint8Array,
JsonRpcResponseResult<JSONValue>
>;
logger: Logger;
}) {
for (const [key, manifestItem] of Object.entries(manifest)) {
Expand All @@ -85,6 +106,7 @@ class RPCServer {
}
}
this.container = container;
this.middleware = middleware;
this.logger = logger;
}

Expand Down Expand Up @@ -120,23 +142,10 @@ class RPCServer {
connectionInfo,
ctx,
) => {
// Middleware
const outputTransformStream = new rpcUtils.JsonMessageToJsonStream();
const outputReadableSteam = outputTransformStream.readable;
let forwardStream = input.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(
rpcUtils.parseJsonRpcRequest,
undefined,
header,
),
);
let reverseStream = outputTransformStream.writable;
for (const middlewareFactory of this.middleware) {
const middleware = middlewareFactory();
forwardStream = forwardStream.pipeThrough(middleware.forward);
void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {});
reverseStream = middleware.reverse.writable;
}
// Setting up middleware
const middleware = this.middleware(header);
const forwardStream = input.pipeThrough(middleware.forward);
const reverseStream = middleware.reverse.writable;
const events = this.events;
const outputGen = async function* (): AsyncGenerator<
JsonRpcResponse<JSONValue>
Expand Down Expand Up @@ -206,7 +215,7 @@ class RPCServer {
});
void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {});

return outputReadableSteam;
return middleware.reverse.readable;
};

this.registerRawStreamHandler(method, rawSteamHandler);
Expand Down Expand Up @@ -344,25 +353,6 @@ class RPCServer {
) {
this.events.removeEventListener(type, callback, options);
}

protected middleware: Array<
MiddlewareFactory<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerMiddleware(
middlewareFactory: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>
>,
) {
this.middleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearMiddleware() {
this.middleware = [];
}
}

export default RPCServer;
8 changes: 5 additions & 3 deletions src/RPC/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,11 @@ type StreamPairCreateCallback = () => Promise<
ReadableWritablePair<Uint8Array, Uint8Array>
>;

type MiddlewareFactory<F, R> = () => {
forward: ReadableWritablePair<F, F>;
reverse: ReadableWritablePair<R, R>;
type MiddlewareFactory<FR, FW, RR, RW> = (
header?: JsonRpcRequest<JSONValue>,
) => {
forward: ReadableWritablePair<FR, FW>;
reverse: ReadableWritablePair<RR, RW>;
};

type DuplexStreamCaller<
Expand Down
56 changes: 56 additions & 0 deletions src/RPC/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type {
JsonRpcResponseResult,
JsonRpcRequest,
JsonRpcResponse,
MiddlewareFactory,
} from 'RPC/types';
import type { JSONValue } from '../types';
import type { JsonValue } from 'fast-check';
Expand Down Expand Up @@ -637,6 +638,59 @@ function getHandlerTypes(manifest: Manifest): Record<string, HandlerType> {
return out;
}

const defaultMiddleware: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
> = () => {
return {
forward: new TransformStream(),
reverse: new TransformStream(),
};
};

const defaultMiddlewareWrapper = (
middleware: MiddlewareFactory<
JsonRpcRequest<JSONValue>,
JsonRpcRequest<JSONValue>,
JsonRpcResponse<JSONValue>,
JsonRpcResponse<JSONValue>
> = defaultMiddleware,
) => {
return (header: JsonRpcRequest<JSONValue>) => {
const inputTransformStream = new JsonToJsonMessageStream(
parseJsonRpcRequest,
undefined,
header,
);
const outputTransformStream = new TransformStream<
JsonRpcResponseResult<JSONValue>,
JsonRpcResponseResult<JSONValue>
>();

const middleMiddleware = middleware(header);

const forwardReadable = inputTransformStream.readable.pipeThrough(
middleMiddleware.forward,
); // Usual middleware here
const reverseReadable = outputTransformStream.readable
.pipeThrough(middleMiddleware.reverse) // Usual middleware here
.pipeThrough(new JsonMessageToJsonStream());

return {
forward: {
readable: forwardReadable,
writable: inputTransformStream.writable,
},
reverse: {
readable: reverseReadable,
writable: outputTransformStream.writable,
},
};
};
};

export {
JsonToJsonMessageStream,
JsonMessageToJsonStream,
Expand All @@ -657,4 +711,6 @@ export {
QueueMergingTransformStream,
extractFirstMessageTransform,
getHandlerTypes,
defaultMiddleware,
defaultMiddlewareWrapper,
};
4 changes: 4 additions & 0 deletions src/clientRPC/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ function authenticationMiddlewareServer(
keyRing: KeyRing,
): MiddlewareFactory<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>
> {
return () => {
Expand Down Expand Up @@ -127,6 +129,8 @@ function authenticationMiddlewareClient(
session: Session,
): MiddlewareFactory<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>
> {
return () => {
Expand Down
Loading

0 comments on commit d53f1ff

Please sign in to comment.