From 5199a0b7a47957a1679dd6b861463ebf5bd40cde Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Tue, 29 Jun 2021 17:44:36 -0700 Subject: [PATCH] GODRIVER-2037 Don't clear the connection pool on client-side connect timeout errors. (#688) --- mongo/integration/primary_stepdown_test.go | 59 ++++ mongo/integration/sdam_error_handling_test.go | 160 ++++++---- x/mongo/driver/topology/connection.go | 10 +- x/mongo/driver/topology/connection_options.go | 4 +- x/mongo/driver/topology/connection_test.go | 2 +- x/mongo/driver/topology/sdam_spec_test.go | 2 +- x/mongo/driver/topology/server.go | 63 +++- x/mongo/driver/topology/server_test.go | 300 ++++++++++++++++++ 8 files changed, 535 insertions(+), 65 deletions(-) diff --git a/mongo/integration/primary_stepdown_test.go b/mongo/integration/primary_stepdown_test.go index bc92d3bef8..4aa08a55d9 100644 --- a/mongo/integration/primary_stepdown_test.go +++ b/mongo/integration/primary_stepdown_test.go @@ -7,6 +7,7 @@ package integration import ( + "sync" "testing" "go.mongodb.org/mongo-driver/bson" @@ -23,7 +24,65 @@ const ( errorInterruptedAtShutdown int32 = 11600 ) +// testPoolMonitor exposes an *event.PoolMonitor and collects all events logged to that +// *event.PoolMonitor. It is safe to use from multiple concurrent goroutines. +type testPoolMonitor struct { + *event.PoolMonitor + + events []*event.PoolEvent + mu sync.RWMutex +} + +func newTestPoolMonitor() *testPoolMonitor { + tpm := &testPoolMonitor{ + events: make([]*event.PoolEvent, 0), + } + tpm.PoolMonitor = &event.PoolMonitor{ + Event: func(evt *event.PoolEvent) { + tpm.mu.Lock() + defer tpm.mu.Unlock() + tpm.events = append(tpm.events, evt) + }, + } + return tpm +} + +// Events returns a copy of the events collected by the testPoolMonitor. Filters can optionally be +// applied to the returned events set and are applied using AND logic (i.e. all filters must return +// true to include the event in the result). +func (tpm *testPoolMonitor) Events(filters ...func(*event.PoolEvent) bool) []*event.PoolEvent { + filtered := make([]*event.PoolEvent, 0, len(tpm.events)) + tpm.mu.RLock() + defer tpm.mu.RUnlock() + + for _, evt := range tpm.events { + keep := true + for _, filter := range filters { + if !filter(evt) { + keep = false + break + } + } + if keep { + filtered = append(filtered, evt) + } + } + + return filtered +} + +// IsPoolCleared returns true if there are any events of type "event.PoolCleared" in the events +// recorded by the testPoolMonitor. +func (tpm *testPoolMonitor) IsPoolCleared() bool { + poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool { + return evt.Type == event.PoolCleared + }) + return len(poolClearedEvents) > 0 +} + var poolChan = make(chan *event.PoolEvent, 100) + +// TODO(GODRIVER-2068): Replace all uses of poolMonitor with individual instances of testPoolMonitor. var poolMonitor = &event.PoolMonitor{ Event: func(event *event.PoolEvent) { poolChan <- event diff --git a/mongo/integration/sdam_error_handling_test.go b/mongo/integration/sdam_error_handling_test.go index 3ae4ebf6f2..1529156e6c 100644 --- a/mongo/integration/sdam_error_handling_test.go +++ b/mongo/integration/sdam_error_handling_test.go @@ -28,13 +28,12 @@ func TestSDAMErrorHandling(t *testing.T) { return options.Client(). ApplyURI(mtest.ClusterURI()). SetRetryWrites(false). - SetPoolMonitor(poolMonitor). SetWriteConcern(mtest.MajorityWc) } baseMtOpts := func() *mtest.Options { mtOpts := mtest.NewOptions(). - Topologies(mtest.ReplicaSet). 
// Don't run on sharded clusters to avoid complexity of sharded failpoints. - MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets. + Topologies(mtest.ReplicaSet, mtest.Single). // Don't run on sharded clusters to avoid complexity of sharded failpoints. + MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets. ClientOptions(baseClientOpts()) if mtest.ClusterTopologyKind() == mtest.Sharded { @@ -48,13 +47,14 @@ func TestSDAMErrorHandling(t *testing.T) { // blockConnection and appName. mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) { mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) { - mt.Run("pool cleared on network timeout", func(mt *mtest.T) { - // Assert that the pool is cleared when a connection created by an application operation thread - // encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test - // connections created in the foreground for timeouts because connections created by the pool - // maintenance routine can't be timed out using a context. - - appName := "authNetworkTimeoutTest" + mt.Run("pool not cleared on operation-scoped network timeout", func(mt *mtest.T) { + // Assert that the pool is not cleared when a connection created by an application + // operation thread encounters an operation timeout during handshaking. Unlike the + // non-timeout test below, we only test connections created in the foreground for + // timeouts because connections created by the pool maintenance routine can't be + // timed out using a context. + + appName := "authOperationTimeoutTest" // Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using // speculative auth. mt.SetFailPoint(mtest.FailPoint{ @@ -70,24 +70,61 @@ func TestSDAMErrorHandling(t *testing.T) { }, }) - // Reset the client with the appName specified in the failpoint. - clientOpts := options.Client(). - SetAppName(appName). - SetRetryWrites(false). - SetPoolMonitor(poolMonitor) - mt.ResetClient(clientOpts) - clearPoolChan() + // Reset the client with the appName specified in the failpoint and the pool monitor. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor)) - // The saslContinue blocks for 150ms so run the InsertOne with a 100ms context to cause a network - // timeout during auth and assert that the pool was cleared. + // Use a context with a 100ms timeout so that the saslContinue delay of 150ms causes + // an operation-scoped context timeout (i.e. a timeout not caused by a client timeout + // like connectTimeoutMS or socketTimeoutMS). timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond) defer cancel() _, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}}) assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) - assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not") + assert.False(mt, tpm.IsPoolCleared(), "expected pool not to be cleared but was cleared") }) + + mt.Run("pool cleared on non-operation-scoped network timeout", func(mt *mtest.T) { + // Assert that the pool is cleared when a connection created by an application + // operation thread encounters a timeout caused by connectTimeoutMS during + // handshaking. 
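// Illustrative sketch (the URI and timeout values here are assumptions for the example, not
// part of this patch): the distinction exercised by this pair of tests is between a deadline
// carried on the operation's context and the client-scoped timeouts configured on the
// ClientOptions; only the latter should clear the connection pool when the handshake times out.
opCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // operation-scoped deadline
defer cancel()

clientOpts := options.Client().
	ApplyURI("mongodb://localhost:27017").
	SetConnectTimeout(10 * time.Second).     // client-scoped: bounds dialing and the handshake
	SetSocketTimeout(100 * time.Millisecond) // client-scoped: bounds individual socket reads/writes

client, err := mongo.Connect(context.Background(), clientOpts)
if err != nil {
	log.Fatal(err)
}
// A handshake timeout surfaced from opCtx should leave the pool alone; one surfaced from the
// client-scoped settings should clear it.
_, err = client.Database("db").Collection("coll").InsertOne(opCtx, bson.D{{"x", 1}})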
+ + appName := "authConnectTimeoutTest" + // Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using + // speculative auth. + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"saslContinue"}, + BlockConnection: true, + BlockTimeMS: 150, + AppName: appName, + }, + }) + + // Reset the client with the appName specified in the failpoint and the pool monitor. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts(). + SetAppName(appName). + SetPoolMonitor(tpm.PoolMonitor). + // Set a 100ms socket timeout so that the saslContinue delay of 150ms causes a + // timeout during socket read (i.e. a timeout not caused by the InsertOne context). + SetSocketTimeout(100 * time.Millisecond)) + + // Use context.Background() so that the new connection will not time out due to an + // operation-scoped timeout. + _, err := mt.Coll.InsertOne(context.Background(), bson.D{{"test", 1}}) + assert.NotNil(mt, err, "expected InsertOne error, got nil") + assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) + assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) + assert.True(mt, tpm.IsPoolCleared(), "expected pool to be cleared but was not") + }) + mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) { mt.Run("background", func(mt *mtest.T) { // Assert that the pool is cleared when a connection created by the background pool maintenance @@ -106,16 +143,19 @@ func TestSDAMErrorHandling(t *testing.T) { }, }) - clientOpts := options.Client(). + // Reset the client with the appName specified in the failpoint. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts(). SetAppName(appName). - SetMinPoolSize(5). - SetPoolMonitor(poolMonitor) - mt.ResetClient(clientOpts) - clearPoolChan() + SetPoolMonitor(tpm.PoolMonitor). + // Set minPoolSize to enable the background pool maintenance goroutine. + SetMinPoolSize(5)) time.Sleep(200 * time.Millisecond) - assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not") + + assert.True(mt, tpm.IsPoolCleared(), "expected pool to be cleared but was not") }) + mt.Run("foreground", func(mt *mtest.T) { // Assert that the pool is cleared when a connection created by an application thread connection // checkout encounters a non-timeout network error during handshaking. @@ -133,16 +173,14 @@ func TestSDAMErrorHandling(t *testing.T) { }, }) - clientOpts := options.Client(). - SetAppName(appName). - SetPoolMonitor(poolMonitor) - mt.ResetClient(clientOpts) - clearPoolChan() + // Reset the client with the appName specified in the failpoint. 
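// Giving each test its own appName matters because the failCommand failpoint's "appName"
// field only triggers for connections whose handshake metadata advertises that application
// name; resetting the client with SetAppName therefore scopes the failpoint to connections
// created by this test and leaves connections from other tests unaffected.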
+ tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor)) _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}}) assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err) - assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not") + assert.True(mt, tpm.IsPoolCleared(), "expected pool to be cleared but was not") }) }) }) @@ -150,7 +188,8 @@ func TestSDAMErrorHandling(t *testing.T) { mt.RunOpts("after handshake completes", baseMtOpts(), func(mt *mtest.T) { mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) { mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) { - clearPoolChan() + appName := "afterHandshakeNetworkError" + mt.SetFailPoint(mtest.FailPoint{ ConfigureFailPoint: "failCommand", Mode: mtest.FailPointMode{ @@ -159,16 +198,22 @@ func TestSDAMErrorHandling(t *testing.T) { Data: mtest.FailPointData{ FailCommands: []string{"insert"}, CloseConnection: true, + AppName: appName, }, }) + // Reset the client with the appName specified in the failpoint. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor)) + _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}}) assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err) - assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not") + assert.True(mt, tpm.IsPoolCleared(), "expected pool to be cleared but was not") }) mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) { - clearPoolChan() + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetPoolMonitor(tpm.PoolMonitor)) _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) @@ -181,11 +226,11 @@ func TestSDAMErrorHandling(t *testing.T) { _, err = mt.Coll.Find(timeoutCtx, filter) assert.NotNil(mt, err, "expected Find error, got %v", err) assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) - - assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was") + assert.False(mt, tpm.IsPoolCleared(), "expected pool to not be cleared but was") }) mt.Run("pool not cleared on context cancellation", func(mt *mtest.T) { - clearPoolChan() + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetPoolMonitor(tpm.PoolMonitor)) _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) @@ -204,8 +249,7 @@ func TestSDAMErrorHandling(t *testing.T) { assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err) assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) assert.True(mt, errors.Is(err, context.Canceled), "expected error %v to be context.Canceled", err) - - assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was") + assert.False(mt, tpm.IsPoolCleared(), "expected pool to not be cleared but was") }) }) mt.RunOpts("server errors", noClientOpts, func(mt *mtest.T) { @@ -242,10 +286,10 @@ func TestSDAMErrorHandling(t *testing.T) { } for _, tc := range testCases { mt.RunOpts(fmt.Sprintf("command error - %s", tc.name), serverErrorsMtOpts, func(mt *mtest.T) { - clearPoolChan() + appName := fmt.Sprintf("command_error_%s", tc.name) // Cause the next insert to fail with an ok:0 response. 
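// The failpoint configured below is the driver-side representation of the server's
// configureFailPoint command and corresponds roughly to running
//
//	{ configureFailPoint: "failCommand", mode: { times: 1 },
//	  data: { failCommands: ["insert"], errorCode: <code>, appName: <appName> } }
//
// so the next matching insert returns an ok:0 response carrying the test case's error code.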
- fp := mtest.FailPoint{ + mt.SetFailPoint(mtest.FailPoint{ ConfigureFailPoint: "failCommand", Mode: mtest.FailPointMode{ Times: 1, @@ -253,17 +297,21 @@ func TestSDAMErrorHandling(t *testing.T) { Data: mtest.FailPointData{ FailCommands: []string{"insert"}, ErrorCode: tc.errorCode, + AppName: appName, }, - } - mt.SetFailPoint(fp) + }) + + // Reset the client with the appName specified in the failpoint. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor)) - runServerErrorsTest(mt, tc.isShutdownError) + runServerErrorsTest(mt, tc.isShutdownError, tpm) }) mt.RunOpts(fmt.Sprintf("write concern error - %s", tc.name), serverErrorsMtOpts, func(mt *mtest.T) { - clearPoolChan() + appName := fmt.Sprintf("write_concern_error_%s", tc.name) // Cause the next insert to fail with a write concern error. - fp := mtest.FailPoint{ + mt.SetFailPoint(mtest.FailPoint{ ConfigureFailPoint: "failCommand", Mode: mtest.FailPointMode{ Times: 1, @@ -273,18 +321,22 @@ func TestSDAMErrorHandling(t *testing.T) { WriteConcernError: &mtest.WriteConcernErrorData{ Code: tc.errorCode, }, + AppName: appName, }, - } - mt.SetFailPoint(fp) + }) + + // Reset the client with the appName specified in the failpoint. + tpm := newTestPoolMonitor() + mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor)) - runServerErrorsTest(mt, tc.isShutdownError) + runServerErrorsTest(mt, tc.isShutdownError, tpm) }) } }) }) } -func runServerErrorsTest(mt *mtest.T, isShutdownError bool) { +func runServerErrorsTest(mt *mtest.T, isShutdownError bool, tpm *testPoolMonitor) { mt.Helper() _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}}) @@ -292,13 +344,13 @@ func runServerErrorsTest(mt *mtest.T, isShutdownError bool) { // The pool should always be cleared for shutdown errors, regardless of server version. if isShutdownError { - assert.True(mt, isPoolCleared(), "expected pool to be cleared, but was not") + assert.True(mt, tpm.IsPoolCleared(), "expected pool to be cleared, but was not") return } // For non-shutdown errors, the pool is only cleared if the error is from a pre-4.2 server. 
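// (Rationale, per the SDAM error-handling rules: MongoDB 4.2+ no longer closes existing
// connections on replica set stepdown, so the pool only needs to be cleared for these
// state-change errors when they come from an older server; shutdown errors always clear it.)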
wantCleared := mtest.CompareServerVersions(mtest.ServerVersion(), "4.2") < 0 - gotCleared := isPoolCleared() - assert.Equal(mt, wantCleared, gotCleared, "expected pool to be cleared: %v; pool was cleared: %v", + gotCleared := tpm.IsPoolCleared() + assert.Equal(mt, wantCleared, gotCleared, "expected pool to be cleared: %t; pool was cleared: %t", wantCleared, gotCleared) } diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 46dd9e6991..e23aeeaf1b 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -101,7 +101,7 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection, return c, nil } -func (c *connection) processInitializationError(err error) { +func (c *connection) processInitializationError(opCtx context.Context, err error) { atomic.StoreInt32(&c.connected, disconnected) if c.nc != nil { _ = c.nc.Close() @@ -109,7 +109,7 @@ func (c *connection) processInitializationError(err error) { c.connectErr = ConnectionError{Wrapped: err, init: true} if c.config.errorHandlingCallback != nil { - c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServiceID) + c.config.errorHandlingCallback(opCtx, c.connectErr, c.generation, c.desc.ServiceID) } } @@ -184,7 +184,7 @@ func (c *connection) connect(ctx context.Context) { var tempNc net.Conn tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String()) if err != nil { - c.processInitializationError(err) + c.processInitializationError(ctx, err) return } c.nc = tempNc @@ -200,7 +200,7 @@ func (c *connection) connect(ctx context.Context) { } tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) if err != nil { - c.processInitializationError(err) + c.processInitializationError(ctx, err) return } c.nc = tlsNc @@ -252,7 +252,7 @@ func (c *connection) connect(ctx context.Context) { // We have a failed handshake here if err != nil { - c.processInitializationError(err) + c.processInitializationError(ctx, err) return } diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go index fc02075575..b9dac8c406 100644 --- a/x/mongo/driver/topology/connection_options.go +++ b/x/mongo/driver/topology/connection_options.go @@ -55,7 +55,7 @@ type connectionConfig struct { zstdLevel *int ocspCache ocsp.Cache disableOCSPEndpointCheck bool - errorHandlingCallback func(error, uint64, *primitive.ObjectID) + errorHandlingCallback func(opCtx context.Context, err error, startGenNum uint64, svcID *primitive.ObjectID) tlsConnectionSource tlsConnectionSource loadBalanced bool getGenerationFn generationNumberFn @@ -92,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C } } -func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption { +func withErrorHandlingCallback(fn func(opCtx context.Context, err error, startGenNum uint64, svcID *primitive.ObjectID)) ConnectionOption { return func(c *connectionConfig) error { c.errorHandlingCallback = fn return nil diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 6249322d04..04b00d52cb 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -129,7 +129,7 @@ func TestConnection(t *testing.T) { return &net.TCPConn{}, nil }) }), - withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) { + 
withErrorHandlingCallback(func(_ context.Context, err error, _ uint64, _ *primitive.ObjectID) { got = err }), ) diff --git a/x/mongo/driver/topology/sdam_spec_test.go b/x/mongo/driver/topology/sdam_spec_test.go index b60cffae7e..5908e8a32e 100644 --- a/x/mongo/driver/topology/sdam_spec_test.go +++ b/x/mongo/driver/topology/sdam_spec_test.go @@ -343,7 +343,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { switch appErr.When { case "beforeHandshakeCompletes": - server.ProcessHandshakeError(currError, generation, nil) + server.ProcessHandshakeError(nil, currError, generation, nil) case "afterHandshakeCompletes": _ = server.ProcessError(currError, &conn) default: diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 8806cf6951..0120fd111c 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -272,8 +272,11 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { return &Connection{connection: connImpl}, nil } -// ProcessHandshakeError implements SDAM error handling for errors that occur before a connection finishes handshaking. -func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) { +// ProcessHandshakeError implements SDAM error handling for errors that occur before a connection +// finishes handshaking. opCtx is the context passed to Server.Connection() and is used to determine +// whether or not an operation-scoped context deadline or cancellation was the cause of the +// handshake error; it is not used for timeout or cancellation of ProcessHandshakeError. +func (s *Server) ProcessHandshakeError(opCtx context.Context, err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) { // Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the // error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to // use for clearing the pool. @@ -290,6 +293,62 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6 return } + isCtxTimeoutOrCanceled := func(ctx context.Context) bool { + if ctx == nil { + return false + } + + if ctx.Err() != nil { + return true + } + + // In some networking functions, the deadline from the context is used to determine timeouts + // instead of the ctx.Done() chan closure. In that case, there can be a race condition + // between the networking functing returning an error and ctx.Err() returning an an error + // (i.e. the networking function returns an error caused by the context deadline, but + // ctx.Err() returns nil). If the operation-scoped context deadline was exceeded, assume + // operation-scoped context timeout caused the error. + if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { + return true + } + + return false + } + + isErrTimeoutOrCanceled := func(err error) bool { + for err != nil { + // Check for errors that implement the "net.Error" interface and self-report as timeout + // errors. Includes some "*net.OpError" errors and "context.DeadlineExceeded". + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return true + } + // Check for context cancellation. Also handle the case where the cancellation error has + // been replaced by "net.errCanceled" (which isn't exported and can't be compared + // directly) by checking the error message. 
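// (On newer Go releases, where the net package wraps the cancellation cause, an
// errors.Is(err, context.Canceled) check would likely match this case as well; the string
// comparison keeps the check working for error values that don't implement Unwrap, such as
// the unexported net.errCanceled on older toolchains.)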
+ if err == context.Canceled || err.Error() == "operation was canceled" { + return true + } + + wrapper, ok := err.(interface{ Unwrap() error }) + if !ok { + break + } + err = wrapper.Unwrap() + } + + return false + } + + // Ignore errors that indicate a client-side timeout occurred when the context passed into an + // operation timed out or was canceled (i.e. errors caused by an operation-scoped timeout). + // Timeouts caused by reaching connectTimeoutMS or other non-operation-scoped timeouts should + // still clear the pool. + // TODO(GODRIVER-2038): Remove this condition when connections are no longer created with an + // operation-scoped timeout. + if isCtxTimeoutOrCanceled(opCtx) && isErrTimeoutOrCanceled(wrappedConnErr) { + return + } + // Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set // the description.Server appropriately. The description should not have a TopologyVersion because the staleness // checking logic above has already determined that this description is not stale. diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 37afac2ff2..bb1746d1f7 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -54,6 +54,306 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n return cnc, nil } +// TestServerConnectionTimeout tests how different timeout errors are handled during connection +// creation and server handshake. +func TestServerConnectionTimeout(t *testing.T) { + testCases := []struct { + desc string + dialer func(Dialer) Dialer + handshaker func(Handshaker) Handshaker + operationTimeout time.Duration + connectTimeout time.Duration + expectErr bool + expectPoolCleared bool + }{ + { + desc: "successful connection should not clear the pool", + expectErr: false, + expectPoolCleared: false, + }, + { + desc: "operation timeout error during dialing should not clear the pool", + dialer: func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + // Wait for the passed in context to time out. Expect the error returned by + // DialContext() to be treated as a timeout caused by an operation-scoped deadline. + // E.g. FindOne(context.WithTimeout(...)) + <-ctx.Done() + return d.DialContext(ctx, network, addr) + }) + }, + operationTimeout: 100 * time.Millisecond, + connectTimeout: 1 * time.Minute, + expectErr: true, + expectPoolCleared: false, + }, + { + desc: "connectTimeMS timeout error during dialing should clear the pool", + dialer: func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + // Wait for the passed in context to time out. Expect the error returned by + // DialContext() to be treated as a timeout caused by reaching connectTimeoutMS. + <-ctx.Done() + return d.DialContext(ctx, network, addr) + }) + }, + operationTimeout: 1 * time.Minute, + connectTimeout: 100 * time.Millisecond, + expectErr: true, + expectPoolCleared: true, + }, + { + desc: "connectTimeMS timeout error during dialing with no operation timeout should clear the pool", + dialer: func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + // Wait for the passed in context to time out. Expect the error returned by + // DialContext() to be treated as a timeout caused by reaching connectTimeoutMS. 
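// (Even though no operation timeout is set in this case, the context passed to DialContext
// is still expected to expire: the connection applies connectTimeoutMS to the dial context,
// along the lines of
//
//	dialCtx, cancel := context.WithTimeout(ctx, connectTimeout)
//	defer cancel()
//
// so waiting on ctx.Done() here simulates a dial that runs past that budget.)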
+ <-ctx.Done() + return d.DialContext(ctx, network, addr) + }) + }, + operationTimeout: 0, // Uses a context.Background() with no timeout. + connectTimeout: 100 * time.Millisecond, + expectErr: true, + expectPoolCleared: true, + }, + { + desc: "operation timeout error during handshake should not clear the pool", + handshaker: func(Handshaker) Handshaker { + h := auth.Handshaker(nil, &auth.HandshakeOptions{}) + return &testHandshaker{ + getHandshakeInformation: func(ctx context.Context, addr address.Address, c driver.Connection) (driver.HandshakeInformation, error) { + // Wait for the passed in context to time out. Expect the error returned by + // GetHandshakeInformation() to be treated as a timeout caused by an + // operation-scoped deadline. + // E.g. FindOne(context.WithTimeout(...)) + <-ctx.Done() + return h.GetHandshakeInformation(ctx, addr, c) + }, + } + }, + operationTimeout: 100 * time.Millisecond, + connectTimeout: 1 * time.Minute, + expectErr: true, + expectPoolCleared: false, + }, + { + desc: "dial errors unrelated to context timeouts should clear the pool", + dialer: func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) { + // Try to dial an invalid TCP address and expect an error. + return d.DialContext(ctx, "tcp", "300.0.0.0:nope") + }) + }, + expectErr: true, + expectPoolCleared: true, + }, + { + desc: "context error with dial errors unrelated to context timeouts should clear the pool", + dialer: func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) { + // Try to dial an invalid TCP address and expect an error. + c, err := d.DialContext(ctx, "tcp", "300.0.0.0:nope") + // Wait for the passed in context to time out. Expect that the context error is + // ignored because the dial error is not a timeout. + <-ctx.Done() + return c, err + }) + }, + operationTimeout: 100 * time.Millisecond, + expectErr: true, + expectPoolCleared: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + // Start a goroutine that pulls events from the events channel and inserts them into the + // events slice. Use a sync.WaitGroup to allow the test code to block until the events + // channel loop exits, guaranteeing that all events were copied from the channel. + // TODO(GODRIVER-2068): Consider using the "testPoolMonitor" from the "mongo/integration" + // package. Requires moving "testPoolMonitor" into a test utilities package. + events := make([]*event.PoolEvent, 0) + eventsCh := make(chan *event.PoolEvent) + var eventsWg sync.WaitGroup + eventsWg.Add(1) + go func() { + defer eventsWg.Done() + for evt := range eventsCh { + events = append(events, evt) + } + }() + + // Create a TCP listener on a random port. The listener will accept connections but not + // read or write to them. + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = l.Close() + }() + + server, err := NewServer( + address.Address(l.Addr().String()), + primitive.NewObjectID(), + // Create a connection pool event monitor that sends all events to an events channel + // so we can assert on the connection pool events later. + WithConnectionPoolMonitor(func(_ *event.PoolMonitor) *event.PoolMonitor { + return &event.PoolMonitor{ + Event: func(event *event.PoolEvent) { + eventsCh <- event + }, + } + }), + // Replace the default dialer and handshaker with the test dialer and handshaker, if + // present. 
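// (WithConnectionOptions receives the server's current slice of ConnectionOption values and
// returns the slice to use, so appending below layers the per-test connect timeout, dialer,
// and handshaker on top of the defaults rather than replacing them.)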
+ WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption { + if tc.connectTimeout > 0 { + opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout })) + } + if tc.dialer != nil { + opts = append(opts, WithDialer(tc.dialer)) + } + if tc.handshaker != nil { + opts = append(opts, WithHandshaker(tc.handshaker)) + } + return opts + }), + // Disable monitoring to prevent unrelated failures from the RTT monitor and + // heartbeats from unexpectedly clearing the connection pool. + withMonitoringDisabled(func(bool) bool { return true }), + ) + require.NoError(t, err) + require.NoError(t, server.Connect(nil)) + + // Create a context with the operation timeout if one is specified in the test case. + ctx := context.Background() + if tc.operationTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tc.operationTimeout) + defer cancel() + } + _, err = server.Connection(ctx) + if tc.expectErr { + assert.NotNil(t, err, "expected an error but got nil") + } else { + assert.Nil(t, err, "expected no error but got %s", err) + } + + // Disconnect the server then close the events channel and expect that no more events + // are sent on the channel. Then wait for the events channel loop to return before + // inspecting the events slice. + _ = server.Disconnect(context.Background()) + close(eventsCh) + eventsWg.Wait() + require.NotEmpty(t, events, "expected more than 0 connection pool monitor events") + + // Look for any "ConnectionPoolCleared" events in the events slice so we can assert that + // the Server connection pool was or wasn't cleared, depending on the test expectations. + poolCleared := false + for _, evt := range events { + if evt.Type == event.PoolCleared { + poolCleared = true + } + } + assert.Equal( + t, + tc.expectPoolCleared, + poolCleared, + "want pool cleared: %t, actual pool cleared: %t", + tc.expectPoolCleared, + poolCleared) + }) + } +} + +// TestServerConnectionCancellation tests how context cancellation errors are handled during +// connection creation. +func TestServerConnectionCancellation(t *testing.T) { + // Start a goroutine that pulls events from the events channel and inserts them into the + // events slice. Use a sync.WaitGroup to allow the test code to block until the events + // channel loop exits, guaranteeing that all events were copied from the channel. + // TODO(GODRIVER-2068): Consider using the "testPoolMonitor" from the "mongo/integration" + // package. Requires moving "testPoolMonitor" into a test utilities package. + events := make([]*event.PoolEvent, 0) + eventsCh := make(chan *event.PoolEvent) + var eventsWg sync.WaitGroup + eventsWg.Add(1) + go func() { + defer eventsWg.Done() + for evt := range eventsCh { + events = append(events, evt) + } + }() + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = l.Close() + }() + + // Create a context with cancellation that can be cancelled before the dial completes. + ctx, cancel := context.WithCancel(context.Background()) + + server, err := NewServer( + address.Address(l.Addr().String()), + primitive.NewObjectID(), + // Create a connection pool event monitor that sends all events to an events channel + // so we can assert on the connection pool events later. 
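// (If the testPoolMonitor added in primary_stepdown_test.go were usable from this package, per
// TODO(GODRIVER-2068), the events channel and WaitGroup bookkeeping above could collapse to a
// sketch like the following, where tpm is hypothetical for this package.)
//
//	tpm := newTestPoolMonitor()
//	// pass tpm.PoolMonitor via WithConnectionPoolMonitor, run the test, then:
//	poolClearedEvents := tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.PoolCleared })
//	assert.True(t, len(poolClearedEvents) == 0, "expected pool to not be cleared, but was cleared")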
+ WithConnectionPoolMonitor(func(_ *event.PoolMonitor) *event.PoolMonitor { + return &event.PoolMonitor{ + Event: func(event *event.PoolEvent) { + eventsCh <- event + }, + } + }), + // Disable monitoring to prevent unrelated failures from the RTT monitor and + // heartbeats from unexpectedly clearing the connection pool. + withMonitoringDisabled(func(bool) bool { return true }), + WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption { + return append(opts, WithDialer(func(Dialer) Dialer { + var d net.Dialer + return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + // Cancel the operation context before dialing and expect a context cancellation + // error. + cancel() + return d.DialContext(ctx, network, addr) + }) + })) + }), + ) + require.NoError(t, err) + require.NoError(t, server.Connect(nil)) + + _, err = server.Connection(ctx) + assert.NotNil(t, err, "expected an error but got nil") + + // Disconnect the server then close the events channel and expect that no more events + // are sent on the channel. Then wait for the events channel loop to return before + // inspecting the events slice. + _ = server.Disconnect(context.Background()) + close(eventsCh) + eventsWg.Wait() + require.NotEmpty(t, events, "expected more than 0 connection pool monitor events") + + // Look for any "ConnectionPoolCleared" events in the events slice so we can assert that + // the Server connection pool wasn't cleared. + poolCleared := false + for _, evt := range events { + if evt.Type == event.PoolCleared { + poolCleared = true + } + } + assert.False(t, poolCleared, "expected pool to not be cleared, but was cleared") +} + func TestServer(t *testing.T) { var serverTestTable = []struct { name string