From d9955820fe054aebed038958c77727a262086a3f Mon Sep 17 00:00:00 2001 From: Evan Wall Date: Sat, 17 Jun 2023 09:28:11 -0400 Subject: [PATCH 1/2] server: remove unnecessary conn method params Informs #105448 Some methods were being passed conn.readBuf instead of accessing it via the receiver. Release note: None --- pkg/sql/pgwire/conn.go | 54 +++++++++++++++++++--------------------- pkg/sql/pgwire/server.go | 10 ++++---- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 46323b449b7d..82fa5a03e191 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -473,26 +473,24 @@ func (c *conn) handleSimpleQuery( // An error is returned iff the statement buffer has been closed. In that case, // the connection should be considered toast. -func (c *conn) handleParse( - ctx context.Context, buf *pgwirebase.ReadBuffer, nakedIntSize *types.T, -) error { +func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error { telemetry.Inc(sqltelemetry.ParseRequestCounter) - name, err := buf.GetString() + name, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - query, err := buf.GetString() + query, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } // The client may provide type information for (some of) the placeholders. - numQArgTypes, err := buf.GetUint16() + numQArgTypes, err := c.readBuf.GetUint16() if err != nil { return err } inTypeHints := make([]oid.Oid, numQArgTypes) for i := range inTypeHints { - typ, err := buf.GetUint32() + typ, err := c.readBuf.GetUint32() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -586,13 +584,13 @@ func (c *conn) handleParse( // An error is returned iff the statement buffer has been closed. In that case, // the connection should be considered toast. -func (c *conn) handleDescribe(ctx context.Context, buf *pgwirebase.ReadBuffer) error { +func (c *conn) handleDescribe(ctx context.Context) error { telemetry.Inc(sqltelemetry.DescribeRequestCounter) - typ, err := buf.GetPrepareType() + typ, err := c.readBuf.GetPrepareType() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := buf.GetString() + name, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -606,13 +604,13 @@ func (c *conn) handleDescribe(ctx context.Context, buf *pgwirebase.ReadBuffer) e // An error is returned iff the statement buffer has been closed. In that case, // the connection should be considered toast. -func (c *conn) handleClose(ctx context.Context, buf *pgwirebase.ReadBuffer) error { +func (c *conn) handleClose(ctx context.Context) error { telemetry.Inc(sqltelemetry.CloseRequestCounter) - typ, err := buf.GetPrepareType() + typ, err := c.readBuf.GetPrepareType() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := buf.GetString() + name, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -632,13 +630,13 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText} // statement. // An error is returned iff the statement buffer has been closed. In that case, // the connection should be considered toast. -func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error { +func (c *conn) handleBind(ctx context.Context) error { telemetry.Inc(sqltelemetry.BindRequestCounter) - portalName, err := buf.GetString() + portalName, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - statementName, err := buf.GetString() + statementName, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -649,7 +647,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error // specified format code is applied to all arguments; or it can equal the // actual number of arguments. // http://www.postgresql.org/docs/current/static/protocol-message-formats.html - numQArgFormatCodes, err := buf.GetUint16() + numQArgFormatCodes, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -660,7 +658,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error qArgFormatCodes = formatCodesAllText case 1: // `1` means read one code and apply it to every argument. - ch, err := buf.GetUint16() + ch, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -674,7 +672,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error qArgFormatCodes = make([]pgwirebase.FormatCode, numQArgFormatCodes) // Read one format code for each argument and apply it to that argument. for i := range qArgFormatCodes { - ch, err := buf.GetUint16() + ch, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -682,13 +680,13 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error } } - numValues, err := buf.GetUint16() + numValues, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } qargs := make([][]byte, numValues) for i := 0; i < int(numValues); i++ { - plen, err := buf.GetUint32() + plen, err := c.readBuf.GetUint32() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -697,7 +695,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error qargs[i] = nil continue } - b, err := buf.GetBytes(int(plen)) + b, err := c.readBuf.GetBytes(int(plen)) if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -711,7 +709,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error // (if any); or it can equal the actual number of result columns of the // query. // http://www.postgresql.org/docs/current/static/protocol-message-formats.html - numColumnFormatCodes, err := buf.GetUint16() + numColumnFormatCodes, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -722,7 +720,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error columnFormatCodes = formatCodesAllText case 1: // All columns will use the one specified format. - ch, err := buf.GetUint16() + ch, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -736,7 +734,7 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error columnFormatCodes = make([]pgwirebase.FormatCode, numColumnFormatCodes) // Read one format code for each column and apply it to that column. for i := range columnFormatCodes { - ch, err := buf.GetUint16() + ch, err := c.readBuf.GetUint16() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -757,14 +755,14 @@ func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error // An error is returned iff the statement buffer has been closed. In that case, // the connection should be considered toast. func (c *conn) handleExecute( - ctx context.Context, buf *pgwirebase.ReadBuffer, timeReceived time.Time, followedBySync bool, + ctx context.Context, timeReceived time.Time, followedBySync bool, ) error { telemetry.Inc(sqltelemetry.ExecuteRequestCounter) - portalName, err := buf.GetString() + portalName, err := c.readBuf.GetString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - limit, err := buf.GetUint32() + limit, err := c.readBuf.GetUint32() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index a7b9942ac017..789b0065fe0c 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -1140,19 +1140,19 @@ func (s *Server) serveImpl( pgwirebase.ClientMessageType(nextMsgType[0]) == pgwirebase.ClientMsgSync { followedBySync = true } - return false, isSimpleQuery, c.handleExecute(ctx, &c.readBuf, timeReceived, followedBySync) + return false, isSimpleQuery, c.handleExecute(ctx, timeReceived, followedBySync) case pgwirebase.ClientMsgParse: - return false, isSimpleQuery, c.handleParse(ctx, &c.readBuf, parser.NakedIntTypeFromDefaultIntSize(atomic.LoadInt32(atomicUnqualifiedIntSize))) + return false, isSimpleQuery, c.handleParse(ctx, parser.NakedIntTypeFromDefaultIntSize(atomic.LoadInt32(atomicUnqualifiedIntSize))) case pgwirebase.ClientMsgDescribe: - return false, isSimpleQuery, c.handleDescribe(ctx, &c.readBuf) + return false, isSimpleQuery, c.handleDescribe(ctx) case pgwirebase.ClientMsgBind: - return false, isSimpleQuery, c.handleBind(ctx, &c.readBuf) + return false, isSimpleQuery, c.handleBind(ctx) case pgwirebase.ClientMsgClose: - return false, isSimpleQuery, c.handleClose(ctx, &c.readBuf) + return false, isSimpleQuery, c.handleClose(ctx) case pgwirebase.ClientMsgTerminate: terminateSeen = true From 822232b4f4994df376e3780f7e96fa4154414b54 Mon Sep 17 00:00:00 2001 From: Evan Wall Date: Wed, 21 Jun 2023 11:58:06 -0400 Subject: [PATCH 2/2] server: remove unused conn.sendErr execCfg param Informs #105448 Release note: None --- pkg/sql/pgwire/auth.go | 14 +++++++------- pkg/sql/pgwire/conn.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/sql/pgwire/auth.go b/pkg/sql/pgwire/auth.go index c218fb80bb8f..084571fa7020 100644 --- a/pkg/sql/pgwire/auth.go +++ b/pkg/sql/pgwire/auth.go @@ -100,7 +100,7 @@ func (c *conn) handleAuthentication( tlsState, hbaEntry, authMethod, err := c.findAuthenticationMethod(authOpt) if err != nil { ac.LogAuthFailed(ctx, eventpb.AuthFailReason_METHOD_NOT_FOUND, err) - return nil, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) + return nil, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) } ac.SetAuthMethod(hbaEntry.Method.String()) @@ -113,7 +113,7 @@ func (c *conn) handleAuthentication( connClose = behaviors.ConnClose if err != nil { ac.LogAuthFailed(ctx, eventpb.AuthFailReason_UNKNOWN, err) - return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) + return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) } // Choose the system identity that we'll use below for mapping @@ -138,7 +138,7 @@ func (c *conn) handleAuthentication( if err := c.checkClientUsernameMatchesMapping(ctx, ac, behaviors.MapRole, systemIdentity); err != nil { log.Warningf(ctx, "unable to map incoming identity %q to any database user: %+v", systemIdentity, err) ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_NOT_FOUND, err) - return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) + return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) } // Once chooseDbRole() returns, we know that the actual DB username @@ -157,7 +157,7 @@ func (c *conn) handleAuthentication( if err != nil { log.Warningf(ctx, "user retrieval failed for user=%q: %+v", dbUser, err) ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_RETRIEVAL_ERROR, err) - return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) + return connClose, c.sendError(ctx, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification)) } c.sessionArgs.IsSuperuser = isSuperuser @@ -166,12 +166,12 @@ func (c *conn) handleAuthentication( // If the user does not exist, we show the same error used for invalid // passwords, to make it harder for an attacker to determine if a user // exists. - return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(security.NewErrPasswordUserAuthFailed(dbUser), pgcode.InvalidPassword)) + return connClose, c.sendError(ctx, pgerror.WithCandidateCode(security.NewErrPasswordUserAuthFailed(dbUser), pgcode.InvalidPassword)) } if !canLoginSQL { ac.LogAuthFailed(ctx, eventpb.AuthFailReason_LOGIN_DISABLED, nil) - return connClose, c.sendError(ctx, execCfg, pgerror.Newf(pgcode.InvalidAuthorizationSpecification, "%s does not have login privilege", dbUser)) + return connClose, c.sendError(ctx, pgerror.Newf(pgcode.InvalidAuthorizationSpecification, "%s does not have login privilege", dbUser)) } // At this point, we know that the requested user exists and is @@ -184,7 +184,7 @@ func (c *conn) handleAuthentication( } else { err = pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification) } - return connClose, c.sendError(ctx, execCfg, err) + return connClose, c.sendError(ctx, err) } // Add all the defaults to this session's defaults. If there is an diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 82fa5a03e191..1c7b8d19e6d0 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -122,7 +122,7 @@ func (c *conn) GetErr() error { return nil } -func (c *conn) sendError(ctx context.Context, execCfg *sql.ExecutorConfig, err error) error { +func (c *conn) sendError(ctx context.Context, err error) error { // We could, but do not, report server-side network errors while // trying to send the client error. This is because clients that // receive error payload are highly correlated with clients @@ -226,7 +226,7 @@ func (c *conn) processCommandsAsync( var decrementConnectionCount func() if decrementConnectionCount, retErr = sqlServer.IncrementConnectionCount(c.sessionArgs); retErr != nil { - _ = c.sendError(ctx, sqlServer.GetExecutorConfig(), retErr) + _ = c.sendError(ctx, retErr) return } defer decrementConnectionCount()