Skip to content

Commit

Permalink
Merge #106066
Browse files Browse the repository at this point in the history
106066: server: remove unused params r=rafiss a=ecwall

Informs #105448

1) Remove unnecessary `conn` method params.
2) Remove unused `conn.sendErr` execCfg param.

Co-authored-by: Evan Wall <[email protected]>
  • Loading branch information
craig[bot] and ecwall committed Jul 16, 2023
2 parents c609b82 + 822232b commit 400e1ad
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
14 changes: 7 additions & 7 deletions pkg/sql/pgwire/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (c *conn) handleAuthentication(
tlsState, hbaEntry, authMethod, err := c.findAuthenticationMethod(authOpt)
if err != nil {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_METHOD_NOT_FOUND, err)
return nil, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return nil, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

ac.SetAuthMethod(hbaEntry.Method.String())
Expand All @@ -113,7 +113,7 @@ func (c *conn) handleAuthentication(
connClose = behaviors.ConnClose
if err != nil {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_UNKNOWN, err)
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

// Choose the system identity that we'll use below for mapping
Expand All @@ -138,7 +138,7 @@ func (c *conn) handleAuthentication(
if err := c.checkClientUsernameMatchesMapping(ctx, ac, behaviors.MapRole, systemIdentity); err != nil {
log.Warningf(ctx, "unable to map incoming identity %q to any database user: %+v", systemIdentity, err)
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_NOT_FOUND, err)
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

// Once chooseDbRole() returns, we know that the actual DB username
Expand All @@ -157,7 +157,7 @@ func (c *conn) handleAuthentication(
if err != nil {
log.Warningf(ctx, "user retrieval failed for user=%q: %+v", dbUser, err)
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_RETRIEVAL_ERROR, err)
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}
c.sessionArgs.IsSuperuser = isSuperuser

Expand All @@ -166,12 +166,12 @@ func (c *conn) handleAuthentication(
// If the user does not exist, we show the same error used for invalid
// passwords, to make it harder for an attacker to determine if a user
// exists.
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(security.NewErrPasswordUserAuthFailed(dbUser), pgcode.InvalidPassword))
return connClose, c.sendError(ctx, pgerror.WithCandidateCode(security.NewErrPasswordUserAuthFailed(dbUser), pgcode.InvalidPassword))
}

if !canLoginSQL {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_LOGIN_DISABLED, nil)
return connClose, c.sendError(ctx, execCfg, pgerror.Newf(pgcode.InvalidAuthorizationSpecification, "%s does not have login privilege", dbUser))
return connClose, c.sendError(ctx, pgerror.Newf(pgcode.InvalidAuthorizationSpecification, "%s does not have login privilege", dbUser))
}

// At this point, we know that the requested user exists and is
Expand All @@ -184,7 +184,7 @@ func (c *conn) handleAuthentication(
} else {
err = pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)
}
return connClose, c.sendError(ctx, execCfg, err)
return connClose, c.sendError(ctx, err)
}

// Add all the defaults to this session's defaults. If there is an
Expand Down
58 changes: 28 additions & 30 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (c *conn) GetErr() error {
return nil
}

