diff --git a/pkg/server/drain.go b/pkg/server/drain.go
index c59c7127bda3..5203fb5fa86f 100644
--- a/pkg/server/drain.go
+++ b/pkg/server/drain.go
@@ -12,6 +12,7 @@ package server
 import (
     "context"
+    "fmt"
     "io"
     "strings"
     "time"
@@ -46,6 +47,15 @@ var (
         "after changing this setting)",
         0*time.Second,
     ).WithPublic()
+
+    connectionWait = settings.RegisterDurationSetting(
+        settings.TenantReadOnly,
+        "server.shutdown.connection_wait",
+        "the amount of time a server waits for users to close all SQL connections "+
+            "before proceeding with the rest of the shutdown process. It must be set "+
+            "to a non-negative value",
+        0*time.Second,
+    ).WithPublic()
 )
 
 // Drain puts the node into the specified drain mode(s) and optionally
@@ -246,12 +256,46 @@ func (s *Server) drainClients(ctx context.Context, reporter func(int, redact.Saf
     // Wait for drainUnreadyWait. This will fail load balancer checks and
     // delay draining so that client traffic can move off this node.
     // Note delay only happens on first call to drain.
+
+    // Log the number of open connections to cockroach.log every 3 seconds.
+    go func() {
+        for {
+            s.sqlServer.pgServer.LogActiveConns(ctx)
+            time.Sleep(3 * time.Second)
+        }
+    }()
+
     if shouldDelayDraining {
+        drainWaitLogMessage := fmt.Sprintf("drain wait starts, timeout after %s",
+            drainWait.Get(&s.st.SV),
+        )
+        log.Ops.Info(ctx, drainWaitLogMessage)
+
         s.drainSleepFn(drainWait.Get(&s.st.SV))
+
+        // Disallow new connections up to the connectionWait timeout, or exit
+        // early once all existing connections are closed.
+        connectionMaxWait := connectionWait.Get(&s.st.SV)
+        if connectionMaxWait < 0 {
+            return errors.New("server.shutdown.connection_wait must be set to a non-negative duration")
+        }
+        // TODO(janexing): maybe print a more informative log, such as
+        // "connection wait starts, no new SQL connections allowed, early exit to
+        // lease transfer phase once all SQL connections are closed."
+        connectionWaitLogMessage := fmt.Sprintf("connection wait starts, timeout after %s",
+            connectionMaxWait,
+        )
+        log.Ops.Info(ctx, connectionWaitLogMessage)
+
+        s.sqlServer.pgServer.DrainConnections(ctx, connectionMaxWait)
     }
 
-    // Disable incoming SQL clients up to the queryWait timeout.
     queryMaxWait := queryWait.Get(&s.st.SV)
+    if queryMaxWait < 0 {
+        return errors.New("server.shutdown.query_wait must be set to a non-negative duration")
+    }
+
+    // Disable incoming SQL queries up to the queryWait timeout.
     if err := s.sqlServer.pgServer.Drain(ctx, queryMaxWait, reporter); err != nil {
         return err
     }
diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go
index 2c13a70ec094..5a4c03ab03f5 100644
--- a/pkg/sql/pgwire/conn.go
+++ b/pkg/sql/pgwire/conn.go
@@ -166,7 +166,7 @@ func (s *Server) serveConn(
     }
 
     // Do the reading of commands from the network.
-    c.serveImpl(ctx, s.IsDraining, s.SQLServer, reserved, authOpt)
+    c.serveImpl(ctx, s.IsQueryDraining, s.SQLServer, reserved, authOpt)
 }
 
 // alwaysLogAuthActivity makes it possible to unconditionally enable
diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go
index d9754d836e10..daf4290f8477 100644
--- a/pkg/sql/pgwire/server.go
+++ b/pkg/sql/pgwire/server.go
@@ -186,7 +186,15 @@ type Server struct {
         // cancel the associated connection. The corresponding key is a channel
         // that is closed when the connection is done.
         connCancelMap cancelChanMap
-        draining      bool
+        // TODO(janexing): draining is only set to true once the drain wait has ended
+        // and the connection wait has started; it should be renamed more informatively.
+        draining bool
+        // queryDraining is set to true when the draining process enters the
+        // query wait period, during which connections with no query in flight
+        // are closed, and the remaining connections are closed as soon as their
+        // queries finish. After the query wait, all connections are closed
+        // regardless of any queries in flight.
+        queryDraining bool
     }
 
     auth struct {
@@ -363,6 +371,12 @@ func (s *Server) IsDraining() bool {
     return s.mu.draining
 }
 
+func (s *Server) IsQueryDraining() bool {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.mu.queryDraining
+}
+
 // Metrics returns the set of metrics structs.
 func (s *Server) Metrics() (res []interface{}) {
     return []interface{}{
@@ -404,6 +418,7 @@ func (s *Server) Drain(
 func (s *Server) Undrain() {
     s.mu.Lock()
     s.setDrainingLocked(false)
+    s.SetQueryDrainingLocked(false)
     s.mu.Unlock()
 }
 
@@ -417,6 +432,62 @@ func (s *Server) setDrainingLocked(drain bool) bool {
     return true
 }
 
+// LogActiveConns logs the number of active SQL connections.
+func (s *Server) LogActiveConns(ctx context.Context) {
+    log.Ops.Infof(ctx, "number of active connections: %d", len(s.GetConnCancelMap()))
+}
+
+// SetQueryDrainingLocked sets the server's queryDraining state and returns
+// whether the state changed (i.e. drain != s.mu.queryDraining). s.mu must be locked.
+func (s *Server) SetQueryDrainingLocked(drain bool) bool {
+    if s.mu.queryDraining == drain {
+        return false
+    }
+    s.mu.queryDraining = drain
+    return true
+}
+
+// DisallowNewSQLConn disallows new SQL connections to the server.
+func (s *Server) DisallowNewSQLConn() {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    s.setDrainingLocked(true)
+}
+
+// GetConnCancelMap returns the server's connCancelMap. It is a helper for the
+// connectionWait process, which polls the status of all connections and exits
+// early once none remain open.
+func (s *Server) GetConnCancelMap() cancelChanMap {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.mu.connCancelMap
+}
+
+// DrainConnections waits for users to close all SQL connections before the
+// connectionWait timeout expires.
+// During the connectionWait period, no new SQL connections are allowed, and as
+// soon as all existing SQL connections are closed, the wait exits early.
+func (s *Server) DrainConnections(ctx context.Context, connectionWait time.Duration) {
+    s.DisallowNewSQLConn()
+
+    timer := time.NewTimer(connectionWait)
+    for {
+        select {
+        // The connection wait times out.
+        case <-timer.C:
+            return
+        default:
+            if len(s.GetConnCancelMap()) == 0 {
+                if !timer.Stop() {
+                    <-timer.C
+                }
+                return
+            }
+            time.Sleep(100 * time.Millisecond) // avoid a hot polling loop
+        }
+    }
+}
+
 // drainImpl drains the SQL clients.
 //
 // The queryWait duration is used to wait on clients to
@@ -448,10 +519,6 @@ func (s *Server) drainImpl(
     connCancelMap := func() cancelChanMap {
         s.mu.Lock()
         defer s.mu.Unlock()
-        if !s.setDrainingLocked(true) {
-            // We are already draining.
-            return nil
-        }
         connCancelMap := make(cancelChanMap)
         for done, cancel := range s.mu.connCancelMap {
             connCancelMap[done] = cancel
@@ -461,11 +528,22 @@ func (s *Server) drainImpl(
     if len(connCancelMap) == 0 {
         return nil
     }
+
+    queryWaitLogMessage := fmt.Sprintf("query wait starts, timeout after %s",
+        queryWait,
+    )
+    log.Ops.Info(ctx, queryWaitLogMessage)
+
     if reporter != nil {
         // Report progress to the Drain RPC.
         reporter(len(connCancelMap), "SQL clients")
     }
 
+    // Mark that the server has entered the query wait period.
+    s.mu.Lock()
+    s.SetQueryDrainingLocked(true)
+    s.mu.Unlock()
+
     // Spin off a goroutine that waits for all connections to signal that they
     // are done and reports it on allConnsDone. The main goroutine signals this
     // goroutine to stop work through quitWaitingForConns.
@@ -486,6 +564,7 @@
     // Wait for all connections to finish up to drainWait.
     select {
     case <-time.After(queryWait):
+        log.Ops.Info(ctx, "query wait period has ended")
         log.Ops.Warningf(ctx, "canceling all sessions after waiting %s", queryWait)
     case <-allConnsDone:
     }
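
For reviewers who want to reason about the connection-wait phase in isolation, here is a minimal, self-contained Go sketch of the pattern that drainClients and DrainConnections implement above: reject new connections, then wait for existing ones to close, exiting early once none remain or giving up when the server.shutdown.connection_wait timeout elapses. The drainer type, its fields, and the 50ms poll interval are hypothetical stand-ins for illustration, not the pgwire API.

// Hypothetical, simplified model of the connection-wait phase.
package main

import (
	"fmt"
	"sync"
	"time"
)

type drainer struct {
	mu        sync.Mutex
	rejectNew bool // set when the connection wait begins
	active    int  // number of currently open connections
}

// accept is called for each incoming connection attempt and reports whether
// the connection is admitted.
func (d *drainer) accept() bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.rejectNew {
		return false
	}
	d.active++
	return true
}

// release is called when a client closes its connection.
func (d *drainer) release() {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.active--
}

// drainConnections rejects new connections, then polls until either all
// existing connections are gone (early exit) or the timeout elapses.
func (d *drainer) drainConnections(timeout time.Duration) {
	d.mu.Lock()
	d.rejectNew = true
	d.mu.Unlock()

	deadline := time.Now().Add(timeout)
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()
	for range ticker.C {
		d.mu.Lock()
		remaining := d.active
		d.mu.Unlock()
		if remaining == 0 || time.Now().After(deadline) {
			return
		}
	}
}

func main() {
	d := &drainer{}
	d.accept() // one client is connected when draining starts

	go func() {
		time.Sleep(200 * time.Millisecond)
		d.release() // the client disconnects on its own
	}()

	start := time.Now()
	d.drainConnections(5 * time.Second) // returns early, well before the 5s timeout
	fmt.Printf("connection wait finished after ~%s\n", time.Since(start).Round(50*time.Millisecond))
}

The real implementation polls the pgwire connection registry (connCancelMap) rather than a counter, but the timeout-with-early-exit shape is the same.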
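The split between the draining and queryDraining flags can also be read as a small state machine: during the connection wait only new connections are refused, and only once the query wait begins do existing connections stop serving queries, which is why serveImpl is now handed IsQueryDraining rather than IsDraining. The sketch below illustrates that ordering under stated assumptions; drainState, acceptConn, and serveNextQuery are invented names, not CockroachDB APIs.

// Hypothetical model of the two drain flags and the checks made against them.
package main

import (
	"fmt"
	"sync"
)

type drainState struct {
	mu            sync.Mutex
	draining      bool // reject new connections (connection wait and later)
	queryDraining bool // also wind down work on existing connections (query wait)
}

// disallowNewConns corresponds to the start of the connection wait.
func (s *drainState) disallowNewConns() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.draining = true
}

// startQueryDrain corresponds to the start of the query wait.
func (s *drainState) startQueryDrain() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.draining = true
	s.queryDraining = true
}

// acceptConn mirrors the check made when a new client connects.
func (s *drainState) acceptConn() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return !s.draining
}

// serveNextQuery mirrors the per-command check an existing connection makes;
// it is the analogue of consulting IsQueryDraining instead of IsDraining.
func (s *drainState) serveNextQuery() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return !s.queryDraining
}

func main() {
	var s drainState

	s.disallowNewConns()
	fmt.Println("connection wait: accept new conn?", s.acceptConn())          // false
	fmt.Println("connection wait: keep serving queries?", s.serveNextQuery()) // true

	s.startQueryDrain()
	fmt.Println("query wait: keep serving queries?", s.serveNextQuery()) // false
}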