From be3412fdedc0bd2a47f978e2ce7a1467504fc5e8 Mon Sep 17 00:00:00 2001 From: Jane Xing Date: Tue, 1 Mar 2022 10:56:40 -0600 Subject: [PATCH] server,sql: implement connection_wait for graceful draining This commit is to add a phase to current draining process. At this phase, the server waits for SQL connections to be closed. New SQL connections are not allowed now. Once all SQL connections are closed, the server proceeds to draining the range leases. The maximum duration of this phase is determined by the cluster setting `server.shutdown.connection_wait` The duration can be set similarly to the other 3 existing draining phases: ``` SET CLUSTER SETTING server.shutdown.connection_wait = '40s' ``` Resolves #66319 Release note (ops change): add `server.shutdown.connection_wait` to the draining process configuration. This provides a workaround when customers encountered intermittent blips and failed requests when they were performing operations that are related to restarting nodes. Release justification: Low risk, high benefit changes to existing functionality (optimize the node draining process). --- .../settings/settings-for-tenants.txt | 3 +- docs/generated/settings/settings.html | 3 +- pkg/cmd/roachtest/tests/BUILD.bazel | 1 + pkg/cmd/roachtest/tests/drain.go | 236 ++++++++++++++++++ pkg/cmd/roachtest/tests/registry.go | 1 + pkg/server/drain.go | 60 ++++- pkg/sql/pgwire/server.go | 195 +++++++++++---- 7 files changed, 439 insertions(+), 60 deletions(-) create mode 100644 pkg/cmd/roachtest/tests/drain.go diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt index 4c597aabb1ec..2568e2cff3e9 100644 --- a/docs/generated/settings/settings-for-tenants.txt +++ b/docs/generated/settings/settings-for-tenants.txt @@ -65,7 +65,8 @@ server.oidc_authentication.provider_url string sets OIDC provider URL ({provide server.oidc_authentication.redirect_url string https://localhost:8080/oidc/v1/callback sets OIDC redirect URL via a URL string or a JSON string containing a required `redirect_urls` key with an object that maps from region keys to URL strings (URLs should point to your load balancer and must route to the path /oidc/v1/callback) server.oidc_authentication.scopes string openid sets OIDC scopes to include with authentication request (space delimited list of strings, required to start with `openid`) server.rangelog.ttl duration 720h0m0s if nonzero, range log entries older than this duration are deleted every 10m0s. Should not be lowered below 24 hours. -server.shutdown.drain_wait duration 0s the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) +server.shutdown.connection_wait duration 0s the maximum amount of time a server waits for all SQL connections to be closed before proceeding with a drain. When all SQL connections are closed before times out, the server early exits and proceeds to draining range leases. (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) +server.shutdown.drain_wait duration 0s the amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting. --drain-wait is to specify the duration of the whole draining process, while server.shutdown.drain_wait is to set the wait time for health probes to notice that the node is not ready.) server.shutdown.lease_transfer_wait duration 5s the timeout for a single iteration of the range lease transfer phase of draining (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) server.shutdown.query_wait duration 10s the timeout for waiting for active queries to finish during a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) server.time_until_store_dead duration 5m0s the time after which if there is no new gossiped information about a store, it is considered dead diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html index 14fd3b3e2c69..b460d26848fc 100644 --- a/docs/generated/settings/settings.html +++ b/docs/generated/settings/settings.html @@ -77,7 +77,8 @@ server.oidc_authentication.redirect_urlstringhttps://localhost:8080/oidc/v1/callbacksets OIDC redirect URL via a URL string or a JSON string containing a required `redirect_urls` key with an object that maps from region keys to URL strings (URLs should point to your load balancer and must route to the path /oidc/v1/callback) server.oidc_authentication.scopesstringopenidsets OIDC scopes to include with authentication request (space delimited list of strings, required to start with `openid`) server.rangelog.ttlduration720h0m0sif nonzero, range log entries older than this duration are deleted every 10m0s. Should not be lowered below 24 hours. -server.shutdown.drain_waitduration0sthe amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) +server.shutdown.connection_waitduration0sthe maximum amount of time a server waits for all SQL connections to be closed before proceeding with a drain. When all SQL connections are closed before times out, the server early exits and proceeds to draining range leases. (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) +server.shutdown.drain_waitduration0sthe amount of time a server waits in an unready state before proceeding with a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting. --drain-wait is to specify the duration of the whole draining process, while server.shutdown.drain_wait is to set the wait time for health probes to notice that the node is not ready.) server.shutdown.lease_transfer_waitduration5sthe timeout for a single iteration of the range lease transfer phase of draining (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) server.shutdown.query_waitduration10sthe timeout for waiting for active queries to finish during a drain (note that the --drain-wait parameter for cockroach node drain may need adjustment after changing this setting) server.time_until_store_deadduration5m0sthe time after which if there is no new gossiped information about a store, it is considered dead diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel index d74089c6e4d9..2133b306bd80 100644 --- a/pkg/cmd/roachtest/tests/BUILD.bazel +++ b/pkg/cmd/roachtest/tests/BUILD.bazel @@ -32,6 +32,7 @@ go_library( "disk_stall.go", "django.go", "django_blocklist.go", + "drain.go", "drop.go", "drt.go", "encryption.go", diff --git a/pkg/cmd/roachtest/tests/drain.go b/pkg/cmd/roachtest/tests/drain.go new file mode 100644 index 000000000000..1d0762ddd79e --- /dev/null +++ b/pkg/cmd/roachtest/tests/drain.go @@ -0,0 +1,236 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tests + +import ( + "context" + gosql "database/sql" + "fmt" + "math/rand" + "path/filepath" + "time" + + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" + "github.com/cockroachdb/cockroach/pkg/roachprod/install" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func registerDrain(r registry.Registry) { + { + r.Add(registry.TestSpec{ + Name: "drain/conn-wait", + Owner: registry.OwnerSQLExperience, + Cluster: r.MakeClusterSpec(1), + Run: func(ctx context.Context, t test.Test, c cluster.Cluster) { + runConnectionWait(ctx, t, c) + }, + }, + ) + } +} + +func runConnectionWait(ctx context.Context, t test.Test, c cluster.Cluster) { + var err error + + err = c.PutE(ctx, t.L(), t.Cockroach(), "./cockroach", c.All()) + require.NoError(t, err, "cannot mount cockroach binary") + + // Verify that draining proceeds immediately after connections are closed client-side. + { + const ( + // Set the duration of each phase of the draining period. + drainWaitDuration = 10 * time.Second + connectionWaitDuration = 100 * time.Second + queryWaitDuration = 10 * time.Second + // pokeDuringConnWaitTimestamp is the timestamp after the server + // starts waiting for SQL connections to close (with the start of the whole + // draining process marked as timestamp 0). It should be set larger than + // drainWaitDuration, but smaller than (drainWaitDuration + + // connectionWaitDuration). + pokeDuringConnWaitTimestamp = 45 * time.Second + connMaxLifetime = 30 * time.Second + connMaxCount = 5 + nodeToDrain = 1 + ) + totalWaitDuration := drainWaitDuration + connectionWaitDuration + queryWaitDuration + + prepareCluster(ctx, t, c, drainWaitDuration, connectionWaitDuration, queryWaitDuration) + + db := c.Conn(ctx, t.L(), nodeToDrain) + defer db.Close() + + db.SetConnMaxLifetime(connMaxLifetime) + db.SetMaxOpenConns(connMaxCount) + + var conns []*gosql.Conn + + // Get two connections from the connection pools. + for j := 0; j < 2; j++ { + conn, err := db.Conn(ctx) + + require.NoError(t, err, "failed to a SQL connection from the connection pool") + + conns = append(conns, conn) + } + + // Start draining the node. + m := c.NewMonitor(ctx, c.Node(nodeToDrain)) + + m.Go(func(ctx context.Context) error { + t.Status(fmt.Sprintf("start draining node %d", nodeToDrain)) + return c.RunE(ctx, + c.Node(nodeToDrain), + fmt.Sprintf("./cockroach node drain --insecure --drain-wait=%fs", + totalWaitDuration.Seconds())) + }) + + drainStartTimestamp := timeutil.Now() + + // Sleep till the server is in the status of waiting for users to close SQL + // connections. Verify that the server is rejecting new SQL connections now. + time.Sleep(pokeDuringConnWaitTimestamp) + _, err = db.Conn(ctx) + if err != nil { + t.Status(fmt.Sprintf("%s after draining starts, server is rejecting "+ + "new SQL connections: %v", pokeDuringConnWaitTimestamp, err)) + } else { + t.Fatal(errors.New("new SQL connections should not be allowed when the server " + + "starts waiting for the user to close SQL connections")) + } + + require.Equalf(t, db.Stats().OpenConnections, 2, "number of open connections should be 2") + + t.Status("number of open connections: ", db.Stats().OpenConnections) + + randConn := conns[rand.Intn(len(conns))] + + // When server is waiting clients to close connections, verify that SQL + // queries do not fail. + _, err = randConn.ExecContext(ctx, "SELECT 1;") + + require.NoError(t, err, "expected query not to fail before the "+ + "server starts draining SQL connections") + + for _, conn := range conns { + err := conn.Close() + require.NoError(t, err, + "expected connection to be able to be successfully closed client-side") + } + + t.Status("all SQL connections are put back to the connection pool") + + err = m.WaitE() + require.NoError(t, err, "error waiting for the draining to finish") + + drainEndTimestamp := timeutil.Now() + actualDrainDuration := drainEndTimestamp.Sub(drainStartTimestamp).Seconds() + + t.L().Printf("the draining lasted %f seconds", actualDrainDuration) + + if actualDrainDuration >= float64(totalWaitDuration)-10 { + t.Fatal(errors.New("the draining process didn't early exit " + + "when waiting for server to close all SQL connections")) + } + + // Fully quit the draining node so that we can restart it for the next test. + quitNode(ctx, t, c, nodeToDrain) + } + + // Verify a warning exists in the case that connectionWait expires. + { + const ( + // Set the duration of the draining period. + drainWaitDuration = 0 * time.Second + connectionWaitDuration = 10 * time.Second + queryWaitDuration = 20 * time.Second + nodeToDrain = 1 + ) + + totalWaitDuration := drainWaitDuration + connectionWaitDuration + queryWaitDuration + + prepareCluster(ctx, t, c, drainWaitDuration, connectionWaitDuration, queryWaitDuration) + + db := c.Conn(ctx, t.L(), nodeToDrain) + defer db.Close() + + // Get a connection from the connection pool. + _, err = db.Conn(ctx) + + require.NoError(t, err, "cannot get a SQL connection from the connection pool") + + m := c.NewMonitor(ctx, c.Node(nodeToDrain)) + m.Go(func(ctx context.Context) error { + t.Status(fmt.Sprintf("draining node %d", nodeToDrain)) + return c.RunE(ctx, + c.Node(nodeToDrain), + fmt.Sprintf("./cockroach node drain --insecure --drain-wait=%fs", + totalWaitDuration.Seconds())) + }) + + err = m.WaitE() + require.NoError(t, err, "error waiting for the draining to finish") + + logFile := filepath.Join("logs", "*.log") + err = c.RunE(ctx, c.Node(nodeToDrain), + "grep", "-q", "'proceeding to drain SQL connections'", logFile) + require.NoError(t, err, "warning is not logged in the log file") + } + +} + +// prepareCluster is to start the server on nodes in the given cluster, and set +// the cluster setting for duration of each phase of the draining process. +func prepareCluster( + ctx context.Context, + t test.Test, + c cluster.Cluster, + drainWait time.Duration, + connectionWait time.Duration, + queryWait time.Duration, +) { + + c.Start(ctx, t.L(), option.DefaultStartOpts(), install.MakeClusterSettings(), c.All()) + + db := c.Conn(ctx, t.L(), 1) + defer db.Close() + + _, err := db.ExecContext(ctx, ` + SET CLUSTER SETTING server.shutdown.drain_wait = $1; + SET CLUSTER SETTING server.shutdown.connection_wait = $2; + SET CLUSTER SETTING server.shutdown.query_wait = $3;`, + drainWait.Seconds(), + connectionWait.Seconds(), + queryWait.Seconds(), + ) + require.NoError(t, err) + +} + +func quitNode(ctx context.Context, t test.Test, c cluster.Cluster, node int) { + args := append([]string{ + "./cockroach", "quit", "--insecure", "--logtostderr=INFO", + fmt.Sprintf("--port={pgport:%d}", node)}) + result, err := c.RunWithDetailsSingleNode(ctx, t.L(), c.Node(node), args...) + output := result.Stdout + result.Stderr + t.L().Printf("cockroach quit:\n%s\n", output) + require.NoError(t, err, "cannot quit cockroach") + + stopOpts := option.DefaultStopOpts() + stopOpts.RoachprodOpts.Sig = 0 + stopOpts.RoachprodOpts.Wait = true + c.Stop(ctx, t.L(), stopOpts, c.All()) + t.L().Printf("stopped cluster") +} diff --git a/pkg/cmd/roachtest/tests/registry.go b/pkg/cmd/roachtest/tests/registry.go index 84aa2711bbc3..fd12f1d0a539 100644 --- a/pkg/cmd/roachtest/tests/registry.go +++ b/pkg/cmd/roachtest/tests/registry.go @@ -34,6 +34,7 @@ func RegisterTests(r registry.Registry) { registerDiskFull(r) RegisterDiskStalledDetection(r) registerDjango(r) + registerDrain(r) registerDrop(r) registerEncryption(r) registerEngineSwitch(r) diff --git a/pkg/server/drain.go b/pkg/server/drain.go index ddc7fcf71bc5..02f73e5650e1 100644 --- a/pkg/server/drain.go +++ b/pkg/server/drain.go @@ -12,6 +12,7 @@ package server import ( "context" + "fmt" "io" "strings" "time" @@ -32,21 +33,38 @@ import ( var ( queryWait = settings.RegisterDurationSetting( - settings.TenantWritable, + settings.TenantReadOnly, "server.shutdown.query_wait", "the timeout for waiting for active queries to finish during a drain "+ "(note that the --drain-wait parameter for cockroach node drain may need adjustment "+ "after changing this setting)", 10*time.Second, + settings.NonNegativeDurationWithMaximum(10*time.Hour), ).WithPublic() drainWait = settings.RegisterDurationSetting( - settings.TenantWritable, + settings.TenantReadOnly, "server.shutdown.drain_wait", "the amount of time a server waits in an unready state before proceeding with a drain "+ + "(note that the --drain-wait parameter for cockroach node drain may need adjustment "+ + "after changing this setting. --drain-wait is to specify the duration of the "+ + "whole draining process, while server.shutdown.drain_wait is to set the "+ + "wait time for health probes to notice that the node is not ready.)", + 0*time.Second, + settings.NonNegativeDurationWithMaximum(10*time.Hour), + ).WithPublic() + + connectionWait = settings.RegisterDurationSetting( + settings.TenantReadOnly, + "server.shutdown.connection_wait", + "the maximum amount of time a server waits for all SQL connections to "+ + "be closed before proceeding with a drain. "+ + "When all SQL connections are closed before times out, the server early "+ + "exits and proceeds to draining range leases. "+ "(note that the --drain-wait parameter for cockroach node drain may need adjustment "+ "after changing this setting)", 0*time.Second, + settings.NonNegativeDurationWithMaximum(10*time.Hour), ).WithPublic() ) @@ -309,19 +327,33 @@ func (s *drainServer) drainClients( s.grpc.setMode(modeDraining) s.sqlServer.isReady.Set(false) + // Log the number of connections periodically. + if err := s.logOpenConns(ctx); err != nil { + log.Ops.Warningf(ctx, "error showing alive SQL connections: %v", err) + } + // Wait the duration of drainWait. // This will fail load balancer checks and delay draining so that client // traffic can move off this node. // Note delay only happens on first call to drain. if shouldDelayDraining { + log.Ops.Info(ctx, "waiting for health probes to notice that the node "+ + "is not ready for new sql connections...") s.drainSleepFn(drainWait.Get(&s.sqlServer.execCfg.Settings.SV)) } - // Drain all SQL connections. - // The queryWait duration is a timeout for waiting on clients - // to self-disconnect. If the timeout is reached, any remaining connections + // Wait for users to close the existing SQL connections. + // During this phase, the server is rejecting new SQL connections. + // The server exit this phase either once all SQL connections are closed, + // or it reaches the connectionMaxWait timeout, whichever earlier. + s.sqlServer.pgServer.WaitForSQLConnsToClose(ctx, connectionWait.Get(&s.sqlServer.execCfg.Settings.SV)) + + // Drain any remaining SQL connections. + // The queryWait duration is a timeout for waiting for SQL queries to finish. + // If the timeout is reached, any remaining connections // will be closed. queryMaxWait := queryWait.Get(&s.sqlServer.execCfg.Settings.SV) + if err := s.sqlServer.pgServer.Drain(ctx, queryMaxWait, reporter); err != nil { return err } @@ -357,3 +389,21 @@ func (s *drainServer) drainNode( // Mark the stores of the node as "draining" and drain all range leases. return s.kvServer.node.SetDraining(true /* drain */, reporter, verbose) } + +// logOpenConns logs the number of open SQL connections every 3 seconds. +func (s *drainServer) logOpenConns(ctx context.Context) error { + return s.stopper.RunAsyncTask(ctx, "log-open-conns", func(ctx context.Context) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + log.Ops.Info(ctx, fmt.Sprintf("number of open connections: %d\n", s.sqlServer.pgServer.GetConnCancelMapLen())) + case <-s.stopper.ShouldQuiesce(): + return + case <-ctx.Done(): + return + } + } + }) +} diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index bbd4c178f244..03da970e1501 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -220,7 +220,16 @@ type Server struct { // cancel the associated connection. The corresponding key is a channel // that is closed when the connection is done. connCancelMap cancelChanMap - draining bool + // draining is set to true when the server starts draining the SQL layer. + // Remaining SQL connections will be closed as + // soon as their queries finish. After the timeout set by + // server.shutdown.query_wait, all connections will be closed regardless any + // queries in flight. + draining bool + // rejectNewConnections is set true when the server does not accept new + // SQL connections, e.g. when the draining process enters the phase whose + // duration is specified by the server.shutdown.connection_wait. + rejectNewConnections bool } auth struct { @@ -443,14 +452,15 @@ func (s *Server) Drain( // Undrain switches the server back to the normal mode of operation in which // connections are accepted. func (s *Server) Undrain() { - s.mu.Lock() - s.setDrainingLocked(false) - s.mu.Unlock() + s.setRejectNewConnections(false) + s.setDraining(false) } -// setDrainingLocked sets the server's draining state and returns whether the -// state changed (i.e. drain != s.mu.draining). s.mu must be locked. -func (s *Server) setDrainingLocked(drain bool) bool { +// setDraining sets the server's draining state and returns whether the +// state changed (i.e. drain != s.mu.draining). +func (s *Server) setDraining(drain bool) bool { + s.mu.Lock() + defer s.mu.Unlock() if s.mu.draining == drain { return false } @@ -458,6 +468,97 @@ func (s *Server) setDrainingLocked(drain bool) bool { return true } +// setRejectNewConnections sets the server's rejectNewConnections state. +func (s *Server) setRejectNewConnections(rej bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.mu.rejectNewConnections = rej +} + +// GetConnCancelMapLen returns the length of connCancelMap of the server. +// This is a helper function when the server waits the SQL connections to be +// closed. During this period, the server listens to the status of all +// connections, and early exits this draining phase if there remains no active +// SQL connections. +func (s *Server) GetConnCancelMapLen() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.mu.connCancelMap) +} + +// WaitForSQLConnsToClose waits for the client to close all SQL connections for the +// duration of connectionWait. +// With this phase, the node starts rejecting SQL connections, and as +// soon as all existing SQL connections are closed, the server early exits this +// draining phase. +func (s *Server) WaitForSQLConnsToClose(ctx context.Context, connectionWait time.Duration) error { + // If we're already draining the SQL connections, we don't need to wait again. + if s.IsDraining() { + return nil + } + + s.setRejectNewConnections(true) + + if connectionWait == 0 { + return nil + } + + log.Ops.Info(ctx, "waiting for clients to close existing SQL connections...") + + timer := time.NewTimer(connectionWait) + defer timer.Stop() + + _, allConnsDone, quitWaitingForConns := s.waitConnsDone() + defer close(quitWaitingForConns) + + select { + // Connection wait times out. + case <-time.After(connectionWait): + log.Ops.Infof(ctx, + "%d connections remain after waiting %s; proceeding to drain SQL connections", + s.GetConnCancelMapLen(), + connectionWait, + ) + case <-allConnsDone: + } + + return nil +} + +// waitConnsDone returns a copy of s.mu.connCancelMap, and a channel that +// will be closed once all sql connections are closed, or the server quits +// waiting for connections, whichever earlier. +func (s *Server) waitConnsDone() (cancelChanMap, chan struct{}, chan struct{}) { + + connCancelMap := func() cancelChanMap { + s.mu.Lock() + defer s.mu.Unlock() + connCancelMap := make(cancelChanMap) + for done, cancel := range s.mu.connCancelMap { + connCancelMap[done] = cancel + } + return connCancelMap + }() + + allConnsDone := make(chan struct{}, 1) + + quitWaitingForConns := make(chan struct{}, 1) + + go func() { + defer close(allConnsDone) + + for done := range connCancelMap { + select { + case <-done: + case <-quitWaitingForConns: + return + } + } + }() + + return connCancelMap, allConnsDone, quitWaitingForConns +} + // drainImpl drains the SQL clients. // // The queryWait duration is used to wait on clients to @@ -476,8 +577,25 @@ func (s *Server) drainImpl( cancelWait time.Duration, reporter func(int, redact.SafeString), ) error { - // This anonymous function returns a copy of s.mu.connCancelMap if there are - // any active connections to cancel. We will only attempt to cancel + + if !s.setDraining(true) { + // We are already draining. + return nil + } + + // If there is no open SQL connections to drain, just return. + if s.GetConnCancelMapLen() == 0 { + return nil + } + + log.Ops.Info(ctx, "starting draining SQL connections...") + + // Spin off a goroutine that waits for all connections to signal that they + // are done and reports it on allConnsDone. The main goroutine signals this + // goroutine to stop work through quitWaitingForConns. + + // This s.waitConnsDone function returns a copy of s.mu.connCancelMap if there + // are any active connections to cancel. We will only attempt to cancel // connections that were active at the moment the draining switch happened. // It is enough to do this because: // 1) If no new connections are added to the original map all connections @@ -486,44 +604,14 @@ func (s *Server) drainImpl( // were added when s.mu.draining = false, thus not requiring cancellation. // These connections are not our responsibility and will be handled when the // server starts draining again. - connCancelMap := func() cancelChanMap { - s.mu.Lock() - defer s.mu.Unlock() - if !s.setDrainingLocked(true) { - // We are already draining. - return nil - } - connCancelMap := make(cancelChanMap) - for done, cancel := range s.mu.connCancelMap { - connCancelMap[done] = cancel - } - return connCancelMap - }() - if len(connCancelMap) == 0 { - return nil - } + connCancelMap, allConnsDone, quitWaitingForConns := s.waitConnsDone() + defer close(quitWaitingForConns) + if reporter != nil { // Report progress to the Drain RPC. reporter(len(connCancelMap), "SQL clients") } - // Spin off a goroutine that waits for all connections to signal that they - // are done and reports it on allConnsDone. The main goroutine signals this - // goroutine to stop work through quitWaitingForConns. - allConnsDone := make(chan struct{}) - quitWaitingForConns := make(chan struct{}) - defer close(quitWaitingForConns) - go func() { - defer close(allConnsDone) - for done := range connCancelMap { - select { - case <-done: - case <-quitWaitingForConns: - return - } - } - }() - // Wait for connections to finish up their queries for the duration of queryWait. select { case <-time.After(queryWait): @@ -606,7 +694,7 @@ func (s *Server) TestingEnableAuthLogging() { // // An error is returned if the initial handshake of the connection fails. func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType SocketType) error { - ctx, draining, onCloseFn := s.registerConn(ctx) + ctx, rejectNewConnections, onCloseFn := s.registerConn(ctx) defer onCloseFn() connDetails := eventpb.CommonConnectionDetails{ @@ -672,7 +760,7 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket } // If the server is shutting down, terminate the connection early. - if draining { + if rejectNewConnections { log.Ops.Info(ctx, "rejecting new connection while server is draining") return s.sendErr(ctx, conn, newAdminShutdownErr(ErrDrainingNewConn)) } @@ -1183,19 +1271,20 @@ func (s *Server) maybeUpgradeToSecureConn( } // registerConn registers the incoming connection to the map of active connections, -// which can be canceled by a concurrent server drain. It also returns -// the current draining status of the server. +// which can be canceled by a concurrent server drain. It also returns a boolean +// variable rejectConn, which shows if the server is rejecting new SQL +// connections. // // The onCloseFn() callback must be called at the end of the // connection by the caller. func (s *Server) registerConn( ctx context.Context, -) (newCtx context.Context, draining bool, onCloseFn func()) { +) (newCtx context.Context, rejectNewConnections bool, onCloseFn func()) { onCloseFn = func() {} newCtx = ctx s.mu.Lock() - draining = s.mu.draining - if !draining { + rejectNewConnections = s.mu.rejectNewConnections + if !rejectNewConnections { var cancel context.CancelFunc newCtx, cancel = contextutil.WithCancel(ctx) done := make(chan struct{}) @@ -1210,11 +1299,11 @@ func (s *Server) registerConn( } s.mu.Unlock() - // If the Server is draining, we will use the connection only to send an - // error, so we don't count it in the stats. This makes sense since - // DrainClient() waits for that number to drop to zero, + // If the server is rejecting new SQL connections, we will use the connection + // only to send an error, so we don't count it in the stats. This makes sense + // since DrainClient() waits for that number to drop to zero, // so we don't want it to oscillate unnecessarily. - if !draining { + if !rejectNewConnections { s.metrics.NewConns.Inc(1) s.metrics.Conns.Inc(1) prevOnCloseFn := onCloseFn