Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: move procCh init into Server.serveImpl #106943

Merged
merged 1 commit into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 90 additions & 108 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,136 +135,118 @@ func (c *conn) authLogEnabled() bool {
return c.alwaysLogAuthActivity || logSessionAuth.Get(c.sv)
}

// processCommandsAsync spawns a goroutine that authenticates the connection and
// then processes commands from c.stmtBuf.
//
// It returns a channel that will be signaled when this goroutine is done.
// Whatever error is returned on that channel has already been written to the
// client connection, if applicable.
//
// If authentication fails, this goroutine finishes and, as always, cancelConn
// is called.
// processCommands authenticates the connection and then processes commands
// from c.stmtBuf.
//
// Args:
// ac: An interface used by the authentication process to receive password data
// and to ultimately declare the authentication successful.
// reserved: Reserved memory. This method takes ownership and guarantees that it
// will be closed when this function returns.
// cancelConn: A function to be called when this goroutine exits. Its goal is to
// cancel the connection's context, thus stopping the connection's goroutine.
// The returned channel is also closed before this goroutine dies, but the
// connection's goroutine is not expected to be reading from that channel
// (instead, it's expected to always be monitoring the network connection).
func (c *conn) processCommandsAsync(
func (c *conn) processCommands(
ctx context.Context,
authOpt authOptions,
ac AuthConn,
sqlServer *sql.Server,
reserved *mon.BoundAccount,
onDefaultIntSizeChange func(newSize int32),
sessionID clusterunique.ID,
) <-chan error {
) {
// reservedOwned is true while we own reserved, false when we pass ownership
// away.
reservedOwned := true
retCh := make(chan error, 1)
go func() {
var retErr error
var connHandler sql.ConnectionHandler
var authOK bool
var connCloseAuthHandler func()
defer func() {
// Release resources, if we still own them.
if reservedOwned {
reserved.Close(ctx)
}
// Notify the connection's goroutine that we're terminating. The
// connection might know already, as it might have triggered this
// goroutine's finish, but it also might be us that we're triggering the
// connection's death. This context cancelation serves to interrupt a
// network read on the connection's goroutine.
c.cancelConn()

pgwireKnobs := sqlServer.GetExecutorConfig().PGWireTestingKnobs
if pgwireKnobs != nil && pgwireKnobs.CatchPanics {
if r := recover(); r != nil {
// Catch the panic and return it to the client as an error.
if err, ok := r.(error); ok {
// Mask the cause but keep the details.
retErr = errors.Handled(err)
} else {
retErr = errors.Newf("%+v", r)
}
retErr = pgerror.WithCandidateCode(retErr, pgcode.CrashShutdown)
// Add a prefix. This also adds a stack trace.
retErr = errors.Wrap(retErr, "caught fatal error")
_ = c.writeErr(ctx, retErr, &c.writerState.buf)
_ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn)
c.stmtBuf.Close()
// Send a ready for query to make sure the client can react.
// TODO(andrei, jordan): Why are we sending this exactly?
c.bufferReadyForQuery('I')
var retErr error
var connHandler sql.ConnectionHandler
var authOK bool
var connCloseAuthHandler func()
defer func() {
// Release resources, if we still own them.
if reservedOwned {
reserved.Close(ctx)
}
// Notify the connection's goroutine that we're terminating. The
// connection might know already, as it might have triggered this
// goroutine's finish, but it also might be us that we're triggering the
// connection's death. This context cancelation serves to interrupt a
// network read on the connection's goroutine.
c.cancelConn()

pgwireKnobs := sqlServer.GetExecutorConfig().PGWireTestingKnobs
if pgwireKnobs != nil && pgwireKnobs.CatchPanics {
if r := recover(); r != nil {
// Catch the panic and return it to the client as an error.
if err, ok := r.(error); ok {
// Mask the cause but keep the details.
retErr = errors.Handled(err)
} else {
retErr = errors.Newf("%+v", r)
}
retErr = pgerror.WithCandidateCode(retErr, pgcode.CrashShutdown)
// Add a prefix. This also adds a stack trace.
retErr = errors.Wrap(retErr, "caught fatal error")
_ = c.writeErr(ctx, retErr, &c.writerState.buf)
_ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn)
c.stmtBuf.Close()
// Send a ready for query to make sure the client can react.
// TODO(andrei, jordan): Why are we sending this exactly?
c.bufferReadyForQuery('I')
}
if !authOK {
ac.AuthFail(retErr)
}
if connCloseAuthHandler != nil {
connCloseAuthHandler()
}
// Inform the connection goroutine of success or failure.
retCh <- retErr
}()

// Authenticate the connection.
if connCloseAuthHandler, retErr = c.handleAuthentication(
ctx, ac, authOpt, sqlServer.GetExecutorConfig(),
); retErr != nil {
// Auth failed or some other error.
return
}

var decrementConnectionCount func()
if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil {
_ = c.sendError(ctx, retErr)
return
if !authOK {
ac.AuthFail(retErr)
}
defer decrementConnectionCount()

if retErr = c.authOKMessage(); retErr != nil {
return
}

// Inform the client of the default session settings.
connHandler, retErr = c.sendInitialConnData(ctx, sqlServer, onDefaultIntSizeChange, sessionID)
if retErr != nil {
return
if connCloseAuthHandler != nil {
connCloseAuthHandler()
}
// Signal the connection was established to the authenticator.
ac.AuthOK(ctx)
ac.LogAuthOK(ctx)

// We count the connection establish latency until we are ready to
// serve a SQL query. It includes the time it takes to authenticate and
// send the initial ReadyForQuery message.
duration := timeutil.Since(c.startTime).Nanoseconds()
c.metrics.ConnLatency.RecordValue(duration)

// Mark the authentication as succeeded in case a panic
// is thrown below and we need to report to the client
// using the defer above.
authOK = true

// Now actually process commands.
reservedOwned = false // We're about to pass ownership away.
retErr = sqlServer.ServeConn(
ctx,
connHandler,
reserved,
c.cancelConn,
)
}()
return retCh

// Authenticate the connection.
if connCloseAuthHandler, retErr = c.handleAuthentication(
ctx, ac, authOpt, sqlServer.GetExecutorConfig(),
); retErr != nil {
// Auth failed or some other error.
return
}

var decrementConnectionCount func()
if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil {
_ = c.sendError(ctx, retErr)
return
}
defer decrementConnectionCount()

if retErr = c.authOKMessage(); retErr != nil {
return
}

// Inform the client of the default session settings.
connHandler, retErr = c.sendInitialConnData(ctx, sqlServer, onDefaultIntSizeChange, sessionID)
if retErr != nil {
return
}
// Signal the connection was established to the authenticator.
ac.AuthOK(ctx)
ac.LogAuthOK(ctx)

// We count the connection establish latency until we are ready to
// serve a SQL query. It includes the time it takes to authenticate and
// send the initial ReadyForQuery message.
duration := timeutil.Since(c.startTime).Nanoseconds()
c.metrics.ConnLatency.RecordValue(duration)

// Mark the authentication as succeeded in case a panic
// is thrown below and we need to report to the client
// using the defer above.
authOK = true

// Now actually process commands.
reservedOwned = false // We're about to pass ownership away.
retErr = sqlServer.ServeConn(
ctx,
connHandler,
reserved,
c.cancelConn,
)
}

func (c *conn) bufferParamStatus(param, value string) error {
Expand Down
54 changes: 25 additions & 29 deletions pkg/sql/pgwire/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -966,11 +967,9 @@ func (s *Server) serveImpl(
systemIdentity = c.sessionArgs.User
}
authPipe := newAuthPipe(c, logAuthn, authOpt, systemIdentity)
var authenticator authenticatorIO = authPipe

// procCh is the channel on which we'll receive the termination signal from
// the command processor.
var procCh <-chan error
// procWg waits for the command processor to return.
var procWg sync.WaitGroup

// We need a value for the unqualified int size here, but it is controlled
// by a session variable, and this layer doesn't have access to the session
Expand All @@ -983,18 +982,22 @@ func (s *Server) serveImpl(

if !inTestWithoutSQL {
// Spawn the command processing goroutine, which also handles connection
// authentication). It will notify us when it's done through procCh, and
// we'll also interact with the authentication process through ac.
var ac AuthConn = authPipe
procCh = c.processCommandsAsync(
ctx,
authOpt,
ac,
sqlServer,
reserved,
onDefaultIntSizeChange,
sessionID,
)
// authentication). It will notify us when it's done through procWg, and
// we'll also interact with the authentication process through authPipe.
procWg.Add(1)
go func() {
// Inform the connection goroutine.
defer procWg.Done()
c.processCommands(
ctx,
authOpt,
authPipe,
sqlServer,
reserved,
onDefaultIntSizeChange,
sessionID,
)
}()
} else {
// sqlServer == nil means we are in a local test. In this case
// we only need the minimum to make pgx happy.
Expand All @@ -1009,12 +1012,8 @@ func (s *Server) serveImpl(
if err != nil {
return
}
var ac AuthConn = authPipe
// Simulate auth succeeding.
ac.AuthOK(ctx)
dummyCh := make(chan error)
close(dummyCh)
procCh = dummyCh
authPipe.AuthOK(ctx)

if err := c.bufferInitialReadyForQuery(0 /* queryCancelKey */); err != nil {
return
Expand Down Expand Up @@ -1098,13 +1097,13 @@ func (s *Server) serveImpl(
// Pass the data to the authenticator. This hopefully causes it to finish
// authentication in the background and give us an intSizer when we loop
// around.
if err = authenticator.sendPwdData(pwd); err != nil {
if err = authPipe.sendPwdData(pwd); err != nil {
return false, isSimpleQuery, err
}
return false, isSimpleQuery, nil
}
// Wait for the auth result.
if err = authenticator.authResult(); err != nil {
if err = authPipe.authResult(); err != nil {
// The error has already been sent to the client.
return true, isSimpleQuery, nil //nolint:returnerrcheck
}
Expand Down Expand Up @@ -1225,13 +1224,10 @@ func (s *Server) serveImpl(
// In case the authenticator is blocked on waiting for data from the client,
// tell it that there's no more data coming. This is a no-op if authentication
// was completed already.
authenticator.noMorePwdData()
authPipe.noMorePwdData()

// Wait for the processor goroutine to finish, if it hasn't already. We're
// ignoring the error we get from it, as we have no use for it. It might be a
// connection error, or a context cancelation error case this goroutine is the
// one that triggered the execution to stop.
<-procCh
// Wait for the processor goroutine to finish, if it hasn't already.
procWg.Wait()

if terminateSeen {
return
Expand Down