diff --git a/x/mongo/driver/session/server_session.go b/x/mongo/driver/session/server_session.go index 4dbb5ee960..e7881fa170 100644 --- a/x/mongo/driver/session/server_session.go +++ b/x/mongo/driver/session/server_session.go @@ -11,6 +11,7 @@ import ( "crypto/rand" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" ) @@ -27,12 +28,18 @@ type Server struct { // returns whether or not a session has expired given a timeout in minutes // a session is considered expired if it has less than 1 minute left before becoming stale -func (ss *Server) expired(timeoutMinutes uint32) bool { - if timeoutMinutes <= 0 { +func (ss *Server) expired(topoDesc topologyDescription) bool { + // There is no server monitoring in LB mode, so we do not track session timeout minutes from server hello responses + // and never consider sessions to be expired. + if topoDesc.kind == description.LoadBalanced { + return false + } + + if topoDesc.timeoutMinutes <= 0 { return true } timeUnused := time.Since(ss.LastUsed).Minutes() - return timeUnused > float64(timeoutMinutes-1) + return timeUnused > float64(topoDesc.timeoutMinutes-1) } // update the last used time for this session. diff --git a/x/mongo/driver/session/server_session_test.go b/x/mongo/driver/session/server_session_test.go index 57b53a1041..f608ef94a8 100644 --- a/x/mongo/driver/session/server_session_test.go +++ b/x/mongo/driver/session/server_session_test.go @@ -10,21 +10,33 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/internal/testutil/assert" + "go.mongodb.org/mongo-driver/mongo/description" ) func TestServerSession(t *testing.T) { - t.Run("Expired", func(t *testing.T) { - sess, err := newServerSession() - require.Nil(t, err, "Unexpected error") - if !sess.expired(0) { - t.Errorf("session should be expired") - } - sess.LastUsed = time.Now().Add(-30 * time.Minute) - if !sess.expired(30) { - t.Errorf("session should be expired") - } + t.Run("non-lb mode", func(t *testing.T) { + sess, err := newServerSession() + assert.Nil(t, err, "newServerSession error: %v", err) + + // The session should be expired if timeoutMinutes is 0 or if its last used time is too old. + assert.True(t, sess.expired(topologyDescription{}), "expected session to be expired when timeoutMinutes=0") + sess.LastUsed = time.Now().Add(-30 * time.Minute) + topoDesc := topologyDescription{timeoutMinutes: 30} + assert.True(t, sess.expired(topoDesc), "expected session to be expired when timeoutMinutes=30") + }) + t.Run("lb mode", func(t *testing.T) { + sess, err := newServerSession() + assert.Nil(t, err, "newServerSession error: %v", err) + + // The session should never be considered expired. + topoDesc := topologyDescription{kind: description.LoadBalanced} + assert.False(t, sess.expired(topoDesc), "session reported that it was expired in LB mode with timeoutMinutes=0") + sess.LastUsed = time.Now().Add(-30 * time.Minute) + topoDesc.timeoutMinutes = 10 + assert.False(t, sess.expired(topoDesc), "session reported that it was expired in LB mode with timeoutMinutes=10") + }) }) } diff --git a/x/mongo/driver/session/session_pool.go b/x/mongo/driver/session/session_pool.go index 61616ac316..27db7c476f 100644 --- a/x/mongo/driver/session/session_pool.go +++ b/x/mongo/driver/session/session_pool.go @@ -20,13 +20,20 @@ type Node struct { prev *Node } +// topologyDescription is used to track a subset of the fields present in a description.Topology instance that are +// relevant for determining session expiration. +type topologyDescription struct { + kind description.TopologyKind + timeoutMinutes uint32 +} + // Pool is a pool of server sessions that can be reused. type Pool struct { - descChan <-chan description.Topology - head *Node - tail *Node - timeout uint32 - mutex sync.Mutex // mutex to protect list and sessionTimeout + descChan <-chan description.Topology + head *Node + tail *Node + latestTopology topologyDescription + mutex sync.Mutex // mutex to protect list and sessionTimeout checkedOut int // number of sessions checked out of pool } @@ -54,7 +61,10 @@ func NewPool(descChan <-chan description.Topology) *Pool { func (p *Pool) updateTimeout() { select { case newDesc := <-p.descChan: - p.timeout = newDesc.SessionTimeoutMinutes + p.latestTopology = topologyDescription{ + kind: newDesc.Kind, + timeoutMinutes: newDesc.SessionTimeoutMinutes, + } default: // no new description waiting } @@ -73,7 +83,7 @@ func (p *Pool) GetSession() (*Server, error) { p.updateTimeout() for p.head != nil { // pull session from head of queue and return if it is valid for at least 1 more minute - if p.head.expired(p.timeout) { + if p.head.expired(p.latestTopology) { p.head = p.head.next continue } @@ -112,7 +122,7 @@ func (p *Pool) ReturnSession(ss *Server) { p.updateTimeout() // check sessions at end of queue for expired // stop checking after hitting the first valid session - for p.tail != nil && p.tail.expired(p.timeout) { + for p.tail != nil && p.tail.expired(p.latestTopology) { if p.tail.prev != nil { p.tail.prev.next = nil } @@ -120,7 +130,7 @@ func (p *Pool) ReturnSession(ss *Server) { } // session expired - if ss.expired(p.timeout) { + if ss.expired(p.latestTopology) { return } diff --git a/x/mongo/driver/session/session_pool_test.go b/x/mongo/driver/session/session_pool_test.go index 0259af69af..62d4164a0f 100644 --- a/x/mongo/driver/session/session_pool_test.go +++ b/x/mongo/driver/session/session_pool_test.go @@ -18,7 +18,9 @@ func TestSessionPool(t *testing.T) { t.Run("TestLifo", func(t *testing.T) { descChan := make(chan description.Topology) p := NewPool(descChan) - p.timeout = 30 // Set to some arbitrarily high number greater than 1 minute. + p.latestTopology = topologyDescription{ + timeoutMinutes: 30, // Set to some arbitrarily high number greater than 1 minute. + } first, err := p.GetSession() assert.Nil(t, err, "GetSession error: %v", err) @@ -45,8 +47,8 @@ func TestSessionPool(t *testing.T) { t.Run("TestExpiredRemoved", func(t *testing.T) { descChan := make(chan description.Topology) p := NewPool(descChan) - // New sessions will always become stale when returned - p.timeout = 0 + // Set timeout minutes to 0 so new sessions will always become stale when returned + p.latestTopology = topologyDescription{} first, err := p.GetSession() assert.Nil(t, err, "GetSession error: %v", err)