Skip to content

Commit

Permalink
sql: remove redundant session iteration
Browse files Browse the repository at this point in the history
Fixes #95743

Improves session/query cancellation with the following changes:
1) Replaces session scanning by session ID with map lookup.
2) Replaces active query scanning by query ID with map lookup
   (session containing query to cancel is still scanned for).
3) Does not serialize entire session to get session username or id.

Informs #77676

#77676 was closed, but some test cases incorrectly stated that addressing
#77676 fixed them. This PR actually fixes these test cases.

Release note: None
  • Loading branch information
ecwall committed Jan 26, 2023
1 parent 9caf758 commit 129af11
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 209 deletions.
65 changes: 20 additions & 45 deletions pkg/ccl/serverccl/statusccl/tenant_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,30 +943,22 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T
testCases := []struct {
sessionID string
expectedError string

// This is a temporary assertion. We should always show the following "not found" error messages,
// regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676.
nonAdminSeesError bool
}{
{
sessionID: "",
expectedError: "session ID 00000000000000000000000000000000 not found",
nonAdminSeesError: true,
sessionID: "",
expectedError: "session ID 00000000000000000000000000000000 not found",
},
{
sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "session ID 00000000000000000000000000000001 not found",
nonAdminSeesError: false,
sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "session ID 00000000000000000000000000000001 not found",
},
{
sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "session ID 00000000000000000000000000000002 not found",
nonAdminSeesError: false,
sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "session ID 00000000000000000000000000000002 not found",
},
{
sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "session ID 00000000000000000000000000000042 not found",
nonAdminSeesError: true,
sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "session ID 00000000000000000000000000000042 not found",
},
}

Expand All @@ -982,12 +974,8 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T
err = client.PostJSONChecked("/_status/cancel_session/0", &serverpb.CancelSessionRequest{
SessionID: sessionID.GetBytes(),
}, &resp)
if isAdmin || testCase.nonAdminSeesError {
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
} else {
require.Error(t, err)
}
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
})
}
})
Expand Down Expand Up @@ -1072,36 +1060,27 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten
testCases := []struct {
queryID string
expectedError string

// This is a temporary assertion. We should always show the following "not found" error messages,
// regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676.
nonAdminSeesError bool
}{
{
queryID: "BOGUS_QUERY_ID",
expectedError: "query ID 00000000000000000000000000000000 malformed: " +
"could not decode BOGUS_QUERY_ID as hex: encoding/hex: invalid byte: U+004F 'O'",
nonAdminSeesError: true,
},
{
queryID: "",
expectedError: "query ID 00000000000000000000000000000000 not found",
nonAdminSeesError: true,
queryID: "",
expectedError: "query ID 00000000000000000000000000000000 not found",
},
{
queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "query ID 00000000000000000000000000000001 not found",
nonAdminSeesError: false,
queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "query ID 00000000000000000000000000000001 not found",
},
{
queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "query ID 00000000000000000000000000000002 not found",
nonAdminSeesError: false,
queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "query ID 00000000000000000000000000000002 not found",
},
{
queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "query ID 00000000000000000000000000000042 not found",
nonAdminSeesError: true,
queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "query ID 00000000000000000000000000000042 not found",
},
}

Expand All @@ -1115,12 +1094,8 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten
err := client.PostJSONChecked("/_status/cancel_query/0", &serverpb.CancelQueryRequest{
QueryID: testCase.queryID,
}, &resp)
if isAdmin || testCase.nonAdminSeesError {
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
} else {
require.Error(t, err)
}
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
})
}
})
Expand Down
150 changes: 60 additions & 90 deletions pkg/server/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,90 +270,42 @@ func (b *baseStatusServer) getLocalSessions(
return userSessions, nil
}

// sessionFinder selects a single session out of a list of serialized
// sessions, returning an error when no session in the list matches.
type sessionFinder func(sessions []serverpb.Session) (serverpb.Session, error)

