Skip to content

Commit

Permalink
feat(NODE-6389): add support for timeoutMS in StateMachine.execute() (#…
Browse files Browse the repository at this point in the history
…4243)

Co-authored-by: Warren James <[email protected]>
Co-authored-by: Neal Beeken <[email protected]>
Co-authored-by: Bailey Pearson <[email protected]>
  • Loading branch information
4 people authored and dariakp committed Nov 6, 2024
1 parent 2398fc6 commit c55f965
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 55 deletions.
88 changes: 62 additions & 26 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import {
} from '../bson';
import { type ProxyOptions } from '../cmap/connection';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
Expand Down Expand Up @@ -173,6 +175,7 @@ export type StateMachineOptions = {
* An internal class that executes across a MongoCryptContext until either
* a finishing state or an error is reached. Do not instantiate directly.
*/
// TODO(DRIVERS-2671): clarify CSOT behavior for FLE APIs
export class StateMachine {
constructor(
private options: StateMachineOptions,
Expand All @@ -182,7 +185,11 @@ export class StateMachine {
/**
* Executes the state machine according to the specification
*/
async execute(executor: StateMachineExecutable, context: MongoCryptContext): Promise<Uint8Array> {
async execute(
executor: StateMachineExecutable,
context: MongoCryptContext,
timeoutContext?: TimeoutContext
): Promise<Uint8Array> {
const keyVaultNamespace = executor._keyVaultNamespace;
const keyVaultClient = executor._keyVaultClient;
const metaDataClient = executor._metaDataClient;
Expand All @@ -201,8 +208,13 @@ export class StateMachine {
'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_COLLINFO but metadata client is undefined'
);
}
const collInfo = await this.fetchCollectionInfo(metaDataClient, context.ns, filter);

const collInfo = await this.fetchCollectionInfo(
metaDataClient,
context.ns,
filter,
timeoutContext
);
if (collInfo) {
context.addMongoOperationResponse(collInfo);
}
Expand All @@ -222,9 +234,9 @@ export class StateMachine {
// When we are using the shared library, we don't have a mongocryptd manager.
const markedCommand: Uint8Array = mongocryptdManager
? await mongocryptdManager.withRespawn(
this.markCommand.bind(this, mongocryptdClient, context.ns, command)
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
)
: await this.markCommand(mongocryptdClient, context.ns, command);
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);

context.addMongoOperationResponse(markedCommand);
context.finishMongoOperation();
Expand All @@ -233,7 +245,12 @@ export class StateMachine {

case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
const filter = context.nextMongoOperation();
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter);
const keys = await this.fetchKeys(
keyVaultClient,
keyVaultNamespace,
filter,
timeoutContext
);

if (keys.length === 0) {
// See docs on EMPTY_V
Expand All @@ -255,9 +272,7 @@ export class StateMachine {
}

case MONGOCRYPT_CTX_NEED_KMS: {
const requests = Array.from(this.requests(context));
await Promise.all(requests);

await Promise.all(this.requests(context, timeoutContext));
context.finishKMSRequests();
break;
}
Expand Down Expand Up @@ -299,7 +314,7 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
async kmsRequest(request: MongoCryptKMSRequest): Promise<void> {
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
Expand Down Expand Up @@ -329,10 +344,6 @@ export class StateMachine {
}
}

function ontimeout() {
return new MongoCryptError('KMS request timed out');
}

function onerror(cause: Error) {
return new MongoCryptError('KMS request failed', { cause });
}
Expand Down Expand Up @@ -364,7 +375,6 @@ export class StateMachine {
resolve: resolveOnNetSocketConnect
} = promiseWithResolvers<void>();
netSocket
.once('timeout', () => rejectOnNetSocketError(ontimeout()))
.once('error', err => rejectOnNetSocketError(onerror(err)))
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());
Expand Down Expand Up @@ -410,8 +420,8 @@ export class StateMachine {
reject: rejectOnTlsSocketError,
resolve
} = promiseWithResolvers<void>();

socket
.once('timeout', () => rejectOnTlsSocketError(ontimeout()))
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
.on('data', data => {
Expand All @@ -425,20 +435,26 @@ export class StateMachine {
resolve();
}
});
await willResolveKmsRequest;
await (timeoutContext?.csotEnabled()
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
: willResolveKmsRequest);
} catch (error) {
if (error instanceof TimeoutError)
throw new MongoOperationTimeoutError('KMS request timed out');
throw error;
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
}
}

*requests(context: MongoCryptContext) {
*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
for (
let request = context.nextKMSRequest();
request != null;
request = context.nextKMSRequest()
) {
yield this.kmsRequest(request);
yield this.kmsRequest(request, timeoutContext);
}
}

Expand Down Expand Up @@ -498,15 +514,19 @@ export class StateMachine {
async fetchCollectionInfo(
client: MongoClient,
ns: string,
filter: Document
filter: Document,
timeoutContext?: TimeoutContext
): Promise<Uint8Array | null> {
const { db } = MongoDBCollectionNamespace.fromString(ns);

const collections = await client
.db(db)
.listCollections(filter, {
promoteLongs: false,
promoteValues: false
promoteValues: false,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' }
: {})
})
.toArray();

Expand All @@ -522,12 +542,22 @@ export class StateMachine {
* @param command - The command to execute.
* @param callback - Invoked with the serialized and marked bson command, or with an error
*/
async markCommand(client: MongoClient, ns: string, command: Uint8Array): Promise<Uint8Array> {
const options = { promoteLongs: false, promoteValues: false };
async markCommand(
client: MongoClient,
ns: string,
command: Uint8Array,
timeoutContext?: TimeoutContext
): Promise<Uint8Array> {
const { db } = MongoDBCollectionNamespace.fromString(ns);
const rawCommand = deserialize(command, options);
const bsonOptions = { promoteLongs: false, promoteValues: false };
const rawCommand = deserialize(command, bsonOptions);

const response = await client.db(db).command(rawCommand, options);
const response = await client.db(db).command(rawCommand, {
...bsonOptions,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS }
: undefined)
});

return serialize(response, this.bsonOptions);
}
Expand All @@ -543,15 +573,21 @@ export class StateMachine {
fetchKeys(
client: MongoClient,
keyVaultNamespace: string,
filter: Uint8Array
filter: Uint8Array,
timeoutContext?: TimeoutContext
): Promise<Array<DataKey>> {
const { db: dbName, collection: collectionName } =
MongoDBCollectionNamespace.fromString(keyVaultNamespace);

return client
.db(dbName)
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
.find(deserialize(filter))
.find(
deserialize(filter),
timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' }
: {}
)
.toArray();
}
}
4 changes: 4 additions & 0 deletions src/sdam/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ export class Server extends TypedEventEmitter<ServerEvents> {
delete finalOptions.readPreference;
}

if (this.description.iscryptd) {
finalOptions.omitMaxTimeMS = true;
}

const session = finalOptions.session;
let conn = session?.pinnedConnection;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/* Specification prose tests */

import { type ChildProcess, spawn } from 'node:child_process';

import { expect } from 'chai';
import * as semver from 'semver';
import * as sinon from 'sinon';
Expand All @@ -16,7 +18,8 @@ import {
MongoServerSelectionError,
now,
ObjectId,
promiseWithResolvers
promiseWithResolvers,
squashError
} from '../../mongodb';
import { type FailPoint } from '../../tools/utils';

Expand Down Expand Up @@ -103,17 +106,55 @@ describe('CSOT spec prose tests', function () {
});
});

context.skip('2. maxTimeMS is not set for commands sent to mongocryptd', () => {
/**
* This test MUST only be run against enterprise server versions 4.2 and higher.
*
* 1. Launch a mongocryptd process on 23000.
* 1. Create a MongoClient (referred to as `client`) using the URI `mongodb://localhost:23000/?timeoutMS=1000`.
* 1. Using `client`, execute the `{ ping: 1 }` command against the `admin` database.
* 1. Verify via command monitoring that the `ping` command sent did not contain a `maxTimeMS` field.
*/
});
context(
'2. maxTimeMS is not set for commands sent to mongocryptd',
{ requires: { mongodb: '>=4.2' } },
() => {
/**
* This test MUST only be run against enterprise server versions 4.2 and higher.
*
* 1. Launch a mongocryptd process on 23000.
* 1. Create a MongoClient (referred to as `client`) using the URI `mongodb://localhost:23000/?timeoutMS=1000`.
* 1. Using `client`, execute the `{ ping: 1 }` command against the `admin` database.
* 1. Verify via command monitoring that the `ping` command sent did not contain a `maxTimeMS` field.
*/

let client: MongoClient;
const mongocryptdTestPort = '23000';
let childProcess: ChildProcess;

beforeEach(async function () {
childProcess = spawn('mongocryptd', ['--port', mongocryptdTestPort, '--ipv6'], {
stdio: 'ignore',
detached: true
});

childProcess.on('error', error => console.warn(this.currentTest?.fullTitle(), error));
client = new MongoClient(`mongodb://localhost:${mongocryptdTestPort}/?timeoutMS=1000`, {
monitorCommands: true
});
});

afterEach(async function () {
await client.close();
childProcess.kill('SIGKILL');
sinon.restore();
});

it('maxTimeMS is not set', async function () {
const commandStarted = [];
client.on('commandStarted', ev => commandStarted.push(ev));
await client
.db('admin')
.command({ ping: 1 })
.catch(e => squashError(e));
expect(commandStarted).to.have.lengthOf(1);
expect(commandStarted[0].command).to.not.have.property('maxTimeMS');
});
}
);

// TODO(NODE-6391): Add timeoutMS support to Explicit Encryption
context.skip('3. ClientEncryption', () => {
/**
* Each test under this category MUST only be run against server versions 4.4 and higher. In these tests,
Expand Down Expand Up @@ -720,6 +761,30 @@ describe('CSOT spec prose tests', function () {
'TODO(NODE-6223): Auto connect performs extra server selection. Explicit connect throws on invalid host name';
});

it.skip("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", async function () {
/**
* 1. Create a MongoClient (referred to as `client`) with URI `mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20`.
* 1. Using `client`, run the command `{ ping: 1 }` against the `admin` database.
* - Expect this to fail with a server selection timeout error after no more than 15ms.
*/
client = new MongoClient('mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20');
const start = now();

const maybeError = await client
.db('test')
.admin()
.ping()
.then(
() => null,
e => e
);
const end = now();

expect(maybeError).to.be.instanceof(MongoOperationTimeoutError);
expect(end - start).to.be.lte(15);
}).skipReason =
'TODO(NODE-6223): Auto connect performs extra server selection. Explicit connect throws on invalid host name';

it.skip("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", async function () {
/**
* 1. Create a MongoClient (referred to as `client`) with URI `mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20`.
Expand Down
Loading

0 comments on commit c55f965

Please sign in to comment.