diff --git a/pkg/server/drain.go b/pkg/server/drain.go
index ddc7fcf71bc5..4bc6cbdef9a0 100644
--- a/pkg/server/drain.go
+++ b/pkg/server/drain.go
@@ -12,6 +12,7 @@ package server
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"strings"
 	"time"
@@ -32,21 +33,38 @@ import (
 
 var (
 	queryWait = settings.RegisterDurationSetting(
-		settings.TenantWritable,
+		settings.TenantReadOnly,
 		"server.shutdown.query_wait",
 		"the timeout for waiting for active queries to finish during a drain "+
 			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
 			"after changing this setting)",
 		10*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
 	).WithPublic()
 
 	drainWait = settings.RegisterDurationSetting(
-		settings.TenantWritable,
+		settings.TenantReadOnly,
 		"server.shutdown.drain_wait",
 		"the amount of time a server waits in an unready state before proceeding with a drain "+
+			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
+			"after changing this setting. --drain-wait specifies the duration of the "+
+			"whole draining process, while server.shutdown.drain_wait sets the "+
+			"wait time for health probes to notice that the node is not ready.)",
+		0*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
+	).WithPublic()
+
+	connectionWait = settings.RegisterDurationSetting(
+		settings.TenantReadOnly,
+		"server.shutdown.connection_wait",
+		"the maximum amount of time a server waits for all SQL connections to "+
+			"be closed before proceeding with a drain. "+
+			"If all SQL connections are closed before the timeout, the server exits "+
+			"this phase early and proceeds to draining range leases. "+
 			"(note that the --drain-wait parameter for cockroach node drain may need adjustment "+
 			"after changing this setting)",
 		0*time.Second,
+		settings.NonNegativeDurationWithMaximum(10*time.Hour),
 	).WithPublic()
 )
@@ -309,19 +327,33 @@ func (s *drainServer) drainClients(
 	s.grpc.setMode(modeDraining)
 	s.sqlServer.isReady.Set(false)
 
+	// Log the number of connections periodically.
+	if err := s.logOpenConns(ctx); err != nil {
+		log.Ops.Warningf(ctx, "error showing alive SQL connections: %v", err)
+	}
+
 	// Wait the duration of drainWait.
 	// This will fail load balancer checks and delay draining so that client
 	// traffic can move off this node.
 	// Note delay only happens on first call to drain.
 	if shouldDelayDraining {
+		log.Ops.Info(ctx, "waiting for health probes to notice that the node "+
+			"is not ready for new SQL connections...")
 		s.drainSleepFn(drainWait.Get(&s.sqlServer.execCfg.Settings.SV))
 	}
-	// Drain all SQL connections.
-	// The queryWait duration is a timeout for waiting on clients
-	// to self-disconnect. If the timeout is reached, any remaining connections
+	// Wait for users to close the existing SQL connections.
+	// During this phase, the server is rejecting new SQL connections.
+	// The server exits this phase once all SQL connections are closed, or when
+	// the connectionWait timeout is reached, whichever comes first.
+	s.sqlServer.pgServer.WaitForSQLConnsToClose(ctx, connectionWait.Get(&s.sqlServer.execCfg.Settings.SV))
+
+	// Drain any remaining SQL connections.
+	// The queryWait duration is a timeout for waiting for SQL queries to finish.
+	// If the timeout is reached, any remaining connections
 	// will be closed.
 	queryMaxWait := queryWait.Get(&s.sqlServer.execCfg.Settings.SV)
+
 	if err := s.sqlServer.pgServer.Drain(ctx, queryMaxWait, reporter); err != nil {
 		return err
 	}
@@ -357,3 +389,21 @@ func (s *drainServer) drainNode(
 	// Mark the stores of the node as "draining" and drain all range leases.
 	return s.kvServer.node.SetDraining(true /* drain */, reporter, verbose)
 }
+
+// logOpenConns logs the number of open SQL connections every 3 seconds.
+func (s *drainServer) logOpenConns(ctx context.Context) error {
+	return s.stopper.RunAsyncTask(ctx, "log-open-conns", func(ctx context.Context) {
+		ticker := time.NewTicker(3 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				log.Ops.Info(ctx, fmt.Sprintf("number of open connections: %d\n", s.sqlServer.pgServer.GetConnCancelMapLen()))
+			case <-s.stopper.ShouldQuiesce():
+				return
+			case <-ctx.Done():
+				return
+			}
+		}
+	})
+}
diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go
index bbd4c178f244..03da970e1501 100644
--- a/pkg/sql/pgwire/server.go
+++ b/pkg/sql/pgwire/server.go
@@ -220,7 +220,16 @@ type Server struct {
 		// cancel the associated connection. The corresponding key is a channel
 		// that is closed when the connection is done.
 		connCancelMap cancelChanMap
-		draining      bool
+		// draining is set to true when the server starts draining the SQL layer.
+		// Remaining SQL connections will be closed as
+		// soon as their queries finish. After the timeout set by
+		// server.shutdown.query_wait, all connections will be closed regardless of
+		// any queries in flight.
+		draining bool
+		// rejectNewConnections is set to true when the server does not accept new
+		// SQL connections, e.g. when the draining process enters the phase whose
+		// duration is specified by server.shutdown.connection_wait.
+		rejectNewConnections bool
 	}
 
 	auth struct {
@@ -443,14 +452,15 @@ func (s *Server) Drain(
 // Undrain switches the server back to the normal mode of operation in which
 // connections are accepted.
 func (s *Server) Undrain() {
-	s.mu.Lock()
-	s.setDrainingLocked(false)
-	s.mu.Unlock()
+	s.setRejectNewConnections(false)
+	s.setDraining(false)
 }
 
-// setDrainingLocked sets the server's draining state and returns whether the
-// state changed (i.e. drain != s.mu.draining). s.mu must be locked.
-func (s *Server) setDrainingLocked(drain bool) bool {
+// setDraining sets the server's draining state and returns whether the
+// state changed (i.e. drain != s.mu.draining).
+func (s *Server) setDraining(drain bool) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	if s.mu.draining == drain {
 		return false
 	}
@@ -458,6 +468,97 @@ func (s *Server) setDrainingLocked(drain bool) bool {
 	return true
 }
 
+// setRejectNewConnections sets the server's rejectNewConnections state.
+func (s *Server) setRejectNewConnections(rej bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.rejectNewConnections = rej
+}
+
+// GetConnCancelMapLen returns the length of connCancelMap of the server.
+// This is a helper function used when the server waits for SQL connections to
+// be closed. During this period, the server watches the status of all
+// connections, and exits this draining phase early if no active SQL
+// connections remain.
+func (s *Server) GetConnCancelMapLen() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return len(s.mu.connCancelMap)
+}
+
+// WaitForSQLConnsToClose waits, for up to the duration connectionWait, for
+// clients to close all existing SQL connections.
+// During this phase, the node starts rejecting new SQL connections, and as
+// soon as all existing SQL connections are closed, the server exits this
+// draining phase early.
+func (s *Server) WaitForSQLConnsToClose(ctx context.Context, connectionWait time.Duration) error {
+	// If we're already draining the SQL connections, we don't need to wait again.
+	if s.IsDraining() {
+		return nil
+	}
+
+	s.setRejectNewConnections(true)
+
+	if connectionWait == 0 {
+		return nil
+	}
+
+	log.Ops.Info(ctx, "waiting for clients to close existing SQL connections...")
+
+	timer := time.NewTimer(connectionWait)
+	defer timer.Stop()
+
+	_, allConnsDone, quitWaitingForConns := s.waitConnsDone()
+	defer close(quitWaitingForConns)
+
+	select {
+	// Connection wait times out.
+	case <-timer.C:
+		log.Ops.Infof(ctx,
+			"%d connections remain after waiting %s; proceeding to drain SQL connections",
+			s.GetConnCancelMapLen(),
+			connectionWait,
+		)
+	case <-allConnsDone:
+	}
+
+	return nil
+}
+
+// waitConnsDone returns a copy of s.mu.connCancelMap, a channel that is closed
+// once all SQL connections are closed or the server quits waiting for
+// connections (whichever happens first), and a channel to signal that quit.
+func (s *Server) waitConnsDone() (cancelChanMap, chan struct{}, chan struct{}) {
+
+	connCancelMap := func() cancelChanMap {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		connCancelMap := make(cancelChanMap)
+		for done, cancel := range s.mu.connCancelMap {
+			connCancelMap[done] = cancel
+		}
+		return connCancelMap
+	}()
+
+	allConnsDone := make(chan struct{}, 1)
+
+	quitWaitingForConns := make(chan struct{}, 1)
+
+	go func() {
+		defer close(allConnsDone)
+
+		for done := range connCancelMap {
+			select {
+			case <-done:
+			case <-quitWaitingForConns:
+				return
+			}
+		}
+	}()
+
+	return connCancelMap, allConnsDone, quitWaitingForConns
+}
+
 // drainImpl drains the SQL clients.
 //
 // The queryWait duration is used to wait on clients to
@@ -476,8 +577,25 @@ func (s *Server) drainImpl(
 	cancelWait time.Duration,
 	reporter func(int, redact.SafeString),
 ) error {
-	// This anonymous function returns a copy of s.mu.connCancelMap if there are
-	// any active connections to cancel. We will only attempt to cancel
+
+	if !s.setDraining(true) {
+		// We are already draining.
+		return nil
+	}
+
+	// If there are no open SQL connections to drain, just return.
+	if s.GetConnCancelMapLen() == 0 {
+		return nil
+	}
+
+	log.Ops.Info(ctx, "starting to drain SQL connections...")
+
+	// Spin off a goroutine that waits for all connections to signal that they
+	// are done and reports it on allConnsDone. The main goroutine signals this
+	// goroutine to stop work through quitWaitingForConns.
+
+	// The s.waitConnsDone function returns a copy of s.mu.connCancelMap if there
+	// are any active connections to cancel. We will only attempt to cancel
 	// connections that were active at the moment the draining switch happened.
 	// It is enough to do this because:
 	// 1) If no new connections are added to the original map all connections
@@ -486,44 +604,14 @@ func (s *Server) drainImpl(
 	// were added when s.mu.draining = false, thus not requiring cancellation.
 	// These connections are not our responsibility and will be handled when the
 	// server starts draining again.
-	connCancelMap := func() cancelChanMap {
-		s.mu.Lock()
-		defer s.mu.Unlock()
-		if !s.setDrainingLocked(true) {
-			// We are already draining.
-			return nil
-		}
-		connCancelMap := make(cancelChanMap)
-		for done, cancel := range s.mu.connCancelMap {
-			connCancelMap[done] = cancel
-		}
-		return connCancelMap
-	}()
-	if len(connCancelMap) == 0 {
-		return nil
-	}
+	connCancelMap, allConnsDone, quitWaitingForConns := s.waitConnsDone()
+	defer close(quitWaitingForConns)
+
 	if reporter != nil {
 		// Report progress to the Drain RPC.
 		reporter(len(connCancelMap), "SQL clients")
 	}
 
-	// Spin off a goroutine that waits for all connections to signal that they
-	// are done and reports it on allConnsDone. The main goroutine signals this
-	// goroutine to stop work through quitWaitingForConns.
-	allConnsDone := make(chan struct{})
-	quitWaitingForConns := make(chan struct{})
-	defer close(quitWaitingForConns)
-	go func() {
-		defer close(allConnsDone)
-		for done := range connCancelMap {
-			select {
-			case <-done:
-			case <-quitWaitingForConns:
-				return
-			}
-		}
-	}()
-
 	// Wait for connections to finish up their queries for the duration of queryWait.
 	select {
 	case <-time.After(queryWait):
@@ -606,7 +694,7 @@ func (s *Server) TestingEnableAuthLogging() {
 //
 // An error is returned if the initial handshake of the connection fails.
 func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType SocketType) error {
-	ctx, draining, onCloseFn := s.registerConn(ctx)
+	ctx, rejectNewConnections, onCloseFn := s.registerConn(ctx)
 	defer onCloseFn()
 
 	connDetails := eventpb.CommonConnectionDetails{
@@ -672,7 +760,7 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket
 	}
 
 	// If the server is shutting down, terminate the connection early.
-	if draining {
+	if rejectNewConnections {
 		log.Ops.Info(ctx, "rejecting new connection while server is draining")
 		return s.sendErr(ctx, conn, newAdminShutdownErr(ErrDrainingNewConn))
 	}
@@ -1183,19 +1271,20 @@ func (s *Server) maybeUpgradeToSecureConn(
 }
 
 // registerConn registers the incoming connection to the map of active connections,
-// which can be canceled by a concurrent server drain. It also returns
-// the current draining status of the server.
+// which can be canceled by a concurrent server drain. It also returns a
+// boolean rejectNewConnections, which indicates whether the server is
+// rejecting new SQL connections.
 //
 // The onCloseFn() callback must be called at the end of the
 // connection by the caller.
 func (s *Server) registerConn(
 	ctx context.Context,
-) (newCtx context.Context, draining bool, onCloseFn func()) {
+) (newCtx context.Context, rejectNewConnections bool, onCloseFn func()) {
 	onCloseFn = func() {}
 	newCtx = ctx
 	s.mu.Lock()
-	draining = s.mu.draining
-	if !draining {
+	rejectNewConnections = s.mu.rejectNewConnections
+	if !rejectNewConnections {
 		var cancel context.CancelFunc
 		newCtx, cancel = contextutil.WithCancel(ctx)
 		done := make(chan struct{})
@@ -1210,11 +1299,11 @@ func (s *Server) registerConn(
 	}
 	s.mu.Unlock()
 
-	// If the Server is draining, we will use the connection only to send an
-	// error, so we don't count it in the stats. This makes sense since
-	// DrainClient() waits for that number to drop to zero,
+	// If the server is rejecting new SQL connections, we will use the connection
+	// only to send an error, so we don't count it in the stats. This makes sense
+	// since DrainClient() waits for that number to drop to zero,
 	// so we don't want it to oscillate unnecessarily.
-	if !draining {
+	if !rejectNewConnections {
 		s.metrics.NewConns.Inc(1)
 		s.metrics.Conns.Inc(1)
 		prevOnCloseFn := onCloseFn
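
The core mechanism shared by waitConnsDone and WaitForSQLConnsToClose above is to snapshot the per-connection done channels, wait on each of them in a helper goroutine, and let the caller race that wait against a timeout. The following standalone sketch illustrates the same pattern; it is not code from this patch, and the names waitForConns, connDone, and timeout are invented for the example.

package main

import (
	"fmt"
	"time"
)

// waitForConns waits until every channel in connDone is closed, or until
// timeout elapses, whichever comes first. It mirrors the snapshot-and-wait
// pattern used by the drain code: a helper goroutine waits on each done
// channel and closes allConnsDone when finished, while the caller selects
// between that signal and the timeout.
func waitForConns(connDone []chan struct{}, timeout time.Duration) bool {
	allConnsDone := make(chan struct{})
	quit := make(chan struct{})
	defer close(quit) // tell the helper goroutine to stop if we time out

	go func() {
		defer close(allConnsDone)
		for _, done := range connDone {
			select {
			case <-done:
			case <-quit:
				return
			}
		}
	}()

	select {
	case <-time.After(timeout):
		return false // timed out; some connections are still open
	case <-allConnsDone:
		return true // every connection closed before the timeout
	}
}

func main() {
	conns := []chan struct{}{make(chan struct{}), make(chan struct{})}
	// Simulate clients disconnecting shortly after the drain starts.
	go func() {
		time.Sleep(50 * time.Millisecond)
		for _, c := range conns {
			close(c)
		}
	}()
	fmt.Println("all connections closed early:", waitForConns(conns, time.Second))
}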