diff --git a/pkg/server/admin.go b/pkg/server/admin.go index f7b1faa90bf7..44736c039a43 100644 --- a/pkg/server/admin.go +++ b/pkg/server/admin.go @@ -361,7 +361,7 @@ func (s *adminServer) Databases( ) (_ *serverpb.DatabasesResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - sessionUser, err := userFromContext(ctx) + sessionUser, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -431,7 +431,7 @@ func (s *adminServer) DatabaseDetails( ctx context.Context, req *serverpb.DatabaseDetailsRequest, ) (_ *serverpb.DatabaseDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -799,7 +799,7 @@ func (s *adminServer) TableDetails( ctx context.Context, req *serverpb.TableDetailsRequest, ) (_ *serverpb.TableDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1202,7 +1202,7 @@ func (s *adminServer) TableStats( ) (*serverpb.TableStatsResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1443,7 +1443,7 @@ func (s *adminServer) Users( ctx context.Context, req *serverpb.UsersRequest, ) (_ *serverpb.UsersResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1647,7 +1647,7 @@ func (s *adminServer) RangeLog( ctx = s.AnnotateCtx(ctx) // Range keys, even when pretty-printed, contain PII. - user, err := userFromContext(ctx) + user, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, err } @@ -1881,7 +1881,7 @@ func (s *adminServer) SetUIData( ) (*serverpb.SetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -1920,7 +1920,7 @@ func (s *adminServer) GetUIData( ) (*serverpb.GetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2188,7 +2188,7 @@ func getLivenessResponse( func (s *adminServer) Liveness( ctx context.Context, req *serverpb.LivenessRequest, ) (*serverpb.LivenessResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) return s.sqlServer.tenantConnect.Liveness(ctx, req) @@ -2210,7 +2210,7 @@ func (s *adminServer) Jobs( ) (_ *serverpb.JobsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2405,7 +2405,7 @@ func (s *adminServer) Job( ) (_ *serverpb.JobResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2466,7 +2466,7 @@ func (s *adminServer) Locations( ctx = s.AnnotateCtx(ctx) // Require authentication. - _, err := userFromContext(ctx) + _, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2536,7 +2536,7 @@ func (s *adminServer) QueryPlan( ) (*serverpb.QueryPlanResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -2579,7 +2579,7 @@ func (s *adminServer) QueryPlan( // getStatementBundle retrieves the statement bundle with the given id and // writes it out as an attachment. func (s *adminServer) getStatementBundle(ctx context.Context, id int64, w http.ResponseWriter) { - sessionUser, err := userFromContext(ctx) + sessionUser, err := userFromIncomingRPCContext(ctx) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2904,7 +2904,7 @@ func (s *adminServer) DataDistribution( return nil, err } - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, serverError(ctx, err) } @@ -3109,7 +3109,7 @@ func (s *adminServer) dataDistributionHelper( func (s *systemAdminServer) EnqueueRange( ctx context.Context, req *serverpb.EnqueueRangeRequest, ) (*serverpb.EnqueueRangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.requireAdminUser(ctx); err != nil { @@ -3981,7 +3981,7 @@ func (c *adminPrivilegeChecker) requireViewDebugPermission(ctx context.Context) func (c *adminPrivilegeChecker) getUserAndRole( ctx context.Context, ) (userName username.SQLUsername, isAdmin bool, err error) { - userName, err = userFromContext(ctx) + userName, err = userFromIncomingRPCContext(ctx) if err != nil { return userName, false, err } diff --git a/pkg/server/api_v2.go b/pkg/server/api_v2.go index f794f49684b2..e627f378ab38 100644 --- a/pkg/server/api_v2.go +++ b/pkg/server/api_v2.go @@ -44,7 +44,6 @@ import ( "strconv" "github.com/cockroachdb/cockroach/pkg/kv" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" @@ -70,13 +69,6 @@ func writeJSONResponse(ctx context.Context, w http.ResponseWriter, code int, pay _, _ = w.Write(res) } -// Returns a SQL username from the request context of a route requiring login. -// Only use in routes that require login (requiresAuth = true in its route -// definition). -func getSQLUsername(ctx context.Context) username.SQLUsername { - return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) -} - type ApiV2System interface { health(w http.ResponseWriter, r *http.Request) listNodes(w http.ResponseWriter, r *http.Request) @@ -325,7 +317,7 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { reqExcludeClosed := r.URL.Query().Get("exclude_closed_sessions") == "true" req := &serverpb.ListSessionsRequest{Username: reqUsername, ExcludeClosedSessions: reqExcludeClosed} response := &listSessionsResponse{} - outgoingCtx := apiToOutgoingGatewayCtx(ctx, r) + outgoingCtx := forwardHTTPAuthInfoToRPCCalls(ctx, r) responseProto, pagState, err := a.status.listSessionsHelper(outgoingCtx, req, limit, start) if err != nil { diff --git a/pkg/server/api_v2_auth.go b/pkg/server/api_v2_auth.go index 877a7e638789..54d11ff9981e 100644 --- a/pkg/server/api_v2_auth.go +++ b/pkg/server/api_v2_auth.go @@ -26,7 +26,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -353,13 +352,11 @@ func (a *authenticationV2Mux) ServeHTTP(w http.ResponseWriter, req *http.Request } // Valid session found, or insecure. Set the username in the request context, // so child http.Handlers can access it. - ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, u) + var sessionID int64 if cookie != nil { - ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) + sessionID = cookie.ID } - req = req.WithContext(ctx) - + req = req.WithContext(contextWithHTTPAuthInfo(req.Context(), u, sessionID)) a.inner.ServeHTTP(w, req) } @@ -443,8 +440,7 @@ func (r *roleAuthorizationMux) hasRoleOption( func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { // The username is set in authenticationV2Mux, and must correspond with a // logged-in user. - username := username.MakeSQLUsernameFromPreNormalizedString( - req.Context().Value(webSessionUserKey{}).(string)) + username := userFromHTTPAuthInfoContext(req.Context()) if role, err := r.getRoleForUser(req.Context(), username); err != nil || role < r.role { if err != nil { apiV2InternalError(req.Context(), err, w) @@ -465,9 +461,3 @@ func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Reques } r.inner.ServeHTTP(w, req) } - -// apiToOutgoingGatewayCtx converts an HTTP API (v1 or v2) context, to one that -// can issue outgoing RPC requests under the same logged-in user. -func apiToOutgoingGatewayCtx(ctx context.Context, r *http.Request) context.Context { - return metadata.NewOutgoingContext(ctx, forwardAuthenticationMetadata(ctx, r)) -} diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go index 0a20c6699392..787df05f47a1 100644 --- a/pkg/server/api_v2_ranges.go +++ b/pkg/server/api_v2_ranges.go @@ -106,7 +106,7 @@ type nodesResponse struct { func (a *apiV2SystemServer) listNodes(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) nodes, next, err := a.systemStatus.nodesHelper(ctx, limit, offset) if err != nil { @@ -195,7 +195,7 @@ type rangeResponse struct { // "$ref": "#/definitions/rangeResponse" func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) rangeID, err := strconv.ParseInt(vars["range_id"], 10, 64) if err != nil { @@ -378,7 +378,7 @@ type nodeRangesResponse struct { // "$ref": "#/definitions/nodeRangesResponse" func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) nodeIDStr := vars["node_id"] if nodeIDStr != "local" { @@ -497,7 +497,7 @@ type hotRangeInfo struct { // "$ref": "#/definitions/hotRangesResponse" func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = apiToOutgoingGatewayCtx(ctx, r) + ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) nodeIDStr := r.URL.Query().Get("node_id") limit, start := getRPCPaginationValues(r) diff --git a/pkg/server/api_v2_sql.go b/pkg/server/api_v2_sql.go index 0de234773dc6..b73fbc03aa32 100644 --- a/pkg/server/api_v2_sql.go +++ b/pkg/server/api_v2_sql.go @@ -359,7 +359,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { } // The SQL username that owns this session. - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) options := []isql.TxnOption{ isql.WithPriority(admissionpb.NormalPri), diff --git a/pkg/server/api_v2_sql_schema.go b/pkg/server/api_v2_sql_schema.go index c83ea7bde0b2..f26b905891f1 100644 --- a/pkg/server/api_v2_sql_schema.go +++ b/pkg/server/api_v2_sql_schema.go @@ -62,7 +62,7 @@ type usersResponse struct { func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) query := `SELECT username FROM system.users WHERE "isRole" = false ORDER BY username` @@ -149,7 +149,7 @@ type eventsResponse struct { func (a *apiV2Server) listEvents(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) queryValues := r.URL.Query() @@ -213,7 +213,7 @@ type databasesResponse struct { func (a *apiV2Server) listDatabases(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) var resp databasesResponse @@ -263,7 +263,7 @@ type databaseDetailsResponse struct { // description: Database not found func (a *apiV2Server) databaseDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -337,7 +337,7 @@ type databaseGrantsResponse struct { func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -412,7 +412,7 @@ type databaseTablesResponse struct { func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -473,7 +473,7 @@ type tableDetailsResponse serverpb.TableDetailsResponse // description: Database or table not found func (a *apiV2Server) tableDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := getSQLUsername(ctx) + username := userFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.TableDetailsRequest{ diff --git a/pkg/server/authentication.go b/pkg/server/authentication.go index ac15f7d70552..784d9d256ee8 100644 --- a/pkg/server/authentication.go +++ b/pkg/server/authentication.go @@ -572,10 +572,8 @@ const webSessionIDKeyStr = "websessionid" func (am *authenticationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { username, cookie, err := am.getSession(w, req) if err == nil { - ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, username) - ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) - req = req.WithContext(ctx) + req = req.WithContext( + contextWithHTTPAuthInfo(req.Context(), username, cookie.ID)) } else if !am.allowAnonymous { if log.V(1) { log.Infof(req.Context(), "web session error: %v", err) @@ -669,7 +667,42 @@ func authenticationHeaderMatcher(key string) (string, bool) { return fmt.Sprintf("%s%s", gwruntime.MetadataHeaderPrefix, key), true } -func forwardAuthenticationMetadata(ctx context.Context, _ *http.Request) metadata.MD { +// contextWithHTTPAuthInfo embeds the HTTP authentication details into +// a go context. Meant for use with userFromHTTPAuthInfoContext(). +func contextWithHTTPAuthInfo( + ctx context.Context, username string, sessionID int64, +) context.Context { + ctx = context.WithValue(ctx, webSessionUserKey{}, username) + if sessionID != 0 { + ctx = context.WithValue(ctx, webSessionIDKey{}, sessionID) + } + return ctx +} + +// userFromHTTPAuthInfoContext returns a SQL username from the request +// context of a HTTP route requiring login. Only use in routes that require +// login (e.g. requiresAuth = true in the API v2 route definition). +// +// Do not use this function in _RPC_ API handlers. These access their +// SQL identity via the RPC incoming context. See +// userFromIncomingRPCContext(). +func userFromHTTPAuthInfoContext(ctx context.Context) username.SQLUsername { + return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) +} + +// maybeUserFromHTTPAuthInfoContext is like userFromHTTPAuthInfoContext but +// it returns a boolean false if there is no user in the context. +func maybeUserFromHTTPAuthInfoContext(ctx context.Context) (username.SQLUsername, bool) { + if u := ctx.Value(webSessionUserKey{}); u != nil { + return username.MakeSQLUsernameFromPreNormalizedString(u.(string)), true + } + return username.SQLUsername{}, false +} + +// translateHTTPAuthInfoToGRPCMetadata translates the context.Value +// that results from HTTP authentication into gRPC metadata suitable +// for use by RPC API handlers. +func translateHTTPAuthInfoToGRPCMetadata(ctx context.Context, _ *http.Request) metadata.MD { md := metadata.MD{} if user := ctx.Value(webSessionUserKey{}); user != nil { md.Set(webSessionUserKeyStr, user.(string)) @@ -680,6 +713,59 @@ func forwardAuthenticationMetadata(ctx context.Context, _ *http.Request) metadat return md } +// forwardSQLIdentityThroughRPCCalls forwards the SQL identity of the +// original request (as populated by translateHTTPAuthInfoToGRPCMetadata in +// grpc-gateway) so it remains available to the remote node handling +// the request. +func forwardSQLIdentityThroughRPCCalls(ctx context.Context) context.Context { + if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { + if u, ok := md[webSessionUserKeyStr]; ok { + return metadata.NewOutgoingContext(ctx, metadata.MD{webSessionUserKeyStr: u}) + } + } + return ctx +} + +// forwardHTTPAuthInfoToRPCCalls converts an HTTP API (v1 or v2) context, to one that +// can issue outgoing RPC requests under the same logged-in user. +func forwardHTTPAuthInfoToRPCCalls(ctx context.Context, r *http.Request) context.Context { + md := translateHTTPAuthInfoToGRPCMetadata(ctx, r) + return metadata.NewOutgoingContext(ctx, md) +} + +// userFromIncomingRPCContext is to be used in RPC API handlers. It +// assumes the SQL identity was populated in the context implicitly by +// gRPC via translateHTTPAuthInfoToGRPCMetadata(), or explicitly via +// forwardHTTPAuthInfoToRPCCalls() or +// forwardSQLIdentityThroughRPCCalls(). +// +// Do not use this function in _HTTP_ API handlers. Those access their +// SQL identity via a special context key. See +// userFromHTTPAuthInfoContext(). +func userFromIncomingRPCContext(ctx context.Context) (res username.SQLUsername, err error) { + md, ok := grpcutil.FastFromIncomingContext(ctx) + if !ok { + return username.RootUserName(), nil + } + usernames, ok := md[webSessionUserKeyStr] + if !ok { + // If the incoming context has metadata but no attached web session user, + // it's a gRPC / internal SQL connection which has root on the cluster. + // This assumption is a historical hiccup, and would be best described + // as a bug. See: https://github.com/cockroachdb/cockroach/issues/45018 + return username.RootUserName(), nil + } + if len(usernames) != 1 { + log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) + return res, fmt.Errorf( + "context's incoming metadata contains unexpected number of usernames: %+v ", md) + } + // At this point the user is already logged in, so we can assume + // the username has been normalized already. + username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) + return username, nil +} + // sessionCookieValue defines the data needed to construct the // aggregate session cookie in the order provided. type sessionCookieValue struct { diff --git a/pkg/server/combined_statement_stats.go b/pkg/server/combined_statement_stats.go index 74359fa4dd77..3e7bac42d6fe 100644 --- a/pkg/server/combined_statement_stats.go +++ b/pkg/server/combined_statement_stats.go @@ -44,7 +44,7 @@ func getTimeFromSeconds(seconds int64) *time.Time { func (s *statusServer) CombinedStatementStats( ctx context.Context, req *serverpb.CombinedStatementsStatsRequest, ) (*serverpb.StatementsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -340,7 +340,7 @@ func collectCombinedTransactions( func (s *statusServer) StatementDetails( ctx context.Context, req *serverpb.StatementDetailsRequest, ) (*serverpb.StatementDetailsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/fanout_clients.go b/pkg/server/fanout_clients.go index c55b86d8cf46..7d1c5dfc1307 100644 --- a/pkg/server/fanout_clients.go +++ b/pkg/server/fanout_clients.go @@ -228,7 +228,7 @@ func (k kvFanoutClient) dialNode(ctx context.Context, serverID serverID) (*grpc. } func (k kvFanoutClient) listNodes(ctx context.Context) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = k.ambientCtx.AnnotateCtx(ctx) statuses, _, err := getNodeStatuses(ctx, k.db, 0, 0) diff --git a/pkg/server/grpc_gateway.go b/pkg/server/grpc_gateway.go index 5bf6686b2c5d..6eed6aa808d8 100644 --- a/pkg/server/grpc_gateway.go +++ b/pkg/server/grpc_gateway.go @@ -72,7 +72,7 @@ func configureGRPCGateway( gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb), gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb), gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher), - gwruntime.WithMetadata(forwardAuthenticationMetadata), + gwruntime.WithMetadata(translateHTTPAuthInfoToGRPCMetadata), ) gwCtx, gwCancel := context.WithCancel(ambientCtx.AnnotateCtx(context.Background())) stopper.AddCloser(stop.CloserFn(gwCancel)) diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go index f374ec455a90..c3c2f90f41e4 100644 --- a/pkg/server/index_usage_stats.go +++ b/pkg/server/index_usage_stats.go @@ -40,7 +40,7 @@ import ( func (s *statusServer) IndexUsageStatistics( ctx context.Context, req *serverpb.IndexUsageStatisticsRequest, ) (*serverpb.IndexUsageStatisticsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -130,7 +130,7 @@ func indexUsageStatsLocal( func (s *statusServer) ResetIndexUsageStats( ctx context.Context, req *serverpb.ResetIndexUsageStatsRequest, ) (*serverpb.ResetIndexUsageStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -207,7 +207,7 @@ func (s *statusServer) ResetIndexUsageStats( func (s *statusServer) TableIndexStats( ctx context.Context, req *serverpb.TableIndexStatsRequest, ) (*serverpb.TableIndexStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -228,7 +228,7 @@ func getTableIndexUsageStats( st *cluster.Settings, execConfig *sql.ExecutorConfig, ) (*serverpb.TableIndexStatsResponse, error) { - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return nil, err } @@ -384,7 +384,7 @@ func getDatabaseIndexRecommendations( return []*serverpb.IndexRecommendation{}, nil } - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) if err != nil { return []*serverpb.IndexRecommendation{}, err } diff --git a/pkg/server/index_usage_stats_test.go b/pkg/server/index_usage_stats_test.go index 939bd5a386ba..1ce2b2c4ab21 100644 --- a/pkg/server/index_usage_stats_test.go +++ b/pkg/server/index_usage_stats_test.go @@ -360,7 +360,7 @@ CREATE TABLE schema.test_table ( `) // Get Table IDs. - userName, err := userFromContext(ctx) + userName, err := userFromIncomingRPCContext(ctx) require.NoError(t, err) testCases := []struct { diff --git a/pkg/server/server_http.go b/pkg/server/server_http.go index 4d25064b26e0..1713c1b686f7 100644 --- a/pkg/server/server_http.go +++ b/pkg/server/server_http.go @@ -111,8 +111,9 @@ func (s *httpServer) setupRoutes( NodeID: s.cfg.IDContainer, OIDC: oidc, GetUser: func(ctx context.Context) *string { - if u, ok := ctx.Value(webSessionUserKey{}).(string); ok { - return &u + if user, ok := maybeUserFromHTTPAuthInfoContext(ctx); ok { + ustring := user.Normalized() + return &ustring } return nil }, @@ -183,7 +184,7 @@ func makeAdminAuthzCheckHandler( return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Retrieve the username embedded in the grpc metadata, if any. // This will be provided by the authenticationMux. - md := forwardAuthenticationMetadata(req.Context(), req) + md := translateHTTPAuthInfoToGRPCMetadata(req.Context(), req) authCtx := metadata.NewIncomingContext(req.Context(), md) // Check the privileges of the requester. err := adminAuthzCheck.requireViewDebugPermission(authCtx) diff --git a/pkg/server/sql_stats.go b/pkg/server/sql_stats.go index 6cb43885ebb3..411166d3bf9c 100644 --- a/pkg/server/sql_stats.go +++ b/pkg/server/sql_stats.go @@ -23,7 +23,7 @@ import ( func (s *statusServer) ResetSQLStats( ctx context.Context, req *serverpb.ResetSQLStatsRequest, ) (*serverpb.ResetSQLStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { diff --git a/pkg/server/statement_diagnostics_requests.go b/pkg/server/statement_diagnostics_requests.go index 5feff91cb49b..724f612705e6 100644 --- a/pkg/server/statement_diagnostics_requests.go +++ b/pkg/server/statement_diagnostics_requests.go @@ -71,7 +71,7 @@ func (diagnostics *stmtDiagnostics) toProto() serverpb.StatementDiagnostics { func (s *statusServer) CreateStatementDiagnosticsReport( ctx context.Context, req *serverpb.CreateStatementDiagnosticsReportRequest, ) (*serverpb.CreateStatementDiagnosticsReportResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -103,7 +103,7 @@ func (s *statusServer) CreateStatementDiagnosticsReport( func (s *statusServer) CancelStatementDiagnosticsReport( ctx context.Context, req *serverpb.CancelStatementDiagnosticsReportRequest, ) (*serverpb.CancelStatementDiagnosticsReportResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -127,7 +127,7 @@ func (s *statusServer) CancelStatementDiagnosticsReport( func (s *statusServer) StatementDiagnosticsRequests( ctx context.Context, _ *serverpb.StatementDiagnosticsReportsRequest, ) (*serverpb.StatementDiagnosticsReportsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { @@ -221,7 +221,7 @@ func (s *statusServer) StatementDiagnosticsRequests( func (s *statusServer) StatementDiagnostics( ctx context.Context, req *serverpb.StatementDiagnosticsRequest, ) (*serverpb.StatementDiagnosticsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/statements.go b/pkg/server/statements.go index cfa55475b0b2..41a8ff5fc7be 100644 --- a/pkg/server/statements.go +++ b/pkg/server/statements.go @@ -33,7 +33,7 @@ func (s *statusServer) Statements( return s.CombinedStatementStats(ctx, &combinedRequest) } - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { diff --git a/pkg/server/status.go b/pkg/server/status.go index 758e61a767c6..54aade2dee81 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -85,7 +85,6 @@ import ( raft "go.etcd.io/raft/v3" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -133,13 +132,6 @@ type metricMarshaler interface { ScrapeIntoPrometheus(pm *metric.PrometheusExporter) } -func propagateGatewayMetadata(ctx context.Context) context.Context { - if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { - return metadata.NewOutgoingContext(ctx, md) - } - return ctx -} - // baseStatusServer implements functionality shared by the tenantStatusServer // and the full statusServer. type baseStatusServer struct { @@ -171,7 +163,7 @@ func isInternalAppName(app string) bool { func (b *baseStatusServer) getLocalSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) ([]serverpb.Session, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) @@ -275,7 +267,7 @@ func (b *baseStatusServer) getLocalSessions( func (b *baseStatusServer) checkCancelPrivilege( ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername, ) error { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) ctxUsername, isCtxAdmin, err := b.privilegeChecker.getUserAndRole(ctx) @@ -343,7 +335,7 @@ func (b *baseStatusServer) checkCancelPrivilege( func (b *baseStatusServer) ListLocalContentionEvents( ctx context.Context, _ *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -360,7 +352,7 @@ func (b *baseStatusServer) ListLocalContentionEvents( func (b *baseStatusServer) ListLocalDistSQLFlows( ctx context.Context, _ *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -690,7 +682,7 @@ func (s *statusServer) dialNode( func (s *systemStatusServer) Gossip( ctx context.Context, req *serverpb.GossipRequest, ) (*gossip.InfoStatus, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -718,7 +710,7 @@ func (s *systemStatusServer) Gossip( func (s *systemStatusServer) EngineStats( ctx context.Context, req *serverpb.EngineStatsRequest, ) (*serverpb.EngineStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -761,7 +753,7 @@ func (s *systemStatusServer) EngineStats( func (s *systemStatusServer) Allocator( ctx context.Context, req *serverpb.AllocatorRequest, ) (*serverpb.AllocatorResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -854,7 +846,7 @@ func recordedSpansToTraceEvents(spans []tracingpb.RecordedSpan) []*serverpb.Trac func (s *systemStatusServer) AllocatorRange( ctx context.Context, req *serverpb.AllocatorRangeRequest, ) (*serverpb.AllocatorRangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -943,7 +935,7 @@ func (s *systemStatusServer) AllocatorRange( func (s *statusServer) Certificates( ctx context.Context, req *serverpb.CertificatesRequest, ) (*serverpb.CertificatesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1064,7 +1056,7 @@ func extractCertFields(contents []byte, details *serverpb.CertificateDetails) er func (s *statusServer) Details( ctx context.Context, req *serverpb.DetailsRequest, ) (*serverpb.DetailsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1105,7 +1097,7 @@ func (s *statusServer) Details( func (s *statusServer) GetFiles( ctx context.Context, req *serverpb.GetFilesRequest, ) (*serverpb.GetFilesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1161,7 +1153,7 @@ func checkFilePattern(pattern string) error { func (s *statusServer) LogFilesList( ctx context.Context, req *serverpb.LogFilesListRequest, ) (*serverpb.LogFilesListResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1196,7 +1188,7 @@ func (s *statusServer) LogFilesList( func (s *statusServer) LogFile( ctx context.Context, req *serverpb.LogFileRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1294,7 +1286,7 @@ func parseInt64WithDefault(s string, defaultValue int64) (int64, error) { func (s *statusServer) Logs( ctx context.Context, req *serverpb.LogsRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1382,7 +1374,7 @@ func (s *statusServer) Logs( func (s *statusServer) Stacks( ctx context.Context, req *serverpb.StacksRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1415,7 +1407,7 @@ func (s *statusServer) Stacks( func (s *statusServer) Profile( ctx context.Context, req *serverpb.ProfileRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -1494,7 +1486,7 @@ func regionsResponseFromNodesResponse(nr *serverpb.NodesResponse) *serverpb.Regi func (s *statusServer) NodesList( ctx context.Context, _ *serverpb.NodesListRequest, ) (*serverpb.NodesListResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1520,7 +1512,7 @@ func (s *statusServer) NodesList( func (s *statusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { @@ -1559,7 +1551,7 @@ func (s *statusServer) Nodes( func (s *systemStatusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -1577,7 +1569,7 @@ func (s *systemStatusServer) Nodes( func (s *statusServer) NodesUI( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponseExternal, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) return s.sqlServer.tenantConnect.NodesUI(ctx, req) @@ -1586,7 +1578,7 @@ func (s *statusServer) NodesUI( func (s *systemStatusServer) NodesUI( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponseExternal, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) hasViewClusterMetadata := false @@ -1757,7 +1749,7 @@ func getNodeStatuses( func (s *systemStatusServer) nodesHelper( ctx context.Context, limit, offset int, ) (*serverpb.NodesResponse, int, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) statuses, next, err := getNodeStatuses(ctx, s.db, limit, offset) @@ -1780,7 +1772,7 @@ func (s *systemStatusServer) nodesHelper( func (s *statusServer) Node( ctx context.Context, req *serverpb.NodeRequest, ) (*statuspb.NodeStatus, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1822,7 +1814,7 @@ func (s *statusServer) nodeStatus( func (s *statusServer) NodeUI( ctx context.Context, req *serverpb.NodeRequest, ) (*serverpb.NodeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network @@ -1846,7 +1838,7 @@ func (s *statusServer) NodeUI( func (s *statusServer) Metrics( ctx context.Context, req *serverpb.MetricsRequest, ) (*serverpb.JSONResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) @@ -1872,7 +1864,7 @@ func (s *statusServer) Metrics( func (s *systemStatusServer) RaftDebug( ctx context.Context, req *serverpb.RaftDebugRequest, ) (*serverpb.RaftDebugResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -1998,7 +1990,7 @@ func (s *systemStatusServer) Ranges( func (s *systemStatusServer) rangesHelper( ctx context.Context, req *serverpb.RangesRequest, limit, offset int, ) (*serverpb.RangesResponse, int, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) @@ -2201,7 +2193,7 @@ func (s *systemStatusServer) rangesHelper( func (t *statusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) (*serverpb.TenantRangesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = t.AnnotateCtx(ctx) // The tenant range report contains replica metadata which is admin-only. @@ -2215,7 +2207,7 @@ func (t *statusServer) TenantRanges( func (s *systemStatusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) (*serverpb.TenantRangesResponse, error) { - propagateGatewayMetadata(ctx) + forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err @@ -2366,7 +2358,7 @@ func (s *systemStatusServer) TenantRanges( func (s *systemStatusServer) HotRanges( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -2445,7 +2437,7 @@ func (t *statusServer) HotRangesV2( func (s *systemStatusServer) HotRangesV2( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponseV2, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) if err != nil { @@ -2718,7 +2710,7 @@ func (s *statusServer) KeyVisSamples( func (s *statusServer) Range( ctx context.Context, req *serverpb.RangeRequest, ) (*serverpb.RangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -3009,7 +3001,7 @@ func (s *statusServer) listSessionsHelper( func (s *statusServer) ListSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) (*serverpb.ListSessionsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, _, err := s.privilegeChecker.getUserAndRole(ctx); err != nil { @@ -3030,7 +3022,7 @@ func (s *statusServer) ListSessions( func (s *statusServer) CancelSession( ctx context.Context, req *serverpb.CancelSessionRequest, ) (*serverpb.CancelSessionResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) sessionIDBytes := req.SessionID @@ -3082,7 +3074,7 @@ func (s *statusServer) CancelSession( func (s *statusServer) CancelQuery( ctx context.Context, req *serverpb.CancelQueryRequest, ) (*serverpb.CancelQueryResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) queryID, err := clusterunique.IDFromString(req.QueryID) @@ -3181,7 +3173,7 @@ func (s *statusServer) CancelQueryByKey( } // This request needs to be forwarded to another node. - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) client, err := s.dialNode(ctx, roachpb.NodeID(req.SQLInstanceID)) if err != nil { @@ -3195,7 +3187,7 @@ func (s *statusServer) CancelQueryByKey( func (s *statusServer) ListContentionEvents( ctx context.Context, req *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3242,7 +3234,7 @@ func (s *statusServer) ListContentionEvents( func (s *statusServer) ListDistSQLFlows( ctx context.Context, request *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3334,7 +3326,7 @@ func mergeDistSQLRemoteFlows(a, b []serverpb.DistSQLRemoteFlows) []serverpb.Dist func (s *statusServer) ListExecutionInsights( ctx context.Context, req *serverpb.ListExecutionInsightsRequest, ) (*serverpb.ListExecutionInsightsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. @@ -3396,7 +3388,7 @@ func (s *statusServer) ListExecutionInsights( func (s *systemStatusServer) SpanStats( ctx context.Context, req *serverpb.SpanStatsRequest, ) (*serverpb.SpanStatsResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -3440,7 +3432,7 @@ func (s *systemStatusServer) SpanStats( func (s *statusServer) Diagnostics( ctx context.Context, req *serverpb.DiagnosticsRequest, ) (*diagnosticspb.DiagnosticReport, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) if err != nil { @@ -3462,7 +3454,7 @@ func (s *statusServer) Diagnostics( func (s *systemStatusServer) Stores( ctx context.Context, req *serverpb.StoresRequest, ) (*serverpb.StoresResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { @@ -3548,28 +3540,6 @@ func marshalJSONResponse(value interface{}) (*serverpb.JSONResponse, error) { return &serverpb.JSONResponse{Data: data}, nil } -func userFromContext(ctx context.Context) (res username.SQLUsername, err error) { - md, ok := grpcutil.FastFromIncomingContext(ctx) - if !ok { - return username.RootUserName(), nil - } - usernames, ok := md[webSessionUserKeyStr] - if !ok { - // If the incoming context has metadata but no attached web session user, - // it's a gRPC / internal SQL connection which has root on the cluster. - return username.RootUserName(), nil - } - if len(usernames) != 1 { - log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) - return res, fmt.Errorf( - "context's incoming metadata contains unexpected number of usernames: %+v ", md) - } - // At this point the user is already logged in, so we can assume - // the username has been normalized already. - username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) - return username, nil -} - type systemInfoOnce struct { once sync.Once info serverpb.SystemInfo @@ -3612,7 +3582,7 @@ func (si *systemInfoOnce) systemInfo(ctx context.Context) serverpb.SystemInfo { func (s *statusServer) JobRegistryStatus( ctx context.Context, req *serverpb.JobRegistryStatusRequest, ) (*serverpb.JobRegistryStatusResponse, error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { @@ -3651,7 +3621,7 @@ func (s *statusServer) JobRegistryStatus( func (s *statusServer) JobStatus( ctx context.Context, req *serverpb.JobStatusRequest, ) (*serverpb.JobStatusResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { // NB: not using serverError() here since the priv checker @@ -3684,7 +3654,7 @@ func (s *statusServer) JobStatus( func (s *statusServer) TxnIDResolution( ctx context.Context, req *serverpb.TxnIDResolutionRequest, ) (*serverpb.TxnIDResolutionResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err } @@ -3708,7 +3678,7 @@ func (s *statusServer) TxnIDResolution( func (s *statusServer) TransactionContentionEvents( ctx context.Context, req *serverpb.TransactionContentionEventsRequest, ) (*serverpb.TransactionContentionEventsResponse, error) { - ctx = s.AnnotateCtx(propagateGatewayMetadata(ctx)) + ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err diff --git a/pkg/server/user.go b/pkg/server/user.go index 13943c243232..2cd2b1d8e454 100644 --- a/pkg/server/user.go +++ b/pkg/server/user.go @@ -22,7 +22,7 @@ import ( func (s *baseStatusServer) UserSQLRoles( ctx context.Context, req *serverpb.UserSQLRolesRequest, ) (_ *serverpb.UserSQLRolesResponse, retErr error) { - ctx = propagateGatewayMetadata(ctx) + ctx = forwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) username, isAdmin, err := s.privilegeChecker.getUserAndRole(ctx)