diff --git a/src/mongo_client.ts b/src/mongo_client.ts index 0de3e8e079..eb5b5ac06e 100644 --- a/src/mongo_client.ts +++ b/src/mongo_client.ts @@ -256,7 +256,7 @@ export interface MongoClientOptions extends BSONSerializeOptions, SupportedNodeC } /** @public */ -export type WithSessionCallback = (session: ClientSession) => Promise; +export type WithSessionCallback = (session: ClientSession) => Promise; /** @internal */ export interface MongoClientPrivate { @@ -605,29 +605,30 @@ export class MongoClient extends TypedEventEmitter { } /** - * Runs a given operation with an implicitly created session. The lifetime of the session - * will be handled without the need for user interaction. + * A convenience method for creating and handling the clean up of a ClientSession. + * The session will always be ended when the executor finishes. * - * NOTE: presently the operation MUST return a Promise (either explicit or implicitly as an async function) - * - * @param options - Optional settings for the command - * @param callback - An callback to execute with an implicitly created session + * @param executor - An executor function that all operations using the provided session must be invoked in + * @param options - optional settings for the session */ - async withSession(callback: WithSessionCallback): Promise; - async withSession(options: ClientSessionOptions, callback: WithSessionCallback): Promise; - async withSession( - optionsOrOperation: ClientSessionOptions | WithSessionCallback, - callback?: WithSessionCallback - ): Promise { + async withSession(executor: WithSessionCallback): Promise; + async withSession( + options: ClientSessionOptions, + executor: WithSessionCallback + ): Promise; + async withSession( + optionsOrExecutor: ClientSessionOptions | WithSessionCallback, + executor?: WithSessionCallback + ): Promise { const options = { // Always define an owner owner: Symbol(), // If it's an object inherit the options - ...(typeof optionsOrOperation === 'object' ? optionsOrOperation : {}) + ...(typeof optionsOrExecutor === 'object' ? optionsOrExecutor : {}) }; const withSessionCallback = - typeof optionsOrOperation === 'function' ? optionsOrOperation : callback; + typeof optionsOrExecutor === 'function' ? optionsOrExecutor : executor; if (withSessionCallback == null) { throw new MongoInvalidArgumentError('Missing required callback parameter'); @@ -636,7 +637,7 @@ export class MongoClient extends TypedEventEmitter { const session = this.startSession(options); try { - await withSessionCallback(session); + return await withSessionCallback(session); } finally { try { await session.endSession(); diff --git a/src/sessions.ts b/src/sessions.ts index 9b4d106352..4c30bf5aea 100644 --- a/src/sessions.ts +++ b/src/sessions.ts @@ -67,7 +67,7 @@ export interface ClientSessionOptions { } /** @public */ -export type WithTransactionCallback = (session: ClientSession) => Promise; +export type WithTransactionCallback = (session: ClientSession) => Promise; /** @public */ export type ClientSessionEvents = { @@ -432,18 +432,16 @@ export class ClientSession extends TypedEventEmitter { } /** - * Runs a provided callback within a transaction, retrying either the commitTransaction operation - * or entire transaction as needed (and when the error permits) to better ensure that - * the transaction can complete successfully. + * Starts a transaction and runs a provided function, ensuring the commitTransaction is always attempted when all operations run in the function have completed. * * **IMPORTANT:** This method requires the user to return a Promise, and `await` all operations. - * Any callbacks that do not return a Promise will result in undefined behavior. * * @remarks * This function: - * - Will return the command response from the final commitTransaction if every operation is successful (can be used as a truthy object) - * - Will return `undefined` if the transaction is explicitly aborted with `await session.abortTransaction()` - * - Will throw if one of the operations throws or `throw` statement is used inside the `withTransaction` callback + * - If all operations successfully complete and the `commitTransaction` operation is successful, then this function will return the result of the provided function. + * - If the transaction is unable to complete or an error is thrown from within the provided function, then this function will throw an error. + * - If the transaction is manually aborted within the provided function it will not throw. + * - May be called multiple times if the driver needs to attempt to retry the operations. * * Checkout a descriptive example here: * @see https://www.mongodb.com/developer/quickstart/node-transactions/ @@ -452,7 +450,7 @@ export class ClientSession extends TypedEventEmitter { * @param options - optional settings for the transaction * @returns A raw command response or undefined */ - async withTransaction( + async withTransaction( fn: WithTransactionCallback, options?: TransactionOptions ): Promise { @@ -543,25 +541,29 @@ function attemptTransactionCommit( session: ClientSession, startTime: number, fn: WithTransactionCallback, - options?: TransactionOptions + result: any, + options: TransactionOptions ): Promise { - return session.commitTransaction().catch((err: MongoError) => { - if ( - err instanceof MongoError && - hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) && - !isMaxTimeMSExpiredError(err) - ) { - if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) { - return attemptTransactionCommit(session, startTime, fn, options); - } + return session.commitTransaction().then( + () => result, + (err: MongoError) => { + if ( + err instanceof MongoError && + hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) && + !isMaxTimeMSExpiredError(err) + ) { + if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) { + return attemptTransactionCommit(session, startTime, fn, result, options); + } - if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) { - return attemptTransaction(session, startTime, fn, options); + if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) { + return attemptTransaction(session, startTime, fn, options); + } } - } - throw err; - }); + throw err; + } + ); } const USER_EXPLICIT_TXN_END_STATES = new Set([ @@ -574,11 +576,11 @@ function userExplicitlyEndedTransaction(session: ClientSession) { return USER_EXPLICIT_TXN_END_STATES.has(session.transaction.state); } -function attemptTransaction( +function attemptTransaction( session: ClientSession, startTime: number, - fn: WithTransactionCallback, - options?: TransactionOptions + fn: WithTransactionCallback, + options: TransactionOptions = {} ): Promise { session.startTransaction(options); @@ -591,18 +593,18 @@ function attemptTransaction( if (!isPromiseLike(promise)) { session.abortTransaction().catch(() => null); - throw new MongoInvalidArgumentError( - 'Function provided to `withTransaction` must return a Promise' + return Promise.reject( + new MongoInvalidArgumentError('Function provided to `withTransaction` must return a Promise') ); } return promise.then( - () => { + result => { if (userExplicitlyEndedTransaction(session)) { - return; + return result; } - return attemptTransactionCommit(session, startTime, fn, options); + return attemptTransactionCommit(session, startTime, fn, result, options); }, err => { function maybeRetryOrThrow(err: MongoError): Promise { diff --git a/test/integration/sessions/sessions.test.ts b/test/integration/sessions/sessions.test.ts index 59aaecbed0..a9629a47e5 100644 --- a/test/integration/sessions/sessions.test.ts +++ b/test/integration/sessions/sessions.test.ts @@ -81,6 +81,7 @@ describe('Sessions Spec', function () { describe('withSession', function () { let client: MongoClient; + beforeEach(async function () { client = await this.configuration.newClient().connect(); }); @@ -184,6 +185,13 @@ describe('Sessions Spec', function () { expect(client.s.sessionPool.sessions).to.have.length(1); expect(sessionWasEnded).to.be.true; }); + + it('resolves with the value the callback returns', async () => { + const result = await client.withSession(async session => { + return client.db('test').collection('foo').find({}, { session }).toArray(); + }); + expect(result).to.be.an('array'); + }); }); context('unacknowledged writes', () => { diff --git a/test/integration/transactions/transactions.test.ts b/test/integration/transactions/transactions.test.ts index e953fecc8c..5169fd6dce 100644 --- a/test/integration/transactions/transactions.test.ts +++ b/test/integration/transactions/transactions.test.ts @@ -8,6 +8,7 @@ import { MongoNetworkError, type ServerSessionPool } from '../../mongodb'; +import { type FailPoint } from '../../tools/utils'; describe('Transactions', function () { describe('withTransaction', function () { @@ -90,7 +91,7 @@ describe('Transactions', function () { await client.close(); }); - it('should return undefined when transaction is aborted explicitly', async () => { + it('returns result of executor when transaction is aborted explicitly', async () => { const session = client.startSession(); const withTransactionResult = await session @@ -98,25 +99,25 @@ describe('Transactions', function () { await collection.insertOne({ a: 1 }, { session }); await collection.findOne({ a: 1 }, { session }); await session.abortTransaction(); + return 'aborted!'; }) .finally(async () => await session.endSession()); - expect(withTransactionResult).to.be.undefined; + expect(withTransactionResult).to.equal('aborted!'); }); - it('should return raw command when transaction is successfully committed', async () => { + it('returns result of executor when transaction is successfully committed', async () => { const session = client.startSession(); const withTransactionResult = await session .withTransaction(async session => { await collection.insertOne({ a: 1 }, { session }); await collection.findOne({ a: 1 }, { session }); + return 'committed!'; }) .finally(async () => await session.endSession()); - expect(withTransactionResult).to.exist; - expect(withTransactionResult).to.be.an('object'); - expect(withTransactionResult).to.have.property('ok', 1); + expect(withTransactionResult).to.equal('committed!'); }); it('should throw when transaction is aborted due to an error', async () => { @@ -136,6 +137,48 @@ describe('Transactions', function () { }); } ); + + context('when retried', { requires: { mongodb: '>=4.2.0', topology: '!single' } }, () => { + let client: MongoClient; + let collection: Collection<{ a: number }>; + + beforeEach(async function () { + client = this.configuration.newClient(); + + await client.db('admin').command({ + configureFailPoint: 'failCommand', + mode: { times: 2 }, + data: { + failCommands: ['commitTransaction'], + errorCode: 24, + errorLabels: ['TransientTransactionError'], + closeConnection: false + } + } as FailPoint); + + collection = await client.db('withTransaction').createCollection('withTransactionRetry'); + }); + + afterEach(async () => { + await client?.close(); + }); + + it('returns the value of the final call to the executor', async () => { + const session = client.startSession(); + + let counter = 0; + const withTransactionResult = await session + .withTransaction(async session => { + await collection.insertOne({ a: 1 }, { session }); + counter += 1; + return counter; + }) + .finally(async () => await session.endSession()); + + expect(counter).to.equal(3); + expect(withTransactionResult).to.equal(3); + }); + }); }); describe('startTransaction', function () { diff --git a/test/tools/unified-spec-runner/operations.ts b/test/tools/unified-spec-runner/operations.ts index 6e69660779..304b8cae06 100644 --- a/test/tools/unified-spec-runner/operations.ts +++ b/test/tools/unified-spec-runner/operations.ts @@ -593,18 +593,11 @@ operations.set('withTransaction', async ({ entities, operation, client, testConf maxCommitTimeMS: operation.arguments!.maxCommitTimeMS }; - let errorFromOperations = null; - const result = await session.withTransaction(async () => { - errorFromOperations = await (async () => { - for (const callbackOperation of operation.arguments!.callback) { - await executeOperationAndCheck(callbackOperation, entities, client, testConfig); - } - })().catch(error => error); + await session.withTransaction(async () => { + for (const callbackOperation of operation.arguments!.callback) { + await executeOperationAndCheck(callbackOperation, entities, client, testConfig); + } }, options); - - if (result == null || errorFromOperations) { - throw errorFromOperations ?? Error('transaction not committed'); - } }); operations.set('countDocuments', async ({ entities, operation }) => { diff --git a/test/types/sessions.test-d.ts b/test/types/sessions.test-d.ts index 333872b1df..48c8fcb0dc 100644 --- a/test/types/sessions.test-d.ts +++ b/test/types/sessions.test-d.ts @@ -15,3 +15,13 @@ expectType( }) ); expectError(client.startSession({ defaultTransactionOptions: { readConcern: 1 } })); + +let something: any; +expectType(await client.withSession(async () => 2)); +expectType(await client.withSession(async () => something)); +const untypedFn: any = () => 2; +expectType(await client.withSession(untypedFn)); +const unknownFn: () => Promise = async () => 2; +expectType(await client.withSession(unknownFn)); +// Not a promise returning function +expectError(await client.withSession(() => null));