diff --git a/mongo/session.go b/mongo/session.go index 36d0523de0..2c1c38b829 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -209,8 +209,10 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi CommitLoop: for { err = s.CommitTransaction(ctx) - if err == nil { - return res, nil + // End when error is nil (transaction has been committed), or when context has timed out or been + // canceled, as retrying has no chance of success. + if err == nil || ctx.Err() != nil { + return res, err } select { @@ -307,6 +309,11 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { } err = op.Execute(ctx) + // Return error without updating transaction state if it is a timeout, as the transaction has not + // actually been committed. + if IsTimeout(err) { + return replaceErrors(err) + } s.clientSession.Committing = false commitErr := s.clientSession.CommitTransaction() diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index e719a6822c..0d41e7c300 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -259,6 +259,128 @@ func TestConvenientTransactions(t *testing.T) { assert.True(t, ok, "expected context for abortTransaction to contain ctxKey") assert.Equal(t, "foobar", ctxValue, "expected value for ctxKey to be 'world', got %s", ctxValue) }) + t.Run("commitTransaction timeout allows abortTransaction", func(t *testing.T) { + // Create a special CommandMonitor that only records information about abortTransaction events. + var abortStarted []*event.CommandStartedEvent + var abortSucceeded []*event.CommandSucceededEvent + var abortFailed []*event.CommandFailedEvent + monitor := &event.CommandMonitor{ + Started: func(ctx context.Context, evt *event.CommandStartedEvent) { + if evt.CommandName == "abortTransaction" { + abortStarted = append(abortStarted, evt) + } + }, + Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) { + if evt.CommandName == "abortTransaction" { + abortSucceeded = append(abortSucceeded, evt) + } + }, + Failed: func(_ context.Context, evt *event.CommandFailedEvent) { + if evt.CommandName == "abortTransaction" { + abortFailed = append(abortFailed, evt) + } + }, + } + + // Set up a new Client using the command monitor defined above get a handle to a collection. The collection + // needs to be explicitly created on the server because implicit collection creation is not allowed in + // transactions for server versions <= 4.2. + client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor)) + db := client.Database("foo") + coll := db.Collection("test") + defer func() { + _ = coll.Drop(bgCtx) + }() + + err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err() + assert.Nil(t, err, "error creating collection on server: %v", err) + + // Start session. + session, err := client.StartSession() + defer session.EndSession(bgCtx) + assert.Nil(t, err, "StartSession error: %v", err) + + _ = WithSession(bgCtx, session, func(sessionContext SessionContext) error { + // Start transaction. + err = session.StartTransaction() + assert.Nil(t, err, "StartTransaction error: %v", err) + + // Insert a document. + _, err := coll.InsertOne(sessionContext, bson.D{{"val", 17}}) + assert.Nil(t, err, "InsertOne error: %v", err) + + // Set a timeout of 0 for commitTransaction. + commitTimeoutCtx, commitCancel := context.WithTimeout(sessionContext, 0) + defer commitCancel() + + // CommitTransaction results in context.DeadlineExceeded. + commitErr := session.CommitTransaction(commitTimeoutCtx) + assert.True(t, IsTimeout(commitErr), + "expected timeout error error; got %v", commitErr) + + // Assert session state is not Committed. + clientSession := session.(XSession).ClientSession() + assert.False(t, clientSession.TransactionCommitted(), "expected session state to not be Committed") + + // AbortTransaction without error. + abortErr := session.AbortTransaction(context.Background()) + assert.Nil(t, abortErr, "AbortTransaction error: %v", abortErr) + + // Assert that AbortTransaction was started once and succeeded. + assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted)) + assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d", + len(abortSucceeded)) + assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed)) + + return nil + }) + }) + t.Run("commitTransaction timeout does not retry", func(t *testing.T) { + withTransactionTimeout = 2 * time.Second + + coll := db.Collection("test") + // Explicitly create the collection on server because implicit collection creation is not allowed in + // transactions for server versions <= 4.2. + err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err() + assert.Nil(t, err, "error creating collection on server: %v", err) + defer func() { + _ = coll.Drop(bgCtx) + }() + + // Start session. + sess, err := client.StartSession() + assert.Nil(t, err, "StartSession error: %v", err) + defer sess.EndSession(context.Background()) + + // Defer running killAllSessions to manually close open transaction. + defer func() { + err := dbAdmin.RunCommand(bgCtx, bson.D{ + {"killAllSessions", bson.A{}}, + }).Err() + if err != nil { + if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + t.Fatalf("killAllSessions error: %v", err) + } + } + }() + + // Create context to manually cancel in callback. + cancelCtx, cancel := context.WithCancel(bgCtx) + defer cancel() + + // Insert a document within a session and manually cancel context. + callback := func() { + _, _ = sess.WithTransaction(cancelCtx, func(sessCtx SessionContext) (interface{}, error) { + _, err := coll.InsertOne(sessCtx, bson.M{"x": 1}) + assert.Nil(t, err, "InsertOne error: %v", err) + cancel() + return nil, nil + }) + } + + // Assert that transaction is canceled within 500ms and not 2 seconds. + assert.Soon(t, callback, 500*time.Millisecond) + }) } func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.ClientOptions) *Client {