From b8518eb037505e73f0e770420e354546f3fdca1f Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Thu, 2 Feb 2023 19:31:43 +0100 Subject: [PATCH] server: only forward the SQL identity in gRPC metadata Prior to this patch, we were forwarding any and all gRPC metadata during a RPC fanout. This was creating doubt and confusion, about how much data is really important/useful to forward. Analysis suggests we only care about the SQL user identity resulting from HTTP authentication. So this patch limits the forwarding to just that information. This specialization makes the forwarding logic easier to understand. This patch additionally renames functions as follows: | Old name | New name | |---------------------------------|---------------------------------------| | `userFromContext` | `userFromIncomingRPCContext` | | `getSQLUsername` | `userFromHTTPAuthInfoContext` | | `apiToOutgoingGatewayCtx` | `forwardHTTPAuthInfoToRPCCalls` | | `forwardAuthenticationMetadata` | `translateHTTPAuthInfoToGRPCMetadata` | | `propagateGatewayMetadata` | `forwardSQLIdentityThroughRPCCalls` | Release note: None --- pkg/server/admin.go | 34 +++--- pkg/server/api_v2.go | 10 +- pkg/server/api_v2_auth.go | 18 +-- pkg/server/api_v2_ranges.go | 8 +- pkg/server/api_v2_sql.go | 2 +- pkg/server/api_v2_sql_schema.go | 14 +-- pkg/server/authentication.go | 96 ++++++++++++++- pkg/server/combined_statement_stats.go | 4 +- pkg/server/fanout_clients.go | 2 +- pkg/server/grpc_gateway.go | 2 +- pkg/server/index_usage_stats.go | 10 +- pkg/server/index_usage_stats_test.go | 2 +- pkg/server/server_http.go | 7 +- pkg/server/sql_stats.go | 2 +- pkg/server/statement_diagnostics_requests.go | 8 +- pkg/server/statements.go | 2 +- pkg/server/status.go | 122 +++++++------------ pkg/server/user.go | 2 +- 18 files changed, 192 insertions(+), 153 deletions(-) diff --git a/pkg/server/admin.go b/pkg/server/admin.go index f7b1faa90bf7..44736c039a43 100644 --- a/pkg/server/admin.go +++ b/pkg/server/admin.go @@ -361,7 +361,7 @@ func (s *adminServer) Databases( ) (_ *serverpb.DatabasesResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - sessionUser, err := userFromContext(ctx) + sessionUser, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -431,7 +431,7 @@ func (s *adminServer) DatabaseDetails( ctx context.Context, req *serverpb.DatabaseDetailsRequest, ) (_ *serverpb.DatabaseDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -799,7 +799,7 @@ func (s *adminServer) TableDetails( ctx context.Context, req *serverpb.TableDetailsRequest, ) (_ *serverpb.TableDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1202,7 +1202,7 @@ func (s *adminServer) TableStats( ) (*serverpb.TableStatsResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1443,7 +1443,7 @@ func (s *adminServer) Users( ctx context.Context, req *serverpb.UsersRequest, ) (_ *serverpb.UsersResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1647,7 +1647,7 @@ func (s *adminServer) RangeLog( ctx = s.AnnotateCtx(ctx) // Range keys, even when pretty-printed, contain PII. - user, err := userFromContext(ctx) + user, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, err } @@ -1881,7 +1881,7 @@ func (s *adminServer) SetUIData( ) (*serverpb.SetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1920,7 +1920,7 @@ func (s *adminServer) GetUIData( ) (*serverpb.GetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2188,7 +2188,7 @@ func getLivenessResponse( func (s *adminServer) Liveness( ctx context.Context, req *serverpb.LivenessRequest, ) (*serverpb.LivenessResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) return s.sqlServer.tenantConnect.Liveness(ctx, req) @@ -2210,7 +2210,7 @@ func (s *adminServer) Jobs( ) (_ *serverpb.JobsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2405,7 +2405,7 @@ func (s *adminServer) Job( ) (_ *serverpb.JobResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2466,7 +2466,7 @@ func (s *adminServer) Locations( ctx = s.AnnotateCtx(ctx) // Require authentication. - _, err := userFromContext(ctx) + _, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2536,7 +2536,7 @@ func (s *adminServer) QueryPlan( ) (*serverpb.QueryPlanResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2579,7 +2579,7 @@ func (s *adminServer) QueryPlan( // getStatementBundle retrieves the statement bundle with the given id and // writes it out as an attachment. func (s *adminServer) getStatementBundle(ctx context.Context, id int64, w http.ResponseWriter) { - sessionUser, err := userFromContext(ctx) + sessionUser, err := userFromIncomingRPCContext(ctx) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2904,7 +2904,7 @@ func (s *adminServer) DataDistribution( return nil, err } - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -3109,7 +3109,7 @@ func (s *adminServer) dataDistributionHelper( func (s *systemAdminServer) EnqueueRange( ctx context.Context, req *serverpb.EnqueueRangeRequest, ) (*serverpb.EnqueueRangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.requireAdminUser(ctx); err != nil { @@ -3981,7 +3981,7 @@ func (c *adminPrivilegeChecker) requireViewDebugPermission(ctx context.Context) func (c *adminPrivilegeChecker) getUserAndRole( ctx context.Context, ) (userName username.SQLUsername, isAdmin bool, err error) { - userName, err = userFromContext(ctx) + userName, err = userFromIncomingRPCContext(ctx) if err != nil { return userName, false, err } diff --git a/pkg/server/api_v2.go b/pkg/server/api_v2.go index f794f49684b2..e627f378ab38 100644 --- a/pkg/server/api_v2.go +++ b/pkg/server/api_v2.go @@ -44,7 +44,6 @@ import ( "strconv" "github.com/cockroachdb/cockroach/pkg/kv" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" @@ -70,13 +69,6 @@ func writeJSONResponse(ctx context.Context, w http.ResponseWriter, code int, pay _, _ = w.Write(res) } -// Returns a SQL username from the request context of a route requiring login. -// Only use in routes that require login (requiresAuth = true in its route -// definition). -func getSQLUsername(ctx context.Context) username.SQLUsername { - return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) -} - type ApiV2System interface { health(w http.ResponseWriter, r *http.Request) listNodes(w http.ResponseWriter, r *http.Request) @@ -325,7 +317,7 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { reqExcludeClosed := r.URL.Query().Get("exclude_closed_sessions") == "true" req := &serverpb.ListSessionsRequest{Username: reqUsername, ExcludeClosedSessions: reqExcludeClosed} response := &listSessionsResponse{} - outgoingCtx := apiToOutgoingGatewayCtx(ctx, r) + outgoingCtx := forwardHTTPAuthInfoToRPCCalls(ctx, r) responseProto, pagState, err := a.status.listSessionsHelper(outgoingCtx, req, limit, start) if err != nil { diff --git a/pkg/server/api_v2_auth.go b/pkg/server/api_v2_auth.go index 877a7e638789..54d11ff9981e 100644 --- a/pkg/server/api_v2_auth.go +++ b/pkg/server/api_v2_auth.go @@ -26,7 +26,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -353,13 +352,11 @@ func (a *authenticationV2Mux) ServeHTTP(w http.ResponseWriter, req *http.Request } // Valid session found, or insecure. Set the username in the request context, // so child http.Handlers can access it. - ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, u) + var sessionID int64 if cookie != nil { - ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) + sessionID = cookie.ID } - req = req.WithContext(ctx) - + req = req.WithContext(contextWithHTTPAuthInfo(req.Context(), u, sessionID)) a.inner.ServeHTTP(w, req) } @@ -443,8 +440,7 @@ func (r *roleAuthorizationMux) hasRoleOption( func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { // The username is set in authenticationV2Mux, and must correspond with a // logged-in user. - username := username.MakeSQLUsernameFromPreNormalizedString( - req.Context().Value(webSessionUserKey{}).(string)) + username := userFromHTTPAuthInfoContext(req.Context()) if role, err := r.getRoleForUser(req.Context(), username); err != nil || role < r.role { if err != nil { apiV2InternalError(req.Context(), err, w) @@ -465,9 +461,3 @@ func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Reques } r.inner.ServeHTTP(w, req) } - -// apiToOutgoingGatewayCtx converts an HTTP API (v1 or v2) context, to one that -// can issue outgoing RPC requests under the same logged-in user. -func apiToOutgoingGatewayCtx(ctx context.Context, r *http.Request) context.Context { - return metadata.NewOutgoingContext(ctx, forwardAuthenticationMetadata(ctx, r)) -} diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go index 0a20c6699392..787df05f47a1 100644 --- a/pkg/server/api_v2_ranges.go +++ b/pkg/server/api_v2_ranges.go @@ -106,7 +106,7 @@ type nodesResponse struct { func (a *apiV2SystemServer) listNodes(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) nodes, next, err := a.systemStatus.nodesHelper(ctx, limit, offset) if err != nil { @@ -195,7 +195,7 @@ type rangeResponse struct { // "$ref": "#/definitions/rangeResponse" func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) rangeID, err := strconv.ParseInt(vars["range_id"], 10, 64) if err != nil { @@ -378,7 +378,7 @@ type nodeRangesResponse struct { // "$ref": "#/definitions/nodeRangesResponse" func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) nodeIDStr := vars["node_id"] if nodeIDStr != "local" { @@ -497,7 +497,7 @@ type hotRangeInfo struct { // "$ref": "#/definitions/hotRangesResponse" func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) nodeIDStr := r.URL.Query().Get("node_id") limit, start := getRPCPaginationValues(r) diff --git a/pkg/server/api_v2_sql.go b/pkg/server/api_v2_sql.go index 0de234773dc6..b73fbc03aa32 100644 --- a/pkg/server/api_v2_sql.go +++ b/pkg/server/api_v2_sql.go @@ -359,7 +359,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { } // The SQL username that owns this session. - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) options := []isql.TxnOption{ isql.WithPriority(admissionpb.NormalPri), diff --git a/pkg/server/api_v2_sql_schema.go b/pkg/server/api_v2_sql_schema.go index c83ea7bde0b2..f26b905891f1 100644 --- a/pkg/server/api_v2_sql_schema.go +++ b/pkg/server/api_v2_sql_schema.go @@ -62,7 +62,7 @@ type usersResponse struct { func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) query := `SELECT username FROM system.users WHERE "isRole" = false ORDER BY username` @@ -149,7 +149,7 @@ type eventsResponse struct { func (a *apiV2Server) listEvents(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) queryValues := r.URL.Query() @@ -213,7 +213,7 @@ type databasesResponse struct { func (a *apiV2Server) listDatabases(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) var resp databasesResponse @@ -263,7 +263,7 @@ type databaseDetailsResponse struct { // description: Database not found func (a *apiV2Server) databaseDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -337,7 +337,7 @@ type databaseGrantsResponse struct { func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -412,7 +412,7 @@ type databaseTablesResponse struct { func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -473,7 +473,7 @@ type tableDetailsResponse serverpb.TableDetailsResponse // description: Database or table not found func (a *apiV2Server) tableDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.TableDetailsRequest{ diff --git a/pkg/server/authentication.go b/pkg/server/authentication.go index ac15f7d70552..784d9d256ee8 100644 --- a/pkg/server/authentication.go +++ b/pkg/server/authentication.go @@ -572,10 +572,8 @@ const webSessionIDKeyStr = "websessionid" func (am *authenticationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { username, cookie, err := am.getSession(w, req) if err == nil { - ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, username) - ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) - req = req.WithContext(ctx) + req = req.WithContext( + contextWithHTTPAuthInfo(req.Context(), username, cookie.ID)) } else if !am.allowAnonymous { if log.V(1) { log.Infof(req.Context(), "web session error: %v", err) @@ -669,7 +667,42 @@ func authenticationHeaderMatcher(key string) (string, bool) { return fmt.Sprintf("%s%s", gwruntime.MetadataHeaderPrefix, key), true } -func forwardAuthenticationMetadata(ctx context.Context, _ *http.Request) metadata.MD { +// contextWithHTTPAuthInfo embeds the HTTP authentication details into +// a go context. Meant for use with userFromHTTPAuthInfoContext(). +func contextWithHTTPAuthInfo( + ctx context.Context, username string, sessionID int64, +) context.Context { + ctx = context.WithValue(ctx, webSessionUserKey{}, username) + if sessionID != 0 { + ctx = context.WithValue(ctx, webSessionIDKey{}, sessionID) + } + return ctx +} + +// userFromHTTPAuthInfoContext returns a SQL username from the request +// context of a HTTP route requiring login. Only use in routes that require +// login (e.g. requiresAuth = true in the API v2 route definition). +// +// Do not use this function in _RPC_ API handlers. These access their +// SQL identity via the RPC incoming context. See +// userFromIncomingRPCContext(). +func userFromHTTPAuthInfoContext(ctx context.Context) username.SQLUsername { + return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) +} + +// maybeUserFromHTTPAuthInfoContext is like userFromHTTPAuthInfoContext but +// it returns a boolean false if there is no user in the context. +func maybeUserFromHTTPAuthInfoContext(ctx context.Context) (username.SQLUsername, bool) { + if u := ctx.Value(webSessionUserKey{}); u != nil { + return username.MakeSQLUsernameFromPreNormalizedString(u.(string)), true + } + return username.SQLUsername{}, false +} + +// translateHTTPAuthInfoToGRPCMetadata translates the context.Value +// that results from HTTP authentication into gRPC metadata suitable +// for use by RPC API handlers. +func translateHTTPAuthInfoToGRPCMetadata(ctx context.Context, _ *http.Request) metadata.MD { md := metadata.MD{} if user := ctx.Value(webSessionUserKey{}); user != nil { md.Set(webSessionUserKeyStr, user.(string)) @@ -680,6 +713,59 @@ func forwardAuthenticationMetadata(ctx context.Context, _ *http.Request) metadat return md } +// forwardSQLIdentityThroughRPCCalls forwards the SQL identity of the +// original request (as populated by translateHTTPAuthInfoToGRPCMetadata in +// grpc-gateway) so it remains available to the remote node handling +// the request. +func forwardSQLIdentityThroughRPCCalls(ctx context.Context) context.Context { + if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { + if u, ok := md[webSessionUserKeyStr]; ok { + return metadata.NewOutgoingContext(ctx, metadata.MD{webSessionUserKeyStr: u}) + } + } + return ctx +} + +// forwardHTTPAuthInfoToRPCCalls converts an HTTP API (v1 or v2) context, to one that +// can issue outgoing RPC requests under the same logged-in user. +func forwardHTTPAuthInfoToRPCCalls(ctx context.Context, r *http.Request) context.Context { + md := translateHTTPAuthInfoToGRPCMetadata(ctx, r) + return metadata.NewOutgoingContext(ctx, md) +} + +// userFromIncomingRPCContext is to be used in RPC API handlers. It +// assumes the SQL identity was populated in the context implicitly by +// gRPC via translateHTTPAuthInfoToGRPCMetadata(), or explicitly via +// forwardHTTPAuthInfoToRPCCalls() or +// forwardSQLIdentityThroughRPCCalls(). +// +// Do not use this function in _HTTP_ API handlers. Those access their +// SQL identity via a special context key. See +// userFromHTTPAuthInfoContext(). +func userFromIncomingRPCContext(ctx context.Context) (res username.SQLUsername, err error) { + md, ok := grpcutil.FastFromIncomingContext(ctx) + if !ok { + return username.RootUserName(), nil + } + usernames, ok := md[webSessionUserKeyStr] + if !ok { + // If the incoming context has metadata but no attached web session user, + // it's a gRPC / internal SQL connection which has root on the cluster. + // This assumption is a historical hiccup, and would be best described + // as a bug. See: https://github.com/cockroachdb/cockroach/issues/45018 + return username.RootUserName(), nil + } + if len(usernames) != 1 { + log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) + return res, fmt.Errorf( + "context's incoming metadata contains unexpected number of usernames: %+v ", md) + } + // At this point the user is already logged in, so we can assume + // the username has been normalized already. + username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) + return username, nil +} + // sessionCookieValue defines the data needed to construct the // aggregate session cookie in the order provided. type sessionCookieValue struct { diff --git a/pkg/server/combined_statement_stats.go b/pkg/server/combined_statement_stats.go index 74359fa4dd77..3e7bac42d6fe 100644 --- a/pkg/server/combined_statement_stats.go +++ b/pkg/server/combined_statement_stats.go @@ -44,7 +44,7 @@ func getTimeFromSeconds(seconds int64) *time.Time { func (s *statusServer) CombinedStatementStats( ctx context.Context, req *serverpb.CombinedStatementsStatsRequest, ) (*serverpb.StatementsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -340,7 +340,7 @@ func collectCombinedTransactions( func (s *statusServer) StatementDetails( ctx context.Context, req *serverpb.StatementDetailsRequest, ) (*serverpb.StatementDetailsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/fanout_clients.go b/pkg/server/fanout_clients.go index c55b86d8cf46..7d1c5dfc1307 100644 --- a/pkg/server/fanout_clients.go +++ b/pkg/server/fanout_clients.go @@ -228,7 +228,7 @@ func (k kvFanoutClient) dialNode(ctx context.Context, serverID serverID) (*grpc. } func (k kvFanoutClient) listNodes(ctx context.Context) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = k.ambientCtx.AnnotateCtx(ctx) statuses, _, err := getNodeStatuses(ctx, k.db, 0, 0) diff --git a/pkg/server/grpc_gateway.go b/pkg/server/grpc_gateway.go index 5bf6686b2c5d..6eed6aa808d8 100644 --- a/pkg/server/grpc_gateway.go +++ b/pkg/server/grpc_gateway.go @@ -72,7 +72,7 @@ func configureGRPCGateway( gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb), gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb), gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher), - gwruntime.WithMetadata(forwardAuthenticationMetadata), + gwruntime.WithMetadata(translateHTTPAuthInfoToGRPCMetadata), ) gwCtx, gwCancel := context.WithCancel(ambientCtx.AnnotateCtx(context.Background())) stopper.AddCloser(stop.CloserFn(gwCancel)) diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go index f374ec455a90..c3c2f90f41e4 100644 --- a/pkg/server/index_usage_stats.go +++ b/pkg/server/index_usage_stats.go @@ -40,7 +40,7 @@ import ( func (s *statusServer) IndexUsageStatistics( ctx context.Context, req *serverpb.IndexUsageStatisticsRequest, ) (*serverpb.IndexUsageStatisticsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -130,7 +130,7 @@ func indexUsageStatsLocal( func (s *statusServer) ResetIndexUsageStats( ctx context.Context, req *serverpb.ResetIndexUsageStatsRequest, ) (*serverpb.ResetIndexUsageStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -207,7 +207,7 @@ func (s *statusServer) ResetIndexUsageStats( func (s *statusServer) TableIndexStats( ctx context.Context, req *serverpb.TableIndexStatsRequest, ) (*serverpb.TableIndexStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -228,7 +228,7 @@ func getTableIndexUsageStats( st *cluster.Settings, execConfig *sql.ExecutorConfig, ) (*serverpb.TableIndexStatsResponse, error) { - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, err } @@ -384,7 +384,7 @@ func getDatabaseIndexRecommendations( return []*serverpb.IndexRecommendation{}, nil } - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return []*serverpb.IndexRecommendation{}, err } diff --git a/pkg/server/index_usage_stats_test.go b/pkg/server/index_usage_stats_test.go index 939bd5a386ba..1ce2b2c4ab21 100644 --- a/pkg/server/index_usage_stats_test.go +++ b/pkg/server/index_usage_stats_test.go @@ -360,7 +360,7 @@ CREATE TABLE schema.test_table ( `) // Get Table IDs. - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) require.NoError(t, err) testCases := []struct { diff --git a/pkg/server/server_http.go b/pkg/server/server_http.go index 4d25064b26e0..1713c1b686f7 100644 --- a/pkg/server/server_http.go +++ b/pkg/server/server_http.go @@ -111,8 +111,9 @@ func (s *httpServer) setupRoutes( NodeID: s.cfg.IDContainer, OIDC: oidc, GetUser: func(ctx context.Context) *string { - if u, ok := ctx.Value(webSessionUserKey{}).(string); ok { - return &u + if user, ok := maybeUserFromHTTPAuthInfoContext(ctx); ok { + ustring := user.Normalized() + return &ustring } return nil }, @@ -183,7 +184,7 @@ func makeAdminAuthzCheckHandler( return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Retrieve the username embedded in the grpc metadata, if any. // This will be provided by the authenticationMux. - md := forwardAuthenticationMetadata(req.Context(), req) + md := translateHTTPAuthInfoToGRPCMetadata(req.Context(), req) authCtx := metadata.NewIncomingContext(req.Context(), md) // Check the privileges of the requester. err := adminAuthzCheck.requireViewDebugPermission(authCtx) diff --git a/pkg/server/sql_stats.go b/pkg/server/sql_stats.go index 6cb43885ebb3..411166d3bf9c 100644 --- a/pkg/server/sql_stats.go +++ b/pkg/server/sql_stats.go @@ -23,7 +23,7 @@ import ( func (s *statusServer) ResetSQLStats( ctx context.Context, req *serverpb.ResetSQLStatsRequest, ) (*serverpb.ResetSQLStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { diff --git a/pkg/server/statement_diagnostics_requests.go b/pkg/server/statement_diagnostics_requests.go index 5feff91cb49b..724f612705e6 100644 --- a/pkg/server/statement_diagnostics_requests.go +++ b/pkg/server/statement_diagnostics_requests.go @@ -71,7 +71,7 @@ func (diagnostics *stmtDiagnostics) toProto() serverpb.StatementDiagnostics { func (s *statusServer) CreateStatementDiagnosticsReport( ctx context.Context, req *serverpb.CreateStatementDiagnosticsReportRequest, ) (*serverpb.CreateStatementDiagnosticsReportResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -103,7 +103,7 @@ func (s *statusServer) CreateStatementDiagnosticsReport( func (s *statusServer) CancelStatementDiagnosticsReport( ctx context.Context, req *serverpb.CancelStatementDiagnosticsReportRequest, ) (*serverpb.CancelStatementDiagnosticsReportResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -127,7 +127,7 @@ func (s *statusServer) CancelStatementDiagnosticsReport( func (s *statusServer) StatementDiagnosticsRequests( ctx context.Context, _ *serverpb.StatementDiagnosticsReportsRequest, ) (*serverpb.StatementDiagnosticsReportsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -221,7 +221,7 @@ func (s *statusServer) StatementDiagnosticsRequests( func (s *statusServer) StatementDiagnostics( ctx context.Context, req *serverpb.StatementDiagnosticsRequest, ) (*serverpb.StatementDiagnosticsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/statements.go b/pkg/server/statements.go index cfa55475b0b2..41a8ff5fc7be 100644 --- a/pkg/server/statements.go +++ b/pkg/server/statements.go @@ -33,7 +33,7 @@ func (s *statusServer) Statements( return s.CombinedStatementStats(ctx, &combinedRequest) } - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/status.go b/pkg/server/status.go index 758e61a767c6..54aade2dee81 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -85,7 +85,6 @@ import ( raft "go.etcd.io/raft/v3" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -133,13 +132,6 @@ type metricMarshaler interface { ScrapeIntoPrometheus(pm *metric.PrometheusExporter) } -func propagateGatewayMetadata(ctx context.Context) context.Context { - if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { - return metadata.NewOutgoingContext(ctx, md) - } - return ctx -} - // baseStatusServer implements functionality shared by the tenantStatusServer // and the full statusServer. type baseStatusServer struct { @@ -171,7 +163,7 @@ func isInternalAppName(app string) bool { func (b *baseStatusServer) getLocalSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) ([]serverpb.Session, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) @@ -275,7 +267,7 @@ func (b *baseStatusServer) getLocalSessions( func (b *baseStatusServer) checkCancelPrivilege( ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername, ) error { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) ctxUsername, isCtxAdmin, err := b.privilegeChecker.getUserAndRole(ctx) @@ -343,7 +335,7 @@ func (b *baseStatusServer) checkCancelPrivilege( func (b *baseStatusServer) ListLocalContentionEvents( ctx context.Context, _ *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -360,7 +352,7 @@ func (b *baseStatusServer) ListLocalContentionEvents( func (b *baseStatusServer) ListLocalDistSQLFlows( ctx context.Context, _ *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -690,7 +682,7 @@ func (s *statusServer) dialNode( func (s *systemStatusServer) Gossip( ctx context.Context, req *serverpb.GossipRequest, ) (*gossip.InfoStatus, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -718,7 +710,7 @@ func (s *systemStatusServer) Gossip( func (s *systemStatusServer) EngineStats( ctx context.Context, req *serverpb.EngineStatsRequest, ) (*serverpb.EngineStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -761,7 +753,7 @@ func (s *systemStatusServer) EngineStats( func (s *systemStatusServer) Allocator( ctx context.Context, req *serverpb.AllocatorRequest, ) (*serverpb.AllocatorResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -854,7 +846,7 @@ func recordedSpansToTraceEvents(spans []tracingpb.RecordedSpan) []*serverpb.Trac func (s *systemStatusServer) AllocatorRange( ctx context.Context, req *serverpb.AllocatorRangeRequest, ) (*serverpb.AllocatorRangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -943,7 +935,7 @@ func (s *systemStatusServer) AllocatorRange( func (s *statusServer) Certificates( ctx context.Context, req *serverpb.CertificatesRequest, ) (*serverpb.CertificatesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1064,7 +1056,7 @@ func extractCertFields(contents []byte, details *serverpb.CertificateDetails) er func (s *statusServer) Details( ctx context.Context, req *serverpb.DetailsRequest, ) (*serverpb.DetailsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1105,7 +1097,7 @@ func (s *statusServer) Details( func (s *statusServer) GetFiles( ctx context.Context, req *serverpb.GetFilesRequest, ) (*serverpb.GetFilesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1161,7 +1153,7 @@ func checkFilePattern(pattern string) error { func (s *statusServer) LogFilesList( ctx context.Context, req *serverpb.LogFilesListRequest, ) (*serverpb.LogFilesListResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1196,7 +1188,7 @@ func (s *statusServer) LogFilesList( func (s *statusServer) LogFile( ctx context.Context, req *serverpb.LogFileRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1294,7 +1286,7 @@ func parseInt64WithDefault(s string, defaultValue int64) (int64, error) { func (s *statusServer) Logs( ctx context.Context, req *serverpb.LogsRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1382,7 +1374,7 @@ func (s *statusServer) Logs( func (s *statusServer) Stacks( ctx context.Context, req *serverpb.StacksRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1415,7 +1407,7 @@ func (s *statusServer) Stacks( func (s *statusServer) Profile( ctx context.Context, req *serverpb.ProfileRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1494,7 +1486,7 @@ func regionsResponseFromNodesResponse(nr *serverpb.NodesResponse) *serverpb.Regi func (s *statusServer) NodesList( ctx context.Context, _ *serverpb.NodesListRequest, ) (*serverpb.NodesListResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1520,7 +1512,7 @@ func (s *statusServer) NodesList( func (s *statusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -1559,7 +1551,7 @@ func (s *statusServer) Nodes( func (s *systemStatusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -1577,7 +1569,7 @@ func (s *systemStatusServer) Nodes( func (s *statusServer) NodesUI( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponseExternal, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) return s.sqlServer.tenantConnect.NodesUI(ctx, req) @@ -1586,7 +1578,7 @@ func (s *statusServer) NodesUI( func (s *systemStatusServer) NodesUI( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponseExternal, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) hasViewClusterMetadata := false @@ -1757,7 +1749,7 @@ func getNodeStatuses( func (s *systemStatusServer) nodesHelper( ctx context.Context, limit, offset int, ) (*serverpb.NodesResponse, int, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) statuses, next, err := getNodeStatuses(ctx, s.db, limit, offset) @@ -1780,7 +1772,7 @@ func (s *systemStatusServer) nodesHelper( func (s *statusServer) Node( ctx context.Context, req *serverpb.NodeRequest, ) (*statuspb.NodeStatus, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1822,7 +1814,7 @@ func (s *statusServer) nodeStatus( func (s *statusServer) NodeUI( ctx context.Context, req *serverpb.NodeRequest, ) (*serverpb.NodeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1846,7 +1838,7 @@ func (s *statusServer) NodeUI( func (s *statusServer) Metrics( ctx context.Context, req *serverpb.MetricsRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) @@ -1872,7 +1864,7 @@ func (s *statusServer) Metrics( func (s *systemStatusServer) RaftDebug( ctx context.Context, req *serverpb.RaftDebugRequest, ) (*serverpb.RaftDebugResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -1998,7 +1990,7 @@ func (s *systemStatusServer) Ranges( func (s *systemStatusServer) rangesHelper( ctx context.Context, req *serverpb.RangesRequest, limit, offset int, ) (*serverpb.RangesResponse, int, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -2201,7 +2193,7 @@ func (s *systemStatusServer) rangesHelper( func (t *statusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) (*serverpb.TenantRangesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = t.AnnotateCtx(ctx) // The tenant range report contains replica metadata which is admin-only. @@ -2215,7 +2207,7 @@ func (t *statusServer) TenantRanges( func (s *systemStatusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) (*serverpb.TenantRangesResponse, error) { - propagateGatewayMetadata(ctx) + forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err @@ -2366,7 +2358,7 @@ func (s *systemStatusServer) TenantRanges( func (s *systemStatusServer) HotRanges( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -2445,7 +2437,7 @@ func (t *statusServer) HotRangesV2( func (s *systemStatusServer) HotRangesV2( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponseV2, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) if err != nil { @@ -2718,7 +2710,7 @@ func (s *statusServer) KeyVisSamples( func (s *statusServer) Range( ctx context.Context, req *serverpb.RangeRequest, ) (*serverpb.RangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -3009,7 +3001,7 @@ func (s *statusServer) listSessionsHelper( func (s *statusServer) ListSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) (*serverpb.ListSessionsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, _, err := s.privilegeChecker.getUserAndRole(ctx); err != nil { @@ -3030,7 +3022,7 @@ func (s *statusServer) ListSessions( func (s *statusServer) CancelSession( ctx context.Context, req *serverpb.CancelSessionRequest, ) (*serverpb.CancelSessionResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) sessionIDBytes := req.SessionID @@ -3082,7 +3074,7 @@ func (s *statusServer) CancelSession( func (s *statusServer) CancelQuery( ctx context.Context, req *serverpb.CancelQueryRequest, ) (*serverpb.CancelQueryResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) queryID, err := clusterunique.IDFromString(req.QueryID) @@ -3181,7 +3173,7 @@ func (s *statusServer) CancelQueryByKey( } // This request needs to be forwarded to another node. - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) client, err := s.dialNode(ctx, roachpb.NodeID(req.SQLInstanceID)) if err != nil { @@ -3195,7 +3187,7 @@ func (s *statusServer) CancelQueryByKey( func (s *statusServer) ListContentionEvents( ctx context.Context, req *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3242,7 +3234,7 @@ func (s *statusServer) ListContentionEvents( func (s *statusServer) ListDistSQLFlows( ctx context.Context, request *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3334,7 +3326,7 @@ func mergeDistSQLRemoteFlows(a, b []serverpb.DistSQLRemoteFlows) []serverpb.Dist func (s *statusServer) ListExecutionInsights( ctx context.Context, req *serverpb.ListExecutionInsightsRequest, ) (*serverpb.ListExecutionInsightsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3396,7 +3388,7 @@ func (s *statusServer) ListExecutionInsights( func (s *systemStatusServer) SpanStats( ctx context.Context, req *serverpb.SpanStatsRequest, ) (*serverpb.SpanStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -3440,7 +3432,7 @@ func (s *systemStatusServer) SpanStats( func (s *statusServer) Diagnostics( ctx context.Context, req *serverpb.DiagnosticsRequest, ) (*diagnosticspb.DiagnosticReport, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) if err != nil { @@ -3462,7 +3454,7 @@ func (s *statusServer) Diagnostics( func (s *systemStatusServer) Stores( ctx context.Context, req *serverpb.StoresRequest, ) (*serverpb.StoresResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -3548,28 +3540,6 @@ func marshalJSONResponse(value interface{}) (*serverpb.JSONResponse, error) { return &serverpb.JSONResponse{Data: data}, nil } -func userFromContext(ctx context.Context) (res username.SQLUsername, err error) { - md, ok := grpcutil.FastFromIncomingContext(ctx) - if !ok { - return username.RootUserName(), nil - } - usernames, ok := md[webSessionUserKeyStr] - if !ok { - // If the incoming context has metadata but no attached web session user, - // it's a gRPC / internal SQL connection which has root on the cluster. - return username.RootUserName(), nil - } - if len(usernames) != 1 { - log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) - return res, fmt.Errorf( - "context's incoming metadata contains unexpected number of usernames: %+v ", md) - } - // At this point the user is already logged in, so we can assume - // the username has been normalized already. - username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) - return username, nil -} - type systemInfoOnce struct { once sync.Once info serverpb.SystemInfo @@ -3612,7 +3582,7 @@ func (si *systemInfoOnce) systemInfo(ctx context.Context) serverpb.SystemInfo { func (s *statusServer) JobRegistryStatus( ctx context.Context, req *serverpb.JobRegistryStatusRequest, ) (*serverpb.JobRegistryStatusResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -3651,7 +3621,7 @@ func (s *statusServer) JobRegistryStatus( func (s *statusServer) JobStatus( ctx context.Context, req *serverpb.JobStatusRequest, ) (*serverpb.JobStatusResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { // NB: not using serverError() here since the priv checker @@ -3684,7 +3654,7 @@ func (s *statusServer) JobStatus( func (s *statusServer) TxnIDResolution( ctx context.Context, req *serverpb.TxnIDResolutionRequest, ) (*serverpb.TxnIDResolutionResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err } @@ -3708,7 +3678,7 @@ func (s *statusServer) TxnIDResolution( func (s *statusServer) TransactionContentionEvents( ctx context.Context, req *serverpb.TransactionContentionEventsRequest, ) (*serverpb.TransactionContentionEventsResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err diff --git a/pkg/server/user.go b/pkg/server/user.go index 13943c243232..2cd2b1d8e454 100644 --- a/pkg/server/user.go +++ b/pkg/server/user.go @@ -22,7 +22,7 @@ import ( func (s *baseStatusServer) UserSQLRoles( ctx context.Context, req *serverpb.UserSQLRolesRequest, ) (_ *serverpb.UserSQLRolesResponse, retErr error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) username, isAdmin, err := s.privilegeChecker.getUserAndRole(ctx)