// findSessionBySessionID builds a sessionFinder that returns the first
// session whose ID is byte-equal to sessionID. When no session matches, or
// the matched session carries an empty ID, the finder returns a
// "session ID ... not found" error.
func findSessionBySessionID(sessionID []byte) sessionFinder {
	return func(sessions []serverpb.Session) (serverpb.Session, error) {
		var match serverpb.Session
		for i := range sessions {
			if !bytes.Equal(sessionID, sessions[i].ID) {
				continue
			}
			// Stop at the first hit; later sessions are not inspected.
			match = sessions[i]
			break
		}
		if len(match.ID) > 0 {
			return match, nil
		}
		// An empty ID means nothing usable was found.
		return match, fmt.Errorf("session ID %s not found", clusterunique.IDFromBytes(sessionID))
	}
}

// findSessionByQueryID builds a sessionFinder that returns the first session
// having an active query whose ID equals queryID. When no such session
// exists, the finder returns a "query ID ... not found" error.
//
// The scan stops at the first match, in the same way findSessionBySessionID
// does. Previously the break only left the inner ActiveQueries loop, so every
// remaining session was still scanned after a hit and a later match would
// overwrite an earlier one.
func findSessionByQueryID(queryID string) sessionFinder {
	return func(sessions []serverpb.Session) (serverpb.Session, error) {
		for i := range sessions {
			for _, q := range sessions[i].ActiveQueries {
				// A session with an empty ID has always been reported as
				// "not found"; keep skipping such sessions.
				if q.ID == queryID && len(sessions[i].ID) > 0 {
					return sessions[i], nil
				}
			}
		}
		return serverpb.Session{}, fmt.Errorf("query ID %s not found", queryID)
	}
}

// checkCancelPrivilege returns nil if the user has the necessary cancel action
// privileges for a session. This function returns a proper gRPC error status.
func (b *baseStatusServer) checkCancelPrivilege(
ctx context.Context, userName username.SQLUsername, findSession sessionFinder,
ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername,
) error {
ctx = propagateGatewayMetadata(ctx)
ctx = b.AnnotateCtx(ctx)
// reqUser is the user who made the cancellation request.
var reqUser username.SQLUsername
{
sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx)
if err != nil {
return serverError(ctx, err)
}
if userName.Undefined() || userName == sessionUser {
reqUser = sessionUser
} else {
// When CANCEL QUERY is run as a SQL statement, sessionUser is always root
// and the user who ran the statement is passed as req.Username.
if !isAdmin {
return errRequiresAdmin
}
reqUser = userName
}

ctxUsername, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx)
if err != nil {
return serverError(ctx, err)
}
if reqUsername.Undefined() {
reqUsername = ctxUsername
} else if reqUsername != ctxUsername && !isAdmin {
// When CANCEL QUERY is run as a SQL statement, ctxUsername is always root
// and the user who ran the statement is passed as req.Username.
return errRequiresAdmin
}

hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUser)
hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUsername)
if err != nil {
return serverError(ctx, err)
}

