From ac55aee3be565da9e59f15492d499fa7d238e187 Mon Sep 17 00:00:00 2001 From: Evan Wall Date: Sat, 8 Jul 2023 20:20:16 -0400 Subject: [PATCH] server: move procCh init into Server.serveImpl Informs #105448 This changes `procCh` to a `sync.WaitGroup` because the channel is never read from and moves initialization into `Server.serveImpl`. Also `processCommandsAsync` is changed to `processCommands` and the goroutine is created inside `Server.serverImpl` to avoid needing a `procCh` parameter. Release note: None --- pkg/sql/pgwire/conn.go | 198 ++++++++++++++++++--------------------- pkg/sql/pgwire/server.go | 54 +++++------ 2 files changed, 115 insertions(+), 137 deletions(-) diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 1c7b8d19e6d0..79508e8399a2 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -135,27 +135,15 @@ func (c *conn) authLogEnabled() bool { return c.alwaysLogAuthActivity || logSessionAuth.Get(c.sv) } -// processCommandsAsync spawns a goroutine that authenticates the connection and -// then processes commands from c.stmtBuf. -// -// It returns a channel that will be signaled when this goroutine is done. -// Whatever error is returned on that channel has already been written to the -// client connection, if applicable. -// -// If authentication fails, this goroutine finishes and, as always, cancelConn -// is called. +// processCommands authenticates the connection and then processes commands +// from c.stmtBuf. // // Args: // ac: An interface used by the authentication process to receive password data // and to ultimately declare the authentication successful. // reserved: Reserved memory. This method takes ownership and guarantees that it // will be closed when this function returns. -// cancelConn: A function to be called when this goroutine exits. Its goal is to -// cancel the connection's context, thus stopping the connection's goroutine. -// The returned channel is also closed before this goroutine dies, but the -// connection's goroutine is not expected to be reading from that channel -// (instead, it's expected to always be monitoring the network connection). -func (c *conn) processCommandsAsync( +func (c *conn) processCommands( ctx context.Context, authOpt authOptions, ac AuthConn, @@ -163,108 +151,102 @@ func (c *conn) processCommandsAsync( reserved *mon.BoundAccount, onDefaultIntSizeChange func(newSize int32), sessionID clusterunique.ID, -) <-chan error { +) { // reservedOwned is true while we own reserved, false when we pass ownership // away. reservedOwned := true - retCh := make(chan error, 1) - go func() { - var retErr error - var connHandler sql.ConnectionHandler - var authOK bool - var connCloseAuthHandler func() - defer func() { - // Release resources, if we still own them. - if reservedOwned { - reserved.Close(ctx) - } - // Notify the connection's goroutine that we're terminating. The - // connection might know already, as it might have triggered this - // goroutine's finish, but it also might be us that we're triggering the - // connection's death. This context cancelation serves to interrupt a - // network read on the connection's goroutine. - c.cancelConn() - - pgwireKnobs := sqlServer.GetExecutorConfig().PGWireTestingKnobs - if pgwireKnobs != nil && pgwireKnobs.CatchPanics { - if r := recover(); r != nil { - // Catch the panic and return it to the client as an error. - if err, ok := r.(error); ok { - // Mask the cause but keep the details. - retErr = errors.Handled(err) - } else { - retErr = errors.Newf("%+v", r) - } - retErr = pgerror.WithCandidateCode(retErr, pgcode.CrashShutdown) - // Add a prefix. This also adds a stack trace. - retErr = errors.Wrap(retErr, "caught fatal error") - _ = c.writeErr(ctx, retErr, &c.writerState.buf) - _ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn) - c.stmtBuf.Close() - // Send a ready for query to make sure the client can react. - // TODO(andrei, jordan): Why are we sending this exactly? - c.bufferReadyForQuery('I') + var retErr error + var connHandler sql.ConnectionHandler + var authOK bool + var connCloseAuthHandler func() + defer func() { + // Release resources, if we still own them. + if reservedOwned { + reserved.Close(ctx) + } + // Notify the connection's goroutine that we're terminating. The + // connection might know already, as it might have triggered this + // goroutine's finish, but it also might be us that we're triggering the + // connection's death. This context cancelation serves to interrupt a + // network read on the connection's goroutine. + c.cancelConn() + + pgwireKnobs := sqlServer.GetExecutorConfig().PGWireTestingKnobs + if pgwireKnobs != nil && pgwireKnobs.CatchPanics { + if r := recover(); r != nil { + // Catch the panic and return it to the client as an error. + if err, ok := r.(error); ok { + // Mask the cause but keep the details. + retErr = errors.Handled(err) + } else { + retErr = errors.Newf("%+v", r) } + retErr = pgerror.WithCandidateCode(retErr, pgcode.CrashShutdown) + // Add a prefix. This also adds a stack trace. + retErr = errors.Wrap(retErr, "caught fatal error") + _ = c.writeErr(ctx, retErr, &c.writerState.buf) + _ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn) + c.stmtBuf.Close() + // Send a ready for query to make sure the client can react. + // TODO(andrei, jordan): Why are we sending this exactly? + c.bufferReadyForQuery('I') } - if !authOK { - ac.AuthFail(retErr) - } - if connCloseAuthHandler != nil { - connCloseAuthHandler() - } - // Inform the connection goroutine of success or failure. - retCh <- retErr - }() - - // Authenticate the connection. - if connCloseAuthHandler, retErr = c.handleAuthentication( - ctx, ac, authOpt, sqlServer.GetExecutorConfig(), - ); retErr != nil { - // Auth failed or some other error. - return } - - var decrementConnectionCount func() - if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil { - _ = c.sendError(ctx, retErr) - return + if !authOK { + ac.AuthFail(retErr) } - defer decrementConnectionCount() - - if retErr = c.authOKMessage(); retErr != nil { - return - } - - // Inform the client of the default session settings. - connHandler, retErr = c.sendInitialConnData(ctx, sqlServer, onDefaultIntSizeChange, sessionID) - if retErr != nil { - return + if connCloseAuthHandler != nil { + connCloseAuthHandler() } - // Signal the connection was established to the authenticator. - ac.AuthOK(ctx) - ac.LogAuthOK(ctx) - - // We count the connection establish latency until we are ready to - // serve a SQL query. It includes the time it takes to authenticate and - // send the initial ReadyForQuery message. - duration := timeutil.Since(c.startTime).Nanoseconds() - c.metrics.ConnLatency.RecordValue(duration) - - // Mark the authentication as succeeded in case a panic - // is thrown below and we need to report to the client - // using the defer above. - authOK = true - - // Now actually process commands. - reservedOwned = false // We're about to pass ownership away. - retErr = sqlServer.ServeConn( - ctx, - connHandler, - reserved, - c.cancelConn, - ) }() - return retCh + + // Authenticate the connection. + if connCloseAuthHandler, retErr = c.handleAuthentication( + ctx, ac, authOpt, sqlServer.GetExecutorConfig(), + ); retErr != nil { + // Auth failed or some other error. + return + } + + var decrementConnectionCount func() + if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil { + _ = c.sendError(ctx, retErr) + return + } + defer decrementConnectionCount() + + if retErr = c.authOKMessage(); retErr != nil { + return + } + + // Inform the client of the default session settings. + connHandler, retErr = c.sendInitialConnData(ctx, sqlServer, onDefaultIntSizeChange, sessionID) + if retErr != nil { + return + } + // Signal the connection was established to the authenticator. + ac.AuthOK(ctx) + ac.LogAuthOK(ctx) + + // We count the connection establish latency until we are ready to + // serve a SQL query. It includes the time it takes to authenticate and + // send the initial ReadyForQuery message. + duration := timeutil.Since(c.startTime).Nanoseconds() + c.metrics.ConnLatency.RecordValue(duration) + + // Mark the authentication as succeeded in case a panic + // is thrown below and we need to report to the client + // using the defer above. + authOK = true + + // Now actually process commands. + reservedOwned = false // We're about to pass ownership away. + retErr = sqlServer.ServeConn( + ctx, + connHandler, + reserved, + c.cancelConn, + ) } func (c *conn) bufferParamStatus(param, value string) error { diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 789b0065fe0c..119f011d4798 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -15,6 +15,7 @@ import ( "context" "io" "net" + "sync" "sync/atomic" "time" @@ -966,11 +967,9 @@ func (s *Server) serveImpl( systemIdentity = c.sessionArgs.User } authPipe := newAuthPipe(c, logAuthn, authOpt, systemIdentity) - var authenticator authenticatorIO = authPipe - // procCh is the channel on which we'll receive the termination signal from - // the command processor. - var procCh <-chan error + // procWg waits for the command processor to return. + var procWg sync.WaitGroup // We need a value for the unqualified int size here, but it is controlled // by a session variable, and this layer doesn't have access to the session @@ -983,18 +982,22 @@ func (s *Server) serveImpl( if !inTestWithoutSQL { // Spawn the command processing goroutine, which also handles connection - // authentication). It will notify us when it's done through procCh, and - // we'll also interact with the authentication process through ac. - var ac AuthConn = authPipe - procCh = c.processCommandsAsync( - ctx, - authOpt, - ac, - sqlServer, - reserved, - onDefaultIntSizeChange, - sessionID, - ) + // authentication). It will notify us when it's done through procWg, and + // we'll also interact with the authentication process through authPipe. + procWg.Add(1) + go func() { + // Inform the connection goroutine. + defer procWg.Done() + c.processCommands( + ctx, + authOpt, + authPipe, + sqlServer, + reserved, + onDefaultIntSizeChange, + sessionID, + ) + }() } else { // sqlServer == nil means we are in a local test. In this case // we only need the minimum to make pgx happy. @@ -1009,12 +1012,8 @@ func (s *Server) serveImpl( if err != nil { return } - var ac AuthConn = authPipe // Simulate auth succeeding. - ac.AuthOK(ctx) - dummyCh := make(chan error) - close(dummyCh) - procCh = dummyCh + authPipe.AuthOK(ctx) if err := c.bufferInitialReadyForQuery(0 /* queryCancelKey */); err != nil { return @@ -1098,13 +1097,13 @@ func (s *Server) serveImpl( // Pass the data to the authenticator. This hopefully causes it to finish // authentication in the background and give us an intSizer when we loop // around. - if err = authenticator.sendPwdData(pwd); err != nil { + if err = authPipe.sendPwdData(pwd); err != nil { return false, isSimpleQuery, err } return false, isSimpleQuery, nil } // Wait for the auth result. - if err = authenticator.authResult(); err != nil { + if err = authPipe.authResult(); err != nil { // The error has already been sent to the client. return true, isSimpleQuery, nil //nolint:returnerrcheck } @@ -1225,13 +1224,10 @@ func (s *Server) serveImpl( // In case the authenticator is blocked on waiting for data from the client, // tell it that there's no more data coming. This is a no-op if authentication // was completed already. - authenticator.noMorePwdData() + authPipe.noMorePwdData() - // Wait for the processor goroutine to finish, if it hasn't already. We're - // ignoring the error we get from it, as we have no use for it. It might be a - // connection error, or a context cancelation error case this goroutine is the - // one that triggered the execution to stop. - <-procCh + // Wait for the processor goroutine to finish, if it hasn't already. + procWg.Wait() if terminateSeen { return