diff --git a/pkg/server/drain.go b/pkg/server/drain.go
index ddc7fcf71bc5..a66c4b18cbc2 100644
--- a/pkg/server/drain.go
+++ b/pkg/server/drain.go
@@ -12,6 +12,7 @@ package server
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"strings"
 	"time"
@@ -32,21 +33,38 @@ import (
 var (
 	queryWait = settings.RegisterDurationSetting(
-		settings.TenantWritable,
+		settings.TenantReadOnly,
 		"server.shutdown.query_wait",
 		"the timeout for waiting for active queries to finish during a drain "+
 			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
 			"after changing this setting)",
 		10*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
 	).WithPublic()
 
 	drainWait = settings.RegisterDurationSetting(
-		settings.TenantWritable,
+		settings.TenantReadOnly,
 		"server.shutdown.drain_wait",
 		"the amount of time a server waits in an unready state before proceeding with a drain "+
+			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
+			"after changing this setting. --drain-wait specifies the duration of the "+
+			"whole draining process, while server.shutdown.drain_wait sets the "+
+			"wait time for health probes to notice that the node is not ready.)",
+		0*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
+	).WithPublic()
+
+	connectionWait = settings.RegisterDurationSetting(
+		settings.TenantReadOnly,
+		"server.shutdown.connection_wait",
+		"the maximum amount of time a server waits for all SQL connections to "+
+			"be closed before proceeding with a drain. "+
+			"If all SQL connections are closed before the timeout, the server exits "+
+			"this phase early and proceeds to draining range leases. "+
 			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
 			"after changing this setting)",
 		0*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
 	).WithPublic()
 )
@@ -309,19 +327,34 @@ func (s *drainServer) drainClients(
 	s.grpc.setMode(modeDraining)
 	s.sqlServer.isReady.Set(false)
 
-	// Wait the duration of drainWait.
-	// This will fail load balancer checks and delay draining so that client
-	// traffic can move off this node.
+	// Log the number of connections periodically.
+	if err := s.logOpenConns(ctx); err != nil {
+		log.Ops.Warningf(ctx, "error logging open SQL connections: %v", err)
+	}
+
+	// Wait for the duration of drainWait. This will fail load balancer checks
+	// and delay draining so that client traffic can move off this node.
 	// Note delay only happens on first call to drain.
 	if shouldDelayDraining {
-		s.drainSleepFn(drainWait.Get(&s.sqlServer.execCfg.Settings.SV))
+		drainWaitDuration := drainWait.Get(&s.sqlServer.execCfg.Settings.SV)
+		log.Ops.Info(ctx, "waiting for health probes to notice that the node "+
+			"is not ready for new SQL connections...")
+		s.drainSleepFn(drainWaitDuration)
 	}
 
+	// Wait for users to close the existing SQL connections.
+	// During this phase, the server rejects new SQL connections.
+	// The server exits this phase once all SQL connections are closed, or
+	// when the connectionMaxWait timeout is reached, whichever comes first.
+	connectionMaxWait := connectionWait.Get(&s.sqlServer.execCfg.Settings.SV)
+	s.sqlServer.pgServer.WaitForSQLConnsToClose(ctx, connectionMaxWait)
+
 	// Drain all SQL connections.
-	// The queryWait duration is a timeout for waiting on clients
-	// to self-disconnect. If the timeout is reached, any remaining connections
+	// The queryWait duration is a timeout for waiting for SQL queries to finish.
+	// If the timeout is reached, any remaining connections
 	// will be closed.
 	queryMaxWait := queryWait.Get(&s.sqlServer.execCfg.Settings.SV)
+
 	if err := s.sqlServer.pgServer.Drain(ctx, queryMaxWait, reporter); err != nil {
 		return err
 	}
@@ -357,3 +390,21 @@ func (s *drainServer) drainNode(
 	// Mark the stores of the node as "draining" and drain all range leases.
 	return s.kvServer.node.SetDraining(true /* drain */, reporter, verbose)
 }
+
+// logOpenConns logs the number of open SQL connections every 3 seconds.
+func (s *drainServer) logOpenConns(ctx context.Context) error {
+	return s.stopper.RunAsyncTask(ctx, "log-open-conns", func(ctx context.Context) {
+		ticker := time.NewTicker(3 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				log.Ops.Info(ctx, fmt.Sprintf("number of open connections: %d\n", s.sqlServer.pgServer.GetConnCancelMapLen()))
+			case <-s.stopper.ShouldQuiesce():
+				return
+			case <-ctx.Done():
+				return
+			}
+		}
+	})
+}
diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go
index 3e79a3d6d048..a59a31b47ab2 100644
--- a/pkg/sql/pgwire/server.go
+++ b/pkg/sql/pgwire/server.go
@@ -220,7 +220,16 @@ type Server struct {
 	// cancel the associated connection. The corresponding key is a channel
 	// that is closed when the connection is done.
 	connCancelMap cancelChanMap
-	draining bool
+	// draining is set true when the server starts draining SQL connections.
+	// Connections without a query in flight are closed immediately; remaining
+	// connections are closed as soon as their queries finish. After the timeout
+	// set by server.shutdown.query_wait, all connections are closed regardless
+	// of any queries in flight.
+	draining bool
+	// rejectNewConnections is set true when the server does not accept new
+	// SQL connections, e.g. when the draining process enters the phase whose
+	// duration is specified by server.shutdown.connection_wait.
+	rejectNewConnections bool
 	}
 
 	auth struct {
@@ -443,14 +452,14 @@ func (s *Server) Drain(
 // Undrain switches the server back to the normal mode of operation in which
 // connections are accepted.
 func (s *Server) Undrain() {
-	s.mu.Lock()
-	s.setDrainingLocked(false)
-	s.mu.Unlock()
+	s.setDraining(false)
 }
 
-// setDrainingLocked sets the server's draining state and returns whether the
-// state changed (i.e. drain != s.mu.draining). s.mu must be locked.
-func (s *Server) setDrainingLocked(drain bool) bool {
+// setDraining sets the server's draining state and returns whether the
+// state changed (i.e. drain != s.mu.draining).
+func (s *Server) setDraining(drain bool) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	if s.mu.draining == drain {
 		return false
 	}
@@ -458,6 +467,99 @@ func (s *Server) setDrainingLocked(drain bool) bool {
 	s.mu.draining = drain
 	return true
 }
+
+// DisallowNewSQLConn disallows new SQL connections to the server.
+func (s *Server) DisallowNewSQLConn() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.rejectNewConnections = true
+}
+
+// GetConnCancelMapLen returns the length of connCancelMap of the server.
+// This is a helper function used when the server waits for SQL connections to
+// be closed. During this period, the server watches the status of all
+// connections, and exits this draining phase early once no active SQL
+// connections remain.
+func (s *Server) GetConnCancelMapLen() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return len(s.mu.connCancelMap)
+}
+
+// WaitForSQLConnsToClose waits for clients to close all SQL connections for up
+// to the duration of connectionWait.
+// During this phase, the server rejects new SQL connections, and as soon as
+// all existing SQL connections are closed, the server exits this draining
+// phase early.
+func (s *Server) WaitForSQLConnsToClose(ctx context.Context, connectionWait time.Duration) {
+	// If we're already draining the SQL connections, we don't need to wait again.
+	if s.IsDraining() {
+		return
+	}
+
+	log.Ops.Info(ctx, "waiting for clients to close existing SQL connections...")
+
+	s.DisallowNewSQLConn()
+
+	timer := time.NewTimer(connectionWait)
+	defer timer.Stop()
+
+	quitWaitingForConns := make(chan struct{}, 1)
+	defer close(quitWaitingForConns)
+
+	_, allConnsDone := s.waitConnsDone(quitWaitingForConns)
+
+	for {
+		select {
+		// Connection wait times out.
+		case <-timer.C:
+			log.Ops.Infof(ctx,
+				"server.shutdown.connection_wait timed out, "+
+					"and there are still %d SQL connections not closed by the client",
+				s.GetConnCancelMapLen(),
+			)
+			return
+		case <-allConnsDone:
+			// All SQL connections were closed before connection_wait timed out.
+			if !timer.Stop() {
+				<-timer.C
+			}
+			return
+		}
+	}
+}
+
+// waitConnsDone returns a copy of s.mu.connCancelMap, and a channel that will
+// be closed once all SQL connections are closed, or once the server quits
+// waiting for connections, whichever comes first.
+func (s *Server) waitConnsDone(quitWaitingForConns chan struct{}) (cancelChanMap, chan struct{}) {
+	connCancelMap := func() cancelChanMap {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		connCancelMap := make(cancelChanMap)
+		for done, cancel := range s.mu.connCancelMap {
+			connCancelMap[done] = cancel
+		}
+		return connCancelMap
+	}()
+
+	allConnsDone := make(chan struct{}, 1)
+
+	go func() {
+		defer close(allConnsDone)
+
+		for done := range connCancelMap {
+			select {
+			case <-done:
+			case <-quitWaitingForConns:
+				return
+			}
+		}
+	}()
+
+	return connCancelMap, allConnsDone
+}
+
 // drainImpl drains the SQL clients.
 //
 // The queryWait duration is used to wait on clients to
@@ -476,8 +578,21 @@ func (s *Server) drainImpl(
 	cancelWait time.Duration,
 	reporter func(int, redact.SafeString),
 ) error {
-	// This anonymous function returns a copy of s.mu.connCancelMap if there are
-	// any active connections to cancel. We will only attempt to cancel
+	if !s.setDraining(true) {
+		// We are already draining.
+		return nil
+	}
+
+	// Spin off a goroutine that waits for all connections to signal that they
+	// are done and reports it on allConnsDone. The main goroutine signals this
+	// goroutine to stop work through quitWaitingForConns.
+	quitWaitingForConns := make(chan struct{}, 1)
+	defer close(quitWaitingForConns)
+
+	// This call to s.waitConnsDone returns a copy of s.mu.connCancelMap if there
+	// are any active connections to cancel. We will only attempt to cancel
 	// connections that were active at the moment the draining switch happened.
 	// It is enough to do this because:
 	// 1) If no new connections are added to the original map all connections
@@ -486,43 +601,18 @@ func (s *Server) drainImpl(
 	// were added when s.mu.draining = false, thus not requiring cancellation.
 	// These connections are not our responsibility and will be handled when the
 	// server starts draining again.
-	connCancelMap := func() cancelChanMap {
-		s.mu.Lock()
-		defer s.mu.Unlock()
-		if !s.setDrainingLocked(true) {
-			// We are already draining.
-			return nil
-		}
-		connCancelMap := make(cancelChanMap)
-		for done, cancel := range s.mu.connCancelMap {
-			connCancelMap[done] = cancel
-		}
-		return connCancelMap
-	}()
+	connCancelMap, allConnsDone := s.waitConnsDone(quitWaitingForConns)
+
 	if len(connCancelMap) == 0 {
 		return nil
 	}
+
 	if reporter != nil {
 		// Report progress to the Drain RPC.
 		reporter(len(connCancelMap), "SQL clients")
 	}
-	// Spin off a goroutine that waits for all connections to signal that they
-	// are done and reports it on allConnsDone. The main goroutine signals this
-	// goroutine to stop work through quitWaitingForConns.
-	allConnsDone := make(chan struct{})
-	quitWaitingForConns := make(chan struct{})
-	defer close(quitWaitingForConns)
-	go func() {
-		defer close(allConnsDone)
-		for done := range connCancelMap {
-			select {
-			case <-done:
-			case <-quitWaitingForConns:
-				return
-			}
-		}
-	}()
+
+	log.Ops.Info(ctx, "starting to drain the SQL layer...")
 
 	// Wait for connections to finish up their queries for the duration of queryWait.
 	select {
@@ -1183,19 +1273,20 @@ func (s *Server) maybeUpgradeToSecureConn(
 	}
 
 // registerConn registers the incoming connection to the map of active connections,
-// which can be canceled by a concurrent server drain. It also returns
-// the current draining status of the server.
+// which can be canceled by a concurrent server drain. It also returns a boolean
+// variable, rejectNewConnections, which indicates whether the server is
+// rejecting new SQL connections.
 //
 // The onCloseFn() callback must be called at the end of the
 // connection by the caller.
 func (s *Server) registerConn(
 	ctx context.Context,
-) (newCtx context.Context, draining bool, onCloseFn func()) {
+) (newCtx context.Context, rejectNewConnections bool, onCloseFn func()) {
 	onCloseFn = func() {}
 	newCtx = ctx
 	s.mu.Lock()
-	draining = s.mu.draining
-	if !draining {
+	rejectNewConnections = s.mu.rejectNewConnections
+	if !rejectNewConnections {
 		var cancel context.CancelFunc
 		newCtx, cancel = contextutil.WithCancel(ctx)
 		done := make(chan struct{})
@@ -1210,11 +1301,11 @@ func (s *Server) registerConn(
 	}
 	s.mu.Unlock()
 
-	// If the Server is draining, we will use the connection only to send an
-	// error, so we don't count it in the stats. This makes sense since
-	// DrainClient() waits for that number to drop to zero,
+	// If the server is rejecting new SQL connections, we will use the connection
+	// only to send an error, so we don't count it in the stats. This makes sense
+	// since DrainClient() waits for that number to drop to zero,
 	// so we don't want it to oscillate unnecessarily.
-	if !draining {
+	if !rejectNewConnections {
 		s.metrics.NewConns.Inc(1)
 		s.metrics.Conns.Inc(1)
 		prevOnCloseFn := onCloseFn
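
For reviewers: below is a minimal, self-contained Go sketch (not CockroachDB code; all names and durations here are illustrative stand-ins) of the fan-in pattern that waitConnsDone and WaitForSQLConnsToClose implement in this patch. One goroutine ranges over a snapshot of per-connection done channels and closes a single allConnsDone channel once every connection has finished, while the caller selects between that channel and a timer standing in for server.shutdown.connection_wait.

package main

import (
	"fmt"
	"time"
)

// waitConnsDone mirrors the fan-in used in the patch: it returns a channel
// that is closed once every per-connection done channel is closed, unless
// quit is closed first (the caller gave up waiting).
func waitConnsDone(conns []chan struct{}, quit <-chan struct{}) <-chan struct{} {
	allConnsDone := make(chan struct{})
	go func() {
		defer close(allConnsDone)
		for _, done := range conns {
			select {
			case <-done: // this connection finished
			case <-quit: // caller abandoned the wait
				return
			}
		}
	}()
	return allConnsDone
}

func main() {
	// Three fake "connections" that close at staggered times.
	conns := make([]chan struct{}, 3)
	for i := range conns {
		conns[i] = make(chan struct{})
		go func(done chan struct{}, d time.Duration) {
			time.Sleep(d)
			close(done)
		}(conns[i], time.Duration(i+1)*100*time.Millisecond)
	}

	quit := make(chan struct{})
	defer close(quit)

	select {
	case <-waitConnsDone(conns, quit):
		fmt.Println("all connections closed; proceeding with drain early")
	case <-time.After(time.Second): // stand-in for server.shutdown.connection_wait
		fmt.Println("connection wait timed out; proceeding with drain anyway")
	}
}

Snapshotting the connection map before ranging over it is what makes the comment in drainImpl hold: connections registered after the draining switch flips belong to a later drain cycle and are never waited on by this one.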