diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index f95d74d8cb..2d036a04b2 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -160,7 +160,7 @@ func runCMAPTest(t *testing.T, testFileName string) { for len(testInfo.backgroundThreadErrors) > 0 { bgErr := <-testInfo.backgroundThreadErrors errs = append(errs, bgErr) - if bgErr != nil && strings.ToLower(test.Error.Message) == bgErr.Error() { + if bgErr != nil && strings.Contains(bgErr.Error(), strings.ToLower(test.Error.Message)) { erroredCorrectly = true break } diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index dfa3ef501b..1e6936a8c3 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -1,6 +1,8 @@ package topology -import "fmt" +import ( + "fmt" +) // ConnectionError represents a connection error. type ConnectionError struct { @@ -25,3 +27,20 @@ func (e ConnectionError) Error() string { func (e ConnectionError) Unwrap() error { return e.Wrapped } + +// WaitQueueTimeoutError represents a timeout when requesting a connection from the pool +type WaitQueueTimeoutError struct { + Wrapped error +} + +func (w WaitQueueTimeoutError) Error() string { + errorMsg := "timed out while checking out a connection from connection pool" + if w.Wrapped != nil { + return fmt.Sprintf("%s: %s", errorMsg, w.Wrapped.Error()) + } + return errorMsg +} + +func (w WaitQueueTimeoutError) Unwrap() error { + return w.Wrapped +} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index f94a4c7347..0a4fe48708 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -31,9 +31,6 @@ var ErrConnectionClosed = ConnectionError{ConnectionID: "<closed>", message: "co // ErrWrongPool is return when a connection is returned to a pool it doesn't belong to. var ErrWrongPool = PoolError("connection does not belong to this pool") -// ErrWaitQueueTimeout is returned when the request to get a connection from the pool timesout when on the wait queue -var ErrWaitQueueTimeout = PoolError("timed out while checking out a connection from connection pool") - // PoolError is an error returned from a Pool method. type PoolError string @@ -340,7 +337,10 @@ func (p *pool) get(ctx context.Context) (*connection, error) { Reason: event.ReasonTimedOut, }) } - return nil, ErrWaitQueueTimeout + errWaitQueueTimeout := WaitQueueTimeoutError{ + Wrapped: ctx.Err(), + } + return nil, errWaitQueueTimeout } // This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index abcd63d015..0775d44c14 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -689,6 +689,40 @@ func TestPool(t *testing.T) { noerr(t, err) }) }) + t.Run("wait queue timeout error", func(t *testing.T) { + cleanup := make(chan struct{}) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + <-cleanup + _ = nc.Close() + }) + d := newdialer(&net.Dialer{}) + pc := poolConfig{ + Address: address.Address(addr.String()), + MaxPoolSize: 1, + } + p, err := newPool(pc, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.connect() + noerr(t, err) + + // get first connection. + _, err = p.get(context.Background()) + noerr(t, err) + + // Set a short timeout and get again. + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + _, err = p.get(ctx) + assert.NotNil(t, err, "expected a WaitQueueTimeout; got nil") + + // Assert that error received is WaitQueueTimeoutError with context deadline exceeded. + wqtErr, ok := err.(WaitQueueTimeoutError) + assert.True(t, ok, "expected a WaitQueueTimeoutError; got %v", err) + assert.True(t, wqtErr.Unwrap() == context.DeadlineExceeded, + "expected a timeout error; got %v", wqtErr) + + close(cleanup) + }) } type sleepDialer struct {