diff --git a/mongo/errors.go b/mongo/errors.go index 329cadb0a3..5267621274 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -127,16 +127,21 @@ func unwrap(err error) error { return u.Unwrap() } -// IsNetworkError returns true if err is a network error -func IsNetworkError(err error) bool { +// errorHasLabel returns true if err contains the specified label +func errorHasLabel(err error, label string) bool { for ; err != nil; err = unwrap(err) { if e, ok := err.(ServerError); ok { - return e.HasErrorLabel("NetworkError") + return e.HasErrorLabel(label) } } return false } +// IsNetworkError returns true if err is a network error +func IsNetworkError(err error) bool { + return errorHasLabel(err, "NetworkError") +} + // MongocryptError represents an libmongocrypt error during client-side encryption. type MongocryptError struct { Code int32 diff --git a/mongo/session.go b/mongo/session.go index 2c1c38b829..f51aee2125 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -193,10 +193,8 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi default: } - if cerr, ok := err.(CommandError); ok { - if cerr.HasErrorLabel(driver.TransientTransactionError) { - continue - } + if errorHasLabel(err, driver.TransientTransactionError) { + continue } return res, err } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 6d64af7292..cd656b9a7e 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -35,6 +35,18 @@ var ( withTxnFailedEvents []*event.CommandFailedEvent ) +type wrappedError struct { + err error +} + +func (we wrappedError) Error() string { + return we.err.Error() +} + +func (we wrappedError) Unwrap() error { + return we.err +} + func TestConvenientTransactions(t *testing.T) { client := setupConvenientTransactions(t) db := client.Database("TestConvenientTransactions") @@ -381,6 +393,30 @@ func TestConvenientTransactions(t *testing.T) { // Assert that transaction is canceled within 500ms and not 2 seconds. assert.Soon(t, callback, 500*time.Millisecond) }) + t.Run("wrapped transient transaction error retried", func(t *testing.T) { + sess, err := client.StartSession() + assert.Nil(t, err, "StartSession error: %v", err) + defer sess.EndSession(context.Background()) + + // returnError tracks whether or not the callback is being retried + returnError := true + res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { + if returnError { + returnError = false + return nil, wrappedError{ + CommandError{ + Name: "test Error", + Labels: []string{driver.TransientTransactionError}, + }, + } + } + return false, nil + }) + assert.Nil(t, err, "WithTransaction error: %v", err) + resBool, ok := res.(bool) + assert.True(t, ok, "expected result type %T, got %T", false, res) + assert.False(t, resBool, "expected result false, got %v", resBool) + }) } func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.ClientOptions) *Client {