Skip to content

Commit

Permalink
GODRIVER-1977 Ensure sessions are not expired in LB mode (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
Divjot Arora authored Apr 22, 2021
1 parent a761f52 commit db0c9d9
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
13 changes: 10 additions & 3 deletions x/mongo/driver/session/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"crypto/rand"

"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
)
Expand All @@ -27,12 +28,18 @@ type Server struct {

// returns whether or not a session has expired given a timeout in minutes
// a session is considered expired if it has less than 1 minute left before becoming stale
func (ss *Server) expired(timeoutMinutes uint32) bool {
if timeoutMinutes <= 0 {
func (ss *Server) expired(topoDesc topologyDescription) bool {
// There is no server monitoring in LB mode, so we do not track session timeout minutes from server hello responses
// and never consider sessions to be expired.
if topoDesc.kind == description.LoadBalanced {
return false
}

if topoDesc.timeoutMinutes <= 0 {
return true
}
timeUnused := time.Since(ss.LastUsed).Minutes()
return timeUnused > float64(timeoutMinutes-1)
return timeUnused > float64(topoDesc.timeoutMinutes-1)
}

// update the last used time for this session.
Expand Down
34 changes: 23 additions & 11 deletions x/mongo/driver/session/server_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,33 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/description"
)

func TestServerSession(t *testing.T) {

t.Run("Expired", func(t *testing.T) {
sess, err := newServerSession()
require.Nil(t, err, "Unexpected error")
if !sess.expired(0) {
t.Errorf("session should be expired")
}
sess.LastUsed = time.Now().Add(-30 * time.Minute)
if !sess.expired(30) {
t.Errorf("session should be expired")
}
t.Run("non-lb mode", func(t *testing.T) {
sess, err := newServerSession()
assert.Nil(t, err, "newServerSession error: %v", err)

// The session should be expired if timeoutMinutes is 0 or if its last used time is too old.
assert.True(t, sess.expired(topologyDescription{}), "expected session to be expired when timeoutMinutes=0")
sess.LastUsed = time.Now().Add(-30 * time.Minute)
topoDesc := topologyDescription{timeoutMinutes: 30}
assert.True(t, sess.expired(topoDesc), "expected session to be expired when timeoutMinutes=30")
})
t.Run("lb mode", func(t *testing.T) {
sess, err := newServerSession()
assert.Nil(t, err, "newServerSession error: %v", err)

// The session should never be considered expired.
topoDesc := topologyDescription{kind: description.LoadBalanced}
assert.False(t, sess.expired(topoDesc), "session reported that it was expired in LB mode with timeoutMinutes=0")

sess.LastUsed = time.Now().Add(-30 * time.Minute)
topoDesc.timeoutMinutes = 10
assert.False(t, sess.expired(topoDesc), "session reported that it was expired in LB mode with timeoutMinutes=10")
})
})
}
28 changes: 19 additions & 9 deletions x/mongo/driver/session/session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ type Node struct {
prev *Node
}

// topologyDescription is used to track a subset of the fields present in a description.Topology instance that are
// relevant for determining session expiration.
type topologyDescription struct {
kind description.TopologyKind
timeoutMinutes uint32
}

// Pool is a pool of server sessions that can be reused.
type Pool struct {
descChan <-chan description.Topology
head *Node
tail *Node
timeout uint32
mutex sync.Mutex // mutex to protect list and sessionTimeout
descChan <-chan description.Topology
head *Node
tail *Node
latestTopology topologyDescription
mutex sync.Mutex // mutex to protect list and sessionTimeout

checkedOut int // number of sessions checked out of pool
}
Expand Down Expand Up @@ -54,7 +61,10 @@ func NewPool(descChan <-chan description.Topology) *Pool {
func (p *Pool) updateTimeout() {
select {
case newDesc := <-p.descChan:
p.timeout = newDesc.SessionTimeoutMinutes
p.latestTopology = topologyDescription{
kind: newDesc.Kind,
timeoutMinutes: newDesc.SessionTimeoutMinutes,
}
default:
// no new description waiting
}
Expand All @@ -73,7 +83,7 @@ func (p *Pool) GetSession() (*Server, error) {
p.updateTimeout()
for p.head != nil {
// pull session from head of queue and return if it is valid for at least 1 more minute
if p.head.expired(p.timeout) {
if p.head.expired(p.latestTopology) {
p.head = p.head.next
continue
}
Expand Down Expand Up @@ -112,15 +122,15 @@ func (p *Pool) ReturnSession(ss *Server) {
p.updateTimeout()
// check sessions at end of queue for expired
// stop checking after hitting the first valid session
for p.tail != nil && p.tail.expired(p.timeout) {
for p.tail != nil && p.tail.expired(p.latestTopology) {
if p.tail.prev != nil {
p.tail.prev.next = nil
}
p.tail = p.tail.prev
}

// session expired
if ss.expired(p.timeout) {
if ss.expired(p.latestTopology) {
return
}

Expand Down
8 changes: 5 additions & 3 deletions x/mongo/driver/session/session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func TestSessionPool(t *testing.T) {
t.Run("TestLifo", func(t *testing.T) {
descChan := make(chan description.Topology)
p := NewPool(descChan)
p.timeout = 30 // Set to some arbitrarily high number greater than 1 minute.
p.latestTopology = topologyDescription{
timeoutMinutes: 30, // Set to some arbitrarily high number greater than 1 minute.
}

first, err := p.GetSession()
assert.Nil(t, err, "GetSession error: %v", err)
Expand All @@ -45,8 +47,8 @@ func TestSessionPool(t *testing.T) {
t.Run("TestExpiredRemoved", func(t *testing.T) {
descChan := make(chan description.Topology)
p := NewPool(descChan)
// New sessions will always become stale when returned
p.timeout = 0
// Set timeout minutes to 0 so new sessions will always become stale when returned
p.latestTopology = topologyDescription{}

first, err := p.GetSession()
assert.Nil(t, err, "GetSession error: %v", err)
Expand Down

0 comments on commit db0c9d9

Please sign in to comment.