From 03b9852a2fe57920e7c882fcd10823b668f92389 Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Wed, 28 Apr 2021 15:06:31 -0400 Subject: [PATCH 1/2] GODRIVER-1750 Ensure contexts are always cancelled during server monitoring --- x/mongo/driver/topology/server.go | 4 +++ x/mongo/driver/topology/server_test.go | 36 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 4a8c9392d5..8806cf6951 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -617,6 +617,10 @@ func (s *Server) setupHeartbeatConnection() error { // Take the lock when assigning the context and connection because they're accessed by cancelCheck. s.heartbeatLock.Lock() + if s.heartbeatCtxCancel != nil { + // Ensure the previous context is cancelled to avoid a leak. + s.heartbeatCtxCancel() + } s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx) s.conn = conn s.heartbeatLock.Unlock() diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 7174d3c87a..06afc3c7b4 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -627,6 +627,42 @@ func TestServer(t *testing.T) { assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout) assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout) }) + t.Run("heartbeat contexts are not leaked", func(t *testing.T) { + // The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks. + + server, err := ConnectServer( + address.Address("invalid"), + nil, + primitive.NewObjectID(), + withMonitoringDisabled(func(bool) bool { + return true + }), + ) + assert.Nil(t, err, "ConnectServer error: %v", err) + + // Expect check to return an error in the server description because the server address doesn't exist. This is + // OK because we just want to ensure the heartbeat context is created. + desc, err := server.check() + assert.Nil(t, err, "check error: %v", err) + assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") + assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil") + assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err()) + + // Override heartbeatCtxCancel with a wrapper that records whether or not it was called. + oldCancelFn := server.heartbeatCtxCancel + var previousCtxCancelled bool + server.heartbeatCtxCancel = func() { + previousCtxCancelled = true + oldCancelFn() + } + + // The second check call should attempt to create a new heartbeat connection and should cancel the previous + // heartbeatCtx during the process. + _, err = server.check() + assert.Nil(t, err, "check error: %v", err) + assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") + assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not") + }) } func includesMetadata(t *testing.T, wm []byte) bool { From 722880bea5450f5572744c3806d0f53bbea6aafe Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Wed, 28 Apr 2021 16:14:18 -0400 Subject: [PATCH 2/2] catch desc --- x/mongo/driver/topology/server_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 06afc3c7b4..37afac2ff2 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -658,7 +658,7 @@ func TestServer(t *testing.T) { // The second check call should attempt to create a new heartbeat connection and should cancel the previous // heartbeatCtx during the process. - _, err = server.check() + desc, err = server.check() assert.Nil(t, err, "check error: %v", err) assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")