func (c *conn) sendError(ctx context.Context, execCfg *sql.ExecutorConfig, err error) error {
func (c *conn) sendError(ctx context.Context, err error) error {
// We could, but do not, report server-side network errors while
// trying to send the client error. This is because clients that
// receive error payload are highly correlated with clients
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *conn) processCommandsAsync(

var decrementConnectionCount func()
if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil {
_ = c.sendError(ctx, sqlServer.GetExecutorConfig(), retErr)
_ = c.sendError(ctx, retErr)
return
}
defer decrementConnectionCount()
Expand Down Expand Up @@ -473,26 +473,24 @@ func (c *conn) handleSimpleQuery(

// An error is returned iff the statement buffer has been closed. In that case,
// the connection should be considered toast.
func (c *conn) handleParse(
ctx context.Context, buf *pgwirebase.ReadBuffer, nakedIntSize *types.T,
) error {
func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error {
telemetry.Inc(sqltelemetry.ParseRequestCounter)
name, err := buf.GetString()
name, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
query, err := buf.GetString()
query, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
// The client may provide type information for (some of) the placeholders.
numQArgTypes, err := buf.GetUint16()
numQArgTypes, err := c.readBuf.GetUint16()
if err != nil {
return err
}
inTypeHints := make([]oid.Oid, numQArgTypes)
for i := range inTypeHints {
typ, err := buf.GetUint32()
typ, err := c.readBuf.GetUint32()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -586,13 +584,13 @@ func (c *conn) handleParse(

// An error is returned iff the statement buffer has been closed. In that case,
// the connection should be considered toast.
func (c *conn) handleDescribe(ctx context.Context, buf *pgwirebase.ReadBuffer) error {
func (c *conn) handleDescribe(ctx context.Context) error {
telemetry.Inc(sqltelemetry.DescribeRequestCounter)
typ, err := buf.GetPrepareType()
typ, err := c.readBuf.GetPrepareType()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := buf.GetString()
name, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -606,13 +604,13 @@ func (c *conn) handleDescribe(ctx context.Context, buf *pgwirebase.ReadBuffer) e

// An error is returned iff the statement buffer has been closed. In that case,
// the connection should be considered toast.
func (c *conn) handleClose(ctx context.Context, buf *pgwirebase.ReadBuffer) error {
func (c *conn) handleClose(ctx context.Context) error {
telemetry.Inc(sqltelemetry.CloseRequestCounter)
typ, err := buf.GetPrepareType()
typ, err := c.readBuf.GetPrepareType()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := buf.GetString()
name, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -632,13 +630,13 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText}
// statement.
// An error is returned iff the statement buffer has been closed. In that case,
// the connection should be considered toast.
func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error {
func (c *conn) handleBind(ctx context.Context) error {
telemetry.Inc(sqltelemetry.BindRequestCounter)
portalName, err := buf.GetString()
portalName, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
statementName, err := buf.GetString()
statementName, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -649,7 +647,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
// specified format code is applied to all arguments; or it can equal the
// actual number of arguments.
// http://www.postgresql.org/docs/current/static/protocol-message-formats.html
numQArgFormatCodes, err := buf.GetUint16()
numQArgFormatCodes, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -660,7 +658,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
qArgFormatCodes = formatCodesAllText
case 1:
// `1` means read one code and apply it to every argument.
ch, err := buf.GetUint16()
ch, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -674,21 +672,21 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
qArgFormatCodes = make([]pgwirebase.FormatCode, numQArgFormatCodes)
// Read one format code for each argument and apply it to that argument.
for i := range qArgFormatCodes {
ch, err := buf.GetUint16()
ch, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
qArgFormatCodes[i] = pgwirebase.FormatCode(ch)
}
}

numValues, err := buf.GetUint16()
numValues, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
qargs := make([][]byte, numValues)
for i := 0; i < int(numValues); i++ {
plen, err := buf.GetUint32()
plen, err := c.readBuf.GetUint32()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -697,7 +695,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
qargs[i] = nil
continue
}
b, err := buf.GetBytes(int(plen))
b, err := c.readBuf.GetBytes(int(plen))
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -711,7 +709,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
// (if any); or it can equal the actual number of result columns of the
// query.
// http://www.postgresql.org/docs/current/static/protocol-message-formats.html
numColumnFormatCodes, err := buf.GetUint16()
numColumnFormatCodes, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -722,7 +720,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
columnFormatCodes = formatCodesAllText
case 1:
// All columns will use the one specified format.
ch, err := buf.GetUint16()
ch, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -736,7 +734,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
columnFormatCodes = make([]pgwirebase.FormatCode, numColumnFormatCodes)
// Read one format code for each column and apply it to that column.
for i := range columnFormatCodes {
ch, err := buf.GetUint16()
ch, err := c.readBuf.GetUint16()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -757,14 +755,14 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error
// An error is returned iff the statement buffer has been closed. In that case,
// the connection should be considered toast.
func (c *conn) handleExecute(
ctx context.Context, buf *pgwirebase.ReadBuffer, timeReceived time.Time, followedBySync bool,
ctx context.Context, timeReceived time.Time, followedBySync bool,
) error {
telemetry.Inc(sqltelemetry.ExecuteRequestCounter)
portalName, err := buf.GetString()
portalName, err := c.readBuf.GetString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
limit, err := buf.GetUint32()
limit, err := c.readBuf.GetUint32()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/sql/pgwire/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,19 +1140,19 @@ func (s *Server) serveImpl(
pgwirebase.ClientMessageType(nextMsgType[0]) == pgwirebase.ClientMsgSync {
followedBySync = true
}
return false, isSimpleQuery, c.handleExecute(ctx, &c.readBuf, timeReceived, followedBySync)
return false, isSimpleQuery, c.handleExecute(ctx, timeReceived, followedBySync)

case pgwirebase.ClientMsgParse:
return false, isSimpleQuery, c.handleParse(ctx, &c.readBuf, parser.NakedIntTypeFromDefaultIntSize(atomic.LoadInt32(atomicUnqualifiedIntSize)))
return false, isSimpleQuery, c.handleParse(ctx, parser.NakedIntTypeFromDefaultIntSize(atomic.LoadInt32(atomicUnqualifiedIntSize)))

case pgwirebase.ClientMsgDescribe:
return false, isSimpleQuery, c.handleDescribe(ctx, &c.readBuf)
return false, isSimpleQuery, c.handleDescribe(ctx)

case pgwirebase.ClientMsgBind:
return false, isSimpleQuery, c.handleBind(ctx, &c.readBuf)
return false, isSimpleQuery, c.handleBind(ctx)

case pgwirebase.ClientMsgClose:
return false, isSimpleQuery, c.handleClose(ctx, &c.readBuf)
return false, isSimpleQuery, c.handleClose(ctx)

case pgwirebase.ClientMsgTerminate:
terminateSeen = true
Expand Down

0 comments on commit 400e1ad

Please sign in to comment.