Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-1750 Ensure contexts are always cancelled during server monitoring #654

Merged
merged 2 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ func (s *Server) setupHeartbeatConnection() error {

// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
s.heartbeatLock.Lock()
if s.heartbeatCtxCancel != nil {
// Ensure the previous context is cancelled to avoid a leak.
s.heartbeatCtxCancel()
}
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
s.conn = conn
s.heartbeatLock.Unlock()
Expand Down
36 changes: 36 additions & 0 deletions x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,42 @@ func TestServer(t *testing.T) {
assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
})
t.Run("heartbeat contexts are not leaked", func(t *testing.T) {
// The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks.

server, err := ConnectServer(
address.Address("invalid"),
nil,
primitive.NewObjectID(),
withMonitoringDisabled(func(bool) bool {
return true
}),
)
assert.Nil(t, err, "ConnectServer error: %v", err)

// Expect check to return an error in the server description because the server address doesn't exist. This is
// OK because we just want to ensure the heartbeat context is created.
desc, err := server.check()
assert.Nil(t, err, "check error: %v", err)
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil")
assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err())

// Override heartbeatCtxCancel with a wrapper that records whether or not it was called.
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
oldCancelFn := server.heartbeatCtxCancel
var previousCtxCancelled bool
server.heartbeatCtxCancel = func() {
previousCtxCancelled = true
oldCancelFn()
}

// The second check call should attempt to create a new heartbeat connection and should cancel the previous
// heartbeatCtx during the process.
_, err = server.check()
iwysiu marked this conversation as resolved.
Show resolved Hide resolved
assert.Nil(t, err, "check error: %v", err)
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
})
}

func includesMetadata(t *testing.T, wm []byte) bool {
Expand Down