if !hasAdmin {
// Check if the user has permission to see the session.
session, err := findSession(b.sessionRegistry.SerializeAll())
if err != nil {
return serverError(ctx, err)
}

sessionUser := username.MakeSQLUsernameFromPreNormalizedString(session.Username)
if sessionUser != reqUser {
if sessionUsername != reqUsername {
// Must have CANCELQUERY privilege to cancel other users'
// sessions/queries.
hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUser, privilege.CANCELQUERY)
hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY)
if err != nil {
return serverError(ctx, err)
}
if !hasCancelQuery {
ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUser, roleoption.CANCELQUERY)
ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY)
if err != nil {
return serverError(ctx, err)
}
Expand All @@ -362,7 +314,7 @@ func (b *baseStatusServer) checkCancelPrivilege(
}
}
// Non-admins cannot cancel admins' sessions/queries.
isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUser)
isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUsername)
if err != nil {
return serverError(ctx, err)
}
Expand Down Expand Up @@ -3063,16 +3015,21 @@ func (s *statusServer) CancelSession(
ctx = propagateGatewayMetadata(ctx)
ctx = s.AnnotateCtx(ctx)

sessionID := clusterunique.IDFromBytes(req.SessionID)
sessionIDBytes := req.SessionID
if len(sessionIDBytes) != 16 {
return &serverpb.CancelSessionResponse{
Error: fmt.Sprintf("session ID %v malformed", sessionIDBytes),
}, nil
}
sessionID := clusterunique.IDFromBytes(sessionIDBytes)
nodeID := sessionID.GetNodeID()
local := nodeID == int32(s.serverIterator.getID())
if !local {
status, err := s.dialNode(ctx, roachpb.NodeID(nodeID))
if err != nil {
if errors.Is(err, sqlinstance.NonExistentInstanceError) {
return &serverpb.CancelSessionResponse{
Canceled: false,
Error: fmt.Sprintf("session ID %s not found", sessionID),
Error: fmt.Sprintf("session ID %s not found", sessionID),
}, nil
}
return nil, serverError(ctx, err)
Expand All @@ -3085,17 +3042,21 @@ func (s *statusServer) CancelSession(
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionBySessionID(req.SessionID)); err != nil {
session, ok := s.sessionRegistry.GetSessionByID(sessionID)
if !ok {
return &serverpb.CancelSessionResponse{
Error: fmt.Sprintf("session ID %s not found", sessionID),
}, nil
}

if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil {
// NB: not using serverError() here since the priv checker
// already returns a proper gRPC error status.
return nil, err
}

r, err := s.sessionRegistry.CancelSession(req.SessionID)
if err != nil {
return nil, serverError(ctx, err)
}
return r, nil
session.CancelSession()
return &serverpb.CancelSessionResponse{Canceled: true}, nil
}

// CancelQuery responds to a query cancellation request, and cancels
Expand All @@ -3109,8 +3070,7 @@ func (s *statusServer) CancelQuery(
queryID, err := clusterunique.IDFromString(req.QueryID)
if err != nil {
return &serverpb.CancelQueryResponse{
Canceled: false,
Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(),
Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(),
}, nil
}

Expand All @@ -3122,8 +3082,7 @@ func (s *statusServer) CancelQuery(
if err != nil {
if errors.Is(err, sqlinstance.NonExistentInstanceError) {
return &serverpb.CancelQueryResponse{
Canceled: false,
Error: fmt.Sprintf("query ID %s not found", queryID),
Error: fmt.Sprintf("query ID %s not found", queryID),
}, nil
}
return nil, serverError(ctx, err)
Expand All @@ -3136,18 +3095,23 @@ func (s *statusServer) CancelQuery(
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionByQueryID(req.QueryID)); err != nil {
session, ok := s.sessionRegistry.GetSessionByQueryID(queryID)
if !ok {
return &serverpb.CancelQueryResponse{
Error: fmt.Sprintf("query ID %s not found", queryID),
}, nil
}

if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil {
// NB: not using serverError() here since the priv checker
// already returns a proper gRPC error status.
return nil, err
}

output := &serverpb.CancelQueryResponse{}
output.Canceled, err = s.sessionRegistry.CancelQuery(req.QueryID)
if err != nil {
output.Error = err.Error()
}
return output, nil
isCanceled := session.CancelQuery(queryID)
return &serverpb.CancelQueryResponse{
Canceled: isCanceled,
}, nil
}

// CancelQueryByKey responds to a pgwire query cancellation request, and cancels
Expand Down Expand Up @@ -3184,12 +3148,18 @@ func (s *statusServer) CancelQueryByKey(
}()

if local {
resp = &serverpb.CancelQueryByKeyResponse{}
resp.Canceled, err = s.sessionRegistry.CancelQueryByKey(req.CancelQueryKey)
if err != nil {
resp.Error = err.Error()
cancelQueryKey := req.CancelQueryKey
session, ok := s.sessionRegistry.GetSessionByCancelKey(cancelQueryKey)
if !ok {
return &serverpb.CancelQueryByKeyResponse{
Error: fmt.Sprintf("session for cancel key %d not found", cancelQueryKey),
}, nil
}
return resp, nil

isCanceled := session.CancelActiveQueries()
return &serverpb.CancelQueryByKeyResponse{
Canceled: isCanceled,
}, nil
}

// This request needs to be forwarded to another node.
Expand Down
Loading

0 comments on commit 129af11

Please sign in to comment.