From d773494953d91cb659e3cc05041cdf93e40e7da7 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 30 Jan 2023 19:09:43 +1100 Subject: [PATCH] wip: agentUnlock example Related #500 Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 35 +--- src/RPC/RPCServer.ts | 5 +- src/clientRPC/handlers/agentUnlock.ts | 29 ++- src/clientRPC/types.ts | 10 + src/clientRPC/utils.ts | 199 ++++++++++++++----- tests/clientRPC/handlers/agentStatus.test.ts | 4 +- tests/clientRPC/handlers/agentUnlock.test.ts | 60 ++++-- 7 files changed, 235 insertions(+), 107 deletions(-) create mode 100644 src/clientRPC/types.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 8c348ac1a..0931a8020 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,5 +1,5 @@ import type { StreamPairCreateCallback } from './types'; -import type { JSONValue, POJO } from 'types'; +import type { JSONValue } from 'types'; import type { ReadableWritablePair } from 'stream/web'; import type { JsonRpcRequest, @@ -52,7 +52,6 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, - _metadata: POJO, ): Promise> { // Creating caller side transforms const outputMessageTransforStream = @@ -92,12 +91,8 @@ class RPCClient { public async serverStreamCaller( method: string, parameters: I, - metadata: POJO, ) { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); await writer.close(); @@ -108,12 +103,8 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, - metadata: POJO, ) { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { if (done) { @@ -131,12 +122,8 @@ class RPCClient { public async unaryCaller( method: string, parameters: I, - metadata: POJO, ): Promise { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); @@ -153,12 +140,8 @@ class RPCClient { public async withDuplexCaller( method: string, f: (output: AsyncGenerator) => AsyncGenerator, - metadata: POJO, ): Promise { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const outputGenerator = async function* () { for await (const value of callerInterface.readable) { yield value; @@ -176,12 +159,10 @@ class RPCClient { method: string, parameters: I, f: (output: AsyncGenerator) => Promise, - metadata: POJO, ) { const callerInterface = await this.serverStreamCaller( method, parameters, - metadata, ); const outputGenerator = async function* () { yield* callerInterface; @@ -193,12 +174,8 @@ class RPCClient { public async withClientCaller( method: string, f: () => AsyncGenerator, - metadata: POJO, ): Promise { - const callerInterface = await this.clientStreamCaller( - method, - metadata, - ); + const callerInterface = await this.clientStreamCaller(method); const writer = callerInterface.writable.getWriter(); for await (const value of f()) { await writer.write(value); diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 7d360d280..e6166fba8 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -20,6 +20,7 @@ import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcUtils from './utils'; import * as rpcErrors from './errors'; +import { sysexits } from '../errors'; interface RPCServer extends CreateDestroy {} @CreateDestroy() @@ -245,8 +246,8 @@ class RPCServer { if (rpcUtils.isReturnableError(e)) { // We want to convert this error to an error message and pass it along const rpcError: JsonRpcError = { - code: e.exitCode, - message: e.description, + code: e.exitCode ?? sysexits.UNKNOWN, + message: e.description ?? '', data: rpcUtils.fromError(e), }; const rpcErrorMessage: JsonRpcResponseError = { diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index 3eadc2a8e..e1e77ad30 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,10 +1,14 @@ import type { UnaryHandler } from '../../RPC/types'; import type Logger from '@matrixai/logger'; import type RPCClient from '../../RPC/RPCClient'; -import type { POJO } from '../../types'; +import type { JSONValue } from '../../types'; +import type { ClientDataAndMetadata } from '../types'; const agentUnlockName = 'agentStatus'; -const agentUnlockHandler: UnaryHandler = async ( +const agentUnlockHandler: UnaryHandler< + ClientDataAndMetadata, + ClientDataAndMetadata +> = async ( _input, _container: { logger: Logger; @@ -13,12 +17,25 @@ const agentUnlockHandler: UnaryHandler = async ( _ctx, ) => { // This is a NOP handler, - // authentication and unlocking is handled via middleware - return null; + // authentication and unlocking is handled via middleware. + // Failure to authenticate will be an error from the middleware layer. + return { + metadata: {}, + data: null, + }; }; -const agentUnlockCaller = async (metadata: POJO, rpcClient: RPCClient) => { - await rpcClient.unaryCaller(agentUnlockName, null, metadata); +const agentUnlockCaller = async ( + metadata: Record, + rpcClient: RPCClient, +) => { + return rpcClient.unaryCaller< + ClientDataAndMetadata, + ClientDataAndMetadata + >(agentUnlockName, { + metadata: metadata, + data: null, + }); }; export { agentUnlockName, agentUnlockHandler, agentUnlockCaller }; diff --git a/src/clientRPC/types.ts b/src/clientRPC/types.ts new file mode 100644 index 000000000..b570749f1 --- /dev/null +++ b/src/clientRPC/types.ts @@ -0,0 +1,10 @@ +import type { JSONValue } from '../types'; + +type ClientDataAndMetadata = { + metadata: JSONValue & { + Authorization?: string; + }; + data: T; +}; + +export type { ClientDataAndMetadata }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index ca96647e7..0aa45cd16 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -1,58 +1,165 @@ import type { SessionToken } from '../sessions/types'; import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; -import type { Authenticate } from '../client/types'; -import * as grpc from '@grpc/grpc-js'; +import type { Session } from '../sessions'; +import type { ClientDataAndMetadata } from './types'; +import type { JSONValue } from '../types'; +import type { + JsonRpcRequest, + JsonRpcResponse, + MiddlewareFactory, +} from '../RPC/types'; +import { TransformStream } from 'stream/web'; import * as clientErrors from '../client/errors'; +import * as utils from '../utils'; -/** - * Encodes an Authorization header from session token - * Assumes token is already encoded - * Will mutate metadata if it is passed in - */ -function encodeAuthFromSession( - token: SessionToken, - metadata: grpc.Metadata = new grpc.Metadata(), -): grpc.Metadata { - metadata.set('Authorization', `Bearer ${token}`); - return metadata; -} - -function authenticator( +async function authenticate( sessionManager: SessionManager, keyRing: KeyRing, -): Authenticate { - return async ( - forwardMetadata: grpc.Metadata, - reverseMetadata: grpc.Metadata = new grpc.Metadata(), - ) => { - const auth = forwardMetadata.get('Authorization')[0] as string | undefined; - if (auth == null) { - throw new clientErrors.ErrorClientAuthMissing(); + message: JsonRpcRequest>, +) { + if (message.params == null) throw new clientErrors.ErrorClientAuthMissing(); + if (message.params.metadata == null) { + throw new clientErrors.ErrorClientAuthMissing(); + } + const auth = message.params.metadata.Authorization; + if (auth == null) { + throw new clientErrors.ErrorClientAuthMissing(); + } + if (auth.startsWith('Bearer ')) { + const token = auth.substring(7) as SessionToken; + if (!(await sessionManager.verifyToken(token))) { + throw new clientErrors.ErrorClientAuthDenied(); } - if (auth.startsWith('Bearer ')) { - const token = auth.substring(7) as SessionToken; - if (!(await sessionManager.verifyToken(token))) { - throw new clientErrors.ErrorClientAuthDenied(); - } - } else if (auth.startsWith('Basic ')) { - const encoded = auth.substring(6); - const decoded = Buffer.from(encoded, 'base64').toString('utf-8'); - const match = decoded.match(/:(.*)/); - if (match == null) { - throw new clientErrors.ErrorClientAuthFormat(); - } - const password = match[1]; - if (!(await keyRing.checkPassword(password))) { - throw new clientErrors.ErrorClientAuthDenied(); - } - } else { - throw new clientErrors.ErrorClientAuthMissing(); + } else if (auth.startsWith('Basic ')) { + const encoded = auth.substring(6); + const decoded = Buffer.from(encoded, 'base64').toString('utf-8'); + const match = decoded.match(/:(.*)/); + if (match == null) { + throw new clientErrors.ErrorClientAuthFormat(); } - const token = await sessionManager.createToken(); - encodeAuthFromSession(token, reverseMetadata); - return reverseMetadata; + const password = match[1]; + if (!(await keyRing.checkPassword(password))) { + throw new clientErrors.ErrorClientAuthDenied(); + } + } else { + throw new clientErrors.ErrorClientAuthMissing(); + } + const token = await sessionManager.createToken(); + return `Bearer ${token}`; +} + +function decodeAuth(messageParams: ClientDataAndMetadata) { + const auth = messageParams.metadata.Authorization; + if (auth == null || !auth.startsWith('Bearer ')) { + return; + } + return auth.substring(7) as SessionToken; +} + +function encodeAuthFromPassword(password: string): string { + const encoded = Buffer.from(`:${password}`).toString('base64'); + return `Basic ${encoded}`; +} + +function authenticationMiddlewareServer( + sessionManager: SessionManager, + keyRing: KeyRing, +): MiddlewareFactory< + JsonRpcRequest>, + JsonRpcResponse> +> { + return () => { + let forwardFirst = true; + let reverseController; + let outgoingToken: string | null = null; + return { + forward: new TransformStream< + JsonRpcRequest>, + JsonRpcRequest> + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + try { + outgoingToken = await authenticate( + sessionManager, + keyRing, + chunk, + ); + } catch (e) { + controller.terminate(); + reverseController.terminate(); + return; + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + reverseController = controller; + }, + transform: (chunk, controller) => { + // Add the outgoing metadata to the next message. + if (outgoingToken != null && 'result' in chunk) { + chunk.result.metadata.Authorization = outgoingToken; + outgoingToken = null; + } + controller.enqueue(chunk); + }, + }), + }; + }; +} + +function authenticationMiddlewareClient( + session: Session, +): MiddlewareFactory< + JsonRpcRequest>, + JsonRpcResponse> +> { + return () => { + let forwardFirst = true; + return { + forward: new TransformStream< + JsonRpcRequest>, + JsonRpcRequest> + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + if (chunk.params == null) utils.never(); + if (chunk.params.metadata.Authorization == null) { + const token = await session.readToken(); + if (token != null) { + chunk.params.metadata.Authorization = `Bearer ${token}`; + } + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream< + JsonRpcResponse>, + JsonRpcResponse> + >({ + transform: async (chunk, controller) => { + controller.enqueue(chunk); + if (!('result' in chunk)) return; + const token = decodeAuth(chunk.result); + if (token == null) return; + await session.writeToken(token); + }, + }), + }; }; } -export { authenticator }; +export { + authenticate, + decodeAuth, + encodeAuthFromPassword, + authenticationMiddlewareServer, + authenticationMiddlewareClient, +}; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index b64d33bf1..6f0de9e39 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -68,8 +68,8 @@ describe('agentStatus', () => { // Setup const rpcServer = await RPCServer.createRPCServer({ container: { - // KeyRing, - // certManager, + keyRing, + certManager, logger, }, logger, diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index b6194d864..9dacadc70 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -2,7 +2,6 @@ import type { ConnectionInfo } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { TransformStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -16,6 +15,8 @@ import { agentUnlockCaller, } from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; +import { Session, SessionManager } from '@/sessions'; +import * as abcUtils from '@/clientRPC/utils'; import * as rpcTestUtils from '../../RPC/utils'; describe('agentStatus', () => { @@ -28,6 +29,8 @@ describe('agentStatus', () => { let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; + let session: Session; + let sessionManager: SessionManager; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -35,6 +38,7 @@ describe('agentStatus', () => { ); const keysPath = path.join(dataDir, 'keys'); const dbPath = path.join(dataDir, 'db'); + const sessionPath = path.join(dataDir, 'session'); db = await DB.createDB({ dbPath, logger, @@ -54,6 +58,15 @@ describe('agentStatus', () => { taskManager, logger, }); + session = await Session.createSession({ + sessionTokenPath: sessionPath, + logger, + }); + sessionManager = await SessionManager.createSessionManager({ + db, + keyRing, + logger, + }); }); afterEach(async () => { await certManager.stop(); @@ -69,25 +82,14 @@ describe('agentStatus', () => { // Setup const rpcServer = await RPCServer.createRPCServer({ container: { - // KeyRing, - // certManager, logger, }, logger, }); rpcServer.registerUnaryHandler(agentUnlockName, agentUnlockHandler); - rpcServer.registerForwardMiddleware(() => { - return (input) => { - // This middleware needs to check the first message for the token - return input.pipeThrough( - new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue(chunk); - }, - }), - ); - }; - }); + rpcServer.registerMiddleware( + abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), + ); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); @@ -96,15 +98,29 @@ describe('agentStatus', () => { }, logger, }); + rpcClient.registerMiddleware( + abcUtils.authenticationMiddlewareClient(session), + ); // Doing the test - const result = await agentUnlockCaller({}, rpcClient); - expect(result).toStrictEqual({ - pid: process.pid, - nodeId: keyRing.getNodeId(), - publicJwk: JSON.stringify( - keysUtils.publicKeyToJWK(keyRing.keyPair.publicKey), - ), + const result = await agentUnlockCaller( + { + Authorization: abcUtils.encodeAuthFromPassword(password), + }, + rpcClient, + ); + expect(result).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + data: null, + }); + const result2 = await agentUnlockCaller({}, rpcClient); + expect(result2).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + data: null, }); }); });