diff --git a/.env.example b/.env.example index fb42c6230..7b59c198f 100644 --- a/.env.example +++ b/.env.example @@ -7,33 +7,5 @@ NODE_ENV=development # Debug node native modules - https://nodejs.org/api/cli.html#node_debug_nativemodule # NODE_DEBUG_NATIVE= -# Path to PK executable to override tests/bin target -# PK_TEST_COMMAND= - -# If set, indicates that `PK_TEST_COMMAND` is targetting docker -# PK_TEST_COMMAND_DOCKER= -# Accessing AWS for testnet.polykey.io and mainnet.polykey.io deployment -AWS_DEFAULT_REGION='ap-southeast-2' -AWS_ACCESS_KEY_ID= -AWS_SECRET_ACCESS_KEY= - -# Path to container registry authentication file used by `skopeo` -# The file has the same contents as `DOCKER_AUTH_CONFIG` -# Use this command to acquire the auth file at `./tmp/auth.json`: -# ``` -# printf 'PASSWORD' | skopeo login \ -# --username 'USERNAME' \ -# --password-stdin \ -# $CI_REGISTRY_IMAGE \ -# --authfile=./tmp/auth.json -# ``` -# REGISTRY_AUTH_FILE= - # Authenticate to GitHub with `gh` # GITHUB_TOKEN= - -# To allow testing different executables in the bin tests -# Both PK_TEST_COMMAND and PK_TEST_PLATFORM must be set at the same time -# PK_TEST_COMMAND= #Specify the shell command we want to test against -# PK_TEST_PLATFORM=docker #Overrides the auto set `testPlatform` variable used for enabling platform specific tests -# PK_TEST_TMPDIR= #Sets the `global.tmpDir` variable to allow overriding the temp directory used for tests diff --git a/jest.config.js b/jest.config.js index 1a7a09533..d964a3d7d 100644 --- a/jest.config.js +++ b/jest.config.js @@ -11,7 +11,6 @@ const moduleNameMapper = pathsToModuleNameMapper(compilerOptions.paths, { // Global variables that are shared across the jest worker pool // These variables must be static and serializable -if ((process.env.PK_TEST_PLATFORM != null) !== (process.env.PK_TEST_COMMAND != null)) throw Error('Both PK_TEST_PLATFORM and PK_TEST_COMMAND must be set together.') const globals = { // Absolute directory to the project root projectDir: __dirname, @@ -23,13 +22,9 @@ const globals = { ), // Default asynchronous test timeout defaultTimeout: 20000, - polykeyStartupTimeout: 30000, failedConnectionTimeout: 50000, // Timeouts rely on setTimeout which takes 32 bit numbers maxTimeout: Math.pow(2, 31) - 1, - testCmd: process.env.PK_TEST_COMMAND, - testPlatform: process.env.PK_TEST_PLATFORM, - tmpDir: path.resolve(process.env.PK_TEST_TMPDIR ?? os.tmpdir()), }; // The `globalSetup` and `globalTeardown` cannot access the `globals` diff --git a/package-lock.json b/package-lock.json index bf81767be..139508a82 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "polykey", - "version": "1.1.3-alpha.0", + "version": "1.1.5-feature-agent-migration-stage2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "polykey", - "version": "1.1.3-alpha.0", + "version": "1.1.5-feature-agent-migration-stage2.0", "license": "GPL-3.0", "dependencies": { "@matrixai/async-cancellable": "^1.1.1", @@ -17,7 +17,7 @@ "@matrixai/errors": "^1.1.7", "@matrixai/id": "^3.3.6", "@matrixai/logger": "^3.1.0", - "@matrixai/quic": "^0.0.13", + "@matrixai/quic": "^0.0.16", "@matrixai/resources": "^1.1.5", "@matrixai/timer": "^1.1.1", "@matrixai/workers": "^1.3.7", @@ -32,7 +32,6 @@ "ajv": "^7.0.4", "canonicalize": "^1.0.5", "cheerio": "^1.0.0-rc.5", - "commander": "^8.3.0", "cross-fetch": "^3.0.6", "cross-spawn": "^7.0.3", "encryptedfs": "^3.5.6", @@ -49,9 +48,6 @@ "resource-counter": "^1.2.4", "sodium-native": "^3.4.1", "threads": "^1.6.5", - "tslib": "^2.4.0", - "tsyringe": "^4.7.0", - "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "devDependencies": { @@ -60,8 +56,7 @@ "@swc/jest": "^0.2.26", "@types/cross-spawn": "^6.0.2", "@types/jest": "^28.1.3", - "@types/nexpect": "^0.4.31", - "@types/node": "^18.11.11", + "@types/node": "^18.15.0", "@types/pako": "^1.0.2", "@types/prompts": "^2.0.13", "@types/readable-stream": "^2.3.11", @@ -1784,14 +1779,14 @@ "integrity": "sha512-C4JWpgbNik3V99bfGfDell5cH3JULD67eEq9CeXl4rYgsvanF8hhuY84ZYvndPhimt9qjA9/Z8uExKGoiv1zVw==" }, "node_modules/@matrixai/quic": { - "version": "0.0.13", - "resolved": "https://registry.npmjs.org/@matrixai/quic/-/quic-0.0.13.tgz", - "integrity": "sha512-tvlA0m2fUIchyEZxzkBbvYNXYf21u0gR4Lv2BaYZYmGa1Fr2VH07MCZu9Ka8DpAEOXKEU94yqhNSEKCCJ83LJA==", + "version": "0.0.16", + "resolved": "https://registry.npmjs.org/@matrixai/quic/-/quic-0.0.16.tgz", + "integrity": "sha512-TSAv9KShBISnKngnD+gpbwLV62Og7X+Na7lmO3BoSaw8YWiUs+yo2hnzJXx8YVr3iYeeJoy45LdcethuUzMwfg==", "dependencies": { - "@matrixai/async-cancellable": "^1.1.0", + "@matrixai/async-cancellable": "^1.1.1", "@matrixai/async-init": "^1.8.4", "@matrixai/async-locks": "^4.0.0", - "@matrixai/contexts": "^1.0.0", + "@matrixai/contexts": "^1.1.0", "@matrixai/errors": "^1.1.7", "@matrixai/logger": "^3.1.0", "@matrixai/resources": "^1.1.5", @@ -1799,16 +1794,16 @@ "ip-num": "^1.5.0" }, "optionalDependencies": { - "@matrixai/quic-darwin-arm64": "0.0.13", - "@matrixai/quic-darwin-x64": "0.0.13", - "@matrixai/quic-linux-x64": "0.0.13", - "@matrixai/quic-win32-x64": "0.0.13" + "@matrixai/quic-darwin-arm64": "0.0.16", + "@matrixai/quic-darwin-x64": "0.0.16", + "@matrixai/quic-linux-x64": "0.0.16", + "@matrixai/quic-win32-x64": "0.0.16" } }, "node_modules/@matrixai/quic-darwin-arm64": { - "version": "0.0.13", - "resolved": "https://registry.npmjs.org/@matrixai/quic-darwin-arm64/-/quic-darwin-arm64-0.0.13.tgz", - "integrity": "sha512-EKBfqYr6mMj0k9cE97KiommyFb7eD3u4OWloMFySERcBzg+9HWwonDX5/kyChllxEDorPXneW/CfF8gtZTQ1ug==", + "version": "0.0.16", + "resolved": "https://registry.npmjs.org/@matrixai/quic-darwin-arm64/-/quic-darwin-arm64-0.0.16.tgz", + "integrity": "sha512-5Tyi2qkJf/VNd2s36Ddt1RzZD4dxyZHX1eH5qCmcyJK6oSd+BFLUlI4CrXlyo1PbmKg+lHFS7FRezVZ73nvmug==", "cpu": [ "arm64" ], @@ -1818,9 +1813,9 @@ ] }, "node_modules/@matrixai/quic-darwin-x64": { - "version": "0.0.13", - "resolved": "https://registry.npmjs.org/@matrixai/quic-darwin-x64/-/quic-darwin-x64-0.0.13.tgz", - "integrity": "sha512-WTf9gKdAqHkWVk48eWZ4JofctjZBrvUxEfg8HcBDzye1kz1O+0IyJlF4web3ZFYu/lvoGQn6DTaJvozdQS5hTw==", + "version": "0.0.16", + "resolved": "https://registry.npmjs.org/@matrixai/quic-darwin-x64/-/quic-darwin-x64-0.0.16.tgz", + "integrity": "sha512-C2zaQrOLu+c8NS4TjnH9ynhXBWAd1H8/wZp3AnwzqVGNJalOfQM7X2G9f0+E4awwtT1jxfNPDPfiDx0hQgoOiA==", "cpu": [ "x64" ], @@ -1830,9 +1825,9 @@ ] }, "node_modules/@matrixai/quic-linux-x64": { - "version": "0.0.13", - "resolved": "https://registry.npmjs.org/@matrixai/quic-linux-x64/-/quic-linux-x64-0.0.13.tgz", - "integrity": "sha512-ExOhO9YjiCNV6OrRMF2+CVQdPANa2zSqlMzCUaLC5whAsll50M08LpoV4J/HnmpTWPcfohr+G28bFWVsnb8/wA==", + "version": "0.0.16", + "resolved": "https://registry.npmjs.org/@matrixai/quic-linux-x64/-/quic-linux-x64-0.0.16.tgz", + "integrity": "sha512-4TmtS030ZNKH8RxDtzZHsjeP2vdmVAZi1+SKmEXpea6ZeC0cUassYLXz4afgX0KanSAKtOYfXX3cVF7+p8Ci5Q==", "cpu": [ "x64" ], @@ -2579,15 +2574,6 @@ "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==", "dev": true }, - "node_modules/@types/nexpect": { - "version": "0.4.31", - "resolved": "https://registry.npmjs.org/@types/nexpect/-/nexpect-0.4.31.tgz", - "integrity": "sha512-Plh9Dlj2AKdsblgF1Pv7s2BjlojqW93d1zIUtK5xVVrUjkZQezyWIOAq0Xfwp0e0SDQ70YmaDqzhoJru2kqVPA==", - "dev": true, - "dependencies": { - "@types/node": "*" - } - }, "node_modules/@types/node": { "version": "18.17.3", "resolved": "https://registry.npmjs.org/@types/node/-/node-18.17.3.tgz", @@ -3602,14 +3588,6 @@ "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "dev": true }, - "node_modules/commander": { - "version": "8.3.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz", - "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==", - "engines": { - "node": ">= 12" - } - }, "node_modules/common-tags": { "version": "1.8.2", "resolved": "https://registry.npmjs.org/common-tags/-/common-tags-1.8.2.tgz", @@ -9580,10 +9558,6 @@ "uuid": "dist/bin/uuid" } }, - "node_modules/uWebSockets.js": { - "version": "20.19.0", - "resolved": "git+ssh://git@github.com/uNetworking/uWebSockets.js.git#42c9c0d5d31f46ca4115dc75672b0037ec970f28" - }, "node_modules/v8-compile-cache-lib": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", diff --git a/package.json b/package.json index d89bbab3e..4b3b51684 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "polykey", - "version": "1.1.3-alpha.0", + "version": "1.1.5-feature-agent-migration-stage2.0", "homepage": "https://polykey.io", "author": "Matrix AI", "contributors": [ @@ -30,6 +30,9 @@ }, { "name": "Emma Casolin" + }, + { + "name": "Amy Yan" } ], "description": "Polykey Core Library", @@ -67,10 +70,10 @@ "@matrixai/errors": "^1.1.7", "@matrixai/id": "^3.3.6", "@matrixai/logger": "^3.1.0", - "@matrixai/quic": "^0.0.13", "@matrixai/resources": "^1.1.5", "@matrixai/timer": "^1.1.1", "@matrixai/workers": "^1.3.7", + "@matrixai/quic": "^0.0.16", "@peculiar/asn1-pkcs8": "^2.3.0", "@peculiar/asn1-schema": "^2.3.0", "@peculiar/asn1-x509": "^2.3.0", @@ -82,7 +85,6 @@ "ajv": "^7.0.4", "canonicalize": "^1.0.5", "cheerio": "^1.0.0-rc.5", - "commander": "^8.3.0", "cross-fetch": "^3.0.6", "cross-spawn": "^7.0.3", "encryptedfs": "^3.5.6", @@ -99,9 +101,6 @@ "resource-counter": "^1.2.4", "sodium-native": "^3.4.1", "threads": "^1.6.5", - "tslib": "^2.4.0", - "tsyringe": "^4.7.0", - "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "devDependencies": { @@ -110,8 +109,7 @@ "@swc/jest": "^0.2.26", "@types/cross-spawn": "^6.0.2", "@types/jest": "^28.1.3", - "@types/nexpect": "^0.4.31", - "@types/node": "^18.11.11", + "@types/node": "^18.15.0", "@types/pako": "^1.0.2", "@types/prompts": "^2.0.13", "@types/readable-stream": "^2.3.11", diff --git a/src/PolykeyAgent.ts b/src/PolykeyAgent.ts index 9384f878f..6ead0d885 100644 --- a/src/PolykeyAgent.ts +++ b/src/PolykeyAgent.ts @@ -1,12 +1,11 @@ import type { FileSystem, PromiseDeconstructed } from './types'; import type { PolykeyWorkerManagerInterface } from './workers/types'; -import type { ConnectionData, Host, Port, TLSConfig } from './network/types'; +import type { ConnectionData, TLSConfig } from './network/types'; import type { SeedNodes } from './nodes/types'; -import type { CertificatePEM, CertManagerChangeData, Key } from './keys/types'; +import type { CertManagerChangeData, Key } from './keys/types'; import type { RecoveryCode, PrivateKey } from './keys/types'; import type { PasswordMemLimit, PasswordOpsLimit } from './keys/types'; -import type * as quicEvents from '@matrixai/quic/dist/events'; -import type { ClientCrypto, QUICConfig, ServerCrypto } from '@matrixai/quic'; +import type { ClientCrypto, ServerCrypto } from '@matrixai/quic'; import path from 'path'; import process from 'process'; import { webcrypto } from 'crypto'; @@ -52,11 +51,12 @@ type NetworkConfig = { agentHost?: string; agentPort?: number; ipv6Only?: boolean; + agentKeepAliveIntervalTime?: number; + agentMaxIdleTimeout?: number; // RPCServer for client service clientHost?: string; clientPort?: number; // Websocket server config - maxReadableStreamBytes?: number; maxIdleTimeout?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; @@ -66,11 +66,6 @@ type NetworkConfig = { handlerTimeoutGraceTime?: number; }; -type PolykeyQUICConfig = Omit< - Partial, - 'ca' | 'key' | 'cert' | 'verifyPeer' | 'verifyAllowFail' ->; - interface PolykeyAgent extends CreateDestroyStartStop {} @CreateDestroyStartStop( new errors.ErrorPolykeyAgentRunning(), @@ -97,8 +92,6 @@ class PolykeyAgent { keyRingConfig = {}, certManagerConfig = {}, networkConfig = {}, - quicServerConfig = {}, - quicClientConfig = {}, nodeConnectionManagerConfig = {}, seedNodes = {}, workers, @@ -124,7 +117,6 @@ class PolykeyAgent { webSocketServerClient, rpcServerAgent, quicSocket, - quicServerAgent, fs = require('fs'), logger = new Logger(this.name), fresh = false, @@ -151,8 +143,6 @@ class PolykeyAgent { connectionHolePunchIntervalTime?: number; }; networkConfig?: NetworkConfig; - quicServerConfig?: PolykeyQUICConfig; - quicClientConfig?: PolykeyQUICConfig; seedNodes?: SeedNodes; workers?: number; status?: Status; @@ -176,7 +166,6 @@ class PolykeyAgent { webSocketServerClient?: WebSocketServer; rpcServerAgent?: RPCServer; quicSocket?: QUICSocket; - quicServerAgent?: QUICServer; fs?: FileSystem; logger?: Logger; fresh?: boolean; @@ -201,14 +190,7 @@ class PolykeyAgent { ...config.defaults.networkConfig, ...utils.filterEmptyObject(networkConfig), }; - const quicServerConfig_ = { - ...config.defaults.quicServerConfig, - ...utils.filterEmptyObject(quicServerConfig), - }; - const quicClientConfig_ = { - ...config.defaults.quicClientConfig, - ...utils.filterEmptyObject(quicClientConfig), - }; + await utils.mkdirExists(fs, nodePath); const statusPath = path.join(nodePath, config.defaults.statusBase); const statusLockPath = path.join(nodePath, config.defaults.statusLockBase); @@ -358,29 +340,57 @@ class PolykeyAgent { logger: logger.getChild(QUICSocket.name), resolveHostname, }); - const clientCrypto: ClientCrypto = { + const crypto: ServerCrypto & ClientCrypto = { randomBytes: async (data: ArrayBuffer) => { const randomBytes = keysUtils.getRandomBytes(data.byteLength); const dataBuf = Buffer.from(data); dataBuf.write(randomBytes.toString('binary'), 'binary'); }, + async sign(key: ArrayBuffer, data: ArrayBuffer) { + const cryptoKey = await webcrypto.subtle.importKey( + 'raw', + key, + { + name: 'HMAC', + hash: 'SHA-256', + }, + true, + ['sign', 'verify'], + ); + return webcrypto.subtle.sign('HMAC', cryptoKey, data); + }, + async verify(key: ArrayBuffer, data: ArrayBuffer, sig: ArrayBuffer) { + const cryptoKey = await webcrypto.subtle.importKey( + 'raw', + key, + { + name: 'HMAC', + hash: 'SHA-256', + }, + true, + ['sign', 'verify'], + ); + return webcrypto.subtle.verify('HMAC', cryptoKey, sig, data); + }, + }; + const tlsConfig: TLSConfig = { + keyPrivatePem: keysUtils.privateKeyToPEM(keyRing.keyPair.privateKey), + certChainPem: await certManager.getCertPEMsChainPEM(), }; nodeConnectionManager = nodeConnectionManager ?? new NodeConnectionManager({ + handleStream: () => {}, keyRing, nodeGraph, seedNodes, quicSocket, - quicClientConfig: { - ...quicClientConfig_, - key: keysUtils.privateKeyToPEM(keyRing.keyPair.privateKey), - cert: await certManager.getCertPEMsChainPEM(), - }, ...nodeConnectionManagerConfig_, - crypto: { - ops: clientCrypto, - }, + connectionKeepAliveIntervalTime: + networkConfig_.agentKeepAliveIntervalTime, + connectionMaxIdleTimeout: networkConfig_.agentMaxIdleTimeout, + tlsConfig, + crypto, logger: logger.getChild(NodeConnectionManager.name), }); nodeManager = @@ -474,20 +484,14 @@ class PolykeyAgent { logger: logger.getChild(RPCServer.name + 'Client'), }); } - const tlsConfig: TLSConfig = { - keyPrivatePem: keysUtils.privateKeyToPEM(keyRing.keyPair.privateKey), - certChainPem: await certManager.getCertPEMsChainPEM(), - }; webSocketServerClient = webSocketServerClient ?? (await WebSocketServer.createWebSocketServer({ connectionCallback: (rpcStream) => rpcServerClient!.handleStream(rpcStream), - fs, host: networkConfig_.clientHost, port: networkConfig_.clientPort, tlsConfig, - maxReadableStreamBytes: networkConfig_.maxReadableStreamBytes, maxIdleTimeout: networkConfig_.maxIdleTimeout, pingIntervalTime: networkConfig_.pingIntervalTime, pingTimeoutTimeTime: networkConfig_.pingTimeoutTimeTime, @@ -517,58 +521,8 @@ class PolykeyAgent { logger: logger.getChild(RPCServer.name + 'Agent'), }); } - const serverCrypto: ServerCrypto = { - async sign(key: ArrayBuffer, data: ArrayBuffer) { - const cryptoKey = await webcrypto.subtle.importKey( - 'raw', - key, - { - name: 'HMAC', - hash: 'SHA-256', - }, - true, - ['sign', 'verify'], - ); - return webcrypto.subtle.sign('HMAC', cryptoKey, data); - }, - async verify(key: ArrayBuffer, data: ArrayBuffer, sig: ArrayBuffer) { - const cryptoKey = await webcrypto.subtle.importKey( - 'raw', - key, - { - name: 'HMAC', - hash: 'SHA-256', - }, - true, - ['sign', 'verify'], - ); - return webcrypto.subtle.verify('HMAC', cryptoKey, sig, data); - }, - }; - quicServerAgent = - quicServerAgent ?? - new QUICServer({ - config: { - ...quicServerConfig_, - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - verifyPeer: true, - verifyAllowFail: true, - }, - crypto: { - key: keysUtils.generateKey(), - ops: serverCrypto, - }, - verifyCallback: networkUtils.verifyClientCertificateChain, - logger: logger.getChild(QUICServer.name + 'Agent'), - socket: quicSocket, - resolveHostname, - reasonToCode: utils.reasonToCode, - codeToReason: utils.codeToReason, - }); } catch (e) { logger.warn(`Failed Creating ${this.name}`); - await quicServerAgent?.stop({ force: true }); await quicSocket?.stop({ force: true }); await rpcServerAgent?.destroy(true); await rpcServerClient?.destroy(); @@ -612,7 +566,6 @@ class PolykeyAgent { webSocketServerClient, rpcServerAgent, quicSocket, - quicServerAgent, events, fs, logger, @@ -653,7 +606,6 @@ class PolykeyAgent { public readonly webSocketServerClient: WebSocketServer; public readonly rpcServerAgent: RPCServer; public readonly quicSocket: QUICSocket; - public readonly quicServerAgent: QUICServer; protected workerManager: PolykeyWorkerManagerInterface | undefined; constructor({ @@ -679,7 +631,6 @@ class PolykeyAgent { webSocketServerClient, rpcServerAgent, quicSocket, - quicServerAgent, events, fs, logger, @@ -706,7 +657,6 @@ class PolykeyAgent { webSocketServerClient: WebSocketServer; rpcServerAgent: RPCServer; quicSocket: QUICSocket; - quicServerAgent: QUICServer; events: EventBus; fs: FileSystem; logger: Logger; @@ -734,7 +684,6 @@ class PolykeyAgent { this.webSocketServerClient = webSocketServerClient; this.rpcServerAgent = rpcServerAgent; this.quicSocket = quicSocket; - this.quicServerAgent = quicServerAgent; this.events = events; this.fs = fs; } @@ -784,14 +733,8 @@ class PolykeyAgent { keyPrivatePem: keysUtils.privateKeyToPEM(data.keyPair.privateKey), certChainPem: await this.certManager.getCertPEMsChainPEM(), }; - // FIXME: Can we even support updating TLS config anymore? - // We would need to shut down the Websocket server and re-create it with the new config. - // Right now graceful shutdown is not supported. - // this.grpcServerClient.setTLSConfig(tlsConfig); - this.quicServerAgent.updateConfig({ - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }); + this.webSocketServerClient.setTlsConfig(tlsConfig); + this.nodeConnectionManager.updateTlsConfig(tlsConfig); this.logger.info(`${KeyRing.name} change propagated`); }, ); @@ -876,79 +819,11 @@ class PolykeyAgent { port: _networkConfig.agentPort, ipv6Only: _networkConfig.ipv6Only, }); - // Setting up stream handling - const handleStream = async ( - event: quicEvents.QUICConnectionStreamEvent, - ) => { - // Streams are handled via the RPCServer. - const stream = event.detail; - this.rpcServerAgent.handleStream(stream); - }; - - const handleConnection = async ( - event: quicEvents.QUICServerConnectionEvent, - ) => { - // Needs to setup stream handler - const connection = event.detail; - try { - // Dispatch connection event - const remoteCertificates = connection.getRemoteCertsChain(); - if (remoteCertificates.length === 0) { - throw Error('remote certificates were not provided'); - } - const remoteCertPem = remoteCertificates[0]; - const remoteCert = keysUtils.certFromPEM( - remoteCertPem as CertificatePEM, - ); - if (remoteCert == null) throw Error('failed to parse certificate'); - const nodeId = keysUtils.certNodeId(remoteCert); - if (nodeId == null) throw Error('failed to extract NodeId from cert'); - const data: ConnectionData = { - remoteNodeId: nodeId, - remoteHost: connection.remoteHost as Host, - remotePort: connection.remotePort as Port, - }; - await this.events.emitAsync( - PolykeyAgent.eventSymbols.QUICServer, - data, - ); - } catch (e) { - this.logger.error(e.message); - await connection.stop({ - applicationError: true, - errorMessage: e.message, - force: true, - }); - } - - connection.addEventListener('connectionStream', handleStream); - connection.addEventListener( - 'connectionStop', - () => { - connection.removeEventListener('connectionStream', handleStream); - }, - { once: true }, - ); - }; - this.quicServerAgent.addEventListener( - 'serverConnection', - handleConnection, - ); - this.quicServerAgent.addEventListener( - 'serverStop', - () => { - this.quicServerAgent.removeEventListener( - 'serverConnection', - handleConnection, - ); - }, - { once: true }, - ); - // Finished setting up handling. - // No host or port is provided here, it's configured in the shared QUICSocket. - await this.quicServerAgent.start(); await this.nodeManager.start(); - await this.nodeConnectionManager.start({ nodeManager: this.nodeManager }); + await this.nodeConnectionManager.start({ + nodeManager: this.nodeManager, + handleStream: (stream) => this.rpcServerAgent.handleStream(stream), + }); await this.nodeGraph.start({ fresh }); await this.nodeManager.syncNodeGraph(false); await this.discovery.start({ fresh }); @@ -989,7 +864,6 @@ class PolykeyAgent { await this.nodeGraph?.stop(); await this.nodeConnectionManager?.stop(); await this.nodeManager?.stop(); - await this.quicServerAgent.stop(); await this.quicSocket.stop(); await this.webSocketServerClient.stop(true); await this.identitiesManager?.stop(); @@ -1025,7 +899,6 @@ class PolykeyAgent { await this.nodeConnectionManager.stop(); await this.nodeGraph.stop(); await this.nodeManager.stop(); - await this.quicServerAgent.stop(); await this.quicSocket.stop(); await this.webSocketServerClient.stop(true); await this.identitiesManager.stop(); diff --git a/src/agent/handlers/clientManifest.ts b/src/agent/handlers/clientManifest.ts index 35b58edb2..f94bb5b73 100644 --- a/src/agent/handlers/clientManifest.ts +++ b/src/agent/handlers/clientManifest.ts @@ -2,17 +2,18 @@ import type { AgentRPCRequestParams, AgentRPCResponseResult } from '../types'; import type { AgentClaimMessage, ClaimIdMessage, - GitPackMessage, HolePunchRelayMessage, NodeAddressMessage, NodeIdMessage, SignedNotificationEncoded, - VaultInfo, - VaultsGitInfoGetMessage, - VaultsGitPackGetMessage, VaultsScanMessage, } from './types'; -import { DuplexCaller, ServerCaller, UnaryCaller } from '../../rpc/callers'; +import { + DuplexCaller, + RawCaller, + ServerCaller, + UnaryCaller, +} from '../../rpc/callers'; const nodesClaimsGet = new ServerCaller< AgentRPCRequestParams, @@ -39,15 +40,9 @@ const notificationsSend = new UnaryCaller< AgentRPCResponseResult >(); -const vaultsGitInfoGet = new ServerCaller< - AgentRPCRequestParams, - AgentRPCResponseResult ->(); +const vaultsGitInfoGet = new RawCaller(); -const vaultsGitPackGet = new ServerCaller< - AgentRPCRequestParams, - AgentRPCResponseResult ->(); +const vaultsGitPackGet = new RawCaller(); const vaultsScan = new ServerCaller< AgentRPCRequestParams, diff --git a/src/agent/handlers/serverManifest.ts b/src/agent/handlers/serverManifest.ts index 8f0f5d76c..952f66e24 100644 --- a/src/agent/handlers/serverManifest.ts +++ b/src/agent/handlers/serverManifest.ts @@ -41,8 +41,8 @@ const serverManifest = (container: { nodesCrossSignClaim: new NodesCrossSignClaimHandler(container), nodesHolePunchMessageSend: new NodesHolePunchMessageSendHandler(container), notificationsSend: new NotificationsSendHandler(container), - VaultsGitInfoGet: new VaultsGitInfoGetHandler(container), - VaultsGitPackGet: new VaultsGitPackGetHandler(container), + vaultsGitInfoGet: new VaultsGitInfoGetHandler(container), + vaultsGitPackGet: new VaultsGitPackGetHandler(container), vaultsScan: new VaultsScanHandler(container), }; }; diff --git a/src/agent/handlers/types.ts b/src/agent/handlers/types.ts index f3a0df180..338997830 100644 --- a/src/agent/handlers/types.ts +++ b/src/agent/handlers/types.ts @@ -40,19 +40,3 @@ export type VaultInfo = { export type VaultsScanMessage = VaultInfo & { vaultPermissions: Array; }; - -export type VaultsGitInfoGetMessage = { - vaultNameOrId: VaultIdEncoded | VaultName; - action: VaultAction; -}; - -export type GitPackMessage = { - /** - * Chunk of data in binary form; - */ - chunk: string; -}; - -export type VaultsGitPackGetMessage = { - body: string; -}; diff --git a/src/agent/handlers/vaultsGitInfoGet.ts b/src/agent/handlers/vaultsGitInfoGet.ts index 969bbb8d8..e0f6b4623 100644 --- a/src/agent/handlers/vaultsGitInfoGet.ts +++ b/src/agent/handlers/vaultsGitInfoGet.ts @@ -1,64 +1,57 @@ -import type { GitPackMessage, VaultInfo } from './types'; -import type { AgentRPCRequestParams, AgentRPCResponseResult } from '../types'; import type { DB } from '@matrixai/db'; import type { VaultManager } from '../../vaults'; import type { ACL } from '../../acl'; import type Logger from '@matrixai/logger'; -import type { VaultsGitInfoGetMessage } from './types'; -import type { VaultAction } from '../../vaults/types'; +import type { JSONRPCRequest } from '../../rpc/types'; +import type { ContextTimed } from '@matrixai/contexts'; +import type { JSONValue } from '../../types'; +import { ReadableStream } from 'stream/web'; import * as agentErrors from '../errors'; import * as vaultsUtils from '../../vaults/utils'; import * as vaultsErrors from '../../vaults/errors'; -import { ServerHandler } from '../../rpc/handlers'; -import { validateSync } from '../../validation'; -import { matchSync } from '../../utils'; +import { RawHandler } from '../../rpc/handlers'; +import { never } from '../../utils'; import * as validationUtils from '../../validation/utils'; import * as nodesUtils from '../../nodes/utils'; import * as agentUtils from '../utils'; +import * as utils from '../../utils'; -class VaultsGitInfoGetHandler extends ServerHandler< - { - db: DB; - vaultManager: VaultManager; - acl: ACL; - logger: Logger; - }, - AgentRPCRequestParams, - AgentRPCResponseResult -> { - public async *handle( - input: AgentRPCRequestParams, +class VaultsGitInfoGetHandler extends RawHandler<{ + db: DB; + vaultManager: VaultManager; + acl: ACL; + logger: Logger; +}> { + public async handle( + input: [JSONRPCRequest, ReadableStream], _cancel, - meta, - ): AsyncGenerator { + meta: Record | undefined, + _ctx: ContextTimed, // TODO: use + ): Promise<[JSONValue, ReadableStream]> { const { db, vaultManager, acl } = this.container; - yield* db.withTransactionG(async function* ( - tran, - ): AsyncGenerator { + const [headerMessage, inputStream] = input; + await inputStream.cancel(); + const params = headerMessage.params; + if (params == null || !utils.isObject(params)) never(); + if ( + !('vaultNameOrId' in params) || + typeof params.vaultNameOrId != 'string' + ) { + never(); + } + if (!('action' in params) || typeof params.action != 'string') never(); + const vaultNameOrId = params.vaultNameOrId; + const actionType = validationUtils.parseVaultAction(params.action); + const data = await db.withTransactionF(async (tran) => { const vaultIdFromName = await vaultManager.getVaultId( - input.vaultNameOrId, + vaultNameOrId, tran, ); const vaultId = - vaultIdFromName ?? vaultsUtils.decodeVaultId(input.vaultNameOrId); + vaultIdFromName ?? vaultsUtils.decodeVaultId(vaultNameOrId); if (vaultId == null) { throw new vaultsErrors.ErrorVaultsVaultUndefined(); } - const { - actionType, - }: { - actionType: VaultAction; - } = validateSync( - (keyPath, value) => { - return matchSync(keyPath)( - [['actionType'], () => validationUtils.parseVaultAction(value)], - () => value, - ); - }, - { - actionType: input.action, - }, - ); const vaultName = (await vaultManager.getVaultMeta(vaultId, tran)) ?.vaultName; if (vaultName == null) { @@ -84,21 +77,37 @@ class VaultsGitInfoGetHandler extends ServerHandler< )}`, ); } - - yield { - vaultName: vaultName, - vaultIdEncoded: vaultsUtils.encodeVaultId(vaultId), + return { + vaultId, + vaultName, }; - for await (const byte of vaultManager.handleInfoRequest(vaultId, tran)) { - if (byte !== null) { - yield { - chunk: byte.toString('binary'), - }; - } else { + }); + + let handleInfoRequestGen: AsyncGenerator; + const stream = new ReadableStream({ + start: async () => { + handleInfoRequestGen = vaultManager.handleInfoRequest(data.vaultId); + }, + pull: async (controller) => { + const result = await handleInfoRequestGen.next(); + if (result.done) { + controller.close(); return; + } else { + controller.enqueue(result.value); } - } + }, + cancel: async (reason) => { + await handleInfoRequestGen.throw(reason).catch(() => {}); + }, }); + return [ + { + vaultName: data.vaultName, + vaultIdEncoded: vaultsUtils.encodeVaultId(data.vaultId), + }, + stream, + ]; } } diff --git a/src/agent/handlers/vaultsGitPackGet.ts b/src/agent/handlers/vaultsGitPackGet.ts index 1f03862f1..6a01f0572 100644 --- a/src/agent/handlers/vaultsGitPackGet.ts +++ b/src/agent/handlers/vaultsGitPackGet.ts @@ -1,105 +1,108 @@ import type { DB } from '@matrixai/db'; -import type { GitPackMessage, VaultsGitPackGetMessage } from './types'; -import type { AgentRPCRequestParams, AgentRPCResponseResult } from '../types'; -import type { VaultAction, VaultName } from '../../vaults/types'; +import type { VaultName } from '../../vaults/types'; import type VaultManager from '../../vaults/VaultManager'; import type ACL from '../../acl/ACL'; +import type { JSONValue } from '../../types'; +import type { PassThrough } from 'readable-stream'; +import type { JSONRPCRequest } from '../../rpc/types'; +import { ReadableStream } from 'stream/web'; +import * as utils from '../../utils'; import * as agentErrors from '../errors'; import * as agentUtils from '../utils'; import * as nodesUtils from '../../nodes/utils'; import * as vaultsUtils from '../../vaults/utils'; import * as vaultsErrors from '../../vaults/errors'; -import { validateSync } from '../../validation'; -import { matchSync } from '../../utils'; +import { never } from '../../utils'; import * as validationUtils from '../../validation/utils'; -import { ServerHandler } from '../../rpc/handlers'; +import { RawHandler } from '../../rpc/handlers'; -class VaultsGitPackGetHandler extends ServerHandler< - { - vaultManager: VaultManager; - acl: ACL; - db: DB; - }, - AgentRPCRequestParams, - AgentRPCResponseResult -> { - public async *handle( - input: AgentRPCRequestParams, +class VaultsGitPackGetHandler extends RawHandler<{ + vaultManager: VaultManager; + acl: ACL; + db: DB; +}> { + public async handle( + input: [JSONRPCRequest, ReadableStream], _cancel, meta, - ): AsyncGenerator> { + ): Promise<[JSONValue, ReadableStream]> { const { vaultManager, acl, db } = this.container; + const [headerMessage, inputStream] = input; const requestingNodeId = agentUtils.nodeIdFromMeta(meta); if (requestingNodeId == null) { throw new agentErrors.ErrorAgentNodeIdMissing(); } const nodeIdEncoded = nodesUtils.encodeNodeId(requestingNodeId); - const nameOrId = meta.get('vaultNameOrId').pop()!.toString(); - yield* db.withTransactionG(async function* ( - tran, - ): AsyncGenerator> { - const vaultIdFromName = await vaultManager.getVaultId( - nameOrId as VaultName, - tran, - ); - const vaultId = vaultIdFromName ?? vaultsUtils.decodeVaultId(nameOrId); - if (vaultId == null) { - throw new vaultsErrors.ErrorVaultsVaultUndefined(); - } - const { - actionType, - }: { - actionType: VaultAction; - } = validateSync( - (keyPath, value) => { - return matchSync(keyPath)( - [['actionType'], () => validationUtils.parseVaultAction(value)], - () => value, - ); - }, - { - actionType: meta.get('vaultAction').pop()!.toString(), - }, - ); - // Checking permissions - const permissions = await acl.getNodePerm(requestingNodeId, tran); - const vaultPerms = permissions?.vaults[vaultId]; - if (vaultPerms?.[actionType] !== null) { - throw new vaultsErrors.ErrorVaultsPermissionDenied( - `${nodeIdEncoded} does not have permission to ${actionType} from vault ${vaultsUtils.encodeVaultId( - vaultId, - )}`, + const params = headerMessage.params; + if (params == null || !utils.isObject(params)) never(); + if (!('nameOrId' in params) || typeof params.nameOrId != 'string') { + never(); + } + if (!('vaultAction' in params) || typeof params.vaultAction != 'string') { + never(); + } + const nameOrId = params.nameOrId; + const actionType = validationUtils.parseVaultAction(params.vaultAction); + const [vaultIdFromName, permissions] = await db.withTransactionF( + async (tran) => { + const vaultIdFromName = await vaultManager.getVaultId( + nameOrId as VaultName, + tran, ); - } - const [sideBand, progressStream] = await vaultManager.handlePackRequest( - vaultId, - Buffer.from(input.body, 'utf-8'), - tran, + const permissions = await acl.getNodePerm(requestingNodeId, tran); + + return [vaultIdFromName, permissions]; + }, + ); + const vaultId = vaultIdFromName ?? vaultsUtils.decodeVaultId(nameOrId); + if (vaultId == null) { + throw new vaultsErrors.ErrorVaultsVaultUndefined(); + } + // Checking permissions + const vaultPerms = permissions?.vaults[vaultId]; + if (vaultPerms?.[actionType] !== null) { + throw new vaultsErrors.ErrorVaultsPermissionDenied( + `${nodeIdEncoded} does not have permission to ${actionType} from vault ${vaultsUtils.encodeVaultId( + vaultId, + )}`, ); - yield { - chunk: Buffer.from('0008NAK\n').toString('binary'), - }; - const responseBuffers: Uint8Array[] = []; - // FIXME: this WHOLE thing needs to change, why are we streaming when we send monolithic messages? - const result = await new Promise((resolve, reject) => { + } + + // Getting data + let sideBand: PassThrough; + let progressStream: PassThrough; + const outputStream = new ReadableStream({ + start: async (controller) => { + const body = new Array(); + for await (const message of inputStream) { + body.push(message); + } + [sideBand, progressStream] = await vaultManager.handlePackRequest( + vaultId, + Buffer.concat(body), + ); + controller.enqueue(Buffer.from('0008NAK\n')); sideBand.on('data', async (data: Uint8Array) => { - responseBuffers.push(data); + controller.enqueue(data); + sideBand.pause(); }); sideBand.on('end', async () => { - const result = Buffer.concat(responseBuffers).toString('binary'); - resolve(result); + controller.close(); }); - sideBand.on('error', (err) => { - reject(err); + sideBand.on('error', (e) => { + controller.error(e); }); progressStream.write(Buffer.from('0014progress is at 50%\n')); progressStream.end(); - }); - yield { - chunk: result, - }; + }, + pull: () => { + sideBand.resume(); + }, + cancel: (e) => { + sideBand.destroy(e); + }, }); - return; + return [null, outputStream]; } } diff --git a/src/bootstrap/utils.ts b/src/bootstrap/utils.ts index aaee2bd58..b6659c493 100644 --- a/src/bootstrap/utils.ts +++ b/src/bootstrap/utils.ts @@ -14,7 +14,7 @@ import { Sigchain } from '../sigchain'; import { ACL } from '../acl'; import { GestaltGraph } from '../gestalts'; import { KeyRing } from '../keys'; -import { NodeConnectionManager, NodeGraph, NodeManager } from '../nodes'; +import { NodeGraph, NodeManager } from '../nodes'; import { VaultManager } from '../vaults'; import { NotificationsManager } from '../notifications'; import { mkdirExists } from '../utils'; @@ -153,19 +153,11 @@ async function bootstrapState({ logger, lazy: true, }); - const nodeConnectionManager = new NodeConnectionManager({ - keyRing, - nodeGraph, - quicClientConfig: {} as any, // No connections are attempted - crypto: {} as any, // No connections are attempted - quicSocket: {} as any, // No connections are attempted - logger: logger.getChild(NodeConnectionManager.name), - }); const nodeManager = new NodeManager({ db, keyRing, nodeGraph, - nodeConnectionManager, + nodeConnectionManager: {} as any, // No connections are attempted sigchain, taskManager, gestaltGraph, @@ -175,7 +167,7 @@ async function bootstrapState({ await NotificationsManager.createNotificationsManager({ acl, db, - nodeConnectionManager, + nodeConnectionManager: {} as any, // No connections are attempted nodeManager, keyRing, logger: logger.getChild(NotificationsManager.name), @@ -186,7 +178,7 @@ async function bootstrapState({ db, gestaltGraph, keyRing, - nodeConnectionManager, + nodeConnectionManager: {} as any, // No connections are attempted vaultsPath, notificationsManager, logger: logger.getChild(VaultManager.name), diff --git a/src/config.ts b/src/config.ts index 7cd5c4809..4de7c276d 100644 --- a/src/config.ts +++ b/src/config.ts @@ -90,30 +90,54 @@ const config = { certDuration: 31536000, }, networkConfig: { - // Config for the QUICSocket - agentHost: '127.0.0.1', + /** + * Agent host defaults to `::` dual stack. + * This is because the agent service is supposed to be public. + */ + agentHost: '::', agentPort: 0, - ipv6Only: false, - // Config for the websocket server - clientHost: '127.0.0.1', + /** + * Client host defaults to `localhost`. + * This will depend on the OS configuration. + * Usually it will be IPv4 `127.0.0.1` or IPv6 `::1`. + * This is because the client service is private most of the time. + */ + clientHost: 'localhost', clientPort: 0, - // Websocket server config - maxReadableStreamBytes: 1_000_000_000, // About 1 GB - maxIdleTimeout: 120, // 2 minutes - pingIntervalTime: 1_000, // 1 second - pingTimeoutTimeTime: 10_000, // 10 seconds - // RPC config + /** + * If using dual stack `::`, then this forces only IPv6 bindings. + */ + ipv6Only: false, + + /** + * Agent service transport keep alive interval time. + * This the maxmum time between keep alive messages. + * This only has effect if `agentMaxIdleTimeout` is greater than 0. + * See the transport layer for further details. + */ + agentKeepAliveIntervalTime: 10_000, // 10 seconds + + /** + * Agent service transport max idle timeout. + * This is the maximum time that a connection can be idle. + * This also controls how long the transport layer will dial + * for a client connection. + * See the transport layer for further details. + */ + agentMaxIdleTimeout: 60_000, // 1 minute + + clientMaxIdleTimeout: 120, // 2 minutes + clientPingIntervalTime: 1_000, // 1 second + clientPingTimeoutTimeTime: 10_000, // 10 seconds + + /** + * Controls the stream parser buffer limit. + * This is the maximum number of bytes that the stream parser + * will buffer before rejecting the RPC call. + */ clientParserBufferByteLimit: 1_000_000, // About 1MB - handlerTimeoutTime: 60_000, // 1 minute - handlerTimeoutGraceTime: 2_000, // 2 seconds - }, - quicServerConfig: { - keepAliveIntervalTime: 10_000, // 10 seconds - maxIdleTimeout: 60_000, // 1 minute - }, - quicClientConfig: { - keepAliveIntervalTime: 10_000, // 10 seconds - maxIdleTimeout: 60_000, // 1 minute + clientHandlerTimeoutTime: 60_000, // 1 minute + clientHandlerTimeoutGraceTime: 2_000, // 2 seconds }, nodeConnectionManagerConfig: { connectionConnectTime: 2000, diff --git a/src/index.ts b/src/index.ts index e9b64d89e..25c85c720 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,38 @@ -import PolykeyAgent from './PolykeyAgent'; -import PolykeyClient from './PolykeyClient'; -import * as errors from './errors'; +export { default as PolykeyAgent } from './PolykeyAgent'; +export { default as PolykeyClient } from './PolykeyClient'; +export { default as ErrorPolykey } from './ErrorPolykey'; +export { default as config } from './config'; +export * as utils from './utils'; +export * as errors from './errors'; +export * from './types'; -export { PolykeyAgent, PolykeyClient, errors }; +// Subdomains for Polykey +// Users should prefer importing them directly to avoid importing the entire +// kitchen sink here + +export * as acl from './acl'; +export * as agent from './agent'; +export * as bootstrap from './bootstrap'; +export * as claims from './claims'; +export * as client from './client'; +export * as discovery from './discovery'; +export * as events from './events'; +export * as gestalts from './gestalts'; +export * as git from './git'; +export * as http from './http'; +export * as identities from './identities'; +export * as ids from './ids'; +export * as keys from './keys'; +export * as network from './network'; +export * as nodes from './nodes'; +export * as notifications from './notifications'; +export * as rpc from './rpc'; +export * as schema from './schema'; +export * as sessions from './sessions'; +export * as sigchain from './sigchain'; +export * as status from './status'; +export * as tasks from './tasks'; +export * as tokens from './tokens'; +export * as validation from './validation'; +export * as vaults from './vaults'; +export * as workers from './workers'; diff --git a/src/nodes/NodeConnection.ts b/src/nodes/NodeConnection.ts index 65667232d..b0eae9da2 100644 --- a/src/nodes/NodeConnection.ts +++ b/src/nodes/NodeConnection.ts @@ -1,10 +1,15 @@ import type { ContextTimed } from '@matrixai/contexts'; import type { PromiseCancellable } from '@matrixai/async-cancellable'; -import type { NodeId, QUICClientConfig } from './types'; -import type { Host, Hostname, Port } from '../network/types'; -import type { CertificatePEM } from '../keys/types'; -import type { ClientManifest } from '../rpc/types'; -import type { QUICSocket, ClientCrypto } from '@matrixai/quic'; +import type { NodeId } from './types'; +import type { Host, Hostname, Port, TLSConfig } from '../network/types'; +import type { Certificate, CertificatePEM } from '../keys/types'; +import type { ClientManifest, RPCStream } from '../rpc/types'; +import type { + QUICSocket, + ClientCrypto, + QUICConnection, + events as quicEvents, +} from '@matrixai/quic'; import type { ContextTimedInput } from '@matrixai/contexts/dist/types'; import type { X509Certificate } from '@peculiar/x509'; import Logger from '@matrixai/logger'; @@ -17,7 +22,9 @@ import RPCClient from '../rpc/RPCClient'; import * as networkUtils from '../network/utils'; import * as rpcUtils from '../rpc/utils'; import * as keysUtils from '../keys/utils'; +import * as nodesUtils from '../nodes/utils'; import { never } from '../utils'; +import * as utils from '../utils'; /** * Encapsulates the unidirectional client-side connection of one node to another. @@ -46,28 +53,33 @@ class NodeConnection extends EventTarget { public readonly certChain: Readonly[]; protected logger: Logger; - public readonly quicClient: QUICClient; + public readonly quicClient: QUICClient | undefined; + public readonly quicConnection: QUICConnection; public readonly rpcClient: RPCClient; static createNodeConnection( { + handleStream, targetNodeIds, targetHost, targetPort, targetHostname, - quicClientConfig, + tlsConfig, + connectionKeepAliveIntervalTime, + connectionMaxIdleTimeout = 60_000, quicSocket, manifest, logger, }: { + handleStream: (stream: RPCStream) => void; targetNodeIds: Array; targetHost: Host; targetPort: Port; targetHostname?: Hostname; - quicClientConfig: QUICClientConfig; - crypto: { - ops: ClientCrypto; - }; + crypto: ClientCrypto; + tlsConfig: TLSConfig; + connectionKeepAliveIntervalTime?: number; + connectionMaxIdleTimeout?: number; quicSocket?: QUICSocket; manifest: M; logger?: Logger; @@ -77,26 +89,30 @@ class NodeConnection extends EventTarget { @timedCancellable(true, 20000) static async createNodeConnection( { + handleStream, targetNodeIds, targetHost, targetPort, targetHostname, - quicClientConfig, crypto, - quicSocket, + tlsConfig, manifest, + connectionKeepAliveIntervalTime, + connectionMaxIdleTimeout = 60_000, + quicSocket, logger = new Logger(this.name), }: { + handleStream: (stream: RPCStream) => void; targetNodeIds: Array; targetHost: Host; targetPort: Port; targetHostname?: Hostname; - quicClientConfig: QUICClientConfig; - crypto: { - ops: ClientCrypto; - }; - quicSocket?: QUICSocket; + crypto: ClientCrypto; + tlsConfig: TLSConfig; manifest: M; + connectionKeepAliveIntervalTime?: number; + connectionMaxIdleTimeout?: number; + quicSocket?: QUICSocket; logger?: Logger; }, @context ctx: ContextTimed, @@ -106,7 +122,6 @@ class NodeConnection extends EventTarget { if (networkUtils.isHostWildcard(targetHost)) { throw new nodesErrors.ErrorNodeConnectionHostWildcard(); } - const clientLogger = logger.getChild(RPCClient.name); let validatedNodeId: NodeId | undefined; const quicClient = await QUICClient.createQUICClient( { @@ -114,10 +129,13 @@ class NodeConnection extends EventTarget { port: targetPort, socket: quicSocket, config: { + keepAliveIntervalTime: connectionKeepAliveIntervalTime, + maxIdleTimeout: connectionMaxIdleTimeout, verifyPeer: true, verifyAllowFail: true, ca: undefined, - ...quicClientConfig, + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, }, verifyCallback: async (certPEMs) => { validatedNodeId = await networkUtils.verifyServerCertificateChain( @@ -125,18 +143,41 @@ class NodeConnection extends EventTarget { certPEMs, ); }, - crypto: crypto, + crypto: { + ops: crypto, + }, + reasonToCode: utils.reasonToCode, + codeToReason: utils.codeToReason, logger: logger.getChild(QUICClient.name), }, ctx, ); + const quicConnection = quicClient.connection; + // Setting up stream handling + const handleConnectionStream = ( + streamEvent: quicEvents.QUICConnectionStreamEvent, + ) => { + const stream = streamEvent.detail; + handleStream(stream); + }; + quicConnection.addEventListener('connectionStream', handleConnectionStream); + quicConnection.addEventListener( + 'connectionStop', + () => { + quicConnection.removeEventListener( + 'connectionStream', + handleConnectionStream, + ); + }, + { once: true }, + ); const rpcClient = await RPCClient.createRPCClient({ manifest, middlewareFactory: rpcUtils.defaultClientMiddlewareWrapper(), streamFactory: () => { - return quicClient.connection.streamNew(); + return quicConnection.streamNew(); }, - logger: clientLogger, + logger: logger.getChild(RPCClient.name), }); if (validatedNodeId == null) never(); // Obtaining remote node ID from certificate chain. It should always exist in the chain if validated. @@ -151,6 +192,7 @@ class NodeConnection extends EventTarget { if (certChain == null) never(); const nodeId = keysUtils.certNodeId(certChain[0]); if (nodeId == null) never(); + const newLogger = logger.getParent() ?? new Logger(this.name); const nodeConnection = new this({ validatedNodeId, nodeId, @@ -161,8 +203,13 @@ class NodeConnection extends EventTarget { certChain, hostname: targetHostname, quicClient, + quicConnection, rpcClient, - logger, + logger: newLogger.getChild( + `${this.name} [${nodesUtils.encodeNodeId(nodeId)}@${ + quicConnection.remoteHost + }:${quicConnection.remotePort}]`, + ), }); quicClient.addEventListener( 'clientDestroy', @@ -176,6 +223,77 @@ class NodeConnection extends EventTarget { return nodeConnection; } + static async createNodeConnectionReverse({ + handleStream, + certChain, + nodeId, + quicConnection, + manifest, + logger = new Logger(this.name), + }: { + handleStream: (stream: RPCStream) => void; + certChain: Array; + nodeId: NodeId; + quicConnection: QUICConnection; + manifest: M; + logger?: Logger; + }): Promise> { + logger.info(`Creating ${this.name}`); + // Creating RPCClient + const rpcClient = await RPCClient.createRPCClient({ + manifest, + middlewareFactory: rpcUtils.defaultClientMiddlewareWrapper(), + streamFactory: () => { + return quicConnection.streamNew(); + }, + logger: logger.getChild(RPCClient.name), + }); + // Setting up stream handling + const handleConnectionStream = ( + streamEvent: quicEvents.QUICConnectionStreamEvent, + ) => { + const stream = streamEvent.detail; + handleStream(stream); + }; + quicConnection.addEventListener('connectionStream', handleConnectionStream); + quicConnection.addEventListener( + 'connectionStop', + () => { + quicConnection.removeEventListener( + 'connectionStream', + handleConnectionStream, + ); + }, + { once: true }, + ); + // Creating NodeConnection + const nodeConnection = new this({ + validatedNodeId: nodeId, + nodeId: nodeId, + localHost: quicConnection.localHost as Host, + localPort: quicConnection.localPort as Port, + host: quicConnection.remoteHost as Host, + port: quicConnection.remotePort as Port, + certChain, + // Hostname and client are not available on reverse connections + hostname: undefined, + quicClient: undefined, + quicConnection, + rpcClient, + logger, + }); + quicConnection.addEventListener( + 'connectionStop', + async () => { + // Trigger the nodeConnection destroying + await nodeConnection.destroy({ force: false }); + }, + { once: true }, + ); + logger.info(`Created ${this.name}`); + return nodeConnection; + } + constructor({ validatedNodeId, nodeId, @@ -186,6 +304,7 @@ class NodeConnection extends EventTarget { certChain, hostname, quicClient, + quicConnection, rpcClient, logger, }: { @@ -197,7 +316,8 @@ class NodeConnection extends EventTarget { localPort: Port; certChain: Readonly[]; hostname?: Hostname; - quicClient: QUICClient; + quicClient?: QUICClient; + quicConnection: QUICConnection; rpcClient: RPCClient; logger: Logger; }) { @@ -211,6 +331,7 @@ class NodeConnection extends EventTarget { this.certChain = certChain; this.hostname = hostname; this.quicClient = quicClient; + this.quicConnection = quicConnection; this.rpcClient = rpcClient; this.logger = logger; } @@ -221,7 +342,18 @@ class NodeConnection extends EventTarget { force?: boolean; } = {}) { this.logger.info(`Destroying ${this.constructor.name}`); - await this.quicClient.destroy({ force }); + await this.quicClient?.destroy({ force }); + // This is only needed for reverse connections, otherwise it is handled by the quicClient. + await this.quicConnection.stop( + force + ? { + applicationError: true, + errorCode: 0, + errorMessage: 'NodeConnection is forcing destruction', + force: true, + } + : {}, + ); await this.rpcClient.destroy(); this.logger.debug(`${this.constructor.name} triggered destroyed event`); this.dispatchEvent(new nodesEvents.NodeConnectionDestroyEvent()); diff --git a/src/nodes/NodeConnectionManager.ts b/src/nodes/NodeConnectionManager.ts index 40e1dc6f7..13c0f629b 100644 --- a/src/nodes/NodeConnectionManager.ts +++ b/src/nodes/NodeConnectionManager.ts @@ -1,6 +1,7 @@ -import type { QUICSocket } from '@matrixai/quic'; +import type { QUICConnection, QUICSocket } from '@matrixai/quic'; import type { ResourceAcquire } from '@matrixai/resources'; import type { ContextTimed } from '@matrixai/contexts'; +import type { CertificatePEM } from '../keys/types'; import type KeyRing from '../keys/KeyRing'; import type { Host, Hostname, Port } from '../network/types'; import type NodeGraph from './NodeGraph'; @@ -9,15 +10,17 @@ import type { NodeData, NodeId, NodeIdString, - QUICClientConfig, SeedNodes, } from './types'; import type NodeManager from './NodeManager'; -import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { LockRequest } from '@matrixai/async-locks/dist/types'; import type { HolePunchRelayMessage } from '../agent/handlers/types'; import type { ClientCrypto } from '@matrixai/quic'; import type { ContextTimedInput } from '@matrixai/contexts/dist/types'; +import type { RPCStream } from '../rpc/types'; +import type { TLSConfig } from '../network/types'; +import type { ServerCrypto, events as QuicEvents } from '@matrixai/quic'; +import type { PromiseCancellable } from '@matrixai/async-cancellable'; import { withF } from '@matrixai/resources'; import Logger from '@matrixai/logger'; import { ready, StartStop } from '@matrixai/async-init/dist/StartStop'; @@ -25,6 +28,7 @@ import { IdInternal } from '@matrixai/id'; import { Lock, LockBox } from '@matrixai/async-locks'; import { Timer } from '@matrixai/timer'; import { timedCancellable, context } from '@matrixai/contexts/dist/decorators'; +import { QUICServer } from '@matrixai/quic'; import NodeConnection from './NodeConnection'; import * as nodesUtils from './utils'; import * as nodesErrors from './errors'; @@ -33,7 +37,7 @@ import * as networkUtils from '../network/utils'; import { never } from '../utils'; import * as utils from '../utils'; import { clientManifest as agentClientManifest } from '../agent/handlers/clientManifest'; -import { getRandomBytes } from '../keys/utils/random'; +import * as keysUtils from '../keys/utils'; // TODO: check all locking and add cancellation for it. @@ -79,10 +83,13 @@ class NodeConnectionManager { */ public readonly connectionHolePunchIntervalTime: number; + protected handleStream: (stream: RPCStream) => void = + () => never() as (stream: RPCStream) => void; protected logger: Logger; protected nodeGraph: NodeGraph; protected keyRing: KeyRing; protected quicSocket: QUICSocket; + protected quicServer: QUICServer; // NodeManager has to be passed in during start to allow co-dependency protected nodeManager: NodeManager | undefined; protected seedNodes: SeedNodes; @@ -105,33 +112,39 @@ class NodeConnectionManager { > = new Map(); protected backoffDefault: number = 1000 * 60 * 5; // 5 min protected backoffMultiplier: number = 2; // Doubles every failure - protected quicClientConfig: QUICClientConfig; - protected crypto: { - ops: ClientCrypto; + protected tlsConfig: TLSConfig; + protected connectionKeepAliveIntervalTime: number; + protected connectionMaxIdleTimeout: number; + protected crypto: ServerCrypto & ClientCrypto; + protected serverConnectionHandler = async ( + connectionEvent: QuicEvents.QUICServerConnectionEvent, + ) => { + const quicConnection = connectionEvent.detail; + await this.handleConnectionReverse(quicConnection); }; public constructor({ keyRing, nodeGraph, quicSocket, - quicClientConfig, crypto, + tlsConfig, seedNodes = {}, initialClosestNodes = 3, - connectionConnectTime = 2000, - connectionTimeoutTime = 60000, - pingTimeoutTime = 2000, - connectionHolePunchTimeoutTime = 4000, + connectionConnectTime = 2_000, + connectionTimeoutTime = 60_000, + pingTimeoutTime = 2_000, + connectionHolePunchTimeoutTime = 4_000, connectionHolePunchIntervalTime = 250, + connectionKeepAliveIntervalTime = 10_000, + connectionMaxIdleTimeout = 60_000, logger, }: { keyRing: KeyRing; nodeGraph: NodeGraph; quicSocket: QUICSocket; - quicClientConfig: QUICClientConfig; - crypto: { - ops: ClientCrypto; - }; + crypto: ServerCrypto & ClientCrypto; + tlsConfig: TLSConfig; seedNodes?: SeedNodes; initialClosestNodes?: number; connectionConnectTime?: number; @@ -139,13 +152,15 @@ class NodeConnectionManager { pingTimeoutTime?: number; connectionHolePunchTimeoutTime?: number; connectionHolePunchIntervalTime?: number; - logger?: Logger; + connectionKeepAliveIntervalTime?: number; + connectionMaxIdleTimeout?: number; + logger: Logger; }) { this.logger = logger ?? new Logger(NodeConnectionManager.name); this.keyRing = keyRing; this.nodeGraph = nodeGraph; this.quicSocket = quicSocket; - this.quicClientConfig = quicClientConfig; + this.tlsConfig = tlsConfig; this.crypto = crypto; const localNodeIdEncoded = nodesUtils.encodeNodeId(keyRing.getNodeId()); delete seedNodes[localNodeIdEncoded]; @@ -156,9 +171,41 @@ class NodeConnectionManager { this.connectionHolePunchTimeoutTime = connectionHolePunchTimeoutTime; this.connectionHolePunchIntervalTime = connectionHolePunchIntervalTime; this.pingTimeoutTime = pingTimeoutTime; + this.connectionKeepAliveIntervalTime = connectionKeepAliveIntervalTime; + this.connectionMaxIdleTimeout = connectionMaxIdleTimeout; + // Setting up QUICServer + const resolveHostname = (host) => { + return networkUtils.resolveHostname(host)[0] ?? ''; + }; + this.quicServer = new QUICServer({ + config: { + keepAliveIntervalTime: connectionKeepAliveIntervalTime, + maxIdleTimeout: connectionMaxIdleTimeout, + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + verifyPeer: true, + verifyAllowFail: true, + }, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, + verifyCallback: networkUtils.verifyClientCertificateChain, + logger: logger.getChild(QUICServer.name + 'Agent'), + socket: quicSocket, + resolveHostname, + reasonToCode: utils.reasonToCode, + codeToReason: utils.codeToReason, + }); } - public async start({ nodeManager }: { nodeManager: NodeManager }) { + public async start({ + nodeManager, + handleStream, + }: { + nodeManager: NodeManager; + handleStream: (stream: RPCStream) => void; + }) { this.logger.info(`Starting ${this.constructor.name}`); this.nodeManager = nodeManager; // Adding seed nodes @@ -171,11 +218,23 @@ class NodeConnectionManager { true, ); } + this.handleStream = handleStream; + // Starting QUICServer + // No host or port is provided here, it's configured in the shared QUICSocket. + await this.quicServer.start(); + this.quicServer.addEventListener( + 'serverConnection', + this.serverConnectionHandler, + ); this.logger.info(`Started ${this.constructor.name}`); } public async stop() { this.logger.info(`Stopping ${this.constructor.name}`); + this.quicServer.removeEventListener( + 'serverConnection', + this.serverConnectionHandler, + ); this.nodeManager = undefined; const destroyProms: Array> = []; for (const [nodeId, connAndTimer] of this.connections) { @@ -187,6 +246,8 @@ class NodeConnectionManager { destroyProms.push(destroyProm); } await Promise.all(destroyProms); + await this.quicServer.stop({ force: true }); + this.handleStream = () => never(); this.logger.info(`Stopped ${this.constructor.name}`); } @@ -511,10 +572,10 @@ class NodeConnectionManager { await Promise.allSettled(connProms); } if (connectionsResults.size === 0) { + // TODO: This needs to throw if none were established. + // The usual use case is a single node, this shouldn't be a aggregate error type. throw Error('No connections established!'); } - // TODO: This needs to throw if none were established. - // The usual use case is a single node, this shouldn't be a aggregate error type. return connectionsResults; } @@ -534,36 +595,43 @@ class NodeConnectionManager { }, connectionsResults: Map, ctx: ContextTimed, - ) { + ): Promise { // TODO: do we bother with a concurrency limit for now? It's simple to use a semaphore. // TODO: if all connections fail then this needs to throw. Or does it? Do we just report the allSettled result? - // TODO: add ICE. Create hole punch relay proms. // 1. attempt connection to an address this.logger.debug( `establishing single connection for address ${address.host}:${address.port}`, ); + const iceProm = this.initiateHolePunch(nodeIds, ctx); const connection = await NodeConnection.createNodeConnection( { + handleStream: this.handleStream, targetNodeIds: nodeIds, manifest: agentClientManifest, - quicClientConfig: this.quicClientConfig, crypto: this.crypto, targetHost: address.host, targetPort: address.port, + tlsConfig: this.tlsConfig, + connectionKeepAliveIntervalTime: this.connectionKeepAliveIntervalTime, + connectionMaxIdleTimeout: this.connectionMaxIdleTimeout, quicSocket: this.quicSocket, logger: this.logger.getChild( `${NodeConnection.name} [${address.host}:${address.port}]`, ), }, ctx, - ).catch((e) => { - this.logger.debug( - `establish single connection failed for ${address.host}:${address.port} with ${e.message}`, - ); - throw e; - }); - // TODO: finally cancel ICE. Use signal and await all settled + ) + .catch((e) => { + this.logger.debug( + `establish single connection failed for ${address.host}:${address.port} with ${e.message}`, + ); + throw e; + }) + .finally(async () => { + iceProm.cancel('Connection was established'); + await iceProm; + }); // 2. if established then add to result map const nodeId = connection.nodeId; const nodeIdString = nodeId.toString() as NodeIdString; @@ -577,30 +645,105 @@ class NodeConnectionManager { throw Error( 'TMP IMP, This should be exceedingly rare, lets see if it happens', ); - return; + // Return; } // Final setup + const newConnAndTimer = this.addConnection(nodeId, connection); + // We can assume connection was established and destination was valid, we can add the target to the nodeGraph + await this.nodeManager?.setNode(nodeId, { + host: address.host, + port: address.port, + }); + connectionsResults.set(nodeIdString, newConnAndTimer); + this.logger.debug( + `Created NodeConnection for ${nodesUtils.encodeNodeId( + nodeId, + )} on ${address}`, + ); + } + + /** + * This will take a `QUICConnection` emitted by the `QUICServer` and handle adding it to the connection map + */ + @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) + protected async handleConnectionReverse(quicConnection: QUICConnection) { + // Checking NodeId + // No specific error here, validation is handled by the QUICServer + const certChain = quicConnection.getRemoteCertsChain().map((pem) => { + const cert = keysUtils.certFromPEM(pem as CertificatePEM); + if (cert == null) never(); + return cert; + }); + if (certChain == null) never(); + const nodeId = keysUtils.certNodeId(certChain[0]); + if (nodeId == null) never(); + const nodeIdString = nodeId.toString() as NodeIdString; + // TODO: A connection can fail while awaiting lock. We should abort early in this case. + return await this.connectionLocks.withF( + [nodeIdString, Lock], + async (): Promise => { + // Check if the connection already exists under that nodeId and reject the connection if so + if (this.connections.has(nodeIdString)) { + // Reject and return early. + await quicConnection.stop({ + applicationError: true, + errorCode: 42, + errorMessage: 'Connection already exists, forcing close', + force: true, + }); + return; + } + const nodeConnection = + await NodeConnection.createNodeConnectionReverse( + { + handleStream: this.handleStream, + nodeId, + certChain, + manifest: agentClientManifest, + quicConnection: quicConnection, + logger: this.logger.getChild( + `${NodeConnection.name} [${nodesUtils.encodeNodeId(nodeId)}@${ + quicConnection.remoteHost + }:${quicConnection.remotePort}]`, + ), + }, + ); + // Final setup + this.addConnection(nodeId, nodeConnection); + // We can add the target to the nodeGraph + await this.nodeManager?.setNode(nodeId, { + host: nodeConnection.host, + port: nodeConnection.port, + }); + }, + ); + } + + /** + * Adds connection to the connections map. Preforms some checks and lifecycle hooks. + * This code is shared between the reverse and forward connection creation. + */ + protected addConnection( + nodeId: NodeId, + nodeConnection: NodeConnection, + ): ConnectionAndTimer { + const nodeIdString = nodeId.toString() as NodeIdString; + // Check if exists in map, this should never happen but better safe than sorry. + if (this.connections.has(nodeIdString)) never(); const handleDestroy = async () => { this.logger.debug('stream destroyed event'); // To avoid deadlock only in the case where this is called // we want to check for destroying connection and read lock - const connAndTimer = this.connections.get(nodeIdString); - // If the connection is calling destroyCallback then it SHOULD - // exist in the connection map - if (connAndTimer == null) return; + // If the connection is calling destroyCallback then it SHOULD exist in the connection map. + if (!this.connections.has(nodeIdString)) return; // Already locked so already destroying if (this.connectionLocks.isLocked(nodeIdString)) return; await this.destroyConnection(nodeId); }; - connection.addEventListener('destroy', handleDestroy, { + nodeConnection.addEventListener('destroy', handleDestroy, { once: true, }); - // We can assume connection was established and destination was valid, - // we can add the target to the nodeGraph - await this.nodeManager?.setNode(nodeId, { - host: address.host, - port: address.port, - }); + // Creating TTL timeout. // We don't create a TTL for seed nodes. const timeToLiveTimer = !this.isSeedNode(nodeId) @@ -611,17 +754,12 @@ class NodeConnectionManager { : null; // Add to map const newConnAndTimer: ConnectionAndTimer = { - connection, + connection: nodeConnection, timer: timeToLiveTimer, usageCount: 0, }; this.connections.set(nodeIdString, newConnAndTimer); - connectionsResults.set(nodeIdString, newConnAndTimer); - this.logger.debug( - `Created NodeConnection for ${nodesUtils.encodeNodeId( - nodeId, - )} on ${address}`, - ); + return newConnAndTimer; } /** @@ -702,7 +840,7 @@ class NodeConnectionManager { // Setting up established event checking try { while (true) { - const message = getRandomBytes(32); + const message = keysUtils.getRandomBytes(32); await this.quicSocket.send(Buffer.from(message), port, host); await Promise.race([utils.sleep(delay), endedProm.p]); if (ended) break; @@ -936,7 +1074,7 @@ class NodeConnectionManager { } /** - * Performs a RPC request to retrieve the closest nodes relative to the given + * Performs an RPC request to retrieve the closest nodes relative to the given * target node ID. * @param nodeId the node ID to search on * @param targetNodeId the node ID to find other nodes closest to it @@ -1002,7 +1140,7 @@ class NodeConnectionManager { } /** - * Performs a RPC request to send a hole-punch message to the target. Used to + * Performs an RPC request to send a hole-punch message to the target. Used to * initially establish the NodeConnection from source to target. * * @param relayNodeId node ID of the relay node (i.e. the seed node) @@ -1132,7 +1270,7 @@ class NodeConnectionManager { /** * Checks if a connection can be made to the target. Returns true if the * connection can be authenticated, it's certificate matches the nodeId and - * the addresses match if provided. Otherwise returns false. + * the addresses match if provided. Otherwise, returns false. * @param nodeId - NodeId of the target * @param host - Host of the target node * @param port - Port of the target node @@ -1248,6 +1386,33 @@ class NodeConnectionManager { return results; } + public updateConnectionConfig({ + connectionKeepAliveIntervalTime, + connectionMaxIdleTimeout, + }: { + connectionKeepAliveIntervalTime?: number; + connectionMaxIdleTimeout?: number; + }) { + if (connectionKeepAliveIntervalTime != null) { + this.connectionKeepAliveIntervalTime = connectionKeepAliveIntervalTime; + } + if (connectionMaxIdleTimeout != null) { + this.connectionMaxIdleTimeout = connectionMaxIdleTimeout; + } + this.quicServer.updateConfig({ + keepAliveIntervalTime: connectionKeepAliveIntervalTime, + maxIdleTimeout: connectionMaxIdleTimeout, + }); + } + + public updateTlsConfig(tlsConfig: TLSConfig) { + this.tlsConfig = tlsConfig; + this.quicServer.updateConfig({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + }); + } + protected hasBackoff(nodeId: NodeId): boolean { const backoff = this.nodesBackoffMap.get(nodeId.toString()); if (backoff == null) return false; @@ -1275,6 +1440,52 @@ class NodeConnectionManager { protected removeBackoff(nodeId: NodeId): void { this.nodesBackoffMap.delete(nodeId.toString()); } + + /** + * This attempts the NAT hole punch procedure. It will return a + * `PromiseCancellable` that will resolve once the procedure times out, is + * cancelled or the other end responds. + * + * This is pretty simple, it will contact all known seed nodes and get them to + * relay a punch signal message. + * + * Note: Avoid using a large set of target nodes, It could trigger a large + * amount of pings to a single target. + */ + protected initiateHolePunch( + targetNodeIds: Array, + ctx?: Partial, + ): PromiseCancellable; + @timedCancellable(true) + protected async initiateHolePunch( + targetNodeIds: Array, + @context ctx: ContextTimed, + ): Promise { + const seedNodes = this.getSeedNodes(); + const allProms: Array>> = []; + for (const targetNodeId of targetNodeIds) { + if (!this.isSeedNode(targetNodeId)) { + const holePunchProms = seedNodes.map((seedNodeId) => { + return ( + this.sendSignalingMessage( + seedNodeId, + this.keyRing.getNodeId(), + targetNodeId, + undefined, + ctx, + ) + // Ignore results + .then( + () => {}, + () => {}, + ) + ); + }); + allProms.push(Promise.all(holePunchProms)); + } + } + await Promise.all(allProms).catch(); + } } export default NodeConnectionManager; diff --git a/src/nodes/types.ts b/src/nodes/types.ts index d07de37e3..41eb082b3 100644 --- a/src/nodes/types.ts +++ b/src/nodes/types.ts @@ -1,6 +1,5 @@ import type { NodeId, NodeIdString, NodeIdEncoded } from '../ids/types'; import type { Host, Hostname, Port } from '../network/types'; -import type { QUICConfig } from '@matrixai/quic'; /** * Key indicating which space the NodeGraph is in @@ -27,14 +26,6 @@ type NodeData = { type SeedNodes = Record; -/** - * These are the config options we pass to the quic system. - * It is re-defined here to only expose the options we want to propagate. - * Other parameters are provided via the internal logic. - */ -type QUICClientConfig = Pick & - Omit, 'ca' | 'verifyPeer' | 'verifyAllowFail'>; - export type { NodeId, NodeIdString, @@ -46,5 +37,4 @@ export type { NodeBucket, NodeData, NodeGraphSpace, - QUICClientConfig, }; diff --git a/src/rpc/RPCClient.ts b/src/rpc/RPCClient.ts index 334ab64ff..a3139828f 100644 --- a/src/rpc/RPCClient.ts +++ b/src/rpc/RPCClient.ts @@ -1,11 +1,12 @@ import type { WritableStream, ReadableStream } from 'stream/web'; -import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts'; +import type { ContextTimedInput } from '@matrixai/contexts'; import type { HandlerType, JSONRPCRequestMessage, StreamFactory, ClientManifest, RPCStream, + JSONRPCResponseResult, } from './types'; import type { JSONValue } from '../types'; import type { @@ -20,7 +21,7 @@ import { Timer } from '@matrixai/timer'; import * as rpcUtilsMiddleware from './utils/middleware'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils/utils'; -import { promise } from '../utils'; +import { never, promise } from '../utils'; const timerCleanupReasonSymbol = Symbol('timerCleanUpReasonSymbol'); @@ -254,18 +255,23 @@ class RPCClient { method: string, ctx: Partial = {}, ): Promise> { + // Setting up abort signal and timer const abortController = new AbortController(); const signal = abortController.signal; // A promise that will reject if there is an abort signal or timeout const abortRaceProm = promise(); // Prevent unhandled rejection when we're done with the promise abortRaceProm.p.catch(() => {}); + const abortRacePromHandler = () => { + abortRaceProm.rejectP(signal.reason); + }; + signal.addEventListener('abort', abortRacePromHandler); + let abortHandler: () => void; if (ctx.signal != null) { // Propagate signal events abortHandler = () => { abortController.abort(ctx.signal?.reason); - abortRaceProm.rejectP(ctx.signal?.reason); }; if (ctx.signal.aborted) abortHandler(); ctx.signal.addEventListener('abort', abortHandler); @@ -281,29 +287,32 @@ class RPCClient { const cleanUp = () => { // Clean up the timer and signal if (ctx.timer == null) timer.cancel(timerCleanupReasonSymbol); - signal.removeEventListener('abort', abortHandler); + if (ctx.signal != null) { + ctx.signal.removeEventListener('abort', abortHandler); + } + signal.addEventListener('abort', abortRacePromHandler); }; // Setting up abort events for timeout const timeoutError = new rpcErrors.ErrorRPCTimedOut(); void timer.then( () => { abortController.abort(timeoutError); - abortRaceProm.rejectP(timeoutError); }, () => {}, // Ignore cancellation error ); + // Hooking up agnostic stream side let rpcStream: RPCStream; + const streamFactoryProm = this.streamFactory({ signal, timer }); try { - rpcStream = await Promise.race([ - this.streamFactory({ signal, timer }), - abortRaceProm.p, - ]); + rpcStream = await Promise.race([streamFactoryProm, abortRaceProm.p]); } catch (e) { cleanUp(); + void streamFactoryProm.then((stream) => + stream.cancel(Error('TMP stream timed out early')), + ); throw e; } - // Setting up event for stream timeout void timer.then( () => { rpcStream.cancel(new rpcErrors.ErrorRPCTimedOut()); @@ -374,45 +383,78 @@ class RPCClient { public async rawStreamCaller( method: string, headerParams: JSONValue, - ctx: Partial = {}, - ): Promise> { + ctx: Partial = {}, + ): Promise< + RPCStream< + Uint8Array, + Uint8Array, + Record & { result: JSONValue; command: string } + > + > { + // Setting up abort signal and timer const abortController = new AbortController(); const signal = abortController.signal; // A promise that will reject if there is an abort signal or timeout const abortRaceProm = promise(); // Prevent unhandled rejection when we're done with the promise abortRaceProm.p.catch(() => {}); + const abortRacePromHandler = () => { + abortRaceProm.rejectP(signal.reason); + }; + signal.addEventListener('abort', abortRacePromHandler); + let abortHandler: () => void; if (ctx.signal != null) { // Propagate signal events abortHandler = () => { abortController.abort(ctx.signal?.reason); - abortRaceProm.rejectP(ctx.signal?.reason); }; if (ctx.signal.aborted) abortHandler(); ctx.signal.addEventListener('abort', abortHandler); } - const timer = - ctx.timer ?? - new Timer({ - delay: this.streamKeepAliveTimeoutTime, + let timer: Timer; + if (!(ctx.timer instanceof Timer)) { + timer = new Timer({ + delay: ctx.timer ?? this.streamKeepAliveTimeoutTime, }); + } else { + timer = ctx.timer; + } const cleanUp = () => { // Clean up the timer and signal if (ctx.timer == null) timer.cancel(timerCleanupReasonSymbol); - signal.removeEventListener('abort', abortHandler); + if (ctx.signal != null) { + ctx.signal.removeEventListener('abort', abortHandler); + } + signal.addEventListener('abort', abortRacePromHandler); }; + // Setting up abort events for timeout const timeoutError = new rpcErrors.ErrorRPCTimedOut(); void timer.then( () => { abortController.abort(timeoutError); - abortRaceProm.rejectP(timeoutError); }, - () => {}, + () => {}, // Ignore cancellation error ); - let rpcStream: RPCStream; - const setupStream = async () => { - const rpcStream = await this.streamFactory({ signal, timer }); + + const setupStream = async (): Promise< + [JSONValue, RPCStream] + > => { + if (signal.aborted) throw signal.reason; + const abortProm = promise(); + // Ignore error if orphaned + void abortProm.p.catch(() => {}); + signal.addEventListener( + 'abort', + () => { + abortProm.rejectP(signal.reason); + }, + { once: true }, + ); + const rpcStream = await Promise.race([ + this.streamFactory({ signal, timer }), + abortProm.p, + ]); const tempWriter = rpcStream.writable.getWriter(); const header: JSONRPCRequestMessage = { jsonrpc: '2.0', @@ -422,15 +464,51 @@ class RPCClient { }; await tempWriter.write(Buffer.from(JSON.stringify(header))); tempWriter.releaseLock(); - return rpcStream; + const headTransformStream = rpcUtils.parseHeadStream( + rpcUtils.parseJSONRPCResponse, + ); + void rpcStream.readable + // Allow us to re-use the readable after reading the first message + .pipeTo(headTransformStream.writable) + // Ignore any errors here, we only care that it ended + .catch(() => {}); + const tempReader = headTransformStream.readable.getReader(); + let leadingMessage: JSONRPCResponseResult; + try { + const message = await Promise.race([tempReader.read(), abortProm.p]); + const messageValue = message.value as JSONRPCResponse; + if (message.done) never(); + if ('error' in messageValue) { + const metadata = { + ...(rpcStream.meta ?? {}), + command: method, + }; + throw rpcUtils.toError(messageValue.error.data, metadata); + } + leadingMessage = messageValue; + } catch (e) { + rpcStream.cancel(Error('TMP received error in leading response')); + throw e; + } + tempReader.releaseLock(); + const newRpcStream: RPCStream = { + writable: rpcStream.writable, + readable: headTransformStream.readable as ReadableStream, + cancel: rpcStream.cancel, + meta: rpcStream.meta, + }; + return [leadingMessage.result, newRpcStream]; }; + let streamCreation: [JSONValue, RPCStream]; try { - rpcStream = await Promise.race([setupStream(), abortRaceProm.p]); + streamCreation = await setupStream(); } finally { cleanUp(); } + const [result, rpcStream] = streamCreation; const metadata = { ...(rpcStream.meta ?? {}), + result, command: method, }; return { diff --git a/src/rpc/RPCServer.ts b/src/rpc/RPCServer.ts index 5fd239911..bf300ed4e 100644 --- a/src/rpc/RPCServer.ts +++ b/src/rpc/RPCServer.ts @@ -1,3 +1,4 @@ +import type { ReadableStreamDefaultReadResult } from 'stream/web'; import type { ClientHandlerImplementation, DuplexHandlerImplementation, @@ -239,7 +240,7 @@ class RPCServer extends EventTarget { handler: DuplexHandlerImplementation, timeout: number | undefined, ): void { - const rawSteamHandler: RawHandlerImplementation = ( + const rawSteamHandler: RawHandlerImplementation = async ( [header, input], cancel, meta, @@ -343,7 +344,7 @@ class RPCServer extends EventTarget { }); // Ignore any errors here, it should propagate to the ends of the stream void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); - return middleware.reverse.readable; + return [undefined, middleware.reverse.readable]; }; this.registerRawStreamHandler(method, rawSteamHandler, timeout); } @@ -442,7 +443,6 @@ class RPCServer extends EventTarget { }); }; abortController.signal.addEventListener('abort', handleAbort); - const prom = (async () => { const headTransformStream = rpcUtilsMiddleware.binaryToJsonMessageStream( rpcUtils.parseJSONRPCRequest, @@ -469,30 +469,80 @@ class RPCServer extends EventTarget { await inputStream.cancel(reason); await rpcStream.writable.abort(reason); await inputStreamEndProm; + timer.cancel(cleanupReason); + graceTimer?.cancel(cleanupReason); + await timer.catch(() => {}); + await graceTimer?.catch(() => {}); }; // Read a single empty value to consume the first message const reader = headTransformStream.readable.getReader(); // Allows timing out when waiting for the first message - const headerMessage = await Promise.race([ - reader.read(), - timer.then( - () => undefined, - () => {}, - ), - ]); + let headerMessage: + | ReadableStreamDefaultReadResult + | undefined + | void; + try { + headerMessage = await Promise.race([ + reader.read(), + timer.then( + () => undefined, + () => {}, + ), + ]); + } catch (e) { + const newErr = new rpcErrors.ErrorRPCHandlerFailed( + 'Stream failed waiting for header', + { cause: e }, + ); + await inputStreamEndProm; + timer.cancel(cleanupReason); + graceTimer?.cancel(cleanupReason); + await timer.catch(() => {}); + await graceTimer?.catch(() => {}); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError( + 'Stream failed waiting for header', + { + cause: newErr, + }, + ), + }), + ); + return; + } // Downgrade back to the raw stream await reader.cancel(); // There are 2 conditions where we just end here // 1. The timeout timer resolves before the first message // 2. the stream ends before the first message if (headerMessage == null) { - await cleanUp( - new rpcErrors.ErrorRPCHandlerFailed('Timed out waiting for header'), + const newErr = new rpcErrors.ErrorRPCHandlerFailed( + 'Timed out waiting for header', + ); + await cleanUp(newErr); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError( + 'Timed out waiting for header', + { + cause: newErr, + }, + ), + }), ); return; } if (headerMessage.done) { - await cleanUp(new rpcErrors.ErrorRPCHandlerFailed('Missing header')); + const newErr = new rpcErrors.ErrorRPCHandlerFailed('Missing header'); + await cleanUp(newErr); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError('Missing header', { + cause: newErr, + }), + }), + ); return; } const method = headerMessage.value.method; @@ -514,16 +564,54 @@ class RPCServer extends EventTarget { // Otherwise refresh timer.refresh(); } - const outputStream = handler( - [headerMessage.value, inputStream], - rpcStream.cancel, - rpcStream.meta, - { signal: abortController.signal, timer }, - ); + this.logger.info(`Handling stream with method (${method})`); + let handlerResult: [JSONValue | undefined, ReadableStream]; + const headerWriter = rpcStream.writable.getWriter(); + try { + handlerResult = await handler( + [headerMessage.value, inputStream], + rpcStream.cancel, + rpcStream.meta, + { signal: abortController.signal, timer }, + ); + } catch (e) { + const rpcError: JSONRPCError = { + code: e.exitCode ?? sysexits.UNKNOWN, + message: e.description ?? '', + data: rpcUtils.fromError(e, this.sensitive), + }; + const rpcErrorMessage: JSONRPCResponseError = { + jsonrpc: '2.0', + error: rpcError, + id: null, + }; + await headerWriter.write(Buffer.from(JSON.stringify(rpcErrorMessage))); + await headerWriter.close(); + // Clean up and return + timer.cancel(cleanupReason); + abortController.signal.removeEventListener('abort', handleAbort); + graceTimer?.cancel(cleanupReason); + abortController.abort(new rpcErrors.ErrorRPCStreamEnded()); + rpcStream.cancel(Error('TMP header message was an error')); + return; + } + const [leadingResult, outputStream] = handlerResult; + + if (leadingResult !== undefined) { + // Writing leading metadata + const leadingMessage: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: leadingResult, + id: null, + }; + await headerWriter.write(Buffer.from(JSON.stringify(leadingMessage))); + } + headerWriter.releaseLock(); const outputStreamEndProm = outputStream .pipeTo(rpcStream.writable) .catch(() => {}); // Ignore any errors, we only care that it finished await Promise.allSettled([inputStreamEndProm, outputStreamEndProm]); + this.logger.info(`Handled stream with method (${method})`); // Cleaning up abort and timer timer.cancel(cleanupReason); abortController.signal.removeEventListener('abort', handleAbort); diff --git a/src/rpc/handlers.ts b/src/rpc/handlers.ts index 5ce685e59..a388c75a1 100644 --- a/src/rpc/handlers.ts +++ b/src/rpc/handlers.ts @@ -29,7 +29,7 @@ abstract class RawHandler< cancel: (reason?: any) => void, meta: Record | undefined, ctx: ContextTimed, - ): ReadableStream; + ): Promise<[JSONValue, ReadableStream]>; } abstract class DuplexHandler< diff --git a/src/rpc/types.ts b/src/rpc/types.ts index c0acec854..574d3045d 100644 --- a/src/rpc/types.ts +++ b/src/rpc/types.ts @@ -164,7 +164,7 @@ type HandlerImplementation = ( type RawHandlerImplementation = HandlerImplementation< [JSONRPCRequest, ReadableStream], - ReadableStream + Promise<[JSONValue | undefined, ReadableStream]> >; type DuplexHandlerImplementation< @@ -264,7 +264,13 @@ type DuplexCallerImplementation< type RawCallerImplementation = ( headerParams: JSONValue, ctx?: Partial, -) => Promise>; +) => Promise< + RPCStream< + Uint8Array, + Uint8Array, + Record & { result: JSONValue; command: string } + > +>; type ConvertDuplexCaller = T extends DuplexCaller ? DuplexCallerImplementation diff --git a/src/rpc/utils/utils.ts b/src/rpc/utils/utils.ts index af8b71fdc..d4432bbf7 100644 --- a/src/rpc/utils/utils.ts +++ b/src/rpc/utils/utils.ts @@ -13,6 +13,7 @@ import type { import type { JSONValue } from '../../types'; import type { Timer } from '@matrixai/timer'; import { TransformStream } from 'stream/web'; +import { JSONParser } from '@streamparser/json'; import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from '../errors'; import * as utils from '../../utils'; @@ -429,6 +430,72 @@ function getHandlerTypes( return out; } +/** + * This function is a factory to create a TransformStream that will + * transform a `Uint8Array` stream to a JSONRPC message stream. + * The parsed messages will be validated with the provided messageParser, this + * also infers the type of the stream output. + * @param messageParser - Validates the JSONRPC messages, so you can select for a + * specific type of message + * @param bufferByteLimit - sets the number of bytes buffered before throwing an + * error. This is used to avoid infinitely buffering the input. + */ +function parseHeadStream( + messageParser: (message: unknown) => T, + bufferByteLimit: number = 1024 * 1024, +): TransformStream { + const parser = new JSONParser({ + separator: '', + paths: ['$'], + }); + let bytesWritten: number = 0; + let parsing = true; + let ended = false; + + const endP = utils.promise(); + parser.onEnd = () => endP.resolveP(); + + return new TransformStream( + { + flush: async () => { + if (!parser.isEnded) parser.end(); + await endP.p; + }, + start: (controller) => { + parser.onValue = async (value) => { + const jsonMessage = messageParser(value.value); + controller.enqueue(jsonMessage); + bytesWritten = 0; + parsing = false; + }; + }, + transform: async (chunk, controller) => { + if (parsing) { + try { + bytesWritten += chunk.byteLength; + parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRPCParse(undefined, { cause: e }); + } + if (bytesWritten > bufferByteLimit) { + throw new rpcErrors.ErrorRPCMessageLength(); + } + } else { + // Wait for parser to end + if (!ended) { + parser.end(); + await endP.p; + ended = true; + } + // Pass through normal chunks + controller.enqueue(chunk); + } + }, + }, + { highWaterMark: 1 }, + ); +} + export { parseJSONRPCRequest, parseJSONRPCRequestMessage, @@ -442,4 +509,5 @@ export { clientInputTransformStream, clientOutputTransformStream, getHandlerTypes, + parseHeadStream, }; diff --git a/src/utils/utils.ts b/src/utils/utils.ts index 53b293b5c..eda95679a 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -429,12 +429,20 @@ function lexiUnpackBuffer(b: Buffer): number { return lexi.unpack([...b]); } +// TODO: remove this, quick hack to allow errors to jump the network +const codeMap = new Map(); +let code = 1; + const reasonToCode = (_type: 'recv' | 'send', _reason?: any): number => { - return 0; + codeMap.set(code, _reason); + const returnCode = code; + code++; + return returnCode; }; const codeToReason = (type: 'recv' | 'send', code: number): any => { - return Error(`${type} ${code}`); + const asd = codeMap.get(code); + return asd; }; export { diff --git a/src/vaults/VaultInternal.ts b/src/vaults/VaultInternal.ts index 5e8cbb933..34caefcc0 100644 --- a/src/vaults/VaultInternal.ts +++ b/src/vaults/VaultInternal.ts @@ -17,6 +17,7 @@ import type { NodeId, NodeIdEncoded } from '../ids/types'; import type NodeConnectionManager from '../nodes/NodeConnectionManager'; import type RPCClient from '../rpc/RPCClient'; import type { clientManifest as agentClientManifest } from '../agent/handlers/clientManifest'; +import type { POJO } from '../types'; import path from 'path'; import git from 'isomorphic-git'; import Logger from '@matrixai/logger'; @@ -29,6 +30,8 @@ import { RWLockWriter } from '@matrixai/async-locks'; import * as vaultsErrors from './errors'; import * as vaultsUtils from './utils'; import { tagLast } from './types'; +import * as validationUtils from '../validation/utils'; +import * as utils from '../utils'; import * as nodesUtils from '../nodes/utils'; import { never } from '../utils/utils'; @@ -137,7 +140,7 @@ class VaultInternal { const vaultIdEncoded = vaultsUtils.encodeVaultId(vaultId); logger.info(`Cloning ${this.name} - ${vaultIdEncoded}`); - const vault = new VaultInternal({ + const vault = new this({ vaultId, db, vaultsDbPath, @@ -756,91 +759,86 @@ class VaultInternal { } protected async request( - _client: RPCClient, - _vaultNameOrId: VaultId | VaultName, - _vaultAction: VaultAction, + client: RPCClient, + vaultNameOrId: VaultId | VaultName, + vaultAction: VaultAction, ): Promise { - throw Error('TMP IMP'); - // Const vaultNameOrId_ = typeof vaultNameOrId === 'string' ? - // vaultNameOrId : - // vaultsUtils.encodeVaultId(vaultNameOrId); - // const response = client.methods.vaultsGitInfoGet({ - // vaultNameOrId: vaultNameOrId_, - // action: vaultAction, - // }); - // let vaultName, remoteVaultId; - // response.stream.on('metadata', async (meta) => { - // // Receive the Id of the remote vault - // vaultName = meta.get('vaultName').pop(); - // if (vaultName) vaultName = vaultName.toString(); - // const vId = meta.get('vaultId').pop(); - // if (vId) remoteVaultId = validationUtils.parseVaultId(vId.toString()); - // }); - // // Collect the response buffers from the GET request - // const infoResponse: Uint8Array[] = []; - // for await (const resp of response) { - // infoResponse.push(resp.getChunk_asU8()); - // } - // const metadata = new grpc.Metadata(); - // metadata.set('vaultAction', vaultAction); - // if (typeof vaultNameOrId === 'string') { - // metadata.set('vaultNameOrId', vaultNameOrId); - // } else { - // // Metadata only accepts the user readable form of the vault Id - // // as the string form has illegal characters - // metadata.set('vaultNameOrId', vaultsUtils.encodeVaultId(vaultNameOrId)); - // } - // return [ - // async function ({ - // url, - // method = 'GET', - // headers = {}, - // body = [Buffer.from('')], - // }: { - // url: string; - // method: string; - // headers: POJO; - // body: Buffer[]; - // }) { - // if (method === 'GET') { - // // Send back the GET request info response - // return { - // url: url, - // method: method, - // body: infoResponse, - // headers: headers, - // statusCode: 200, - // statusMessage: 'OK', - // }; - // } else if (method === 'POST') { - // const responseBuffers: Array = []; - // const stream = client.vaultsGitPackGet(metadata); - // const chunk = new vaultsPB.PackChunk(); - // // Body is usually an async generator but in the cases we are using, - // // only the first value is used - // chunk.setChunk(body[0]); - // // Tell the server what commit we need - // await stream.write(chunk); - // let packResponse = (await stream.read()).value; - // while (packResponse != null) { - // responseBuffers.push(packResponse.getChunk_asU8()); - // packResponse = (await stream.read()).value; - // } - // return { - // url: url, - // method: method, - // body: responseBuffers, - // headers: headers, - // statusCode: 200, - // statusMessage: 'OK', - // }; - // } else { - // never(); - // } - // }, - // vaultName, - // remoteVaultId, - // ]; + const vaultNameOrId_ = + typeof vaultNameOrId === 'string' + ? vaultNameOrId + : vaultsUtils.encodeVaultId(vaultNameOrId); + const vaultsGitInfoGetStream = await client.methods.vaultsGitInfoGet({ + vaultNameOrId: vaultNameOrId_, + action: vaultAction, + }); + const result = vaultsGitInfoGetStream.meta?.result; + if (result == null || !utils.isObject(result)) never(); + if (!('vaultName' in result) || typeof result.vaultName != 'string') { + never(); + } + if ( + !('vaultIdEncoded' in result) || + typeof result.vaultIdEncoded != 'string' + ) { + never(); + } + const vaultName = result.vaultName; + const remoteVaultId = validationUtils.parseVaultId(result.vaultIdEncoded); + + // Collect the response buffers from the GET request + const infoResponse: Uint8Array[] = []; + for await (const chunk of vaultsGitInfoGetStream.readable) { + infoResponse.push(chunk); + } + return [ + async function ({ + url, + method = 'GET', + headers = {}, + body = [Buffer.from('')], + }: { + url: string; + method: string; + headers: POJO; + body: Buffer[]; + }) { + if (method === 'GET') { + // Send back the GET request info response + return { + url: url, + method: method, + body: infoResponse, + headers: headers, + statusCode: 200, + statusMessage: 'OK', + }; + } else if (method === 'POST') { + const responseBuffers: Array = []; + const vaultsGitPackGetStream = await client.methods.vaultsGitPackGet({ + nameOrId: result.vaultIdEncoded as string, + vaultAction, + }); + const writer = vaultsGitPackGetStream.writable.getWriter(); + await writer.write(body[0]); + await writer.close(); + for await (const value of vaultsGitPackGetStream.readable) { + responseBuffers.push(value); + } + return { + url: url, + method: method, + body: responseBuffers, + headers: headers, + statusCode: 200, + statusMessage: 'OK', + }; + } else { + never(); + } + }, + vaultName, + remoteVaultId, + ]; } /** diff --git a/src/vaults/VaultManager.ts b/src/vaults/VaultManager.ts index 345318e0c..e98dc89c9 100644 --- a/src/vaults/VaultManager.ts +++ b/src/vaults/VaultManager.ts @@ -796,11 +796,11 @@ class VaultManager { tran?: DBTransaction, ): AsyncGenerator { if (tran == null) { - return this.db.withTransactionF(async (tran) => - this.handleInfoRequest(vaultId, tran), - ); + const handleInfoRequest = (tran) => this.handleInfoRequest(vaultId, tran); + return yield* this.db.withTransactionG(async function* (tran) { + return yield* handleInfoRequest(tran); + }); } - const efs = this.efs; const vault = await this.getVault(vaultId, tran); return yield* withG( @@ -986,7 +986,6 @@ class VaultManager { if (tran == null) { return this.db.withTransactionF((tran) => this.getVault(vaultId, tran)); } - const vaultIdString = vaultId.toString() as VaultIdString; // 1. get the vault, if it exists then return that const vault = this.vaultMap.get(vaultIdString); @@ -1035,7 +1034,6 @@ class VaultManager { return [vaultId.toString(), RWLockWriter, 'read']; }, ); - // Running the function with locking return await this.vaultLocks.withF(...vaultLocks, async () => { // Getting the vaults while locked @@ -1044,7 +1042,7 @@ class VaultManager { return await this.getVault(vaultId, tran); }), ); - return f(...vaults); + return await f(...vaults); }); } diff --git a/src/websockets/WebSocketClient.ts b/src/websockets/WebSocketClient.ts index ffa5d3510..481e9a8f1 100644 --- a/src/websockets/WebSocketClient.ts +++ b/src/websockets/WebSocketClient.ts @@ -1,12 +1,6 @@ import type { TLSSocket } from 'tls'; -import type { - ReadableStreamController, - WritableStreamDefaultController, -} from 'stream/web'; import type { ContextTimed } from '@matrixai/contexts'; import type { NodeId, NodeIdEncoded } from '../ids'; -import type { JSONValue } from '../types'; -import { WritableStream, ReadableStream } from 'stream/web'; import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import WebSocket from 'ws'; @@ -33,7 +27,6 @@ class WebSocketClient { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. - * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ static async createWebSocketClient({ @@ -43,7 +36,6 @@ class WebSocketClient { connectionTimeoutTime = Infinity, pingIntervalTime = 1_000, pingTimeoutTimeTime = 10_000, - maxReadableStreamBytes = 1_000, // About 1kB logger = new Logger(this.name), }: { host: string; @@ -52,7 +44,6 @@ class WebSocketClient { connectionTimeoutTime?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; - maxReadableStreamBytes?: number; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); @@ -60,7 +51,6 @@ class WebSocketClient { logger, host, port, - maxReadableStreamBytes, expectedNodeIds, connectionTimeoutTime, pingIntervalTime, @@ -77,7 +67,6 @@ class WebSocketClient { protected logger: Logger, host: string, protected port: number, - protected maxReadableStreamBytes: number, protected expectedNodeIds: Array, protected connectionTimeoutTime: number, protected pingIntervalTime: number, @@ -126,7 +115,7 @@ class WebSocketClient { @createDestroy.ready(new webSocketErrors.ErrorClientDestroyed()) public async startConnection( ctx: Partial = {}, - ): Promise { + ): Promise { // Setting up abort/cancellation logic const abortRaceProm = promise(); // Ignore unhandled rejection @@ -161,7 +150,13 @@ class WebSocketClient { const address = `wss://${this.host}:${this.port}`; this.logger.info(`Connecting to ${address}`); const connectProm = promise(); - const authenticateProm = promise(); + const authenticateProm = promise<{ + nodeId: NodeIdEncoded; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; + }>(); const ws = new WebSocket(address, { rejectUnauthorized: false, }); @@ -178,12 +173,21 @@ class WebSocketClient { ws.once('upgrade', async (request) => { const tlsSocket = request.socket as TLSSocket; const peerCert = tlsSocket.getPeerCertificate(true); - webSocketUtils - .verifyServerCertificateChain( + try { + const nodeId = await webSocketUtils.verifyServerCertificateChain( this.expectedNodeIds, webSocketUtils.detailedToCertChain(peerCert), - ) - .then(authenticateProm.resolveP, authenticateProm.rejectP); + ); + authenticateProm.resolveP({ + nodeId: nodesUtils.encodeNodeId(nodeId), + localHost: request.connection.localAddress ?? '', + localPort: request.connection.localPort ?? 0, + remoteHost: request.connection.remoteAddress ?? '', + remotePort: request.connection.remotePort ?? 0, + }); + } catch (e) { + authenticateProm.rejectP(e); + } }); ws.once('open', () => { this.logger.info('starting connection'); @@ -222,17 +226,14 @@ class WebSocketClient { // Constructing the `ReadableWritablePair`, the lifecycle is handed off to // the webSocketStream at this point. - const webSocketStreamClient = new WebSocketStreamClientInternal( + const webSocketStreamClient = new WebSocketStream( ws, - this.maxReadableStreamBytes, this.pingIntervalTime, this.pingTimeoutTimeTime, { - host: this.host, - nodeId: nodesUtils.encodeNodeId(await authenticateProm.p), - port: this.port, + ...(await authenticateProm.p), }, - this.logger, + this.logger.getChild(WebSocketStream.name), ); const abortStream = () => { webSocketStreamClient.cancel( @@ -258,219 +259,4 @@ class WebSocketClient { } // This is the internal implementation of the client's stream pair. -class WebSocketStreamClientInternal extends WebSocketStream { - protected readableController: - | ReadableStreamController - | undefined; - protected writableController: WritableStreamDefaultController | undefined; - - constructor( - protected ws: WebSocket, - maxReadableStreamBytes: number, - pingInterval: number, - pingTimeoutTime: number, - protected clientMetadata: { - nodeId: NodeIdEncoded; - host: string; - port: number; - }, - logger: Logger, - ) { - super(); - const readableLogger = logger.getChild('readable'); - const writableLogger = logger.getChild('writable'); - - this.readable = new ReadableStream( - { - start: (controller) => { - this.readableController = controller; - readableLogger.info('Starting'); - const messageHandler = (data) => { - readableLogger.debug(`Received ${data.toString()}`); - if (controller.desiredSize == null) { - controller.error(Error('NEVER')); - return; - } - if (controller.desiredSize < 0) { - readableLogger.debug('Applying readable backpressure'); - ws.pause(); - } - const message = data as Buffer; - if (message.length === 0) { - readableLogger.debug('Null message received'); - ws.removeListener('message', messageHandler); - if (!this._readableEnded) { - this.signalReadableEnd(); - readableLogger.debug('Closing'); - controller.close(); - } - if (this._writableEnded) { - logger.debug('Closing socket'); - ws.close(); - } - return; - } - controller.enqueue(message); - }; - readableLogger.debug('Registering socket message handler'); - ws.on('message', messageHandler); - ws.once('close', (code, reason) => { - logger.info('Socket closed'); - ws.removeListener('message', messageHandler); - if (!this._readableEnded) { - readableLogger.debug( - `Closed early, ${code}, ${reason.toString()}`, - ); - const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); - this.signalReadableEnd(e); - controller.error(e); - } - }); - ws.once('error', (e) => { - if (!this._readableEnded) { - readableLogger.error(e); - this.signalReadableEnd(e); - controller.error(e); - } - }); - }, - cancel: (reason) => { - readableLogger.debug('Cancelled'); - this.signalReadableEnd(reason); - if (!this._writableEnded) { - readableLogger.debug('Closing socket'); - this.signalWritableEnd(reason); - ws.close(); - } - }, - pull: () => { - readableLogger.debug('Releasing backpressure'); - ws.resume(); - }, - }, - { - highWaterMark: maxReadableStreamBytes, - size: (chunk) => chunk?.byteLength ?? 0, - }, - ); - this.writable = new WritableStream({ - start: (controller) => { - this.writableController = controller; - writableLogger.info('Starting'); - ws.once('error', (e) => { - if (!this._writableEnded) { - writableLogger.error(e); - this.signalWritableEnd(e); - controller.error(e); - } - }); - ws.once('close', (code, reason) => { - if (!this._writableEnded) { - writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); - const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); - this.signalWritableEnd(e); - controller.error(e); - } - }); - }, - close: () => { - writableLogger.debug('Closing, sending null message'); - ws.send(Buffer.from([])); - this.signalWritableEnd(); - if (this._readableEnded) { - writableLogger.debug('Closing socket'); - ws.close(); - } - }, - abort: (reason) => { - writableLogger.debug('Aborted'); - this.signalWritableEnd(reason); - if (this._readableEnded) { - writableLogger.debug('Closing socket'); - ws.close(); - } - }, - write: async (chunk, controller) => { - if (this._writableEnded) return; - writableLogger.debug(`Sending ${chunk?.toString()}`); - const wait = promise(); - ws.send(chunk, (e) => { - if (e != null && !this._writableEnded) { - // Opting to debug message here and not log an error, sending - // failure is common if we send before the close event. - writableLogger.debug('failed to send'); - const err = new webSocketErrors.ErrorClientConnectionEndedEarly( - undefined, - { - cause: e, - }, - ); - this.signalWritableEnd(err); - controller.error(err); - } - wait.resolveP(); - }); - await wait.p; - }, - }); - - // Setting up heartbeat - const pingTimer = setInterval(() => { - ws.ping(); - }, pingInterval); - const pingTimeoutTimeTimer = setTimeout(() => { - logger.debug('Ping timed out'); - ws.close(4002, 'Timed out'); - }, pingTimeoutTime); - ws.on('ping', () => { - logger.debug('Received ping'); - ws.pong(); - }); - ws.on('pong', () => { - logger.debug('Received pong'); - pingTimeoutTimeTimer.refresh(); - }); - ws.once('close', (code, reason) => { - logger.debug('WebSocket closed'); - const err = - code !== 1000 - ? new webSocketErrors.ErrorClientConnectionEndedEarly( - `ended with code ${code}, ${reason.toString()}`, - ) - : undefined; - this.signalWebSocketEnd(err); - logger.debug('Cleaning up timers'); - // Clean up timers - clearTimeout(pingTimer); - clearTimeout(pingTimeoutTimeTimer); - }); - } - - get meta(): Record { - // Spreading to avoid modifying the data - return { - ...this.clientMetadata, - }; - } - - cancel(reason?: any): void { - // Default error - const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); - // Close the streams with the given error, - if (!this._readableEnded) { - this.readableController?.error(err); - this.signalReadableEnd(err); - } - if (!this._writableEnded) { - this.writableController?.error(err); - this.signalWritableEnd(err); - } - // Then close the websocket - if (!this._webSocketEnded) { - this.ws.close(4000, 'Ending connection'); - this.signalWebSocketEnd(err); - } - } -} - export default WebSocketClient; diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index a9df35d92..5d99f79a5 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -1,40 +1,17 @@ -import type { - ReadableStreamController, - WritableStreamDefaultController, -} from 'stream/web'; -import type { - HttpRequest, - HttpResponse, - us_socket_context_t, - WebSocket, -} from 'uWebSockets.js'; -import type { FileSystem, JSONValue, PromiseDeconstructed } from '../types'; import type { TLSConfig } from '../network/types'; -import { WritableStream, ReadableStream } from 'stream/web'; -import path from 'path'; -import os from 'os'; -import { startStop } from '@matrixai/async-init'; +import type { IncomingMessage, ServerResponse } from 'http'; +import type tls from 'tls'; +import https from 'https'; +import { startStop, status } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; -import uWebsocket from 'uWebSockets.js'; +import * as ws from 'ws'; import WebSocketStream from './WebSocketStream'; import * as webSocketErrors from './errors'; import * as webSocketEvents from './events'; -import { promise } from '../utils'; +import { never, promise } from '../utils'; type ConnectionCallback = (streamPair: WebSocketStream) => void; -type Context = { - message: ( - ws: WebSocket, - message: ArrayBuffer, - isBinary: boolean, - ) => void; - drain: (ws: WebSocket) => void; - close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; - pong: (ws: WebSocket, message: ArrayBuffer) => void; - logger: Logger; -}; - /** * Events: * - start @@ -48,7 +25,6 @@ class WebSocketServer extends EventTarget { * @param obj * @param obj.connectionCallback - * @param obj.tlsConfig - TLSConfig containing the private key and cert chain used for TLS. - * @param obj.basePath - Directory path used for storing temp cert files for starting the `uWebsocket` server. * @param obj.host - Listen address to bind to. * @param obj.port - Listen port to bind to. * @param obj.maxIdleTimeout - Timeout time for when the connection is cleaned up after no activity. @@ -57,40 +33,30 @@ class WebSocketServer extends EventTarget { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. - * @param obj.fs - FileSystem interface used for creating files. - * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ static async createWebSocketServer({ connectionCallback, tlsConfig, - basePath, host, port, maxIdleTimeout = 120, pingIntervalTime = 1_000, pingTimeoutTimeTime = 10_000, - fs = require('fs'), - maxReadableStreamBytes = 1_000_000_000, // About 1 GB logger = new Logger(this.name), }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; - basePath?: string; host?: string; port?: number; maxIdleTimeout?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; - fs?: FileSystem; - maxReadableStreamBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); const wsServer = new this( logger, - fs, - maxReadableStreamBytes, maxIdleTimeout, pingIntervalTime, pingTimeoutTimeTime, @@ -98,7 +64,6 @@ class WebSocketServer extends EventTarget { await wsServer.start({ connectionCallback, tlsConfig, - basePath, host, port, }); @@ -106,29 +71,24 @@ class WebSocketServer extends EventTarget { return wsServer; } - protected server: uWebsocket.TemplatedApp; - protected listenSocket: uWebsocket.us_listen_socket; + protected server: https.Server; + protected webSocketServer: ws.WebSocketServer; protected _port: number; protected _host: string; protected connectionEventHandler: ( event: webSocketEvents.ConnectionEvent, ) => void; protected activeSockets: Set = new Set(); - protected connectionIndex: number = 0; /** * * @param logger - * @param fs - * @param maxReadableStreamBytes Max number of bytes stored in read buffer before error * @param maxIdleTimeout * @param pingIntervalTime * @param pingTimeoutTimeTime */ constructor( protected logger: Logger, - protected fs: FileSystem, - protected maxReadableStreamBytes, protected maxIdleTimeout: number | undefined, protected pingIntervalTime: number, protected pingTimeoutTimeTime: number, @@ -138,13 +98,11 @@ class WebSocketServer extends EventTarget { public async start({ tlsConfig, - basePath = os.tmpdir(), host, port = 0, connectionCallback, }: { tlsConfig: TLSConfig; - basePath?: string; host?: string; port?: number; connectionCallback?: ConnectionCallback; @@ -158,47 +116,29 @@ class WebSocketServer extends EventTarget { }; this.addEventListener('connection', this.connectionEventHandler); } - await this.setupServer(basePath, tlsConfig); - this.server.ws('/*', { - sendPingsAutomatically: true, - idleTimeout: this.maxIdleTimeout, - upgrade: this.upgrade, - open: this.open, - message: this.message, - close: this.close, - drain: this.drain, - pong: this.pong, - // Ping uses default behaviour. - // We don't use subscriptions. + this.server = https.createServer({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, }); - this.server.any('/*', (res, _) => { - // Reject normal requests with an upgrade code - res - .writeStatus('426') - .writeHeader('connection', 'Upgrade') - .writeHeader('upgrade', 'websocket') - .end('426 Upgrade Required', true); + this.webSocketServer = new ws.WebSocketServer({ + server: this.server, }); + + this.webSocketServer.on('connection', this.connectionHandler); + this.webSocketServer.on('close', this.closeHandler); + this.server.on('close', this.closeHandler); + this.webSocketServer.on('error', this.errorHandler); + this.server.on('error', this.errorHandler); + this.server.on('request', this.requestHandler); + const listenProm = promise(); - const listenCallback = (listenSocket) => { - if (listenSocket) { - this.listenSocket = listenSocket; - listenProm.resolveP(); - } else { - listenProm.rejectP(new webSocketErrors.ErrorServerPortUnavailable()); - } - }; - if (host != null) { - // With custom host - this.server.listen(host, port ?? 0, listenCallback); - } else { - // With default host - this.server.listen(port, listenCallback); - } + this.server.listen(port ?? 0, host, listenProm.resolveP); await listenProm.p; - this._port = uWebsocket.us_socket_local_port(this.listenSocket); + const address = this.server.address(); + if (address == null || typeof address === 'string') never(); + this._port = address.port; this.logger.debug(`Listening on port ${this._port}`); - this._host = host ?? '127.0.0.1'; + this._host = address.address ?? '127.0.0.1'; this.dispatchEvent( new webSocketEvents.StartEvent({ detail: { @@ -212,8 +152,6 @@ class WebSocketServer extends EventTarget { public async stop(force: boolean = false): Promise { this.logger.info(`Stopping ${this.constructor.name}`); - // Close the server by closing the underlying socket - uWebsocket.us_listen_socket_close(this.listenSocket); // Shutting down active websockets if (force) { for (const webSocketStream of this.activeSockets) { @@ -225,9 +163,37 @@ class WebSocketServer extends EventTarget { // Ignore errors, we only care that it finished webSocketStream.endedProm.catch(() => {}); } + // Close the server by closing the underlying socket + const wssCloseProm = promise(); + this.webSocketServer.close((e) => { + if (e == null || e.message === 'The server is not running') { + wssCloseProm.resolveP(); + } else { + wssCloseProm.rejectP(e); + } + }); + await wssCloseProm.p; + const serverCloseProm = promise(); + this.server.close((e) => { + if (e == null || e.message === 'Server is not running.') { + serverCloseProm.resolveP(); + } else { + serverCloseProm.rejectP(e); + } + }); + await serverCloseProm.p; + // Removing handlers if (this.connectionEventHandler != null) { this.removeEventListener('connection', this.connectionEventHandler); } + + this.webSocketServer.off('connection', this.connectionHandler); + this.webSocketServer.off('close', this.closeHandler); + this.server.off('close', this.closeHandler); + this.webSocketServer.off('error', this.errorHandler); + this.server.off('error', this.errorHandler); + this.server.on('request', this.requestHandler); + this.dispatchEvent(new webSocketEvents.StopEvent()); this.logger.info(`Stopped ${this.constructor.name}`); } @@ -242,68 +208,39 @@ class WebSocketServer extends EventTarget { return this._host; } - /** - * This creates the pem files and starts the server with them. It ensures that - * files are cleaned up to the best of its ability. - */ - protected async setupServer(basePath: string, tlsConfig: TLSConfig) { - const tmpDir = await this.fs.promises.mkdtemp( - path.join(basePath, 'polykey-'), - ); - // TODO: The key file needs to be in the encrypted format - const keyFile = path.join(tmpDir, 'keyFile.pem'); - const certFile = path.join(tmpDir, 'certFile.pem'); - await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); - await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); - try { - this.server = uWebsocket.SSLApp({ - key_file_name: keyFile, - cert_file_name: certFile, - }); - } finally { - await this.fs.promises.rm(keyFile); - await this.fs.promises.rm(certFile); - await this.fs.promises.rm(tmpDir, { recursive: true, force: true }); - } + @startStop.ready(new webSocketErrors.ErrorWebSocketServerNotRunning()) + public setTlsConfig(tlsConfig: TLSConfig): void { + const tlsServer = this.server as tls.Server; + tlsServer.setSecureContext({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + }); } - /** - * Applies default upgrade behaviour and creates a UserData object we can - * mutate for the Context - */ - protected upgrade = ( - res: HttpResponse, - req: HttpRequest, - context: us_socket_context_t, - ) => { - const logger = this.logger.getChild(`Connection ${this.connectionIndex}`); - res.upgrade>( - { - logger, - }, - req.getHeader('sec-websocket-key'), - req.getHeader('sec-websocket-protocol'), - req.getHeader('sec-websocket-extensions'), - context, - ); - this.connectionIndex += 1; - }; - /** * Handles the creation of the `ReadableWritablePair` and provides it to the * StreamPair handler. */ - protected open = (ws: WebSocket) => { - const webSocketStream = new WebSocketStreamServerInternal( - ws, - this.maxReadableStreamBytes, + protected connectionHandler = ( + webSocket: ws.WebSocket, + request: IncomingMessage, + ) => { + const connection = request.connection; + const webSocketStream = new WebSocketStream( + webSocket, this.pingIntervalTime, this.pingTimeoutTimeTime, - {}, // TODO: fill in connection metadata + { + localHost: connection.localAddress ?? '', + localPort: connection.localPort ?? 0, + remoteHost: connection.remoteAddress ?? '', + remotePort: connection.remotePort ?? 0, + }, + this.logger.getChild(WebSocketStream.name), ); // Adding socket to the active sockets map this.activeSockets.add(webSocketStream); - webSocketStream.endedProm + void webSocketStream.endedProm // Ignore errors, we only care that it finished .catch(() => {}) .finally(() => { @@ -322,215 +259,35 @@ class WebSocketServer extends EventTarget { }; /** - * Routes incoming messages to each stream using the `Context` message - * callback. + * Used to trigger stopping if the underlying server fails */ - protected message = ( - ws: WebSocket, - message: ArrayBuffer, - isBinary: boolean, - ) => { - ws.getUserData().message(ws, message, isBinary); - }; - - protected drain = (ws: WebSocket) => { - ws.getUserData().drain(ws); + protected closeHandler = async () => { + if (this[status] == null || this[status] === 'stopping') { + this.logger.debug('close event but already stopping'); + return; + } + this.logger.debug('close event, forcing stop'); + await this.stop(true); }; - protected close = ( - ws: WebSocket, - code: number, - message: ArrayBuffer, - ) => { - ws.getUserData().close(ws, code, message); + /** + * Used to propagate error conditions + */ + protected errorHandler = (e: Error) => { + this.logger.error(e); }; - protected pong = (ws: WebSocket, message: ArrayBuffer) => { - ws.getUserData().pong(ws, message); + /** + * Will tell any normal HTTP request to upgrade + */ + protected requestHandler = (_req, res: ServerResponse) => { + res + .writeHead(426, '426 Upgrade Required', { + connection: 'Upgrade', + upgrade: 'websocket', + }) + .end('426 Upgrade Required'); }; } -class WebSocketStreamServerInternal extends WebSocketStream { - protected backPressure: PromiseDeconstructed | null = null; - protected writeBackpressure: boolean = false; - protected writableController: WritableStreamDefaultController | undefined; - protected readableController: - | ReadableStreamController - | undefined; - - constructor( - protected ws: WebSocket, - maxReadBufferBytes: number, - pingInterval: number, - pingTimeoutTime: number, - protected metadata: Record, - ) { - super(); - const context = ws.getUserData(); - const logger = context.logger; - logger.info('WS opened'); - const writableLogger = logger.getChild('Writable'); - const readableLogger = logger.getChild('Readable'); - // Setting up the writable stream - this.writable = new WritableStream({ - start: (controller) => { - this.writableController = controller; - }, - write: async (chunk, controller) => { - await this.backPressure?.p; - const writeResult = ws.send(chunk, true); - switch (writeResult) { - default: - case 2: - // Write failure, emit error - writableLogger.error('Send error'); - controller.error(new webSocketErrors.ErrorServerSendFailed()); - break; - case 0: - writableLogger.info('Write backpressure'); - // Signal backpressure - this.backPressure = promise(); - this.writeBackpressure = true; - this.backPressure.p.finally(() => { - this.writeBackpressure = false; - }); - break; - case 1: - // Success - writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); - break; - } - }, - close: () => { - writableLogger.info('Closed, sending null message'); - if (!this._webSocketEnded) ws.send(Buffer.from([]), true); - this.signalWritableEnd(); - if (this._readableEnded && !this._webSocketEnded) { - writableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - ws.end(); - } - }, - abort: (reason) => { - writableLogger.info('Aborted'); - if (this._readableEnded && !this._webSocketEnded) { - writableLogger.debug('Ending socket'); - this.signalWebSocketEnd(reason); - ws.end(4000, 'Aborting connection'); - } - }, - }); - // Setting up the readable stream - this.readable = new ReadableStream( - { - start: (controller) => { - this.readableController = controller; - context.message = (ws, message, _) => { - const messageBuffer = Buffer.from(message); - readableLogger.debug(`Received ${messageBuffer.toString()}`); - if (message.byteLength === 0) { - readableLogger.debug('Null message received'); - if (!this._readableEnded) { - readableLogger.debug('Closing'); - this.signalReadableEnd(); - controller.close(); - if (this._writableEnded && !this._webSocketEnded) { - readableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - ws.end(); - } - } - return; - } - controller.enqueue(messageBuffer); - if (controller.desiredSize != null && controller.desiredSize < 0) { - readableLogger.error('Read stream buffer full'); - const err = new webSocketErrors.ErrorServerReadableBufferLimit(); - if (!this._webSocketEnded) { - this.signalWebSocketEnd(err); - ws.end(4000, 'Read stream buffer full'); - } - controller.error(err); - } - }; - }, - cancel: (reason) => { - this.signalReadableEnd(reason); - if (this._writableEnded && !this._webSocketEnded) { - readableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - ws.end(); - } - }, - }, - { - highWaterMark: maxReadBufferBytes, - size: (chunk) => chunk?.byteLength ?? 0, - }, - ); - - const pingTimer = setInterval(() => { - ws.ping(); - }, pingInterval); - const pingTimeoutTimeTimer = setTimeout(() => { - logger.debug('Ping timed out'); - ws.end(); - }, pingTimeoutTime); - context.pong = () => { - logger.debug('Received pong'); - pingTimeoutTimeTimer.refresh(); - }; - context.close = () => { - logger.debug('Closing'); - this.signalWebSocketEnd(); - // Cleaning up timers - logger.debug('Cleaning up timers'); - clearTimeout(pingTimer); - clearTimeout(pingTimeoutTimeTimer); - // Closing streams - logger.debug('Cleaning streams'); - const err = new webSocketErrors.ErrorServerConnectionEndedEarly(); - if (!this._readableEnded) { - readableLogger.debug('Closing'); - this.signalReadableEnd(err); - this.readableController?.error(err); - } - if (!this._writableEnded) { - writableLogger.debug('Closing'); - this.signalWritableEnd(err); - this.writableController?.error(err); - } - }; - context.drain = () => { - logger.debug('Drained'); - this.backPressure?.resolveP(); - }; - } - - get meta(): Record { - return { - ...this.metadata, - }; - } - - cancel(reason?: any): void { - // Default error - const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); - // Close the streams with the given error, - if (!this._readableEnded) { - this.readableController?.error(err); - this.signalReadableEnd(err); - } - if (!this._writableEnded) { - this.writableController?.error(err); - this.signalWritableEnd(err); - } - // Then close the websocket - if (!this._webSocketEnded) { - this.ws.end(4000, 'Ending connection'); - this.signalWebSocketEnd(err); - } - } -} - export default WebSocketServer; diff --git a/src/websockets/WebSocketStream.ts b/src/websockets/WebSocketStream.ts index f71d50e26..ea5934cfc 100644 --- a/src/websockets/WebSocketStream.ts +++ b/src/websockets/WebSocketStream.ts @@ -1,13 +1,17 @@ +import type { ReadableWritablePair } from 'stream/web'; import type { - ReadableStream, - ReadableWritablePair, - WritableStream, + ReadableStreamController, + WritableStreamDefaultController, } from 'stream/web'; +import type * as ws from 'ws'; +import type Logger from '@matrixai/logger'; +import type { NodeIdEncoded } from '../ids/types'; +import { WritableStream, ReadableStream } from 'stream/web'; +import * as webSocketErrors from './errors'; +import * as utilsErrors from '../utils/errors'; import { promise } from '../utils'; -abstract class WebSocketStream - implements ReadableWritablePair -{ +class WebSocketStream implements ReadableWritablePair { public readable: ReadableStream; public writable: WritableStream; @@ -19,7 +23,24 @@ abstract class WebSocketStream protected _webSocketEndedProm = promise(); protected _endedProm: Promise; - protected constructor() { + protected readableController: + | ReadableStreamController + | undefined; + protected writableController: WritableStreamDefaultController | undefined; + + constructor( + protected ws: ws.WebSocket, + pingInterval: number, + pingTimeoutTime: number, + protected metadata: { + nodeId?: NodeIdEncoded; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; + }, + logger: Logger, + ) { // Sanitise promises so they don't result in unhandled rejections this._readableEndedProm.p.catch(() => {}); this._writableEndedProm.p.catch(() => {}); @@ -42,6 +63,193 @@ abstract class WebSocketStream }); // Ignore errors if it's never used this._endedProm.catch(() => {}); + + logger.info('WS opened'); + const readableLogger = logger.getChild('readable'); + const writableLogger = logger.getChild('writable'); + // Setting up the readable stream + this.readable = new ReadableStream( + { + start: (controller) => { + readableLogger.debug('Starting'); + this.readableController = controller; + const messageHandler = (data: ws.RawData, isBinary: boolean) => { + if (!isBinary || data instanceof Array) { + controller.error(new utilsErrors.ErrorUtilsUndefinedBehaviour()); + return; + } + const message = data as Buffer; + readableLogger.debug(`Received ${message.toString()}`); + if (message.length === 0) { + readableLogger.debug('Null message received'); + ws.removeListener('message', messageHandler); + if (!this._readableEnded) { + readableLogger.debug('Closing'); + this.signalReadableEnd(); + controller.close(); + } + if (this._writableEnded) { + logger.debug('Closing socket'); + ws.close(); + } + return; + } + if (this._readableEnded) { + return; + } + controller.enqueue(message); + if (controller.desiredSize == null) { + controller.error(new utilsErrors.ErrorUtilsUndefinedBehaviour()); + return; + } + if (controller.desiredSize < 0) { + readableLogger.debug('Applying readable backpressure'); + ws.pause(); + } + }; + readableLogger.debug('Registering socket message handler'); + ws.on('message', messageHandler); + ws.once('close', (code, reason) => { + logger.info('Socket closed'); + ws.removeListener('message', messageHandler); + if (!this._readableEnded) { + readableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); + this.signalReadableEnd(e); + controller.error(e); + } + }); + ws.once('error', (e) => { + if (!this._readableEnded) { + readableLogger.error(e); + this.signalReadableEnd(e); + controller.error(e); + } + }); + }, + cancel: (reason) => { + readableLogger.debug('Cancelled'); + this.signalReadableEnd(reason); + if (this._writableEnded) { + readableLogger.debug('Closing socket'); + this.signalWritableEnd(reason); + ws.close(); + } + }, + pull: () => { + readableLogger.debug('Releasing backpressure'); + ws.resume(); + }, + }, + { highWaterMark: 1 }, + ); + this.writable = new WritableStream( + { + start: (controller) => { + this.writableController = controller; + writableLogger.info('Starting'); + ws.once('error', (e) => { + if (!this._writableEnded) { + writableLogger.error(e); + this.signalWritableEnd(e); + controller.error(e); + } + }); + ws.once('close', (code, reason) => { + if (!this._writableEnded) { + writableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); + this.signalWritableEnd(e); + controller.error(e); + } + }); + }, + close: async () => { + writableLogger.debug('Closing, sending null message'); + const sendProm = promise(); + ws.send(Buffer.from([]), (err) => { + if (err == null) sendProm.resolveP(); + else sendProm.rejectP(err); + }); + await sendProm.p; + this.signalWritableEnd(); + if (this._readableEnded) { + writableLogger.debug('Closing socket'); + ws.close(); + } + }, + abort: (reason) => { + writableLogger.debug('Aborted'); + this.signalWritableEnd(reason); + if (this._readableEnded) { + writableLogger.debug('Closing socket'); + ws.close(4000, `Aborting connection with ${reason.message}`); + } + }, + write: async (chunk, controller) => { + if (this._writableEnded) return; + writableLogger.debug(`Sending ${chunk?.toString()}`); + const wait = promise(); + ws.send(chunk, (e) => { + if (e != null && !this._writableEnded) { + // Opting to debug message here and not log an error, sending + // failure is common if we send before the close event. + writableLogger.debug('failed to send'); + const err = new webSocketErrors.ErrorClientConnectionEndedEarly( + undefined, + { + cause: e, + }, + ); + this.signalWritableEnd(err); + controller.error(err); + } + wait.resolveP(); + }); + await wait.p; + }, + }, + { highWaterMark: 1 }, + ); + + // Setting up heartbeat + const pingTimer = setInterval(() => { + ws.ping(); + }, pingInterval); + const pingTimeoutTimeTimer = setTimeout(() => { + logger.debug('Ping timed out'); + ws.close(4002, 'Timed out'); + }, pingTimeoutTime); + const pingHandler = () => { + logger.debug('Received ping'); + ws.pong(); + }; + const pongHandler = () => { + logger.debug('Received pong'); + pingTimeoutTimeTimer.refresh(); + }; + ws.on('ping', pingHandler); + ws.on('pong', pongHandler); + ws.once('close', (code, reason) => { + ws.off('ping', pingHandler); + ws.off('pong', pongHandler); + logger.debug('WebSocket closed'); + const err = + code !== 1000 + ? new webSocketErrors.ErrorClientConnectionEndedEarly( + `ended with code ${code}, ${reason.toString()}`, + ) + : undefined; + this.signalWebSocketEnd(err); + logger.debug('Cleaning up timers'); + // Clean up timers + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimeTimer); + }); } get readableEnded() { @@ -88,10 +296,34 @@ abstract class WebSocketStream return this._endedProm; } + get meta() { + // Spreading to avoid modifying the data + return { + ...this.metadata, + }; + } + /** * Forces the active stream to end early */ - abstract cancel(reason?: any): void; + public cancel(reason?: any): void { + // Default error + const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); + // Close the streams with the given error, + if (!this._readableEnded) { + this.readableController?.error(err); + this.signalReadableEnd(err); + } + if (!this._writableEnded) { + this.writableController?.error(err); + this.signalWritableEnd(err); + } + // Then close the websocket + if (!this._webSocketEnded) { + this.ws.close(4000, 'Ending connection'); + this.signalWebSocketEnd(err); + } + } /** * Signals the end of the ReadableStream. to be used with the extended class diff --git a/tests/agent/handlers/nodesClaimsGet.test.ts b/tests/agent/handlers/nodesClaimsGet.test.ts index 246898f1a..9ca427a0b 100644 --- a/tests/agent/handlers/nodesClaimsGet.test.ts +++ b/tests/agent/handlers/nodesClaimsGet.test.ts @@ -87,7 +87,10 @@ describe('nodesClaimsGet', () => { cert: tlsConfig.certChainPem, verifyPeer: false, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, logger, }); const handleStream = async ( @@ -134,7 +137,9 @@ describe('nodesClaimsGet', () => { logger, }); quicClient = await QUICClient.createQUICClient({ - crypto, + crypto: { + ops: crypto, + }, config: { verifyPeer: false, }, diff --git a/tests/agent/handlers/nodesClosestLocalNode.test.ts b/tests/agent/handlers/nodesClosestLocalNode.test.ts index 58c574cb2..03bdc7305 100644 --- a/tests/agent/handlers/nodesClosestLocalNode.test.ts +++ b/tests/agent/handlers/nodesClosestLocalNode.test.ts @@ -86,7 +86,10 @@ describe('nodesClosestLocalNode', () => { cert: tlsConfig.certChainPem, verifyPeer: false, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, logger, }); const handleStream = async ( @@ -133,7 +136,9 @@ describe('nodesClosestLocalNode', () => { logger, }); quicClient = await QUICClient.createQUICClient({ - crypto, + crypto: { + ops: crypto, + }, config: { verifyPeer: false, }, diff --git a/tests/agent/handlers/nodesCrossSignClaim.test.ts b/tests/agent/handlers/nodesCrossSignClaim.test.ts index 03afc9ed8..a522864c9 100644 --- a/tests/agent/handlers/nodesCrossSignClaim.test.ts +++ b/tests/agent/handlers/nodesCrossSignClaim.test.ts @@ -136,7 +136,10 @@ describe('nodesCrossSignClaim', () => { verifyPeer: true, verifyAllowFail: true, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, logger, }); const handleStream = async ( @@ -186,7 +189,9 @@ describe('nodesCrossSignClaim', () => { localNodeId = keysUtils.publicKeyToNodeId(clientKeyPair.publicKey); const tlsConfigClient = await tlsTestsUtils.createTLSConfig(clientKeyPair); quicClient = await QUICClient.createQUICClient({ - crypto, + crypto: { + ops: crypto, + }, config: { key: tlsConfigClient.keyPrivatePem, cert: tlsConfigClient.certChainPem, diff --git a/tests/agent/handlers/nodesHolePunchMessage.test.ts b/tests/agent/handlers/nodesHolePunchMessage.test.ts index 7171b5389..22d44aa6f 100644 --- a/tests/agent/handlers/nodesHolePunchMessage.test.ts +++ b/tests/agent/handlers/nodesHolePunchMessage.test.ts @@ -105,10 +105,7 @@ describe('nodesHolePunchMessage', () => { keyRing.keyPair, ); nodeConnectionManager = new NodeConnectionManager({ - quicClientConfig: { - key: tlsConfigClient.keyPrivatePem, - cert: tlsConfigClient.certChainPem, - }, + tlsConfig: tlsConfigClient, crypto, quicSocket, keyRing, @@ -128,7 +125,7 @@ describe('nodesHolePunchMessage', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); // Setting up server @@ -153,7 +150,10 @@ describe('nodesHolePunchMessage', () => { verifyPeer: true, verifyAllowFail: true, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, logger, }); const handleStream = async ( @@ -200,7 +200,9 @@ describe('nodesHolePunchMessage', () => { logger, }); quicClient = await QUICClient.createQUICClient({ - crypto, + crypto: { + ops: crypto, + }, config: { key: tlsConfigClient.keyPrivatePem, cert: tlsConfigClient.certChainPem, diff --git a/tests/agent/handlers/notificationsSend.test.ts b/tests/agent/handlers/notificationsSend.test.ts index a97552e0c..3409628d5 100644 --- a/tests/agent/handlers/notificationsSend.test.ts +++ b/tests/agent/handlers/notificationsSend.test.ts @@ -128,10 +128,7 @@ describe('notificationsSend', () => { keyRing.keyPair, ); nodeConnectionManager = new NodeConnectionManager({ - quicClientConfig: { - key: tlsConfigClient.keyPrivatePem, - cert: tlsConfigClient.certChainPem, - }, + tlsConfig: tlsConfigClient, crypto, quicSocket, keyRing, @@ -151,7 +148,7 @@ describe('notificationsSend', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); notificationsManager = await NotificationsManager.createNotificationsManager({ @@ -183,7 +180,10 @@ describe('notificationsSend', () => { verifyPeer: true, verifyAllowFail: true, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, logger, }); const handleStream = async ( @@ -230,7 +230,9 @@ describe('notificationsSend', () => { logger, }); quicClient = await QUICClient.createQUICClient({ - crypto, + crypto: { + ops: crypto, + }, config: { key: tlsConfigClient.keyPrivatePem, cert: tlsConfigClient.certChainPem, diff --git a/tests/client/handlers/agent.test.ts b/tests/client/handlers/agent.test.ts index 2decd3c78..a7192e624 100644 --- a/tests/client/handlers/agent.test.ts +++ b/tests/client/handlers/agent.test.ts @@ -200,8 +200,8 @@ describe('agentStatus', () => { nodeIdEncoded: nodesUtils.encodeNodeId(pkAgent.keyRing.getNodeId()), clientHost: pkAgent.webSocketServerClient.getHost(), clientPort: pkAgent.webSocketServerClient.getPort(), - agentHost: pkAgent.quicServerAgent.host, - agentPort: pkAgent.quicServerAgent.port, + agentHost: pkAgent.quicSocket.host, + agentPort: pkAgent.quicSocket.port, publicKeyJwk: keysUtils.publicKeyToJWK(pkAgent.keyRing.keyPair.publicKey), certChainPEM: await pkAgent.certManager.getCertPEMsChainPEM(), }); diff --git a/tests/client/handlers/gestalts.test.ts b/tests/client/handlers/gestalts.test.ts index 1fdfb9ef9..fa3b21e2f 100644 --- a/tests/client/handlers/gestalts.test.ts +++ b/tests/client/handlers/gestalts.test.ts @@ -448,12 +448,8 @@ describe('gestaltsDiscoverByIdentity', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -471,7 +467,7 @@ describe('gestaltsDiscoverByIdentity', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); discovery = await Discovery.createDiscovery({ db, gestaltGraph, @@ -637,12 +633,8 @@ describe('gestaltsDiscoverByNode', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -660,7 +652,7 @@ describe('gestaltsDiscoverByNode', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); discovery = await Discovery.createDiscovery({ db, gestaltGraph, @@ -1264,10 +1256,8 @@ describe('gestaltsGestaltTrustByIdentity', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -1285,7 +1275,7 @@ describe('gestaltsGestaltTrustByIdentity', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); discovery = await Discovery.createDiscovery({ db, gestaltGraph, @@ -1785,12 +1775,8 @@ describe('gestaltsGestaltTrustByNode', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -1808,10 +1794,10 @@ describe('gestaltsGestaltTrustByNode', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await nodeManager.setNode(nodeIdRemote, { - host: node.quicServerAgent.host as Host, - port: node.quicServerAgent.port as Port, + host: node.quicSocket.host as Host, + port: node.quicSocket.port as Port, }); discovery = await Discovery.createDiscovery({ db, diff --git a/tests/client/handlers/keys.test.ts b/tests/client/handlers/keys.test.ts index feccb6f8b..b76c84647 100644 --- a/tests/client/handlers/keys.test.ts +++ b/tests/client/handlers/keys.test.ts @@ -546,7 +546,7 @@ describe('keysKeyPairRenew', () => { const rootKeyPair1 = pkAgent.keyRing.keyPair; const nodeId1 = pkAgent.keyRing.getNodeId(); // @ts-ignore - get protected property - const config1 = pkAgent.quicServerAgent.config; + const config1 = pkAgent.nodeConnectionManager.quicServer.config; const fwdTLSConfig1 = { keyPrivatePem: config1.key, certChainPem: config1.cert, @@ -566,7 +566,7 @@ describe('keysKeyPairRenew', () => { const rootKeyPair2 = pkAgent.keyRing.keyPair; const nodeId2 = pkAgent.keyRing.getNodeId(); // @ts-ignore - get protected property - const config2 = pkAgent.quicServerAgent.config; + const config2 = pkAgent.nodeConnectionManager.quicServer.config; const fwdTLSConfig2 = { keyPrivatePem: config2.key, certChainPem: config2.cert, @@ -662,7 +662,7 @@ describe('keysKeyPairReset', () => { const rootKeyPair1 = pkAgent.keyRing.keyPair; const nodeId1 = pkAgent.keyRing.getNodeId(); // @ts-ignore - get protected property - const config1 = pkAgent.quicServerAgent.config; + const config1 = pkAgent.nodeConnectionManager.quicServer.config; const fwdTLSConfig1 = { keyPrivatePem: config1.key, certChainPem: config1.cert, @@ -682,7 +682,7 @@ describe('keysKeyPairReset', () => { const rootKeyPair2 = pkAgent.keyRing.keyPair; const nodeId2 = pkAgent.keyRing.getNodeId(); // @ts-ignore - get protected property - const config2 = pkAgent.quicServerAgent.config; + const config2 = pkAgent.nodeConnectionManager.quicServer.config; const fwdTLSConfig2 = { keyPrivatePem: config2.key, certChainPem: config2.cert, diff --git a/tests/client/handlers/nodes.test.ts b/tests/client/handlers/nodes.test.ts index 3eda8fe46..1d7477a07 100644 --- a/tests/client/handlers/nodes.test.ts +++ b/tests/client/handlers/nodes.test.ts @@ -103,12 +103,8 @@ describe('nodesAdd', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -126,7 +122,7 @@ describe('nodesAdd', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); }); afterEach(async () => { @@ -360,12 +356,8 @@ describe('nodesClaim', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -383,7 +375,7 @@ describe('nodesClaim', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); notificationsManager = await NotificationsManager.createNotificationsManager({ @@ -568,19 +560,18 @@ describe('nodesFind', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, connectionTimeoutTime: 2000, logger: logger.getChild('NodeConnectionManager'), }); - await nodeConnectionManager.start({ nodeManager: {} as NodeManager }); + await nodeConnectionManager.start({ + nodeManager: {} as NodeManager, + handleStream: () => {}, + }); await taskManager.startProcessing(); }); afterEach(async () => { @@ -745,12 +736,8 @@ describe('nodesPing', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -767,7 +754,7 @@ describe('nodesPing', () => { gestaltGraph: {} as GestaltGraph, logger, }); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); }); afterEach(async () => { diff --git a/tests/client/handlers/notifications.test.ts b/tests/client/handlers/notifications.test.ts index e3ab5a8f4..da07a20ac 100644 --- a/tests/client/handlers/notifications.test.ts +++ b/tests/client/handlers/notifications.test.ts @@ -111,12 +111,8 @@ describe('notificationsClear', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -134,7 +130,7 @@ describe('notificationsClear', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.startProcessing(); notificationsManager = await NotificationsManager.createNotificationsManager({ @@ -281,12 +277,8 @@ describe('notificationsRead', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -304,7 +296,7 @@ describe('notificationsRead', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.start(); notificationsManager = await NotificationsManager.createNotificationsManager({ @@ -833,12 +825,8 @@ describe('notificationsSend', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - // @ts-ignore: TLS not needed for this test - key: undefined, - // @ts-ignore: TLS not needed for this test - cert: undefined, - }, + // @ts-ignore: TLS not needed for this test + tlsConfig: {}, crypto, quicSocket, connectionConnectTime: 2000, @@ -856,7 +844,7 @@ describe('notificationsSend', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.start(); notificationsManager = await NotificationsManager.createNotificationsManager({ diff --git a/tests/discovery/Discovery.test.ts b/tests/discovery/Discovery.test.ts index 0d1de554d..1d81ed52f 100644 --- a/tests/discovery/Discovery.test.ts +++ b/tests/discovery/Discovery.test.ts @@ -164,10 +164,7 @@ describe('Discovery', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket, connectionConnectTime: 2000, @@ -185,7 +182,7 @@ describe('Discovery', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); // Set up other gestalt nodeA = await PolykeyAgent.createPolykeyAgent({ password: password, @@ -219,8 +216,8 @@ describe('Discovery', () => { nodeIdB = nodeB.keyRing.getNodeId(); await testNodesUtils.nodesConnect(nodeA, nodeB); await nodeGraph.setNode(nodeA.keyRing.getNodeId(), { - host: nodeA.quicServerAgent.host as Host, - port: nodeA.quicServerAgent.port as Port, + host: nodeA.quicSocket.host as Host, + port: nodeA.quicSocket.port as Port, }); await nodeB.acl.setNodeAction(nodeA.keyRing.getNodeId(), 'claim'); await nodeA.nodeManager.claimNode(nodeB.keyRing.getNodeId()); diff --git a/tests/global.d.ts b/tests/global.d.ts index ecd25dd85..1db39a105 100644 --- a/tests/global.d.ts +++ b/tests/global.d.ts @@ -10,9 +10,5 @@ declare var projectDir: string; declare var testDir: string; declare var dataDir: string; declare var defaultTimeout: number; -declare var polykeyStartupTimeout: number; declare var failedConnectionTimeout: number; declare var maxTimeout: number; -declare var testCmd: string | undefined; -declare var testPlatform: string; -declare var tmpDir: string; diff --git a/tests/nodes/NodeConnection.test.ts b/tests/nodes/NodeConnection.test.ts index 388f9619a..bc3c2ecf2 100644 --- a/tests/nodes/NodeConnection.test.ts +++ b/tests/nodes/NodeConnection.test.ts @@ -1,6 +1,8 @@ import type { Host, Port, TLSConfig } from '@/network/types'; import type * as quicEvents from '@matrixai/quic/dist/events'; import type { NodeId, NodeIdEncoded } from '@/ids'; +import type { RPCStream } from '@/rpc/types'; +import type { CertificatePEM } from '@/keys/types'; import { QUICServer, QUICSocket } from '@matrixai/quic'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { errors as quicErrors } from '@matrixai/quic'; @@ -9,7 +11,7 @@ import * as nodesUtils from '@/nodes/utils'; import * as keysUtils from '@/keys/utils'; import RPCServer from '@/rpc/RPCServer'; import NodeConnection from '@/nodes/NodeConnection'; -import { promise } from '@/utils'; +import { never, promise } from '@/utils'; import * as networkUtils from '@/network/utils'; import * as tlsTestUtils from '../utils/tls'; @@ -32,7 +34,15 @@ describe(`${NodeConnection.name}`, () => { let rpcServer: RPCServer; let clientSocket: QUICSocket; - let nodeConnection_: NodeConnection; + const nodeConnections: Array> = []; + /** + * Adds created nodeConnections to the `nodeConnections` array for automated cleanup. + * @param nc + */ + const extractNodeConnection = (nc: NodeConnection) => { + nodeConnections.push(nc); + return nc; + }; const handleStream = async (event: quicEvents.QUICConnectionStreamEvent) => { // Streams are handled via the RPCServer. @@ -79,7 +89,10 @@ describe(`${NodeConnection.name}`, () => { verifyAllowFail: true, }, verifyCallback: networkUtils.verifyClientCertificateChain, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, socket: serverSocket, logger: logger.getChild(`${QUICServer.name}`), }); @@ -111,7 +124,7 @@ describe(`${NodeConnection.name}`, () => { }); afterEach(async () => { - await nodeConnection_?.destroy({ force: true }); + await Promise.all(nodeConnections.map((nc) => nc.destroy({ force: true }))); await clientSocket.stop({ force: true }); await rpcServer.destroy(true); await quicServer.stop({ force: true }); // Ignore errors due to socket already stopped @@ -120,64 +133,49 @@ describe(`${NodeConnection.name}`, () => { test('session readiness', async () => { const nodeConnection = await NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); await nodeConnection.destroy(); // Should be a noop await nodeConnection.destroy(); }); test('connects to the target', async () => { await NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); }); test('connection fails to target (times out)', async () => { const nodeConnectionProm = NodeConnection.createNodeConnection( { + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: 12345 as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - maxIdleTimeout: 1000, - }, + connectionMaxIdleTimeout: 1000, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), }, { timer: 100 }, - ).then((n) => { - nodeConnection_ = n; - return n; - }); + ).then(extractNodeConnection); await expect(nodeConnectionProm).rejects.toThrow( // QuicErrors.ErrorQUICClientCreateTimeOut, // FIXME: this is not throwing the right error ErrorContextsTimedTimeOut, @@ -186,24 +184,19 @@ describe(`${NodeConnection.name}`, () => { test('connection drops out (socket stops responding)', async () => { const nodeConnection = await NodeConnection.createNodeConnection( { + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - maxIdleTimeout: 100, - }, + connectionMaxIdleTimeout: 100, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), }, { timer: 100 }, - ).then((n) => { - nodeConnection_ = n; - return n; - }); + ).then(extractNodeConnection); const destroyProm = promise(); nodeConnection.addEventListener( 'destroy', @@ -219,109 +212,83 @@ describe(`${NodeConnection.name}`, () => { }); test('get the root chain cert', async () => { const nodeConnection = await NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); const remoteCertPem = keysUtils.certToPEM(nodeConnection.certChain[0]); expect(remoteCertPem).toEqual(serverTlsConfig.certChainPem); }); test('get the NodeId', async () => { const nodeConnection = await NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); expect(serverNodeIdEncoded).toEqual( nodesUtils.encodeNodeId(nodeConnection.nodeId), ); }); test('Should fail due to server rejecting client certificate (no certs)', async () => { const nodeConnectionProm = NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - // @ts-ignore: TLS not used for this test - key: undefined, - // @ts-ignore: TLS not used for this test - cert: undefined, - }, + // @ts-ignore: TLS not used for this test + tlsConfig: {}, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); await expect(nodeConnectionProm).rejects.toThrow( quicErrors.ErrorQUICConnectionInternal, ); }); test('Should fail due to client rejecting server certificate (missing NodeId)', async () => { const nodeConnectionProm = NodeConnection.createNodeConnection({ + handleStream: () => {}, targetNodeIds: [clientNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), - }).then((n) => { - nodeConnection_ = n; - return n; - }); + }).then(extractNodeConnection); await expect(nodeConnectionProm).rejects.toThrow(); }); test('Should fail and destroy due to connection failure', async () => { const nodeConnection = await NodeConnection.createNodeConnection( { + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - keepAliveIntervalTime: 100, - maxIdleTimeout: 200, - }, + connectionKeepAliveIntervalTime: 100, + connectionMaxIdleTimeout: 200, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), }, { timer: 150 }, - ).then((n) => { - nodeConnection_ = n; - return n; - }); + ).then(extractNodeConnection); const destroyProm = promise(); nodeConnection.addEventListener('destroy', () => { destroyProm.resolveP(); @@ -332,30 +299,25 @@ describe(`${NodeConnection.name}`, () => { test('Should fail and destroy due to connection ending local', async () => { const nodeConnection = await NodeConnection.createNodeConnection( { + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - maxIdleTimeout: 200, - keepAliveIntervalTime: 100, - }, + connectionMaxIdleTimeout: 200, + connectionKeepAliveIntervalTime: 100, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), }, { timer: 150 }, - ).then((n) => { - nodeConnection_ = n; - return n; - }); + ).then(extractNodeConnection); const destroyProm = promise(); nodeConnection.addEventListener('destroy', () => { destroyProm.resolveP(); }); - await nodeConnection.quicClient.connection.stop({ + await nodeConnection.quicConnection.stop({ applicationError: true, errorCode: 0, force: false, @@ -365,25 +327,20 @@ describe(`${NodeConnection.name}`, () => { test('Should fail and destroy due to connection ending remote', async () => { const nodeConnection = await NodeConnection.createNodeConnection( { + handleStream: () => {}, targetNodeIds: [serverNodeId], targetHost: localHost as Host, targetPort: quicServer.port as Port, manifest: {}, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - maxIdleTimeout: 200, - keepAliveIntervalTime: 100, - }, + connectionMaxIdleTimeout: 200, + connectionKeepAliveIntervalTime: 100, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, logger: logger.getChild(`${NodeConnection.name}`), }, { timer: 150 }, - ).then((n) => { - nodeConnection_ = n; - return n; - }); + ).then(extractNodeConnection); const destroyProm = promise(); nodeConnection.addEventListener('destroy', () => { destroyProm.resolveP(); @@ -397,4 +354,121 @@ describe(`${NodeConnection.name}`, () => { }); await destroyProm.p; }); + test('should wrap reverse connection', async () => { + const nodeConnectionReverseProm = promise>(); + quicServer.removeEventListener('serverConnection', handleConnection); + quicServer.addEventListener( + 'serverConnection', + async (event: quicEvents.QUICServerConnectionEvent) => { + const quicConnection = event.detail; + const certChain = quicConnection.getRemoteCertsChain().map((pem) => { + const cert = keysUtils.certFromPEM(pem as CertificatePEM); + if (cert == null) never(); + return cert; + }); + if (certChain == null) never(); + const nodeId = keysUtils.certNodeId(certChain[0]); + if (nodeId == null) never(); + const nodeConnection = await NodeConnection.createNodeConnectionReverse( + { + handleStream: () => {}, + nodeId, + certChain, + manifest: {}, + quicConnection, + logger, + }, + ).then(extractNodeConnection); + nodeConnectionReverseProm.resolveP(nodeConnection); + }, + { once: true }, + ); + const nodeConnection = await NodeConnection.createNodeConnection({ + handleStream: () => {}, + targetNodeIds: [serverNodeId], + targetHost: localHost as Host, + targetPort: quicServer.port as Port, + manifest: {}, + tlsConfig: clientTlsConfig, + crypto, + quicSocket: clientSocket, + logger: logger.getChild(`${NodeConnection.name}`), + }).then(extractNodeConnection); + const nodeConnectionReverse = await nodeConnectionReverseProm.p; + const nodeConnectionDestroyProm = promise(); + nodeConnection.addEventListener( + 'destroy', + () => nodeConnectionDestroyProm.resolveP(), + { once: true }, + ); + await nodeConnectionReverse.destroy({ force: true }); + await nodeConnectionDestroyProm.p; + }); + test('should handle reverse streams', async () => { + const nodeConnectionReverseProm = promise>(); + const reverseStreamProm = promise>(); + quicServer.removeEventListener('serverConnection', handleConnection); + quicServer.addEventListener( + 'serverConnection', + async (event: quicEvents.QUICServerConnectionEvent) => { + const quicConnection = event.detail; + const certChain = quicConnection.getRemoteCertsChain().map((pem) => { + const cert = keysUtils.certFromPEM(pem as CertificatePEM); + if (cert == null) never(); + return cert; + }); + if (certChain == null) never(); + const nodeId = keysUtils.certNodeId(certChain[0]); + if (nodeId == null) never(); + const nodeConnection = await NodeConnection.createNodeConnectionReverse( + { + handleStream: (stream) => { + reverseStreamProm.resolveP(stream); + }, + nodeId, + certChain, + manifest: {}, + quicConnection, + logger, + }, + ).then(extractNodeConnection); + nodeConnectionReverseProm.resolveP(nodeConnection); + }, + { once: true }, + ); + const forwardStreamProm = promise>(); + const nodeConnection = await NodeConnection.createNodeConnection({ + handleStream: (stream) => forwardStreamProm.resolveP(stream), + targetNodeIds: [serverNodeId], + targetHost: localHost as Host, + targetPort: quicServer.port as Port, + manifest: {}, + tlsConfig: clientTlsConfig, + crypto, + quicSocket: clientSocket, + logger: logger.getChild(`${NodeConnection.name}`), + }).then(extractNodeConnection); + const nodeConnectionReverse = await nodeConnectionReverseProm.p; + + // Checking stream creation + const forwardStream = await nodeConnection.quicConnection.streamNew(); + const writer1 = forwardStream.writable.getWriter(); + await writer1.write(Buffer.from('Hello!')); + await reverseStreamProm.p; + + const reverseStream = + await nodeConnectionReverse.quicConnection.streamNew(); + const writer2 = reverseStream.writable.getWriter(); + await writer2.write(Buffer.from('Hello!')); + await forwardStreamProm.p; + + const nodeConnectionDestroyProm = promise(); + nodeConnection.addEventListener( + 'destroy', + () => nodeConnectionDestroyProm.resolveP(), + { once: true }, + ); + await nodeConnectionReverse.destroy({ force: true }); + await nodeConnectionDestroyProm.p; + }); }); diff --git a/tests/nodes/NodeConnectionManager.general.test.ts b/tests/nodes/NodeConnectionManager.general.test.ts index c4ecaac72..43da9bc1d 100644 --- a/tests/nodes/NodeConnectionManager.general.test.ts +++ b/tests/nodes/NodeConnectionManager.general.test.ts @@ -92,6 +92,8 @@ describe(`${NodeConnectionManager.name} general test`, () => { let nodeManager: NodeManager; let nodeConnectionManager: NodeConnectionManager; + // Default stream handler, just drop the stream + const handleStream = () => {}; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -195,10 +197,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -216,6 +215,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -238,12 +238,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -261,6 +258,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // Mocking pinging to always return true @@ -294,12 +292,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -317,6 +312,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // Mocking pinging to always return true @@ -343,12 +339,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -366,6 +359,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // Mocking pinging to always return true @@ -427,12 +421,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -450,6 +441,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // Mocking pinging to always return true @@ -511,12 +503,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -534,6 +523,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -577,12 +567,9 @@ describe(`${NodeConnectionManager.name} general test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 10000, - keepAliveIntervalTime: 1000, - }, + connectionMaxIdleTimeout: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -600,6 +587,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -641,4 +629,5 @@ describe(`${NodeConnectionManager.name} general test`, () => { await nodeConnectionManager.stop(); }); + test.todo('Handles reverse streams'); }); diff --git a/tests/nodes/NodeConnectionManager.lifecycle.test.ts b/tests/nodes/NodeConnectionManager.lifecycle.test.ts index b33ab7338..8cc9ddd7e 100644 --- a/tests/nodes/NodeConnectionManager.lifecycle.test.ts +++ b/tests/nodes/NodeConnectionManager.lifecycle.test.ts @@ -57,6 +57,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { let nodeManager: NodeManager; let nodeConnectionManager: NodeConnectionManager; + const handleStream = () => {}; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -80,7 +81,10 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { key: serverTlsConfig.keyPrivatePem, cert: serverTlsConfig.certChainPem, }, - crypto, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, socket: serverSocket, logger: logger.getChild(`${QUICServer.name}`), }); @@ -172,11 +176,8 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, crypto, + tlsConfig: clientTlsConfig, quicSocket: clientSocket, seedNodes: undefined, }); @@ -193,6 +194,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await nodeConnectionManager.stop(); @@ -204,10 +206,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -225,6 +224,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -240,10 +240,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -261,6 +258,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -276,10 +274,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -297,6 +292,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -320,10 +316,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -341,6 +334,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // @ts-ignore: kidnap protected property @@ -364,10 +358,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -385,6 +376,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); // @ts-ignore: kidnap protected property @@ -416,10 +408,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -437,6 +426,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); await nodeConnectionManager.withConnF(serverNodeId, async () => { @@ -458,10 +448,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -479,6 +466,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); const waitProm = promise(); @@ -508,10 +496,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -529,6 +514,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); await nodeConnectionManager.withConnF(serverNodeId, async () => { @@ -557,10 +543,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -578,6 +561,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); await nodeConnectionManager.withConnF(serverNodeId, async () => { @@ -598,10 +582,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -619,6 +600,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); const result = await nodeConnectionManager.pingNode( @@ -636,10 +618,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -657,6 +636,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); const result = await nodeConnectionManager.pingNode( @@ -675,10 +655,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: clientTlsConfig.keyPrivatePem, - cert: clientTlsConfig.certChainPem, - }, + tlsConfig: clientTlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -696,6 +673,7 @@ describe(`${NodeConnectionManager.name} lifecycle test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); const result = await nodeConnectionManager.pingNode( diff --git a/tests/nodes/NodeConnectionManager.seednodes.test.ts b/tests/nodes/NodeConnectionManager.seednodes.test.ts index ab25e7cb7..e8e5759d0 100644 --- a/tests/nodes/NodeConnectionManager.seednodes.test.ts +++ b/tests/nodes/NodeConnectionManager.seednodes.test.ts @@ -21,7 +21,6 @@ import Sigchain from '@/sigchain/Sigchain'; import TaskManager from '@/tasks/TaskManager'; import NodeManager from '@/nodes/NodeManager'; import PolykeyAgent from '@/PolykeyAgent'; -import { sleep } from '@/utils'; import * as testNodesUtils from './utils'; import * as tlsTestUtils from '../utils/tls'; @@ -63,6 +62,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { let taskManager: TaskManager; let nodeManager: NodeManager; let tlsConfig: TLSConfig; + const handleStream = () => {}; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -195,10 +195,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: dummySeedNodes, @@ -216,6 +213,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -236,11 +234,8 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - keepAliveIntervalTime: 1000, - }, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: { @@ -264,6 +259,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -284,12 +280,9 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - maxIdleTimeout: 1000, - keepAliveIntervalTime: 500, - }, + connectionMaxIdleTimeout: 1000, + connectionKeepAliveIntervalTime: 500, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: { @@ -309,13 +302,13 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); await remotePolykeyAgent1.nodeGraph.setNode(remoteNodeId2, remoteAddress2); await nodeManager.syncNodeGraph(true, 100); - await sleep(1000); expect(mockedRefreshBucket).toHaveBeenCalled(); await nodeConnectionManager.stop(); @@ -336,11 +329,8 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - keepAliveIntervalTime: 1000, - }, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: { @@ -360,6 +350,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -384,11 +375,8 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - keepAliveIntervalTime: 1000, - }, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: { @@ -408,6 +396,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -427,11 +416,8 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - keepAliveIntervalTime: 1000, - }, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: { @@ -451,6 +437,7 @@ describe(`${NodeConnectionManager.name} seednodes test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); diff --git a/tests/nodes/NodeConnectionManager.timeout.test.ts b/tests/nodes/NodeConnectionManager.timeout.test.ts index 2802ecf84..f2b3dc3df 100644 --- a/tests/nodes/NodeConnectionManager.timeout.test.ts +++ b/tests/nodes/NodeConnectionManager.timeout.test.ts @@ -53,6 +53,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { let taskManager: TaskManager; let nodeManager: NodeManager; let tlsConfig: TLSConfig; + const handleStream = () => {}; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -153,10 +154,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -175,6 +173,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -210,10 +209,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -232,6 +228,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -282,10 +279,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -304,6 +298,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -338,10 +333,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -361,6 +353,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -380,10 +373,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -403,6 +393,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); @@ -428,10 +419,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { keyRing, logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, seedNodes: undefined, @@ -451,6 +439,7 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeManager.start(); await nodeConnectionManager.start({ nodeManager, + handleStream, }); await taskManager.startProcessing(); diff --git a/tests/nodes/NodeGraph.test.ts b/tests/nodes/NodeGraph.test.ts index a0765f7a8..01717a5f4 100644 --- a/tests/nodes/NodeGraph.test.ts +++ b/tests/nodes/NodeGraph.test.ts @@ -358,17 +358,17 @@ describe(`${NodeGraph.name} test`, () => { ).toBe(true); // Sort by lastUpdated asc bucket = await nodeGraph.getBucket(bucketIndex, 'lastUpdated', 'asc'); - let bucketLastUpdateds = bucket.map(([, nodeData]) => nodeData.lastUpdated); + let bucketLastUpdated = bucket.map(([, nodeData]) => nodeData.lastUpdated); expect( - bucketLastUpdateds.slice(1).every((lastUpdated, i) => { - return bucketLastUpdateds[i] <= lastUpdated; + bucketLastUpdated.slice(1).every((lastUpdated, i) => { + return bucketLastUpdated[i] <= lastUpdated; }), ).toBe(true); bucket = await nodeGraph.getBucket(bucketIndex, 'lastUpdated', 'desc'); - bucketLastUpdateds = bucket.map(([, nodeData]) => nodeData.lastUpdated); + bucketLastUpdated = bucket.map(([, nodeData]) => nodeData.lastUpdated); expect( - bucketLastUpdateds.slice(1).every((lastUpdated, i) => { - return bucketLastUpdateds[i] >= lastUpdated; + bucketLastUpdated.slice(1).every((lastUpdated, i) => { + return bucketLastUpdated[i] >= lastUpdated; }), ).toBe(true); await nodeGraph.stop(); @@ -526,12 +526,12 @@ describe(`${NodeGraph.name} test`, () => { expect(nodeData.address.port < 2 ** 16).toBe(true); expect(nodeData.lastUpdated >= now).toBe(true); } - const bucketLastUpdateds = bucket.map( + const bucketLastUpdated = bucket.map( ([, nodeData]) => nodeData.lastUpdated, ); expect( - bucketLastUpdateds.slice(1).every((lastUpdated, i) => { - return bucketLastUpdateds[i] <= lastUpdated; + bucketLastUpdated.slice(1).every((lastUpdated, i) => { + return bucketLastUpdated[i] <= lastUpdated; }), ).toBe(true); } @@ -556,12 +556,12 @@ describe(`${NodeGraph.name} test`, () => { expect(nodeData.address.port < 2 ** 16).toBe(true); expect(nodeData.lastUpdated >= now).toBe(true); } - const bucketLastUpdateds = bucket.map( + const bucketLastUpdated = bucket.map( ([, nodeData]) => nodeData.lastUpdated, ); expect( - bucketLastUpdateds.slice(1).every((lastUpdated, i) => { - return bucketLastUpdateds[i] >= lastUpdated; + bucketLastUpdated.slice(1).every((lastUpdated, i) => { + return bucketLastUpdated[i] >= lastUpdated; }), ).toBe(true); } @@ -661,8 +661,8 @@ describe(`${NodeGraph.name} test`, () => { ); testProp( 'reset buckets should re-order the buckets', - [testNodesUtils.uniqueNodeIdArb(2), testNodesUtils.nodeIdArrayArb(50)], - async (nodeIds, initialNodes) => { + [testNodesUtils.uniqueNodeIdArb(2)], + async (nodeIds) => { const getNodeIdMock = jest.fn(); const dummyKeyRing = { getNodeId: getNodeIdMock, @@ -674,7 +674,8 @@ describe(`${NodeGraph.name} test`, () => { fresh: true, logger, }); - for (const nodeId of initialNodes) { + for (let i = 1; i < 255 / 25; i += 50) { + const nodeId = nodesUtils.generateRandomNodeIdForBucket(nodeIds[0], i); await nodeGraph.setNode(nodeId, { host: '127.0.0.1', port: utils.getRandomInt(0, 2 ** 16), @@ -866,7 +867,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -910,7 +911,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -954,7 +955,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -998,7 +999,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -1042,7 +1043,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -1086,7 +1087,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -1130,7 +1131,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending @@ -1161,7 +1162,7 @@ describe(`${NodeGraph.name} test`, () => { nodesUtils.bucketSortByDistance(nodeIds, targetNodeId); const a = nodeIds.map((a) => nodesUtils.encodeNodeId(a[0])); const b = result.map((a) => nodesUtils.encodeNodeId(a[0])); - // Are the closest nodes out of all of the nodes + // Are the closest nodes out of all the nodes expect(a.slice(0, b.length)).toEqual(b); // Check that the list is strictly ascending diff --git a/tests/nodes/NodeManager.test.ts b/tests/nodes/NodeManager.test.ts index 4f7c3ba40..3538f9087 100644 --- a/tests/nodes/NodeManager.test.ts +++ b/tests/nodes/NodeManager.test.ts @@ -333,10 +333,7 @@ describe(`${NodeManager.name} test`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, logger, @@ -369,8 +366,8 @@ describe(`${NodeManager.name} test`, () => { }); const serverNodeId = server.keyRing.getNodeId(); const serverNodeAddress: NodeAddress = { - host: server.quicServerAgent.host as Host, - port: server.quicServerAgent.port as Port, + host: server.quicSocket.host as Host, + port: server.quicSocket.port as Port, }; await nodeGraph.setNode(serverNodeId, serverNodeAddress); @@ -385,7 +382,8 @@ describe(`${NodeManager.name} test`, () => { await nodeConnectionManager.withConnF(serverNodeId, async () => { // Do nothing }); - + // Wait for background logic to settle + await sleep(100); const nodeData2 = await server.nodeGraph.getNode(expectedNodeId); expect(nodeData2).toBeDefined(); expect(nodeData2?.address.host).toEqual(expectedHost); @@ -557,10 +555,7 @@ describe(`${NodeManager.name} test`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket: clientSocket, logger, @@ -637,6 +632,7 @@ describe(`${NodeManager.name} test`, () => { logger, }); await nodeManager.start(); + await taskManager.stopProcessing(); // Creating dummy tasks const task1 = await taskManager.scheduleTask({ @@ -651,7 +647,6 @@ describe(`${NodeManager.name} test`, () => { }); // Stopping nodeManager should cancel any nodeManager tasks - await taskManager.stopProcessing(); await nodeManager.stop(); const tasks: Array = []; for await (const task of taskManager.getTasks('asc', true, [ diff --git a/tests/nodes/utils.ts b/tests/nodes/utils.ts index e6336b96b..63ace36b7 100644 --- a/tests/nodes/utils.ts +++ b/tests/nodes/utils.ts @@ -71,13 +71,13 @@ function generateNodeIdForBucket( async function nodesConnect(localNode: PolykeyAgent, remoteNode: PolykeyAgent) { // Add remote node's details to local node await localNode.nodeManager.setNode(remoteNode.keyRing.getNodeId(), { - host: remoteNode.quicServerAgent.host, - port: remoteNode.quicServerAgent.port, + host: remoteNode.quicSocket.host, + port: remoteNode.quicSocket.port, } as NodeAddress); // Add local node's details to remote node await remoteNode.nodeManager.setNode(localNode.keyRing.getNodeId(), { - host: localNode.quicServerAgent.host, - port: localNode.quicServerAgent.port, + host: localNode.quicSocket.host, + port: localNode.quicSocket.port, } as NodeAddress); } diff --git a/tests/notifications/NotificationsManager.test.ts b/tests/notifications/NotificationsManager.test.ts index 81dedf69b..98d82b9e1 100644 --- a/tests/notifications/NotificationsManager.test.ts +++ b/tests/notifications/NotificationsManager.test.ts @@ -123,10 +123,7 @@ describe('NotificationsManager', () => { nodeConnectionManager = new NodeConnectionManager({ nodeGraph, keyRing, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket, logger, @@ -142,7 +139,7 @@ describe('NotificationsManager', () => { logger, }); await nodeManager.start(); - await nodeConnectionManager.start({ nodeManager }); + await nodeConnectionManager.start({ nodeManager, handleStream: () => {} }); await taskManager.start(); // Set up node for receiving notifications receiver = await PolykeyAgent.createPolykeyAgent({ @@ -160,8 +157,8 @@ describe('NotificationsManager', () => { }, }); await nodeGraph.setNode(receiver.keyRing.getNodeId(), { - host: receiver.quicServerAgent.host as Host, - port: receiver.quicServerAgent.port as Port, + host: receiver.quicSocket.host as Host, + port: receiver.quicSocket.port as Port, }); }, globalThis.defaultTimeout); afterEach(async () => { diff --git a/tests/rpc/RPC.test.ts b/tests/rpc/RPC.test.ts index c9a03129c..bdc377c93 100644 --- a/tests/rpc/RPC.test.ts +++ b/tests/rpc/RPC.test.ts @@ -4,7 +4,7 @@ import type { JSONValue } from '@/types'; import { TransformStream } from 'stream/web'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import { sleep } from 'ix/asynciterable/_sleep'; +import * as utils from '@/utils'; import RPCServer from '@/rpc/RPCServer'; import RPCClient from '@/rpc/RPCClient'; import { @@ -43,10 +43,10 @@ describe('RPC', () => { class TestMethod extends RawHandler { public handle( input: [JSONRPCRequest, ReadableStream], - ): ReadableStream { + ): [JSONValue, ReadableStream] { const [header_, stream] = input; header = header_; - return stream; + return ['some leading data', stream]; } } const rpcServer = await RPCServer.createRPCServer({ @@ -89,12 +89,91 @@ describe('RPC', () => { id: null, }; expect(header).toStrictEqual(expectedHeader); + expect(callerInterface.meta?.result).toBe('some leading data'); expect(await outputResult).toStrictEqual(inputData); await pipeProm; await rpcServer.destroy(); await rpcClient.destroy(); }, ); + test('RPC communication with raw stream times out waiting for leading message', async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + void (async () => { + for await (const _ of serverPair.readable) { + // Just consume + } + })(); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new RawCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + }); + + await expect( + rpcClient.methods.testMethod( + { + hello: 'world', + }, + { timer: 100 }, + ), + ).rejects.toThrow(rpcErrors.ErrorRPCTimedOut); + await rpcClient.destroy(); + }); + test('RPC communication with raw stream, raw handler throws', async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends RawHandler { + public handle(): [JSONValue, ReadableStream] { + throw Error('some error'); + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new RawCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + }); + + await expect( + rpcClient.methods.testMethod({ + hello: 'world', + }), + ).rejects.toThrow(rpcErrors.ErrorPolykeyRemote); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }); testProp( 'RPC communication with duplex stream', [fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 })], @@ -466,7 +545,7 @@ describe('RPC', () => { const writer = callerInterface.writable.getWriter(); await writer.write({}); // Allow time to process buffer - await sleep(0); + await utils.sleep(0); await expect(writer.write({})).toReject(); const reader = callerInterface.readable.getReader(); await expect(reader.read()).toReject(); diff --git a/tests/rpc/RPCClient.test.ts b/tests/rpc/RPCClient.test.ts index 083d73f32..9b1c5fea6 100644 --- a/tests/rpc/RPCClient.test.ts +++ b/tests/rpc/RPCClient.test.ts @@ -4,12 +4,12 @@ import type { JSONRPCRequest, JSONRPCRequestMessage, JSONRPCResponse, + JSONRPCResponseResult, RPCStream, } from '@/rpc/types'; import { TransformStream, ReadableStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; -import { Timer } from '@matrixai/timer'; import RPCClient from '@/rpc/RPCClient'; import RPCServer from '@/rpc/RPCServer'; import * as rpcErrors from '@/rpc/errors'; @@ -53,6 +53,12 @@ describe(`${RPCClient.name}`, () => { meta: undefined, readable: new ReadableStream({ start: (controller) => { + const leadingResponse: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: null, + id: null, + }; + controller.enqueue(Buffer.from(JSON.stringify(leadingResponse))); for (const datum of outputData) { controller.enqueue(datum); } @@ -618,6 +624,12 @@ describe(`${RPCClient.name}`, () => { meta: undefined, readable: new ReadableStream({ start: (controller) => { + const leadingResponse: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: null, + id: null, + }; + controller.enqueue(Buffer.from(JSON.stringify(leadingResponse))); for (const datum of outputData) { controller.enqueue(datum); } @@ -653,6 +665,7 @@ describe(`${RPCClient.name}`, () => { ]); expect(await outputResult).toStrictEqual(outputData); }, + { seed: -783452149, path: '0:0:0:0:0:0:0', endOnFailure: true }, ); testProp( 'manifest duplex caller', @@ -748,7 +761,7 @@ describe(`${RPCClient.name}`, () => { const callerInterfaceProm = rpcClient.rawStreamCaller( 'testMethod', {}, - { timer: new Timer({ delay: 100 }) }, + { timer: 100 }, ); await expect(callerInterfaceProm).toReject(); await expect(callerInterfaceProm).rejects.toThrow( @@ -759,11 +772,11 @@ describe(`${RPCClient.name}`, () => { }); test('raw caller handles abort when creating stream', async () => { const holdProm = promise(); - let ctx: ContextTimed | undefined; + const ctxProm = promise(); const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamFactory: async (ctx_) => { - ctx = ctx_; + ctxProm.resolveP(ctx_); await holdProm.p; // Should never reach this when testing return {} as RPCStream; @@ -772,7 +785,6 @@ describe(`${RPCClient.name}`, () => { }); const abortController = new AbortController(); const rejectReason = Symbol('rejectReason'); - abortController.abort(rejectReason); // Timing out on stream creation const callerInterfaceProm = rpcClient.rawStreamCaller( @@ -780,6 +792,8 @@ describe(`${RPCClient.name}`, () => { {}, { signal: abortController.signal }, ); + abortController.abort(rejectReason); + const ctx = await ctxProm.p; await expect(callerInterfaceProm).toReject(); await expect(callerInterfaceProm).rejects.toBe(rejectReason); expect(ctx?.signal.aborted).toBeTrue(); @@ -800,24 +814,23 @@ describe(`${RPCClient.name}`, () => { writable: forwardPassThroughStream.writable, readable: reversePassThroughStream.readable, }; - let ctx: ContextTimed | undefined; + const ctxProm = promise(); const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamFactory: async (ctx_) => { - ctx = ctx_; + ctxProm.resolveP(ctx_); return streamPair; }, logger, }); // Timing out on stream - await Promise.all([ - rpcClient.rawStreamCaller( - 'testMethod', - {}, - { timer: new Timer({ delay: 100 }) }, - ), - forwardPassThroughStream.readable.getReader().read(), - ]); + await expect( + Promise.all([ + rpcClient.rawStreamCaller('testMethod', {}, { timer: 100 }), + forwardPassThroughStream.readable.getReader().read(), + ]), + ).rejects.toThrow(rpcErrors.ErrorRPCTimedOut); + const ctx = await ctxProm.p; await ctx?.timer; expect(ctx?.signal.aborted).toBeTrue(); expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); @@ -850,21 +863,22 @@ describe(`${RPCClient.name}`, () => { const rejectReason = Symbol('rejectReason'); // Timing out on stream const reader = forwardPassThroughStream.readable.getReader(); - await Promise.all([ - rpcClient.rawStreamCaller( - 'testMethod', - {}, - { signal: abortController.signal }, - ), - reader.read(), - ]); - const ctx = await ctxProm.p; const abortProm = promise(); - if (ctx.signal.aborted) abortProm.resolveP(); - ctx.signal.addEventListener('abort', () => { - abortProm.resolveP(); + const ctxWaitProm = ctxProm.p.then((ctx) => { + if (ctx.signal.aborted) abortProm.resolveP(); + ctx.signal.addEventListener('abort', () => { + abortProm.resolveP(); + }); + abortController.abort(rejectReason); }); - abortController.abort(rejectReason); + const rawStreamProm = rpcClient.rawStreamCaller( + 'testMethod', + {}, + { signal: abortController.signal }, + ); + await Promise.allSettled([rawStreamProm, reader.read(), ctxWaitProm]); + await expect(rawStreamProm).rejects.toBe(rejectReason); + const ctx = await ctxProm.p; await abortProm.p; expect(ctx?.signal.aborted).toBeTrue(); expect(ctx?.signal.reason).toBe(rejectReason); @@ -909,7 +923,7 @@ describe(`${RPCClient.name}`, () => { }); // Timing out on stream creation const callerInterfaceProm = rpcClient.duplexStreamCaller('testMethod', { - timer: new Timer({ delay: 100 }), + timer: 100, }); await expect(callerInterfaceProm).toReject(); await expect(callerInterfaceProm).rejects.toThrow( @@ -1003,7 +1017,7 @@ describe(`${RPCClient.name}`, () => { // Timing out on stream await rpcClient.duplexStreamCaller('testMethod', { - timer: new Timer({ delay: 100 }), + timer: 100, }); await ctx?.timer; expect(ctx?.signal.aborted).toBeTrue(); @@ -1019,7 +1033,10 @@ describe(`${RPCClient.name}`, () => { Uint8Array >(); const streamPair: RPCStream = { - cancel: () => {}, + cancel: async (reason) => { + await forwardPassThroughStream.readable.cancel(reason); + await reversePassThroughStream.writable.abort(reason); + }, meta: undefined, writable: forwardPassThroughStream.writable, readable: reversePassThroughStream.readable, @@ -1037,7 +1054,7 @@ describe(`${RPCClient.name}`, () => { const rejectReason = Symbol('rejectReason'); abortController.abort(rejectReason); // Timing out on stream - await rpcClient.duplexStreamCaller('testMethod', { + const stream = await rpcClient.duplexStreamCaller('testMethod', { signal: abortController.signal, }); const ctx = await ctxProm.p; @@ -1048,6 +1065,7 @@ describe(`${RPCClient.name}`, () => { }); expect(ctx?.signal.aborted).toBeTrue(); expect(ctx?.signal.reason).toBe(rejectReason); + stream.cancel(Error('asd')); }); testProp( 'duplex caller timeout is refreshed when sending message', @@ -1123,7 +1141,7 @@ describe(`${RPCClient.name}`, () => { }, middlewareFactory: rpcUtilsMiddleware.defaultClientMiddlewareWrapper( (ctx) => { - ctx.timer.reset(1000); + ctx.timer.reset(123); return { forward: new TransformStream(), reverse: new TransformStream(), @@ -1141,7 +1159,7 @@ describe(`${RPCClient.name}`, () => { // Writing should refresh timer engage the middleware const writer = callerInterface.writable.getWriter(); await writer.write({}); - expect(ctx.timer.delay).toBe(1000); + expect(ctx.timer.delay).toBe(123); await writer.close(); await outputResult; diff --git a/tests/rpc/RPCServer.test.ts b/tests/rpc/RPCServer.test.ts index c4c2889c1..4986f5071 100644 --- a/tests/rpc/RPCServer.test.ts +++ b/tests/rpc/RPCServer.test.ts @@ -67,18 +67,22 @@ describe(`${RPCServer.name}`, () => { rpcTestUtils.binaryStreamToSnippedStream([4, 7, 13, 2, 6]), ); class TestHandler extends RawHandler { - public handle([_header, input]): ReadableStream { + public handle([_header, input]): [ + JSONValue, + ReadableStream, + ] { void (async () => { for await (const _ of input) { // No touch, only consume } })().catch(() => {}); - return new ReadableStream({ + const readableStream = new ReadableStream({ start: (controller) => { controller.enqueue(Buffer.from('hello world!')); controller.close(); }, }); + return [null, readableStream]; } } const rpcServer = await RPCServer.createRPCServer({ @@ -769,7 +773,12 @@ describe(`${RPCServer.name}`, () => { test('timeout with default time after handler selected', async () => { const ctxProm = promise(); class TestHandler extends RawHandler { - public handle(_input, _cancel, _meta, ctx_): ReadableStream { + public handle( + _input, + _cancel, + _meta, + ctx_, + ): [JSONValue, ReadableStream] { ctxProm.resolveP(ctx_); // Do nothing, expecting timeout let controller: ReadableStreamController; @@ -781,7 +790,7 @@ describe(`${RPCServer.name}`, () => { ctx_.signal.addEventListener('abort', () => { controller!.error(Error('ending')); }); - return stream; + return [null, stream]; } } const rpcServer = await RPCServer.createRPCServer({ @@ -987,7 +996,12 @@ describe(`${RPCServer.name}`, () => { test('stream ending cleans up timer and abortSignal', async () => { const ctxProm = promise(); class TestHandler extends RawHandler { - public handle(input, _cancel, _meta, ctx_): ReadableStream { + public handle( + input, + _cancel, + _meta, + ctx_, + ): [JSONValue, ReadableStream] { ctxProm.resolveP(ctx_); // Do nothing, expecting timeout void (async () => { @@ -995,11 +1009,12 @@ describe(`${RPCServer.name}`, () => { // Do nothing, only consume } })(); - return new ReadableStream({ + const readableStream = new ReadableStream({ start: (controller) => { controller.close(); }, }); + return [null, readableStream]; } } const rpcServer = await RPCServer.createRPCServer({ @@ -1034,10 +1049,15 @@ describe(`${RPCServer.name}`, () => { test('Timeout has a grace period before forcing the streams closed', async () => { const ctxProm = promise(); class TestHandler extends RawHandler { - public handle(_input, _cancel, _meta, ctx_): ReadableStream { + public handle( + _input, + _cancel, + _meta, + ctx_, + ): [JSONValue, ReadableStream] { ctxProm.resolveP(ctx_); // Do nothing, expecting timeout - return new ReadableStream(); + return [null, new ReadableStream()]; } } const rpcServer = await RPCServer.createRPCServer({ @@ -1125,10 +1145,15 @@ describe(`${RPCServer.name}`, () => { test('destroying the `RPCServer` sends an abort signal and closes connection', async () => { const ctxProm = promise(); class TestHandler extends RawHandler { - public handle(input, _cancel, _meta, ctx_): ReadableStream { + public handle( + input, + _cancel, + _meta, + ctx_, + ): [JSONValue, ReadableStream] { ctxProm.resolveP(ctx_); // Echo messages - return input[1]; + return [null, input[1]]; } } const rpcServer = await RPCServer.createRPCServer({ diff --git a/tests/rpc/utils.ts b/tests/rpc/utils.ts index 3744306ca..4da98b217 100644 --- a/tests/rpc/utils.ts +++ b/tests/rpc/utils.ts @@ -82,8 +82,8 @@ const messagesToReadableStream = (messages: Array) => { * a json stringify and parse cycle. */ const safeJsonValueArb = fc - .jsonValue() - .map((value) => JSON.parse(JSON.stringify(value)) as JSONValue); + .json() + .map((value) => JSON.parse(value.replace('__proto__', 'proto')) as JSONValue); const idArb = fc.oneof(fc.string(), fc.integer(), fc.constant(null)); diff --git a/tests/scratch.test.ts b/tests/scratch.test.ts index c7c21d965..4959b5280 100644 --- a/tests/scratch.test.ts +++ b/tests/scratch.test.ts @@ -1,14 +1,13 @@ import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import NodeManager from '@/nodes/NodeManager'; // This is a 'scratch paper' test file for quickly running tests in the CI describe('scratch', () => { - const _logger = new Logger(`${NodeManager.name} test`, LogLevel.WARN, [ + const _logger = new Logger(`scratch test`, LogLevel.WARN, [ new StreamHandler(), ]); -}); -// We can't have empty test files so here is a sanity test -test('Should avoid empty test suite', async () => { - expect(1 + 1).toBe(2); + // We can't have empty test files so here is a sanity test + test('Should avoid empty test suite', async () => { + expect(1 + 1).toBe(2); + }); }); diff --git a/tests/utils/tls.ts b/tests/utils/tls.ts index 5306df1fe..97f32af99 100644 --- a/tests/utils/tls.ts +++ b/tests/utils/tls.ts @@ -2,7 +2,6 @@ import type { CertId, Certificate, CertificatePEMChain, - Key, KeyPair, } from '@/keys/types'; import type { TLSConfig } from '@/network/types'; @@ -64,8 +63,8 @@ async function createTLSConfigWithChain( }; } -function createCrypto(key: Key = keysUtils.generateKey()) { - const ops: ClientCrypto & ServerCrypto = { +function createCrypto(): ServerCrypto & ClientCrypto { + return { randomBytes: async (data: ArrayBuffer) => { const randomBytes = keysUtils.getRandomBytes(data.byteLength); const dataBuf = Buffer.from(data); @@ -74,10 +73,6 @@ function createCrypto(key: Key = keysUtils.generateKey()) { sign: testNodesUtils.sign, verify: testNodesUtils.verify, }; - return { - key: key, - ops, - }; } export { createTLSConfig, createTLSConfigWithChain, createCrypto }; diff --git a/tests/vaults/VaultManager.test.ts b/tests/vaults/VaultManager.test.ts index 7d1fab63b..5b597db43 100644 --- a/tests/vaults/VaultManager.test.ts +++ b/tests/vaults/VaultManager.test.ts @@ -469,8 +469,7 @@ describe('VaultManager', () => { await vaultManager?.destroy(); } }); - // TODO: disabled until feature is addressed in agent migration stage 2 - describe.skip('with remote agents', () => { + describe('with remote agents', () => { let allDataDir: string; let keyRing: KeyRing; let nodeGraph: NodeGraph; @@ -517,12 +516,12 @@ describe('VaultManager', () => { // Adding details to each agent await remoteKeynode1.nodeGraph.setNode(remoteKeynode2Id, { - host: remoteKeynode2.quicServerAgent.host as Host, - port: remoteKeynode2.quicServerAgent.port as Port, + host: remoteKeynode2.quicSocket.host as Host, + port: remoteKeynode2.quicSocket.port as Port, }); await remoteKeynode2.nodeGraph.setNode(remoteKeynode1Id, { - host: remoteKeynode1.quicServerAgent.host as Host, - port: remoteKeynode1.quicServerAgent.port as Port, + host: remoteKeynode1.quicSocket.host as Host, + port: remoteKeynode1.quicSocket.port as Port, }); await remoteKeynode1.gestaltGraph.setNode({ @@ -580,25 +579,23 @@ describe('VaultManager', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, nodeGraph, - quicClientConfig: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }, + tlsConfig, crypto, quicSocket, logger, }); await nodeConnectionManager.start({ nodeManager: { setNode: jest.fn() } as unknown as NodeManager, + handleStream: () => {}, }); await taskManager.startProcessing(); await nodeGraph.setNode(remoteKeynode1Id, { - host: remoteKeynode1.quicServerAgent.host as Host, - port: remoteKeynode1.quicServerAgent.port as Port, + host: remoteKeynode1.quicSocket.host as Host, + port: remoteKeynode1.quicSocket.port as Port, }); await nodeGraph.setNode(remoteKeynode2Id, { - host: remoteKeynode2.quicServerAgent.host as Host, - port: remoteKeynode2.quicServerAgent.port as Port, + host: remoteKeynode2.quicSocket.host as Host, + port: remoteKeynode2.quicSocket.port as Port, }); }); afterEach(async () => { @@ -655,7 +652,6 @@ describe('VaultManager', () => { localNodeId, 'pull', ); - await vaultManager.cloneVault(remoteKeynode1Id, vaultName); const vaultId = await vaultManager.getVaultId(vaultName); if (vaultId === undefined) fail('VaultId is not found.'); @@ -1400,8 +1396,8 @@ describe('VaultManager', () => { // Letting nodeGraph know where the remote agent is await nodeGraph.setNode(targetNodeId, { - host: remoteKeynode1.quicServerAgent.host as Host, - port: remoteKeynode1.quicServerAgent.port as Port, + host: remoteKeynode1.quicSocket.host as Host, + port: remoteKeynode1.quicSocket.port as Port, }); await remoteKeynode1.gestaltGraph.setNode({ diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts index 3d56191d6..64815dcf9 100644 --- a/tests/websockets/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -1,8 +1,8 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { KeyPair } from '@/keys/types'; +import type { NodeId } from '@/ids/types'; import type http from 'http'; -import type WebSocketStream from '@/websockets/WebSocketStream'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -10,7 +10,7 @@ import https from 'https'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; -import { KeyRing } from '@/keys/index'; +import { status } from '@matrixai/async-init'; import WebSocketServer from '@/websockets/WebSocketServer'; import WebSocketClient from '@/websockets/WebSocketClient'; import { promise } from '@/utils'; @@ -28,8 +28,8 @@ describe('WebSocket', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - let dataDir: string; - let keyRing: KeyRing; + let keyPair: KeyPair; + let nodeId: NodeId; let tlsConfig: TLSConfig; const host = '127.0.0.2'; let webSocketServer: WebSocketServer; @@ -63,19 +63,14 @@ describe('WebSocket', () => { dataDir = await fs.promises.mkdtemp( path.join(os.tmpdir(), 'polykey-test-'), ); - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath: keysPath, - password: 'password', - logger: logger.getChild('keyRing'), - }); - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + keyPair = keysUtils.generateKeyPair(); + nodeId = keysUtils.publicKeyToNodeId(keyPair.publicKey); + tlsConfig = await testsUtils.createTLSConfig(keyPair); }); afterEach(async () => { logger.info('AFTEREACH'); await webSocketServer?.stop(true); await webSocketClient?.destroy(true); - await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); @@ -89,7 +84,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -98,7 +92,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -115,6 +109,42 @@ describe('WebSocket', () => { expect((await reader.read()).done).toBeTrue(); logger.info('ending'); }); + test('can change TLS config', async () => { + const keyPairNew = keysUtils.generateKeyPair(); + const nodeIdNew = keysUtils.publicKeyToNodeId(keyPairNew.publicKey); + const tlsConfigNew = await testsUtils.createTLSConfig(keyPairNew); + + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.getPort()}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.getPort(), + expectedNodeIds: [nodeId, nodeIdNew], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + expect(websocket.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeId)); + websocket.cancel(); + + // Changing certs + webSocketServer.setTlsConfig(tlsConfigNew); + const websocket2 = await webSocketClient.startConnection(); + expect(websocket2.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeIdNew)); + websocket2.cancel(); + + logger.info('ending'); + }); test('makes a connection over IPv6', async () => { webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { @@ -124,7 +154,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host: '::1', logger: logger.getChild('server'), @@ -133,7 +162,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host: '::1', port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -159,7 +188,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -168,7 +196,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -190,7 +218,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -199,7 +226,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); @@ -227,129 +254,56 @@ describe('WebSocket', () => { } }, ); - test('reverse backpressure', async () => { - const backpressure = promise(); - const resumeWriting = promise(); - let webSocketStream: WebSocketStream | null = null; + test('handles https server failure', async () => { webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); - void Promise.allSettled([ - (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })(), - (async () => { - // Kidnap the context - // @ts-ignore: kidnap protected property - for (const websocket of webSocketServer.activeSockets.values()) { - webSocketStream = websocket; - } - if (webSocketStream == null) { - await streamPair.writable.close(); - return; - } - // Write until backPressured - const message = Buffer.alloc(128, 0xf0); - const writer = streamPair.writable.getWriter(); - // @ts-ignore: kidnap protected property - while (!webSocketStream.writeBackpressure) { - await writer.write(message); - } - logger.info('BACK PRESSURED'); - backpressure.resolveP(); - await resumeWriting.p; - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - await writer.close(); - logger.info('WRITING ENDED'); - })(), - ]).catch((e) => logger.error(e.toString())); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), }); logger.info(`Server started on port ${webSocketServer.getPort()}`); - webSocketClient = await WebSocketClient.createWebSocketClient({ - host, - port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), + + const closeP = promise(); + // @ts-ignore: protected property + webSocketServer.server.close(() => { + closeP.resolveP(); }); - const websocket = await webSocketClient.startConnection(); - await websocket.writable.close(); + await closeP.p; + // The webSocketServer should stop itself + expect(webSocketServer[status]).toBe(null); - await backpressure.p; - // @ts-ignore: kidnap protected property - expect(webSocketStream.writeBackpressure).toBeTrue(); - resumeWriting.resolveP(); - // Consume all the back-pressured data - for await (const _ of websocket.readable) { - // No touch, only consume - } - // @ts-ignore: kidnap protected property - expect(webSocketStream.writeBackpressure).toBeFalse(); logger.info('ending'); }); - // Readable backpressure is not actually supported. We're dealing with it by - // using a buffer with a provided limit that can be very large. - test('exceeding readable buffer limit causes error', async () => { - const startReading = promise(); - const handlingProm = promise(); + test('handles webSocketServer server failure', async () => { webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); - Promise.all([ - (async () => { - await startReading.p; - logger.info('Starting consumption'); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - logger.info('Reads ended'); - })(), - (async () => { - await streamPair.writable.close(); - })(), - ]) + void streamPair.readable + .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => handlingProm.resolveP()); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, - // Setting a really low buffer limit - maxReadableStreamBytes: 1500, logger: logger.getChild('server'), }); logger.info(`Server started on port ${webSocketServer.getPort()}`); - webSocketClient = await WebSocketClient.createWebSocketClient({ - host, - port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), + + const closeP = promise(); + // @ts-ignore: protected property + webSocketServer.webSocketServer.close(() => { + closeP.resolveP(); }); - const websocket = await webSocketClient.startConnection(); - const message = Buffer.alloc(1_000, 0xf0); - const writer = websocket.writable.getWriter(); - logger.info('Starting writes'); - await expect(async () => { - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - }).rejects.toThrow(); - startReading.resolveP(); - logger.info('writes ended'); - await expect(async () => { - for await (const _ of websocket.readable) { - // No touch, only consume - } - }).rejects.toThrow(); - await handlingProm.p; + await closeP.p; + // The webSocketServer should stop itself + expect(webSocketServer[status]).toBe(null); + logger.info('ending'); }); test('client ends connection abruptly', async () => { @@ -360,7 +314,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -378,7 +331,7 @@ describe('WebSocket', () => { env: { PK_TEST_HOST: host, PK_TEST_PORT: `${webSocketServer.getPort()}`, - PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), + PK_TEST_NODE_ID: nodesUtils.encodeNodeId(nodeId), }, }, logger, @@ -440,7 +393,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: await startedProm.p, - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -468,6 +421,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -480,9 +434,8 @@ describe('WebSocket', () => { for await (const _ of streamPair.readable) { // No touch, only consume } - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -491,11 +444,12 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true); @@ -507,6 +461,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -519,9 +474,8 @@ describe('WebSocket', () => { await writer.write(val); } await writer.close(); - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -530,11 +484,12 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true); @@ -546,6 +501,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -556,9 +512,8 @@ describe('WebSocket', () => { await writer.write(val); } await writer.close(); - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -567,11 +522,12 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true); @@ -586,7 +542,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -595,7 +550,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -627,7 +582,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -657,7 +611,6 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, - basePath: dataDir, tlsConfig, host, pingTimeoutTimeTime: 100, @@ -667,7 +620,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); await webSocketClient.startConnection(); @@ -684,7 +637,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -693,7 +645,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -728,7 +680,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -751,7 +702,7 @@ describe('WebSocket', () => { }); test('authenticates with multiple certs in chain', async () => { const keyPairs: Array = [ - keyRing.keyPair, + keyPair, keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), @@ -767,7 +718,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -797,7 +747,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -806,7 +755,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], + expectedNodeIds: [nodeId, alternativeNodeId], logger: logger.getChild('clientClient'), }); await expect(webSocketClient.startConnection()).toResolve(); @@ -819,7 +768,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: 12345, - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], connectionTimeoutTime: 0, logger: logger.getChild('clientClient'), }); @@ -837,7 +786,6 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -846,7 +794,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], pingTimeoutTimeTime: 100, logger: logger.getChild('clientClient'), }); @@ -874,7 +822,6 @@ describe('WebSocket', () => { })().catch(() => {}), ]); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -883,7 +830,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const abortController = new AbortController();