diff --git a/pkg/ccl/serverccl/statusccl/tenant_status_test.go b/pkg/ccl/serverccl/statusccl/tenant_status_test.go index 28bb645b3695..5859b8b4015a 100644 --- a/pkg/ccl/serverccl/statusccl/tenant_status_test.go +++ b/pkg/ccl/serverccl/statusccl/tenant_status_test.go @@ -943,30 +943,22 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T testCases := []struct { sessionID string expectedError string - - // This is a temporary assertion. We should always show the following "not found" error messages, - // regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676. - nonAdminSeesError bool }{ { - sessionID: "", - expectedError: "session ID 00000000000000000000000000000000 not found", - nonAdminSeesError: true, + sessionID: "", + expectedError: "session ID 00000000000000000000000000000000 not found", }, { - sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. - expectedError: "session ID 00000000000000000000000000000001 not found", - nonAdminSeesError: false, + sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. + expectedError: "session ID 00000000000000000000000000000001 not found", }, { - sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. - expectedError: "session ID 00000000000000000000000000000002 not found", - nonAdminSeesError: false, + sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. + expectedError: "session ID 00000000000000000000000000000002 not found", }, { - sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. - expectedError: "session ID 00000000000000000000000000000042 not found", - nonAdminSeesError: true, + sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. + expectedError: "session ID 00000000000000000000000000000042 not found", }, } @@ -982,12 +974,8 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T err = client.PostJSONChecked("/_status/cancel_session/0", &serverpb.CancelSessionRequest{ SessionID: sessionID.GetBytes(), }, &resp) - if isAdmin || testCase.nonAdminSeesError { - require.NoError(t, err) - require.Equal(t, testCase.expectedError, resp.Error) - } else { - require.Error(t, err) - } + require.NoError(t, err) + require.Equal(t, testCase.expectedError, resp.Error) }) } }) @@ -1072,36 +1060,27 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten testCases := []struct { queryID string expectedError string - - // This is a temporary assertion. We should always show the following "not found" error messages, - // regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676. - nonAdminSeesError bool }{ { queryID: "BOGUS_QUERY_ID", expectedError: "query ID 00000000000000000000000000000000 malformed: " + "could not decode BOGUS_QUERY_ID as hex: encoding/hex: invalid byte: U+004F 'O'", - nonAdminSeesError: true, }, { - queryID: "", - expectedError: "query ID 00000000000000000000000000000000 not found", - nonAdminSeesError: true, + queryID: "", + expectedError: "query ID 00000000000000000000000000000000 not found", }, { - queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. - expectedError: "query ID 00000000000000000000000000000001 not found", - nonAdminSeesError: false, + queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. + expectedError: "query ID 00000000000000000000000000000001 not found", }, { - queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. - expectedError: "query ID 00000000000000000000000000000002 not found", - nonAdminSeesError: false, + queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. + expectedError: "query ID 00000000000000000000000000000002 not found", }, { - queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. - expectedError: "query ID 00000000000000000000000000000042 not found", - nonAdminSeesError: true, + queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. + expectedError: "query ID 00000000000000000000000000000042 not found", }, } @@ -1115,12 +1094,8 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten err := client.PostJSONChecked("/_status/cancel_query/0", &serverpb.CancelQueryRequest{ QueryID: testCase.queryID, }, &resp) - if isAdmin || testCase.nonAdminSeesError { - require.NoError(t, err) - require.Equal(t, testCase.expectedError, resp.Error) - } else { - require.Error(t, err) - } + require.NoError(t, err) + require.Equal(t, testCase.expectedError, resp.Error) }) } }) diff --git a/pkg/server/status.go b/pkg/server/status.go index dfc669db4947..b1e214f6edd0 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -270,90 +270,42 @@ func (b *baseStatusServer) getLocalSessions( return userSessions, nil } -type sessionFinder func(sessions []serverpb.Session) (serverpb.Session, error) - -func findSessionBySessionID(sessionID []byte) sessionFinder { - return func(sessions []serverpb.Session) (serverpb.Session, error) { - var session serverpb.Session - for _, s := range sessions { - if bytes.Equal(sessionID, s.ID) { - session = s - break - } - } - if len(session.ID) == 0 { - return session, fmt.Errorf("session ID %s not found", clusterunique.IDFromBytes(sessionID)) - } - return session, nil - } -} - -func findSessionByQueryID(queryID string) sessionFinder { - return func(sessions []serverpb.Session) (serverpb.Session, error) { - var session serverpb.Session - for _, s := range sessions { - for _, q := range s.ActiveQueries { - if queryID == q.ID { - session = s - break - } - } - } - if len(session.ID) == 0 { - return session, fmt.Errorf("query ID %s not found", queryID) - } - return session, nil - } -} - // checkCancelPrivilege returns nil if the user has the necessary cancel action // privileges for a session. This function returns a proper gRPC error status. func (b *baseStatusServer) checkCancelPrivilege( - ctx context.Context, userName username.SQLUsername, findSession sessionFinder, + ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername, ) error { ctx = propagateGatewayMetadata(ctx) ctx = b.AnnotateCtx(ctx) - // reqUser is the user who made the cancellation request. - var reqUser username.SQLUsername - { - sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if userName.Undefined() || userName == sessionUser { - reqUser = sessionUser - } else { - // When CANCEL QUERY is run as a SQL statement, sessionUser is always root - // and the user who ran the statement is passed as req.Username. - if !isAdmin { - return errRequiresAdmin - } - reqUser = userName - } + + ctxUsername, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) + if err != nil { + return serverError(ctx, err) + } + if reqUsername.Undefined() { + reqUsername = ctxUsername + } else if reqUsername != ctxUsername && !isAdmin { + // When CANCEL QUERY is run as a SQL statement, sessionUser is always root + // and the user who ran the statement is passed as req.Username. + return errRequiresAdmin } - hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUser) + hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUsername) if err != nil { return serverError(ctx, err) } if !hasAdmin { // Check if the user has permission to see the session. - session, err := findSession(b.sessionRegistry.SerializeAll()) - if err != nil { - return serverError(ctx, err) - } - - sessionUser := username.MakeSQLUsernameFromPreNormalizedString(session.Username) - if sessionUser != reqUser { + if sessionUsername != reqUsername { // Must have CANCELQUERY privilege to cancel other users' // sessions/queries. - hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUser, privilege.CANCELQUERY) + hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY) if err != nil { return serverError(ctx, err) } if !hasCancelQuery { - ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUser, roleoption.CANCELQUERY) + ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY) if err != nil { return serverError(ctx, err) } @@ -362,7 +314,7 @@ func (b *baseStatusServer) checkCancelPrivilege( } } // Non-admins cannot cancel admins' sessions/queries. - isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUser) + isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUsername) if err != nil { return serverError(ctx, err) } @@ -3063,7 +3015,13 @@ func (s *statusServer) CancelSession( ctx = propagateGatewayMetadata(ctx) ctx = s.AnnotateCtx(ctx) - sessionID := clusterunique.IDFromBytes(req.SessionID) + sessionIDBytes := req.SessionID + if len(sessionIDBytes) != 16 { + return &serverpb.CancelSessionResponse{ + Error: fmt.Sprintf("session ID %v malformed", sessionIDBytes), + }, nil + } + sessionID := clusterunique.IDFromBytes(sessionIDBytes) nodeID := sessionID.GetNodeID() local := nodeID == int32(s.serverIterator.getID()) if !local { @@ -3071,8 +3029,7 @@ func (s *statusServer) CancelSession( if err != nil { if errors.Is(err, sqlinstance.NonExistentInstanceError) { return &serverpb.CancelSessionResponse{ - Canceled: false, - Error: fmt.Sprintf("session ID %s not found", sessionID), + Error: fmt.Sprintf("session ID %s not found", sessionID), }, nil } return nil, serverError(ctx, err) @@ -3085,17 +3042,21 @@ func (s *statusServer) CancelSession( return nil, status.Errorf(codes.InvalidArgument, err.Error()) } - if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionBySessionID(req.SessionID)); err != nil { + session, ok := s.sessionRegistry.GetSessionByID(sessionID) + if !ok { + return &serverpb.CancelSessionResponse{ + Error: fmt.Sprintf("session ID %s not found", sessionID), + }, nil + } + + if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil { // NB: not using serverError() here since the priv checker // already returns a proper gRPC error status. return nil, err } - r, err := s.sessionRegistry.CancelSession(req.SessionID) - if err != nil { - return nil, serverError(ctx, err) - } - return r, nil + session.CancelSession() + return &serverpb.CancelSessionResponse{Canceled: true}, nil } // CancelQuery responds to a query cancellation request, and cancels @@ -3109,8 +3070,7 @@ func (s *statusServer) CancelQuery( queryID, err := clusterunique.IDFromString(req.QueryID) if err != nil { return &serverpb.CancelQueryResponse{ - Canceled: false, - Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(), + Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(), }, nil } @@ -3122,8 +3082,7 @@ func (s *statusServer) CancelQuery( if err != nil { if errors.Is(err, sqlinstance.NonExistentInstanceError) { return &serverpb.CancelQueryResponse{ - Canceled: false, - Error: fmt.Sprintf("query ID %s not found", queryID), + Error: fmt.Sprintf("query ID %s not found", queryID), }, nil } return nil, serverError(ctx, err) @@ -3136,18 +3095,23 @@ func (s *statusServer) CancelQuery( return nil, status.Errorf(codes.InvalidArgument, err.Error()) } - if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionByQueryID(req.QueryID)); err != nil { + session, ok := s.sessionRegistry.GetSessionByQueryID(queryID) + if !ok { + return &serverpb.CancelQueryResponse{ + Error: fmt.Sprintf("query ID %s not found", queryID), + }, nil + } + + if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil { // NB: not using serverError() here since the priv checker // already returns a proper gRPC error status. return nil, err } - output := &serverpb.CancelQueryResponse{} - output.Canceled, err = s.sessionRegistry.CancelQuery(req.QueryID) - if err != nil { - output.Error = err.Error() - } - return output, nil + isCanceled := session.CancelQuery(queryID) + return &serverpb.CancelQueryResponse{ + Canceled: isCanceled, + }, nil } // CancelQueryByKey responds to a pgwire query cancellation request, and cancels @@ -3184,12 +3148,18 @@ func (s *statusServer) CancelQueryByKey( }() if local { - resp = &serverpb.CancelQueryByKeyResponse{} - resp.Canceled, err = s.sessionRegistry.CancelQueryByKey(req.CancelQueryKey) - if err != nil { - resp.Error = err.Error() + cancelQueryKey := req.CancelQueryKey + session, ok := s.sessionRegistry.GetSessionByCancelKey(cancelQueryKey) + if !ok { + return &serverpb.CancelQueryByKeyResponse{ + Error: fmt.Sprintf("session for cancel key %d not found", cancelQueryKey), + }, nil } - return resp, nil + + isCanceled := session.CancelActiveQueries() + return &serverpb.CancelQueryByKeyResponse{ + Canceled: isCanceled, + }, nil } // This request needs to be forwarded to another node. diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 70c8cf522d83..cc3c035ccf3b 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -3137,8 +3137,16 @@ func (ex *connExecutor) initStatementResult( return nil } -// cancelQuery is part of the registrySession interface. -func (ex *connExecutor) cancelQuery(queryID clusterunique.ID) bool { +// hasQuery is part of the RegistrySession interface. +func (ex *connExecutor) hasQuery(queryID clusterunique.ID) bool { + ex.mu.RLock() + defer ex.mu.RUnlock() + _, exists := ex.mu.ActiveQueries[queryID] + return exists +} + +// CancelQuery is part of the RegistrySession interface. +func (ex *connExecutor) CancelQuery(queryID clusterunique.ID) bool { ex.mu.Lock() defer ex.mu.Unlock() if queryMeta, exists := ex.mu.ActiveQueries[queryID]; exists { @@ -3148,8 +3156,8 @@ func (ex *connExecutor) cancelQuery(queryID clusterunique.ID) bool { return false } -// cancelCurrentQueries is part of the registrySession interface. -func (ex *connExecutor) cancelCurrentQueries() bool { +// CancelActiveQueries is part of the RegistrySession interface. +func (ex *connExecutor) CancelActiveQueries() bool { ex.mu.Lock() defer ex.mu.Unlock() canceled := false @@ -3160,8 +3168,8 @@ func (ex *connExecutor) cancelCurrentQueries() bool { return canceled } -// cancelSession is part of the registrySession interface. -func (ex *connExecutor) cancelSession() { +// CancelSession is part of the RegistrySession interface. +func (ex *connExecutor) CancelSession() { if ex.onCancelSession == nil { return } @@ -3169,12 +3177,17 @@ func (ex *connExecutor) cancelSession() { ex.onCancelSession() } -// user is part of the registrySession interface. +// user is part of the RegistrySession interface. func (ex *connExecutor) user() username.SQLUsername { return ex.sessionData().User() } -// serialize is part of the registrySession interface. +// BaseSessionUser is part of the RegistrySession interface. +func (ex *connExecutor) BaseSessionUser() username.SQLUsername { + return ex.sessionDataStack.Base().SessionUser() +} + +// serialize is part of the RegistrySession interface. func (ex *connExecutor) serialize() serverpb.Session { ex.mu.RLock() defer ex.mu.RUnlock() diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index a410f3866f3b..8edb0095a068 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -149,7 +149,7 @@ func (ex *connExecutor) execStmt( // Cancel the session if the idle time exceeds the idle in session timeout. ex.mu.IdleInSessionTimeout = timeout{time.AfterFunc( ex.sessionData().IdleInSessionTimeout, - ex.cancelSession, + ex.CancelSession, )} } @@ -162,7 +162,7 @@ func (ex *connExecutor) execStmt( default: ex.mu.IdleInTransactionSessionTimeout = timeout{time.AfterFunc( ex.sessionData().IdleInTransactionSessionTimeout, - ex.cancelSession, + ex.CancelSession, )} } } diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 1e223a6cec5b..cc75ad84f3bb 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -2068,8 +2068,8 @@ type SessionArgs struct { type SessionRegistry struct { mu struct { syncutil.RWMutex - sessionsByID map[clusterunique.ID]registrySession - sessionsByCancelKey map[pgwirecancel.BackendKeyData]registrySession + sessionsByID map[clusterunique.ID]RegistrySession + sessionsByCancelKey map[pgwirecancel.BackendKeyData]RegistrySession } } @@ -2077,31 +2077,40 @@ type SessionRegistry struct { // of sessions. func NewSessionRegistry() *SessionRegistry { r := SessionRegistry{} - r.mu.sessionsByID = make(map[clusterunique.ID]registrySession) - r.mu.sessionsByCancelKey = make(map[pgwirecancel.BackendKeyData]registrySession) + r.mu.sessionsByID = make(map[clusterunique.ID]RegistrySession) + r.mu.sessionsByCancelKey = make(map[pgwirecancel.BackendKeyData]RegistrySession) return &r } -func (r *SessionRegistry) getSessionByID(id clusterunique.ID) (registrySession, bool) { +func (r *SessionRegistry) GetSessionByID(sessionID clusterunique.ID) (RegistrySession, bool) { r.mu.RLock() defer r.mu.RUnlock() - session, ok := r.mu.sessionsByID[id] + session, ok := r.mu.sessionsByID[sessionID] return session, ok } -func (r *SessionRegistry) getSessionByCancelKey( +func (r *SessionRegistry) GetSessionByQueryID(queryID clusterunique.ID) (RegistrySession, bool) { + for _, session := range r.getSessions() { + if session.hasQuery(queryID) { + return session, true + } + } + return nil, false +} + +func (r *SessionRegistry) GetSessionByCancelKey( cancelKey pgwirecancel.BackendKeyData, -) (registrySession, bool) { +) (RegistrySession, bool) { r.mu.RLock() defer r.mu.RUnlock() session, ok := r.mu.sessionsByCancelKey[cancelKey] return session, ok } -func (r *SessionRegistry) getSessions() []registrySession { +func (r *SessionRegistry) getSessions() []RegistrySession { r.mu.RLock() defer r.mu.RUnlock() - sessions := make([]registrySession, 0, len(r.mu.sessionsByID)) + sessions := make([]RegistrySession, 0, len(r.mu.sessionsByID)) for _, session := range r.mu.sessionsByID { sessions = append(sessions, session) } @@ -2109,7 +2118,7 @@ func (r *SessionRegistry) getSessions() []registrySession { } func (r *SessionRegistry) register( - id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, s registrySession, + id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, s RegistrySession, ) { r.mu.Lock() defer r.mu.Unlock() @@ -2126,65 +2135,22 @@ func (r *SessionRegistry) deregister( delete(r.mu.sessionsByCancelKey, queryCancelKey) } -type registrySession interface { +type RegistrySession interface { user() username.SQLUsername - cancelQuery(queryID clusterunique.ID) bool - cancelCurrentQueries() bool - cancelSession() + // BaseSessionUser returns the base session's username. + BaseSessionUser() username.SQLUsername + hasQuery(queryID clusterunique.ID) bool + // CancelQuery cancels the query specified by queryID if it exists. + CancelQuery(queryID clusterunique.ID) bool + // CancelActiveQueries cancels all currently active queries. + CancelActiveQueries() bool + // CancelSession cancels the session. + CancelSession() // serialize serializes a Session into a serverpb.Session // that can be served over RPC. serialize() serverpb.Session } -// CancelQuery looks up the associated query in the session registry and cancels -// it. The caller is responsible for all permission checks. -func (r *SessionRegistry) CancelQuery(queryIDStr string) (bool, error) { - queryID, err := clusterunique.IDFromString(queryIDStr) - if err != nil { - return false, errors.Wrapf(err, "query ID %s malformed", queryID) - } - - for _, session := range r.getSessions() { - if session.cancelQuery(queryID) { - return true, nil - } - } - - return false, fmt.Errorf("query ID %s not found", queryID) -} - -// CancelQueryByKey looks up the associated query in the session registry and -// cancels it. -func (r *SessionRegistry) CancelQueryByKey( - queryCancelKey pgwirecancel.BackendKeyData, -) (canceled bool, err error) { - session, ok := r.getSessionByCancelKey(queryCancelKey) - if !ok { - return false, fmt.Errorf("session for cancel key %d not found", queryCancelKey) - } - return session.cancelCurrentQueries(), nil -} - -// CancelSession looks up the specified session in the session registry and -// cancels it. The caller is responsible for all permission checks. -func (r *SessionRegistry) CancelSession( - sessionIDBytes []byte, -) (*serverpb.CancelSessionResponse, error) { - if len(sessionIDBytes) != 16 { - return nil, errors.Errorf("invalid non-16-byte UUID %v", sessionIDBytes) - } - sessionID := clusterunique.IDFromBytes(sessionIDBytes) - - session, ok := r.getSessionByID(sessionID) - if !ok { - return &serverpb.CancelSessionResponse{ - Error: fmt.Sprintf("session ID %s not found", sessionID), - }, nil - } - session.cancelSession() - return &serverpb.CancelSessionResponse{Canceled: true}, nil -} - // SerializeAll returns a slice of all sessions in the registry converted to // serverpb.Sessions. func (r *SessionRegistry) SerializeAll() []serverpb.Session {