diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt
index 4c597aabb1ec..2568e2cff3e9 100644
--- a/docs/generated/settings/settings-for-tenants.txt
+++ b/docs/generated/settings/settings-for-tenants.txt
@@ -65,7 +65,8 @@ server.oidc_authentication.provider_url string sets OIDC provider URL ({provide
server.oidc_authentication.redirect_url string https://localhost:8080/oidc/v1/callback sets OIDC redirect URL via a URL string or a JSON string containing a required `redirect_urls` key with an object that maps from region keys to URL strings (URLs should point to your load balancer and must route to the path /oidc/v1/callback)
server.oidc_authentication.scopes string openid sets OIDC scopes to include with authentication request (space delimited list of strings, required to start with `openid`)
server.rangelog.ttl duration 720h0m0s if nonzero, range log entries older than this duration are deleted every 10m0s. Should not be lowered below 24 hours.
-server.shutdown.drain_wait duration 0s the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting)
+server.shutdown.connection_wait duration 0s the maximum amount of time a server waits for all SQL connections to be closed before proceeding with a drain. When all SQL connections are closed before times out, the server early exits and proceeds to draining range leases. (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting)
+server.shutdown.drain_wait duration 0s the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting. --drain-wait is to specify the duration of the whole draining process, while server.shutdown.drain_wait is to set the wait time for health probes to notice that the node is not ready.)
server.shutdown.lease_transfer_wait duration 5s the timeout for a single iteration of the range lease transfer phase of draining (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting)
server.shutdown.query_wait duration 10s the timeout for waiting for active queries to finish during a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting)
server.time_until_store_dead duration 5m0s the time after which if there is no new gossiped information about a store, it is considered dead
diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html
index 14fd3b3e2c69..b460d26848fc 100644
--- a/docs/generated/settings/settings.html
+++ b/docs/generated/settings/settings.html
@@ -77,7 +77,8 @@
server.oidc_authentication.redirect_url | string | https://localhost:8080/oidc/v1/callback | sets OIDC redirect URL via a URL string or a JSON string containing a required `redirect_urls` key with an object that maps from region keys to URL strings (URLs should point to your load balancer and must route to the path /oidc/v1/callback) |
server.oidc_authentication.scopes | string | openid | sets OIDC scopes to include with authentication request (space delimited list of strings, required to start with `openid`) |
server.rangelog.ttl | duration | 720h0m0s | if nonzero, range log entries older than this duration are deleted every 10m0s. Should not be lowered below 24 hours. |
-server.shutdown.drain_wait | duration | 0s | the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) |
+server.shutdown.connection_wait | duration | 0s | the maximum amount of time a server waits for all SQL connections to be closed before proceeding with a drain. When all SQL connections are closed before times out, the server early exits and proceeds to draining range leases. (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) |
+server.shutdown.drain_wait | duration | 0s | the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting. --drain-wait is to specify the duration of the whole draining process, while server.shutdown.drain_wait is to set the wait time for health probes to notice that the node is not ready.) |
server.shutdown.lease_transfer_wait | duration | 5s | the timeout for a single iteration of the range lease transfer phase of draining (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) |
server.shutdown.query_wait | duration | 10s | the timeout for waiting for active queries to finish during a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) |
server.time_until_store_dead | duration | 5m0s | the time after which if there is no new gossiped information about a store, it is considered dead |
diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel
index d74089c6e4d9..2133b306bd80 100644
--- a/pkg/cmd/roachtest/tests/BUILD.bazel
+++ b/pkg/cmd/roachtest/tests/BUILD.bazel
@@ -32,6 +32,7 @@ go_library(
"disk_stall.go",
"django.go",
"django_blocklist.go",
+ "drain.go",
"drop.go",
"drt.go",
"encryption.go",
diff --git a/pkg/cmd/roachtest/tests/drain.go b/pkg/cmd/roachtest/tests/drain.go
new file mode 100644
index 000000000000..1d0762ddd79e
--- /dev/null
+++ b/pkg/cmd/roachtest/tests/drain.go
@@ -0,0 +1,236 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package tests
+
+import (
+ "context"
+ gosql "database/sql"
+ "fmt"
+ "math/rand"
+ "path/filepath"
+ "time"
+
+ "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster"
+ "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option"
+ "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry"
+ "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
+ "github.com/cockroachdb/cockroach/pkg/roachprod/install"
+ "github.com/cockroachdb/cockroach/pkg/util/timeutil"
+ "github.com/cockroachdb/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func registerDrain(r registry.Registry) {
+ {
+ r.Add(registry.TestSpec{
+ Name: "drain/conn-wait",
+ Owner: registry.OwnerSQLExperience,
+ Cluster: r.MakeClusterSpec(1),
+ Run: func(ctx context.Context, t test.Test, c cluster.Cluster) {
+ runConnectionWait(ctx, t, c)
+ },
+ },
+ )
+ }
+}
+
+func runConnectionWait(ctx context.Context, t test.Test, c cluster.Cluster) {
+ var err error
+
+ err = c.PutE(ctx, t.L(), t.Cockroach(), "./cockroach", c.All())
+ require.NoError(t, err, "cannot mount cockroach binary")
+
+ // Verify that draining proceeds immediately after connections are closed client-side.
+ {
+ const (
+ // Set the duration of each phase of the draining period.
+ drainWaitDuration = 10 * time.Second
+ connectionWaitDuration = 100 * time.Second
+ queryWaitDuration = 10 * time.Second
+ // pokeDuringConnWaitTimestamp is the timestamp after the server
+ // starts waiting for SQL connections to close (with the start of the whole
+ // draining process marked as timestamp 0). It should be set larger than
+ // drainWaitDuration, but smaller than (drainWaitDuration +
+ // connectionWaitDuration).
+ pokeDuringConnWaitTimestamp = 45 * time.Second
+ connMaxLifetime = 30 * time.Second
+ connMaxCount = 5
+ nodeToDrain = 1
+ )
+ totalWaitDuration := drainWaitDuration + connectionWaitDuration + queryWaitDuration
+
+ prepareCluster(ctx, t, c, drainWaitDuration, connectionWaitDuration, queryWaitDuration)
+
+ db := c.Conn(ctx, t.L(), nodeToDrain)
+ defer db.Close()
+
+ db.SetConnMaxLifetime(connMaxLifetime)
+ db.SetMaxOpenConns(connMaxCount)
+
+ var conns []*gosql.Conn
+
+ // Get two connections from the connection pools.
+ for j := 0; j < 2; j++ {
+ conn, err := db.Conn(ctx)
+
+ require.NoError(t, err, "failed to a SQL connection from the connection pool")
+
+ conns = append(conns, conn)
+ }
+
+ // Start draining the node.
+ m := c.NewMonitor(ctx, c.Node(nodeToDrain))
+
+ m.Go(func(ctx context.Context) error {
+ t.Status(fmt.Sprintf("start draining node %d", nodeToDrain))
+ return c.RunE(ctx,
+ c.Node(nodeToDrain),
+ fmt.Sprintf("./cockroach node drain --insecure --drain-wait=%fs",
+ totalWaitDuration.Seconds()))
+ })
+
+ drainStartTimestamp := timeutil.Now()
+
+ // Sleep till the server is in the status of waiting for users to close SQL
+ // connections. Verify that the server is rejecting new SQL connections now.
+ time.Sleep(pokeDuringConnWaitTimestamp)
+ _, err = db.Conn(ctx)
+ if err != nil {
+ t.Status(fmt.Sprintf("%s after draining starts, server is rejecting "+
+ "new SQL connections: %v", pokeDuringConnWaitTimestamp, err))
+ } else {
+ t.Fatal(errors.New("new SQL connections should not be allowed when the server " +
+ "starts waiting for the user to close SQL connections"))
+ }
+
+ require.Equalf(t, db.Stats().OpenConnections, 2, "number of open connections should be 2")
+
+ t.Status("number of open connections: ", db.Stats().OpenConnections)
+
+ randConn := conns[rand.Intn(len(conns))]
+
+ // When server is waiting clients to close connections, verify that SQL
+ // queries do not fail.
+ _, err = randConn.ExecContext(ctx, "SELECT 1;")
+
+ require.NoError(t, err, "expected query not to fail before the "+
+ "server starts draining SQL connections")
+
+ for _, conn := range conns {
+ err := conn.Close()
+ require.NoError(t, err,
+ "expected connection to be able to be successfully closed client-side")
+ }
+
+ t.Status("all SQL connections are put back to the connection pool")
+
+ err = m.WaitE()
+ require.NoError(t, err, "error waiting for the draining to finish")
+
+ drainEndTimestamp := timeutil.Now()
+ actualDrainDuration := drainEndTimestamp.Sub(drainStartTimestamp).Seconds()
+
+ t.L().Printf("the draining lasted %f seconds", actualDrainDuration)
+
+ if actualDrainDuration >= float64(totalWaitDuration)-10 {
+ t.Fatal(errors.New("the draining process didn't early exit " +
+ "when waiting for server to close all SQL connections"))
+ }
+
+ // Fully quit the draining node so that we can restart it for the next test.
+ quitNode(ctx, t, c, nodeToDrain)
+ }
+
+ // Verify a warning exists in the case that connectionWait expires.
+ {
+ const (
+ // Set the duration of the draining period.
+ drainWaitDuration = 0 * time.Second
+ connectionWaitDuration = 10 * time.Second
+ queryWaitDuration = 20 * time.Second
+ nodeToDrain = 1
+ )
+
+ totalWaitDuration := drainWaitDuration + connectionWaitDuration + queryWaitDuration
+
+ prepareCluster(ctx, t, c, drainWaitDuration, connectionWaitDuration, queryWaitDuration)
+
+ db := c.Conn(ctx, t.L(), nodeToDrain)
+ defer db.Close()
+
+ // Get a connection from the connection pool.
+ _, err = db.Conn(ctx)
+
+ require.NoError(t, err, "cannot get a SQL connection from the connection pool")
+
+ m := c.NewMonitor(ctx, c.Node(nodeToDrain))
+ m.Go(func(ctx context.Context) error {
+ t.Status(fmt.Sprintf("draining node %d", nodeToDrain))
+ return c.RunE(ctx,
+ c.Node(nodeToDrain),
+ fmt.Sprintf("./cockroach node drain --insecure --drain-wait=%fs",
+ totalWaitDuration.Seconds()))
+ })
+
+ err = m.WaitE()
+ require.NoError(t, err, "error waiting for the draining to finish")
+
+ logFile := filepath.Join("logs", "*.log")
+ err = c.RunE(ctx, c.Node(nodeToDrain),
+ "grep", "-q", "'proceeding to drain SQL connections'", logFile)
+ require.NoError(t, err, "warning is not logged in the log file")
+ }
+
+}
+
+// prepareCluster is to start the server on nodes in the given cluster, and set
+// the cluster setting for duration of each phase of the draining process.
+func prepareCluster(
+ ctx context.Context,
+ t test.Test,
+ c cluster.Cluster,
+ drainWait time.Duration,
+ connectionWait time.Duration,
+ queryWait time.Duration,
+) {
+
+ c.Start(ctx, t.L(), option.DefaultStartOpts(), install.MakeClusterSettings(), c.All())
+
+ db := c.Conn(ctx, t.L(), 1)
+ defer db.Close()
+
+ _, err := db.ExecContext(ctx, `
+ SET CLUSTER SETTING server.shutdown.drain_wait = $1;
+ SET CLUSTER SETTING server.shutdown.connection_wait = $2;
+ SET CLUSTER SETTING server.shutdown.query_wait = $3;`,
+ drainWait.Seconds(),
+ connectionWait.Seconds(),
+ queryWait.Seconds(),
+ )
+ require.NoError(t, err)
+
+}
+
+func quitNode(ctx context.Context, t test.Test, c cluster.Cluster, node int) {
+ args := append([]string{
+ "./cockroach", "quit", "--insecure", "--logtostderr=INFO",
+ fmt.Sprintf("--port={pgport:%d}", node)})
+ result, err := c.RunWithDetailsSingleNode(ctx, t.L(), c.Node(node), args...)
+ output := result.Stdout + result.Stderr
+ t.L().Printf("cockroach quit:\n%s\n", output)
+ require.NoError(t, err, "cannot quit cockroach")
+
+ stopOpts := option.DefaultStopOpts()
+ stopOpts.RoachprodOpts.Sig = 0
+ stopOpts.RoachprodOpts.Wait = true
+ c.Stop(ctx, t.L(), stopOpts, c.All())
+ t.L().Printf("stopped cluster")
+}
diff --git a/pkg/cmd/roachtest/tests/registry.go b/pkg/cmd/roachtest/tests/registry.go
index 84aa2711bbc3..fd12f1d0a539 100644
--- a/pkg/cmd/roachtest/tests/registry.go
+++ b/pkg/cmd/roachtest/tests/registry.go
@@ -34,6 +34,7 @@ func RegisterTests(r registry.Registry) {
registerDiskFull(r)
RegisterDiskStalledDetection(r)
registerDjango(r)
+ registerDrain(r)
registerDrop(r)
registerEncryption(r)
registerEngineSwitch(r)
diff --git a/pkg/server/drain.go b/pkg/server/drain.go
index ddc7fcf71bc5..02f73e5650e1 100644
--- a/pkg/server/drain.go
+++ b/pkg/server/drain.go
@@ -12,6 +12,7 @@ package server
import (
"context"
+ "fmt"
"io"
"strings"
"time"
@@ -32,21 +33,38 @@ import (
var (
queryWait = settings.RegisterDurationSetting(
- settings.TenantWritable,
+ settings.TenantReadOnly,
"server.shutdown.query_wait",
"the timeout for waiting for active queries to finish during a drain "+
"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
"after changing this setting)",
10*time.Second,
+ settings.NonNegativeDurationWithMaximum(10*time.Hour),
).WithPublic()
drainWait = settings.RegisterDurationSetting(
- settings.TenantWritable,
+ settings.TenantReadOnly,
"server.shutdown.drain_wait",
"the amount of time a server waits in an unready state before proceeding with a drain "+
+ "(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
+ "after changing this setting. --drain-wait is to specify the duration of the "+
+ "whole draining process, while server.shutdown.drain_wait is to set the "+
+ "wait time for health probes to notice that the node is not ready.)",
+ 0*time.Second,
+ settings.NonNegativeDurationWithMaximum(10*time.Hour),
+ ).WithPublic()
+
+ connectionWait = settings.RegisterDurationSetting(
+ settings.TenantReadOnly,
+ "server.shutdown.connection_wait",
+ "the maximum amount of time a server waits for all SQL connections to "+
+ "be closed before proceeding with a drain. "+
+ "When all SQL connections are closed before times out, the server early "+
+ "exits and proceeds to draining range leases. "+
"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
"after changing this setting)",
0*time.Second,
+ settings.NonNegativeDurationWithMaximum(10*time.Hour),
).WithPublic()
)
@@ -309,19 +327,33 @@ func (s *drainServer) drainClients(
s.grpc.setMode(modeDraining)
s.sqlServer.isReady.Set(false)
+ // Log the number of connections periodically.
+ if err := s.logOpenConns(ctx); err != nil {
+ log.Ops.Warningf(ctx, "error showing alive SQL connections: %v", err)
+ }
+
// Wait the duration of drainWait.
// This will fail load balancer checks and delay draining so that client
// traffic can move off this node.
// Note delay only happens on first call to drain.
if shouldDelayDraining {
+ log.Ops.Info(ctx, "waiting for health probes to notice that the node "+
+ "is not ready for new sql connections...")
s.drainSleepFn(drainWait.Get(&s.sqlServer.execCfg.Settings.SV))
}
- // Drain all SQL connections.
- // The queryWait duration is a timeout for waiting on clients
- // to self-disconnect. If the timeout is reached, any remaining connections
+ // Wait for users to close the existing SQL connections.
+ // During this phase, the server is rejecting new SQL connections.
+ // The server exit this phase either once all SQL connections are closed,
+ // or it reaches the connectionMaxWait timeout, whichever earlier.
+ s.sqlServer.pgServer.WaitForSQLConnsToClose(ctx, connectionWait.Get(&s.sqlServer.execCfg.Settings.SV))
+
+ // Drain any remaining SQL connections.
+ // The queryWait duration is a timeout for waiting for SQL queries to finish.
+ // If the timeout is reached, any remaining connections
// will be closed.
queryMaxWait := queryWait.Get(&s.sqlServer.execCfg.Settings.SV)
+
if err := s.sqlServer.pgServer.Drain(ctx, queryMaxWait, reporter); err != nil {
return err
}
@@ -357,3 +389,21 @@ func (s *drainServer) drainNode(
// Mark the stores of the node as "draining" and drain all range leases.
return s.kvServer.node.SetDraining(true /* drain */, reporter, verbose)
}
+
+// logOpenConns logs the number of open SQL connections every 3 seconds.
+func (s *drainServer) logOpenConns(ctx context.Context) error {
+ return s.stopper.RunAsyncTask(ctx, "log-open-conns", func(ctx context.Context) {
+ ticker := time.NewTicker(3 * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ log.Ops.Info(ctx, fmt.Sprintf("number of open connections: %d\n", s.sqlServer.pgServer.GetConnCancelMapLen()))
+ case <-s.stopper.ShouldQuiesce():
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+ })
+}
diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go
index bbd4c178f244..03da970e1501 100644
--- a/pkg/sql/pgwire/server.go
+++ b/pkg/sql/pgwire/server.go
@@ -220,7 +220,16 @@ type Server struct {
// cancel the associated connection. The corresponding key is a channel
// that is closed when the connection is done.
connCancelMap cancelChanMap
- draining bool
+ // draining is set to true when the server starts draining the SQL layer.
+ // Remaining SQL connections will be closed as
+ // soon as their queries finish. After the timeout set by
+ // server.shutdown.query_wait, all connections will be closed regardless any
+ // queries in flight.
+ draining bool
+ // rejectNewConnections is set true when the server does not accept new
+ // SQL connections, e.g. when the draining process enters the phase whose
+ // duration is specified by the server.shutdown.connection_wait.
+ rejectNewConnections bool
}
auth struct {
@@ -443,14 +452,15 @@ func (s *Server) Drain(
// Undrain switches the server back to the normal mode of operation in which
// connections are accepted.
func (s *Server) Undrain() {
- s.mu.Lock()
- s.setDrainingLocked(false)
- s.mu.Unlock()
+ s.setRejectNewConnections(false)
+ s.setDraining(false)
}
-// setDrainingLocked sets the server's draining state and returns whether the
-// state changed (i.e. drain != s.mu.draining). s.mu must be locked.
-func (s *Server) setDrainingLocked(drain bool) bool {
+// setDraining sets the server's draining state and returns whether the
+// state changed (i.e. drain != s.mu.draining).
+func (s *Server) setDraining(drain bool) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
if s.mu.draining == drain {
return false
}
@@ -458,6 +468,97 @@ func (s *Server) setDrainingLocked(drain bool) bool {
return true
}
+// setRejectNewConnections sets the server's rejectNewConnections state.
+func (s *Server) setRejectNewConnections(rej bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.mu.rejectNewConnections = rej
+}
+
+// GetConnCancelMapLen returns the length of connCancelMap of the server.
+// This is a helper function when the server waits the SQL connections to be
+// closed. During this period, the server listens to the status of all
+// connections, and early exits this draining phase if there remains no active
+// SQL connections.
+func (s *Server) GetConnCancelMapLen() int {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return len(s.mu.connCancelMap)
+}
+
+// WaitForSQLConnsToClose waits for the client to close all SQL connections for the
+// duration of connectionWait.
+// With this phase, the node starts rejecting SQL connections, and as
+// soon as all existing SQL connections are closed, the server early exits this
+// draining phase.
+func (s *Server) WaitForSQLConnsToClose(ctx context.Context, connectionWait time.Duration) error {
+ // If we're already draining the SQL connections, we don't need to wait again.
+ if s.IsDraining() {
+ return nil
+ }
+
+ s.setRejectNewConnections(true)
+
+ if connectionWait == 0 {
+ return nil
+ }
+
+ log.Ops.Info(ctx, "waiting for clients to close existing SQL connections...")
+
+ timer := time.NewTimer(connectionWait)
+ defer timer.Stop()
+
+ _, allConnsDone, quitWaitingForConns := s.waitConnsDone()
+ defer close(quitWaitingForConns)
+
+ select {
+ // Connection wait times out.
+ case <-time.After(connectionWait):
+ log.Ops.Infof(ctx,
+ "%d connections remain after waiting %s; proceeding to drain SQL connections",
+ s.GetConnCancelMapLen(),
+ connectionWait,
+ )
+ case <-allConnsDone:
+ }
+
+ return nil
+}
+
+// waitConnsDone returns a copy of s.mu.connCancelMap, and a channel that
+// will be closed once all sql connections are closed, or the server quits
+// waiting for connections, whichever earlier.
+func (s *Server) waitConnsDone() (cancelChanMap, chan struct{}, chan struct{}) {
+
+ connCancelMap := func() cancelChanMap {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ connCancelMap := make(cancelChanMap)
+ for done, cancel := range s.mu.connCancelMap {
+ connCancelMap[done] = cancel
+ }
+ return connCancelMap
+ }()
+
+ allConnsDone := make(chan struct{}, 1)
+
+ quitWaitingForConns := make(chan struct{}, 1)
+
+ go func() {
+ defer close(allConnsDone)
+
+ for done := range connCancelMap {
+ select {
+ case <-done:
+ case <-quitWaitingForConns:
+ return
+ }
+ }
+ }()
+
+ return connCancelMap, allConnsDone, quitWaitingForConns
+}
+
// drainImpl drains the SQL clients.
//
// The queryWait duration is used to wait on clients to
@@ -476,8 +577,25 @@ func (s *Server) drainImpl(
cancelWait time.Duration,
reporter func(int, redact.SafeString),
) error {
- // This anonymous function returns a copy of s.mu.connCancelMap if there are
- // any active connections to cancel. We will only attempt to cancel
+
+ if !s.setDraining(true) {
+ // We are already draining.
+ return nil
+ }
+
+ // If there is no open SQL connections to drain, just return.
+ if s.GetConnCancelMapLen() == 0 {
+ return nil
+ }
+
+ log.Ops.Info(ctx, "starting draining SQL connections...")
+
+ // Spin off a goroutine that waits for all connections to signal that they
+ // are done and reports it on allConnsDone. The main goroutine signals this
+ // goroutine to stop work through quitWaitingForConns.
+
+ // This s.waitConnsDone function returns a copy of s.mu.connCancelMap if there
+ // are any active connections to cancel. We will only attempt to cancel
// connections that were active at the moment the draining switch happened.
// It is enough to do this because:
// 1) If no new connections are added to the original map all connections
@@ -486,44 +604,14 @@ func (s *Server) drainImpl(
// were added when s.mu.draining = false, thus not requiring cancellation.
// These connections are not our responsibility and will be handled when the
// server starts draining again.
- connCancelMap := func() cancelChanMap {
- s.mu.Lock()
- defer s.mu.Unlock()
- if !s.setDrainingLocked(true) {
- // We are already draining.
- return nil
- }
- connCancelMap := make(cancelChanMap)
- for done, cancel := range s.mu.connCancelMap {
- connCancelMap[done] = cancel
- }
- return connCancelMap
- }()
- if len(connCancelMap) == 0 {
- return nil
- }
+ connCancelMap, allConnsDone, quitWaitingForConns := s.waitConnsDone()
+ defer close(quitWaitingForConns)
+
if reporter != nil {
// Report progress to the Drain RPC.
reporter(len(connCancelMap), "SQL clients")
}
- // Spin off a goroutine that waits for all connections to signal that they
- // are done and reports it on allConnsDone. The main goroutine signals this
- // goroutine to stop work through quitWaitingForConns.
- allConnsDone := make(chan struct{})
- quitWaitingForConns := make(chan struct{})
- defer close(quitWaitingForConns)
- go func() {
- defer close(allConnsDone)
- for done := range connCancelMap {
- select {
- case <-done:
- case <-quitWaitingForConns:
- return
- }
- }
- }()
-
// Wait for connections to finish up their queries for the duration of queryWait.
select {
case <-time.After(queryWait):
@@ -606,7 +694,7 @@ func (s *Server) TestingEnableAuthLogging() {
//
// An error is returned if the initial handshake of the connection fails.
func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType SocketType) error {
- ctx, draining, onCloseFn := s.registerConn(ctx)
+ ctx, rejectNewConnections, onCloseFn := s.registerConn(ctx)
defer onCloseFn()
connDetails := eventpb.CommonConnectionDetails{
@@ -672,7 +760,7 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket
}
// If the server is shutting down, terminate the connection early.
- if draining {
+ if rejectNewConnections {
log.Ops.Info(ctx, "rejecting new connection while server is draining")
return s.sendErr(ctx, conn, newAdminShutdownErr(ErrDrainingNewConn))
}
@@ -1183,19 +1271,20 @@ func (s *Server) maybeUpgradeToSecureConn(
}
// registerConn registers the incoming connection to the map of active connections,
-// which can be canceled by a concurrent server drain. It also returns
-// the current draining status of the server.
+// which can be canceled by a concurrent server drain. It also returns a boolean
+// variable rejectConn, which shows if the server is rejecting new SQL
+// connections.
//
// The onCloseFn() callback must be called at the end of the
// connection by the caller.
func (s *Server) registerConn(
ctx context.Context,
-) (newCtx context.Context, draining bool, onCloseFn func()) {
+) (newCtx context.Context, rejectNewConnections bool, onCloseFn func()) {
onCloseFn = func() {}
newCtx = ctx
s.mu.Lock()
- draining = s.mu.draining
- if !draining {
+ rejectNewConnections = s.mu.rejectNewConnections
+ if !rejectNewConnections {
var cancel context.CancelFunc
newCtx, cancel = contextutil.WithCancel(ctx)
done := make(chan struct{})
@@ -1210,11 +1299,11 @@ func (s *Server) registerConn(
}
s.mu.Unlock()
- // If the Server is draining, we will use the connection only to send an
- // error, so we don't count it in the stats. This makes sense since
- // DrainClient() waits for that number to drop to zero,
+ // If the server is rejecting new SQL connections, we will use the connection
+ // only to send an error, so we don't count it in the stats. This makes sense
+ // since DrainClient() waits for that number to drop to zero,
// so we don't want it to oscillate unnecessarily.
- if !draining {
+ if !rejectNewConnections {
s.metrics.NewConns.Inc(1)
s.metrics.Conns.Inc(1)
prevOnCloseFn := onCloseFn