From f28d69f15d68afcbda6e30c4eec037bbffc78dc7 Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Fri, 30 Apr 2021 12:15:29 -0400 Subject: [PATCH] GODRIVER-1750 Ensure contexts are always cancelled during server monitoring (#654) --- x/mongo/driver/topology/server.go | 4 +++ x/mongo/driver/topology/server_test.go | 36 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 857730f9fe..649faa534d 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -596,6 +596,10 @@ func (s *Server) setupHeartbeatConnection() error { // Take the lock when assigning the context and connection because they're accessed by cancelCheck. s.heartbeatLock.Lock() + if s.heartbeatCtxCancel != nil { + // Ensure the previous context is cancelled to avoid a leak. + s.heartbeatCtxCancel() + } s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx) s.conn = conn s.heartbeatLock.Unlock() diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 4b23a27049..61ef63b6ff 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -507,6 +507,42 @@ func TestServer(t *testing.T) { assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout) assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout) }) + t.Run("heartbeat contexts are not leaked", func(t *testing.T) { + // The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks. + + server, err := ConnectServer( + address.Address("invalid"), + nil, + primitive.NewObjectID(), + withMonitoringDisabled(func(bool) bool { + return true + }), + ) + assert.Nil(t, err, "ConnectServer error: %v", err) + + // Expect check to return an error in the server description because the server address doesn't exist. This is + // OK because we just want to ensure the heartbeat context is created. + desc, err := server.check() + assert.Nil(t, err, "check error: %v", err) + assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") + assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil") + assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err()) + + // Override heartbeatCtxCancel with a wrapper that records whether or not it was called. + oldCancelFn := server.heartbeatCtxCancel + var previousCtxCancelled bool + server.heartbeatCtxCancel = func() { + previousCtxCancelled = true + oldCancelFn() + } + + // The second check call should attempt to create a new heartbeat connection and should cancel the previous + // heartbeatCtx during the process. + desc, err = server.check() + assert.Nil(t, err, "check error: %v", err) + assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") + assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not") + }) } func includesMetadata(t *testing.T, wm []byte) bool {