Skip to content

Commit

Permalink
feat: agentUnlock example
Browse files Browse the repository at this point in the history
Related #500
Related #501
Related #502

[ci skip]
  • Loading branch information
tegefaulkes committed Feb 14, 2023
1 parent fb198ff commit 29e61a9
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 107 deletions.
35 changes: 6 additions & 29 deletions src/RPC/RPCClient.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { StreamPairCreateCallback } from './types';
import type { JSONValue, POJO } from 'types';
import type { JSONValue } from 'types';
import type { ReadableWritablePair } from 'stream/web';
import type {
JsonRpcRequest,
Expand Down Expand Up @@ -52,7 +52,6 @@ class RPCClient {
@ready(new rpcErrors.ErrorRpcDestroyed())
public async duplexStreamCaller<I extends JSONValue, O extends JSONValue>(
method: string,
_metadata: POJO,
): Promise<ReadableWritablePair<O, I>> {
// Creating caller side transforms
const outputMessageTransforStream =
Expand Down Expand Up @@ -92,12 +91,8 @@ class RPCClient {
public async serverStreamCaller<I extends JSONValue, O extends JSONValue>(
method: string,
parameters: I,
metadata: POJO,
) {
const callerInterface = await this.duplexStreamCaller<I, O>(
method,
metadata,
);
const callerInterface = await this.duplexStreamCaller<I, O>(method);
const writer = callerInterface.writable.getWriter();
await writer.write(parameters);
await writer.close();
Expand All @@ -108,12 +103,8 @@ class RPCClient {
@ready(new rpcErrors.ErrorRpcDestroyed())
public async clientStreamCaller<I extends JSONValue, O extends JSONValue>(
method: string,
metadata: POJO,
) {
const callerInterface = await this.duplexStreamCaller<I, O>(
method,
metadata,
);
const callerInterface = await this.duplexStreamCaller<I, O>(method);
const reader = callerInterface.readable.getReader();
const output = reader.read().then(({ value, done }) => {
if (done) {
Expand All @@ -131,12 +122,8 @@ class RPCClient {
public async unaryCaller<I extends JSONValue, O extends JSONValue>(
method: string,
parameters: I,
metadata: POJO,
): Promise<O> {
const callerInterface = await this.duplexStreamCaller<I, O>(
method,
metadata,
);
const callerInterface = await this.duplexStreamCaller<I, O>(method);
const reader = callerInterface.readable.getReader();
const writer = callerInterface.writable.getWriter();
await writer.write(parameters);
Expand All @@ -153,12 +140,8 @@ class RPCClient {
public async withDuplexCaller<I extends JSONValue, O extends JSONValue>(
method: string,
f: (output: AsyncGenerator<O>) => AsyncGenerator<I>,
metadata: POJO,
): Promise<void> {
const callerInterface = await this.duplexStreamCaller<I, O>(
method,
metadata,
);
const callerInterface = await this.duplexStreamCaller<I, O>(method);
const outputGenerator = async function* () {
for await (const value of callerInterface.readable) {
yield value;
Expand All @@ -176,12 +159,10 @@ class RPCClient {
method: string,
parameters: I,
f: (output: AsyncGenerator<O>) => Promise<void>,
metadata: POJO,
) {
const callerInterface = await this.serverStreamCaller<I, O>(
method,
parameters,
metadata,
);
const outputGenerator = async function* () {
yield* callerInterface;
Expand All @@ -193,12 +174,8 @@ class RPCClient {
public async withClientCaller<I extends JSONValue, O extends JSONValue>(
method: string,
f: () => AsyncGenerator<I>,
metadata: POJO,
): Promise<O> {
const callerInterface = await this.clientStreamCaller<I, O>(
method,
metadata,
);
const callerInterface = await this.clientStreamCaller<I, O>(method);
const writer = callerInterface.writable.getWriter();
for await (const value of f()) {
await writer.write(value);
Expand Down
5 changes: 3 additions & 2 deletions src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import Logger from '@matrixai/logger';
import { PromiseCancellable } from '@matrixai/async-cancellable';
import * as rpcUtils from './utils';
import * as rpcErrors from './errors';
import { sysexits } from '../errors';

interface RPCServer extends CreateDestroy {}
@CreateDestroy()
Expand Down Expand Up @@ -245,8 +246,8 @@ class RPCServer {
if (rpcUtils.isReturnableError(e)) {
// We want to convert this error to an error message and pass it along
const rpcError: JsonRpcError = {
code: e.exitCode,
message: e.description,
code: e.exitCode ?? sysexits.UNKNOWN,
message: e.description ?? '',
data: rpcUtils.fromError(e),
};
const rpcErrorMessage: JsonRpcResponseError = {
Expand Down
29 changes: 23 additions & 6 deletions src/clientRPC/handlers/agentUnlock.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import type { UnaryHandler } from '../../RPC/types';
import type Logger from '@matrixai/logger';
import type RPCClient from '../../RPC/RPCClient';
import type { POJO } from '../../types';
import type { JSONValue } from '../../types';
import type { ClientDataAndMetadata } from '../types';

const agentUnlockName = 'agentStatus';
const agentUnlockHandler: UnaryHandler<null, null> = async (
const agentUnlockHandler: UnaryHandler<
ClientDataAndMetadata<null>,
ClientDataAndMetadata<null>
> = async (
_input,
_container: {
logger: Logger;
Expand All @@ -13,12 +17,25 @@ const agentUnlockHandler: UnaryHandler<null, null> = async (
_ctx,
) => {
// This is a NOP handler,
// authentication and unlocking is handled via middleware
return null;
// authentication and unlocking is handled via middleware.
// Failure to authenticate will be an error from the middleware layer.
return {
metadata: {},
data: null,
};
};

const agentUnlockCaller = async (metadata: POJO, rpcClient: RPCClient) => {
await rpcClient.unaryCaller<null, null>(agentUnlockName, null, metadata);
const agentUnlockCaller = async (
metadata: Record<string, JSONValue>,
rpcClient: RPCClient,
) => {
return rpcClient.unaryCaller<
ClientDataAndMetadata<null>,
ClientDataAndMetadata<null>
>(agentUnlockName, {
metadata: metadata,
data: null,
});
};

export { agentUnlockName, agentUnlockHandler, agentUnlockCaller };
10 changes: 10 additions & 0 deletions src/clientRPC/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import type { JSONValue } from '../types';

type ClientDataAndMetadata<T extends JSONValue> = {
metadata: JSONValue & {
Authorization?: string;
};
data: T;
};

export type { ClientDataAndMetadata };
199 changes: 153 additions & 46 deletions src/clientRPC/utils.ts
Original file line number Diff line number Diff line change
@@ -1,58 +1,165 @@
import type { SessionToken } from '../sessions/types';
import type KeyRing from '../keys/KeyRing';
import type SessionManager from '../sessions/SessionManager';
import type { Authenticate } from '../client/types';
import * as grpc from '@grpc/grpc-js';
import type { Session } from '../sessions';
import type { ClientDataAndMetadata } from './types';
import type { JSONValue } from '../types';
import type {
JsonRpcRequest,
JsonRpcResponse,
MiddlewareFactory,
} from '../RPC/types';
import { TransformStream } from 'stream/web';
import * as clientErrors from '../client/errors';
import * as utils from '../utils';

/**
* Encodes an Authorization header from session token
* Assumes token is already encoded
* Will mutate metadata if it is passed in
*/
function encodeAuthFromSession(
token: SessionToken,
metadata: grpc.Metadata = new grpc.Metadata(),
): grpc.Metadata {
metadata.set('Authorization', `Bearer ${token}`);
return metadata;
}

function authenticator(
async function authenticate(
sessionManager: SessionManager,
keyRing: KeyRing,
): Authenticate {
return async (
forwardMetadata: grpc.Metadata,
reverseMetadata: grpc.Metadata = new grpc.Metadata(),
) => {
const auth = forwardMetadata.get('Authorization')[0] as string | undefined;
if (auth == null) {
throw new clientErrors.ErrorClientAuthMissing();
message: JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
) {
if (message.params == null) throw new clientErrors.ErrorClientAuthMissing();
if (message.params.metadata == null) {
throw new clientErrors.ErrorClientAuthMissing();
}
const auth = message.params.metadata.Authorization;
if (auth == null) {
throw new clientErrors.ErrorClientAuthMissing();
}
if (auth.startsWith('Bearer ')) {
const token = auth.substring(7) as SessionToken;
if (!(await sessionManager.verifyToken(token))) {
throw new clientErrors.ErrorClientAuthDenied();
}
if (auth.startsWith('Bearer ')) {
const token = auth.substring(7) as SessionToken;
if (!(await sessionManager.verifyToken(token))) {
throw new clientErrors.ErrorClientAuthDenied();
}
} else if (auth.startsWith('Basic ')) {
const encoded = auth.substring(6);
const decoded = Buffer.from(encoded, 'base64').toString('utf-8');
const match = decoded.match(/:(.*)/);
if (match == null) {
throw new clientErrors.ErrorClientAuthFormat();
}
const password = match[1];
if (!(await keyRing.checkPassword(password))) {
throw new clientErrors.ErrorClientAuthDenied();
}
} else {
throw new clientErrors.ErrorClientAuthMissing();
} else if (auth.startsWith('Basic ')) {
const encoded = auth.substring(6);
const decoded = Buffer.from(encoded, 'base64').toString('utf-8');
const match = decoded.match(/:(.*)/);
if (match == null) {
throw new clientErrors.ErrorClientAuthFormat();
}
const token = await sessionManager.createToken();
encodeAuthFromSession(token, reverseMetadata);
return reverseMetadata;
const password = match[1];
if (!(await keyRing.checkPassword(password))) {
throw new clientErrors.ErrorClientAuthDenied();
}
} else {
throw new clientErrors.ErrorClientAuthMissing();
}
const token = await sessionManager.createToken();
return `Bearer ${token}`;
}

function decodeAuth(messageParams: ClientDataAndMetadata<JSONValue>) {
const auth = messageParams.metadata.Authorization;
if (auth == null || !auth.startsWith('Bearer ')) {
return;
}
return auth.substring(7) as SessionToken;
}

function encodeAuthFromPassword(password: string): string {
const encoded = Buffer.from(`:${password}`).toString('base64');
return `Basic ${encoded}`;
}

function authenticationMiddlewareServer(
sessionManager: SessionManager,
keyRing: KeyRing,
): MiddlewareFactory<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>
> {
return () => {
let forwardFirst = true;
let reverseController;
let outgoingToken: string | null = null;
return {
forward: new TransformStream<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>
>({
transform: async (chunk, controller) => {
if (forwardFirst) {
try {
outgoingToken = await authenticate(
sessionManager,
keyRing,
chunk,
);
} catch (e) {
controller.terminate();
reverseController.terminate();
return;
}
}
forwardFirst = false;
controller.enqueue(chunk);
},
}),
reverse: new TransformStream({
start: (controller) => {
reverseController = controller;
},
transform: (chunk, controller) => {
// Add the outgoing metadata to the next message.
if (outgoingToken != null && 'result' in chunk) {
chunk.result.metadata.Authorization = outgoingToken;
outgoingToken = null;
}
controller.enqueue(chunk);
},
}),
};
};
}

function authenticationMiddlewareClient(
session: Session,
): MiddlewareFactory<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>
> {
return () => {
let forwardFirst = true;
return {
forward: new TransformStream<
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>,
JsonRpcRequest<ClientDataAndMetadata<JSONValue>>
>({
transform: async (chunk, controller) => {
if (forwardFirst) {
if (chunk.params == null) utils.never();
if (chunk.params.metadata.Authorization == null) {
const token = await session.readToken();
if (token != null) {
chunk.params.metadata.Authorization = `Bearer ${token}`;
}
}
}
forwardFirst = false;
controller.enqueue(chunk);
},
}),
reverse: new TransformStream<
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>,
JsonRpcResponse<ClientDataAndMetadata<JSONValue>>
>({
transform: async (chunk, controller) => {
controller.enqueue(chunk);
if (!('result' in chunk)) return;
const token = decodeAuth(chunk.result);
if (token == null) return;
await session.writeToken(token);
},
}),
};
};
}

export { authenticator };
export {
authenticate,
decodeAuth,
encodeAuthFromPassword,
authenticationMiddlewareServer,
authenticationMiddlewareClient,
};
4 changes: 2 additions & 2 deletions tests/clientRPC/handlers/agentStatus.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ describe('agentStatus', () => {
// Setup
const rpcServer = await RPCServer.createRPCServer({
container: {
// KeyRing,
// certManager,
keyRing,
certManager,
logger,
},
logger,
Expand Down
Loading

0 comments on commit 29e61a9

Please sign in to comment.