diff --git a/DEPS.bzl b/DEPS.bzl index 9e868da42f33..b0c3961bd882 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -1169,8 +1169,8 @@ def go_deps(): name = "com_github_gorilla_mux", build_file_proto_mode = "disable_global", importpath = "github.com/gorilla/mux", - sum = "h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=", - version = "v1.7.4", + sum = "h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=", + version = "v1.8.0", ) go_repository( name = "com_github_gorilla_securecookie", diff --git a/go.mod b/go.mod index 3b69f4f094c6..1ea210525845 100644 --- a/go.mod +++ b/go.mod @@ -81,7 +81,7 @@ require ( github.com/google/pprof v0.0.0-20190109223431-e84dfd68c163 github.com/googleapis/gax-go v2.0.2+incompatible // indirect github.com/gorhill/cronexpr v0.0.0-20140423231348-a557574d6c02 - github.com/gorilla/mux v1.7.4 // indirect + github.com/gorilla/mux v1.8.0 github.com/goware/modvendor v0.3.0 github.com/grpc-ecosystem/grpc-gateway v1.13.0 github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 diff --git a/go.sum b/go.sum index 5f1780ab85f8..6116772b87e7 100644 --- a/go.sum +++ b/go.sum @@ -358,8 +358,8 @@ github.com/googleapis/gax-go v2.0.2+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorhill/cronexpr v0.0.0-20140423231348-a557574d6c02 h1:Spo+4PFAGDqULAsZ7J69MOxq4/fwgZ0zvmDTBqpq7yU= github.com/gorhill/cronexpr v0.0.0-20140423231348-a557574d6c02/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA= -github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= 
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 31820685f03c..6c5c9210e207 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "server", srcs = [ "admin.go", + "api.go", "api_error.go", "authentication.go", "auto_upgrade.go", @@ -166,6 +167,7 @@ go_library( "@com_github_cockroachdb_sentry_go//:sentry-go", "@com_github_elastic_gosigar//:gosigar", "@com_github_gogo_protobuf//proto", + "@com_github_gorilla_mux//:mux", "@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway//utilities:go_default_library", "@com_github_marusama_semaphore//:semaphore", diff --git a/pkg/server/api.go b/pkg/server/api.go new file mode 100644 index 000000000000..1358236d5fcd --- /dev/null +++ b/pkg/server/api.go @@ -0,0 +1,208 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
// writeJsonResponse serializes payload as JSON and writes it to w with the
// given HTTP status code and a Content-Type of application/json.
//
// The payload is marshaled before any header is written. That way a payload
// that cannot be encoded results in a plain 500 response, rather than a reply
// that has already advertised the success status code and JSON content type.
func writeJsonResponse(w http.ResponseWriter, code int, payload interface{}) {
	res, err := json.Marshal(payload)
	if err != nil {
		// Nothing has been sent yet, so a proper error status can still be used.
		// The encoder's message is kept out of the response body.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// The status line has already been written at this point; if the write
	// fails there is nothing further that can be reported to the client, so the
	// error is intentionally dropped instead of panicking inside the handler.
	_, _ = w.Write(res)
}
+ {"sessions/", a.listSessions, true /* requiresAuth */ , adminRole, noOption}, + {"hotranges/", a.hotRanges, true /* requiresAuth */ , adminRole, noOption}, + {"ranges/{range_id}/", a.rangeHandler, true /* requiresAuth */ , adminRole, noOption}, + {"nodes/", a.nodes, true /* requiresAuth */ , regularRole, noOption}, + } + + // For all routes requiring authentication, have the outer mux (a.mux) + // send requests through to the authMux, and also register the relevant route + // in innerMux. Routes not requiring login can directly be handled in a.mux. + for _, route := range routeDefinitions { + if route.requiresAuth { + a.mux.Handle(apiV2Path + route.endpoint, authMux) + handler := http.Handler(route.handler) + if route.role != regularRole { + handler = &roleAuthorizationMux{ + ie: a.admin.ie, + role: route.role, + option: route.option, + inner: route.handler, + } + } + innerMux.Handle(apiV2Path + route.endpoint, handler) + } else { + a.mux.HandleFunc(apiV2Path + route.endpoint, route.handler) + } + } +} + +type listSessionsResponse struct { + serverpb.ListSessionsResponse + + Next string `json:"next"` +} + +func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { + limit, start := getRPCPaginationValues(r) + req := &serverpb.ListSessionsRequest{Username: r.Context().Value(webSessionUserKey{}).(string)} + response := &listSessionsResponse{} + + responseProto, pagState, err := a.status.listSessionsHelper(r.Context(), req, limit, start) + var nextBytes []byte + if nextBytes, err = pagState.MarshalText(); err != nil { + err := serverpb.ListSessionsError{Message: err.Error()} + response.Errors = append(response.Errors, err) + } else { + response.Next = string(nextBytes) + } + response.ListSessionsResponse = *responseProto + writeJsonResponse(w, http.StatusOK, response) +} + +type rangeResponse struct { + serverpb.RangeResponse + + Next string `json:"next"` +} + +func (a *apiV2Server) rangeHandler(w http.ResponseWriter, r *http.Request) { + limit, 
start := getRPCPaginationValues(r) + var err error + var rangeID int64 + vars := mux.Vars(r) + if rangeID, err = strconv.ParseInt(vars["range_id"], 10, 64); err != nil { + http.Error(w, "invalid range id", http.StatusBadRequest) + return + } + + req := &serverpb.RangeRequest{RangeId: rangeID} + response := &rangeResponse{} + responseProto, next, err := a.status.rangeHelper(r.Context(), req, limit, start) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + response.RangeResponse = *responseProto + if nextBytes, err := next.MarshalText(); err == nil { + response.Next = string(nextBytes) + } + writeJsonResponse(w, http.StatusOK, response) +} + +type hotRangesResponse struct { + serverpb.HotRangesResponse + + Next string `json:"next"` +} + +func (a *apiV2Server) hotRanges(w http.ResponseWriter, r *http.Request) { + limit, start := getRPCPaginationValues(r) + req := &serverpb.HotRangesRequest{NodeID: r.URL.Query().Get("node_id")} + response := &hotRangesResponse{} + + responseProto, next, err := a.status.hotRangesHelper(r.Context(), req, limit, start) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + response.HotRangesResponse = *responseProto + if nextBytes, err := next.MarshalText(); err == nil { + response.Next = string(nextBytes) + } + writeJsonResponse(w, http.StatusOK, response) +} + +type nodesResponse struct { + serverpb.NodesResponse + + Next int `json:"next"` +} + +func (a *apiV2Server) nodes(w http.ResponseWriter, r *http.Request) { + limit, offset := getSimplePaginationValues(r) + req := &serverpb.NodesRequest{} + response := &nodesResponse{} + + responseProto, next, err := a.status.nodesHelper(r.Context(), req, limit, offset) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + response.NodesResponse = *responseProto + response.Next = next + writeJsonResponse(w, http.StatusOK, response) +} + +func (a *apiV2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.mux.ServeHTTP(w, r) 
+} + diff --git a/pkg/server/api_auth.go b/pkg/server/api_auth.go new file mode 100644 index 000000000000..33de2e334cfd --- /dev/null +++ b/pkg/server/api_auth.go @@ -0,0 +1,355 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "encoding/base64" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/security" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/roleoption" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type authenticationV2Server struct { + ctx context.Context + sqlServer *SQLServer + authServer *authenticationServer + mux *http.ServeMux + basePath string +} + +func newAuthenticationV2Server(ctx context.Context, s *Server, basePath string) *authenticationV2Server { + simpleMux := http.NewServeMux() + + authServer := &authenticationV2Server{ + sqlServer: s.sqlServer, + authServer: newAuthenticationServer(s), + mux: simpleMux, + ctx: ctx, + basePath: basePath, + } + + authServer.registerRoutes() + return authServer +} + +func (a *authenticationV2Server) registerRoutes() { + a.bindEndpoint("login/", a.login) + a.bindEndpoint("logout/", a.logout) +} + +func (a *authenticationV2Server) bindEndpoint(endpoint string, handler http.HandlerFunc) { + a.mux.HandleFunc(a.basePath+endpoint, handler) +} + +// createSessionFor 
creates a login session for the given user. +// +// The caller is responsible to ensure the username has been normalized already. +func (a *authenticationV2Server) createSessionFor( + ctx context.Context, username security.SQLUsername, +) (string, error) { + // Create a new database session, generating an ID and secret key. + id, secret, err := a.authServer.newAuthSession(ctx, username) + if err != nil { + return "", apiInternalError(ctx, err) + } + + // Generate and set a session for the response. Because HTTP cookies + // must be strings, the cookie value (a marshaled protobuf) is encoded in + // base64. We just piggyback on the v1 API SessionCookie here, however + // this won't be set as an HTTP cookie on the client side. + cookieValue := &serverpb.SessionCookie{ + ID: id, + Secret: secret, + } + cookieValueBytes, err := protoutil.Marshal(cookieValue) + if err != nil { + return "", errors.Wrap(err, "session cookie could not be encoded") + } + value := base64.StdEncoding.EncodeToString(cookieValueBytes) + return value, nil +} + +type loginResponse struct { + Session string `json:"session"` +} + +func (a *authenticationV2Server) login(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "not found", http.StatusNotFound) + } + if err := r.ParseForm(); err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + if r.Form.Get("username") == "" { + http.Error(w, "username not specified", http.StatusBadRequest) + return + } + + // In CockroachDB SQL, unlike in PostgreSQL, usernames are + // case-insensitive. Therefore we need to normalize the username + // here, so that the normalized username is retained in the session + // table: the APIs extract the username from the session table + // without further normalization. + username, _ := security.MakeSQLUsernameFromUserInput(r.Form.Get("username"), security.UsernameValidation) + + // Verify the provided username/password pair. 
+ verified, expired, err := a.authServer.verifyPassword(a.ctx, username, r.Form.Get("password")) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + if expired { + http.Error(w, "the password has expired", http.StatusUnauthorized) + return + } + if !verified { + http.Error(w, "the provided credentials did not match any account on the server", http.StatusUnauthorized) + return + } + + session, err := a.createSessionFor(a.ctx, username) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + + writeJsonResponse(w, http.StatusOK, &loginResponse{Session: session}) +} + +type logoutResponse struct { + LoggedOut bool `json:"logged_out"` +} + +func (a *authenticationV2Server) logout(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "not found", http.StatusNotFound) + } + session := r.Header.Get(apiV2AuthHeader) + if session == "" { + http.Error(w, "invalid or unspecified session", http.StatusBadRequest) + return + } + var sessionCookie serverpb.SessionCookie + decoded, err := base64.StdEncoding.DecodeString(session) + if err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + if err := protoutil.Unmarshal(decoded, &sessionCookie); err != nil { + apiV2InternalError(r.Context(), err, w) + return + } + + // Revoke the session. 
+ if n, err := a.sqlServer.internalExecutor.ExecEx( + a.ctx, + "revoke-auth-session", + nil, /* txn */ + sessiondata.InternalExecutorOverride{User: security.RootUserName()}, + `UPDATE system.web_sessions SET "revokedAt" = now() WHERE id = $1`, + sessionCookie.ID, + ); err != nil { + apiV2InternalError(r.Context(), err, w) + return + } else if n == 0 { + err := status.Errorf( + codes.InvalidArgument, + "session with id %d nonexistent", sessionCookie.ID) + log.Infof(a.ctx, "%v", err) + http.Error(w, "invalid session", http.StatusBadRequest) + return + } + + writeJsonResponse(w, http.StatusOK, &logoutResponse{LoggedOut: true}) +} + +func (a *authenticationV2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.mux.ServeHTTP(w, r) +} + +// authenticationV2Mux provides authentication checks for an arbitrary inner +// http.Handler. If the session cookie is not set, an HTTP 401 error is returned +// and the request isn't routed through to the inner handler. On success, the +// username is set on the request context for use in the inner handler. +type authenticationV2Mux struct { + s *authenticationV2Server + inner http.Handler +} + +func newAuthenticationV2Mux(s *authenticationV2Server, inner http.Handler) *authenticationV2Mux { + return &authenticationV2Mux{ + s: s, + inner: inner, + } +} + +// getSession decodes the cookie from the request, looks up the corresponding session, and +// returns the logged in user name. If there's an error, it returns an error value and +// also sends the error over http using w. +func (a *authenticationV2Mux) getSession( + w http.ResponseWriter, req *http.Request, +) (string, *serverpb.SessionCookie, error) { + // Validate the returned cookie. 
// roleAuthorizationMux enforces a minimum role (e.g. adminRole) and,
// optionally, a role option for an arbitrary inner http.Handler. Its ServeHTTP
// method reads the username from the request context (webSessionUserKey) and
// answers with an HTTP 403 error, without calling the inner handler, when the
// user's role is lower than the required one or the required option is
// missing.
type roleAuthorizationMux struct {
	// ie runs the SQL queries used to look up the user's role and options.
	ie *sql.InternalExecutor
	// role is the lowest apiRole allowed through to inner.
	role apiRole
	// option is an additional role option the user must hold; it is only
	// checked when its value is greater than zero.
	option roleoption.Option
	// inner is the handler that serves the request once the checks pass.
	inner http.Handler
}
+ return superUserRole, nil + } + rows, _, err := r.ie.QueryWithCols( + ctx, "check-is-admin", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: user}, + "SELECT crdb_internal.is_admin()") + if err != nil { + return regularRole, err + } + if len(rows) != 1 { + return regularRole, errors.AssertionFailedf("hasAdminRole: expected 1 row, got %d", len(rows)) + } + if len(rows[0]) != 1 { + return regularRole, errors.AssertionFailedf("hasAdminRole: expected 1 column, got %d", len(rows[0])) + } + dbDatum, ok := tree.AsDBool(rows[0][0]) + if !ok { + return regularRole, errors.AssertionFailedf("hasAdminRole: expected bool, got %T", rows[0][0]) + } + if dbDatum { + return adminRole, nil + } + return regularRole, nil +} + +func (r *roleAuthorizationMux) hasRoleOption( + ctx context.Context, user security.SQLUsername, roleOption roleoption.Option, +) (bool, error) { + if user.IsRootUser() { + // Shortcut. + return true, nil + } + rows, _, err := r.ie.QueryWithCols( + ctx, "check-role-option", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: user}, + "SELECT crdb_internal.has_role_option($1)", roleOption.String()) + if err != nil { + return false, err + } + if len(rows) != 1 { + return false, errors.AssertionFailedf("hasRoleOption: expected 1 row, got %d", len(rows)) + } + if len(rows[0]) != 1 { + return false, errors.AssertionFailedf("hasRoleOption: expected 1 column, got %d", len(rows[0])) + } + dbDatum, ok := tree.AsDBool(rows[0][0]) + if !ok { + return false, errors.AssertionFailedf("hasRoleOption: expected bool, got %T", rows[0][0]) + } + return bool(dbDatum), nil +} + +func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // The username is set in authenticationV2Mux, and must correspond with a + // logged-in user. 
+ username := security.MakeSQLUsernameFromPreNormalizedString( + req.Context().Value(webSessionUserKey{}).(string)) + if role, err := r.getRoleForUser(req.Context(), username); err != nil || role < r.role { + if err != nil { + apiV2InternalError(req.Context(), err, w) + } else { + http.Error(w, "user not allowed to access this endpoint", http.StatusForbidden) + } + return + } + if r.option > 0 { + ok, err := r.hasRoleOption(req.Context(), username, r.option) + if err != nil { + apiV2InternalError(req.Context(), err, w) + return + } else if !ok { + http.Error(w, "user not allowed to access this endpoint", http.StatusForbidden) + return + } + } + r.inner.ServeHTTP(w, req) +} diff --git a/pkg/server/api_error.go b/pkg/server/api_error.go index d1a7ca7d4a1f..102ad0587d13 100644 --- a/pkg/server/api_error.go +++ b/pkg/server/api_error.go @@ -12,15 +12,18 @@ package server import ( "context" + "net/http" "github.com/cockroachdb/cockroach/pkg/util/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +var errAPIInternalErrorString = "An internal server error has occurred. Please check your CockroachDB logs for more details." + var errAPIInternalError = status.Errorf( codes.Internal, - "An internal server error has occurred. Please check your CockroachDB logs for more details.", + errAPIInternalErrorString, ) // apiInternalError should be used to wrap server-side errors during API @@ -31,3 +34,12 @@ func apiInternalError(ctx context.Context, err error) error { log.ErrorfDepth(ctx, 1, "%s", err) return errAPIInternalError } + +// apiV2InternalError should be used to wrap server-side errors during API +// requests for V2 (non-GRPC) endpoints. This method records the contents +// of the error to the server log, and sends the standard internal error string +// over the http.ResponseWriter. 
+func apiV2InternalError(ctx context.Context, err error, w http.ResponseWriter) { + log.ErrorfDepth(ctx, 1, "%s", err) + http.Error(w, errAPIInternalErrorString, http.StatusInternalServerError) +} diff --git a/pkg/server/authentication_test.go b/pkg/server/authentication_test.go index fe831f5a1d1c..2965331c980f 100644 --- a/pkg/server/authentication_test.go +++ b/pkg/server/authentication_test.go @@ -562,7 +562,7 @@ func TestLogout(t *testing.T) { ts := s.(*TestServer) // Log in. - authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authenticatedUserName(), true) + authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authenticatedUserName(), true, false) if err != nil { t.Fatal("error opening HTTP client", err) } diff --git a/pkg/server/pagination.go b/pkg/server/pagination.go new file mode 100644 index 000000000000..47c54ba4a8b2 --- /dev/null +++ b/pkg/server/pagination.go @@ -0,0 +1,429 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/contextutil" + "github.com/cockroachdb/errors" +) + +// simplePaginate takes in an input slice, and returns a sub-slice of the next +// `limit` elements starting at `offset`. The second returned value is the +// next offset that can be used to return the next "limit" results, or +// len(result) if there are no more results. 
// simplePaginate returns the page of at most limit elements of input that
// starts at offset, together with the offset at which the following page
// starts.
//
// input is returned untouched, with a next offset of 0, when it is not a slice
// or when limit is not positive. A negative offset is treated as 0, and both
// ends of the page are capped at the length of input, so an exhausted input
// yields an empty slice and a next offset equal to that length.
func simplePaginate(input interface{}, limit, offset int) (result interface{}, next int) {
	items := reflect.ValueOf(input)
	if items.Kind() != reflect.Slice || limit <= 0 {
		return input, 0
	}
	if offset < 0 {
		offset = 0
	}

	total := items.Len()
	// capAt keeps an index from running past the end of the slice.
	capAt := func(idx int) int {
		if idx > total {
			return total
		}
		return idx
	}
	from, to := capAt(offset), capAt(offset+limit)
	return items.Slice(from, to).Interface(), to
}
+ for j < len(sortedNodeIDs) && sortedNodeIDs[j] < allNodeIDs[i] { + j++ + } + // If allNodeIDs[i] is not in sortedNodeIDs, add it to p.nodesToQuery. + if j >= len(sortedNodeIDs) || sortedNodeIDs[j] != allNodeIDs[i] { + p.nodesToQuery = append(p.nodesToQuery, allNodeIDs[i]) + } + } + if p.inProgress == 0 && len(p.nodesToQuery) > 0 { + p.inProgress = p.nodesToQuery[0] + p.inProgressIndex = 0 + p.nodesToQuery = p.nodesToQuery[1:] + } +} + +// paginate processes the response from a given node, and returns start/end +// indices that the response should be sliced at (if it is a slice; otherwise +// inclusion/exclusion is denoted by end > start). Note that this method +// expects that it is called serially, with nodeIDs in the same order as +// p.nodesToQuery (nodes skipped in that slice are considered to have returned +// an error; out-of-order nodeIDs generate a panic). +func (p *paginationState) paginate(limit int, nodeID roachpb.NodeID, length int) (start, end, newLimit int) { + if limit <= 0 || nodeID == 0 || int(p.inProgress) == 0 { + // Already reached limit. + return 0, 0, 0 + } + if p.inProgress != nodeID { + p.nodesQueried = append(p.nodesQueried, p.inProgress) + p.inProgress = 0 + p.inProgressIndex = 0 + for i := range p.nodesToQuery { + if p.nodesToQuery[i] == nodeID { + // Deducing from the caller contract, all the nodes in + // p.nodesToQuery[0:i] must have returned errors. + p.inProgress = nodeID + p.nodesQueried = append(p.nodesQueried, p.nodesToQuery[0:i]...) + p.nodesToQuery = p.nodesToQuery[i+1:] + break + } + } + if p.inProgress == 0 { + // This node isn't in list. This should never happen. 
+ panic(fmt.Sprintf("could not find node %d in pagination state %v", nodeID, p)) + } + } + doneWithNode := false + if length > 0 { + start = p.inProgressIndex + if start > length { + start = length + } + // end = min(length, start + limit) + if start + limit >= length { + end = length + doneWithNode = true + } else { + end = start + limit + } + limit -= end - start + p.inProgressIndex = end + } + if doneWithNode { + p.nodesQueried = append(p.nodesQueried, nodeID) + p.inProgressIndex = 0 + if len(p.nodesToQuery) > 0 { + p.inProgress = p.nodesToQuery[0] + p.nodesToQuery = p.nodesToQuery[1:] + } else { + p.nodesToQuery = p.nodesToQuery[:0] + p.inProgress = 0 + } + } + return start, end, limit +} + +// UnmarshalText takes a URL-friendly base64-encoded version of a continuation/ +// next token (likely coming from a user HTTP request), and unmarshals it to a +// paginationState. The format is: +// +// ||| +// +// Where: +// - nodesQueried is a comma-separated list of node IDs that have already been +// queried (matching p.nodesQueried). +// - inProgressNode is the ID of the node where the cursor is currently at. +// - inProgressNodeIndex is the index of the response from inProgressNode's +// node-local function where the cursor is currently at. +// - nodesToQuery is a comma-separated list of node IDs of nodes that are yet +// to be queried. +// +// All node IDs and indices are represented as unsigned 32-bit ints, and +// comma-separated lists are allowed to have trailing commas. The character +// separating all of the above components is the pipe (|) character. 
+func (p *paginationState) UnmarshalText(text []byte) error { + decoder := base64.NewDecoder(base64.URLEncoding, bytes.NewReader(text)) + var decodedText []byte + var err error + if decodedText, err = ioutil.ReadAll(decoder); err != nil { + return err + } + parts := strings.Split(string(decodedText), "|") + if len(parts) != 4 { + return errors.New("invalid pagination state") + } + parseNodeIDSlice := func(str string) ([]roachpb.NodeID, error) { + parts := strings.Split(str, ",") + res := make([]roachpb.NodeID, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if len(part) == 0 { + continue + } + val, err := strconv.ParseUint(part, 10, 32) + if err != nil { + return nil, errors.Wrap(err, "invalid pagination state") + } + if val <= 0 { + return nil, errors.New("expected positive nodeID in pagination token") + } + res = append(res, roachpb.NodeID(val)) + } + return res, nil + } + p.nodesQueried, err = parseNodeIDSlice(parts[0]) + if err != nil { + return err + } + var inProgressInt int + inProgressInt, err = strconv.Atoi(parts[1]) + if err != nil { + return errors.Wrap(err, "invalid pagination state") + } + p.inProgress = roachpb.NodeID(inProgressInt) + p.inProgressIndex, err = strconv.Atoi(parts[2]) + if err != nil { + return errors.Wrap(err, "invalid pagination state") + } + if p.inProgressIndex < 0 || (p.inProgressIndex > 0 && p.inProgress <= 0) { + return errors.Newf("invalid pagination resumption token: (%d, %d)", p.inProgress, p.inProgressIndex) + } + p.nodesToQuery, err = parseNodeIDSlice(parts[3]) + if err != nil { + return err + } + return nil +} + +// MarshalText converts the current paginationState to an ascii text +// representation that can be sent back to the user as a next/continuation +// token. For format, see the comment on UnmarshalText. 
+func (p *paginationState) MarshalText() (text []byte, err error) { + var builder, builder2 bytes.Buffer + for _, nid := range p.nodesQueried { + fmt.Fprintf(&builder, "%d,", nid) + } + fmt.Fprintf(&builder, "|%d|%d|", p.inProgress, p.inProgressIndex) + for _, nid := range p.nodesToQuery { + fmt.Fprintf(&builder, "%d,", nid) + } + encoder := base64.NewEncoder(base64.URLEncoding, &builder2) + if _, err = encoder.Write(builder.Bytes()); err != nil { + return nil, err + } + if err = encoder.Close(); err != nil { + return nil, err + } + return builder2.Bytes(), nil +} + +// paginatedNodeResponse stores the response from one node in a paginated fan-out +// request. For use with rpcNodePaginator. +type paginatedNodeResponse struct { + nodeID roachpb.NodeID + response interface{} + value reflect.Value + len int + err error +} + +// rpcNodePaginator allows for concurrent fan-out RPC requests to be made to +// multiple nodes, and their responses ordered back in the same ordering as +// that in pagState, and with responses limit-ed to the specified limit. Uses +// reflection to limit the response in the responseFn if it's a slice, and +// treats it as an item of length 1 if it's not a slice. +type rpcNodePaginator struct { + limit int + numNodes int + errorCtx string + pagState paginationState + responseChan chan paginatedNodeResponse + nodeStatuses map[roachpb.NodeID]nodeStatusWithLiveness + + dialFn func(ctx context.Context, id roachpb.NodeID) (client interface{}, err error) + nodeFn func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (res interface{}, err error) + responseFn func(nodeID roachpb.NodeID, resp interface{}) + errorFn func(nodeID roachpb.NodeID, nodeFnError error) + + mu struct { + sync.Mutex + + turnCond sync.Cond + + currentIdx, currentLen int + } + + // Stores a 1 if the limit has been reached. Must be accessed and updated + // atomically. 
+ done int32 +} + +func (r *rpcNodePaginator) init() { + r.mu.turnCond.L = &r.mu + r.responseChan = make(chan paginatedNodeResponse, r.numNodes) +} + +// queryNode queries the given node, and sends the responses back through responseChan +// in order of idx (i.e. when all nodes with a lower idx have already sent theirs). +// Safe for concurrent use. +func (r *rpcNodePaginator) queryNode(ctx context.Context, nodeID roachpb.NodeID, idx int) { + if atomic.LoadInt32(&r.done) != 0 { + // There are more values than we need. currentLen >= limit. + return + } + var client interface{} + addNodeResp := func(resp paginatedNodeResponse) { + r.mu.Lock() + defer r.mu.Unlock() + + for r.mu.currentIdx < idx && atomic.LoadInt32(&r.done) == 0 { + r.mu.turnCond.Wait() + select { + case <-ctx.Done(): + r.mu.turnCond.Broadcast() + return + default: + } + } + if atomic.LoadInt32(&r.done) != 0 { + // There are more values than we need. currentLen >= limit. + r.mu.turnCond.Broadcast() + return + } + r.responseChan <- resp + r.mu.currentLen += resp.len + if nodeID == r.pagState.inProgress { + // We're resuming partway through a node's response. Subtract away the + // count of values already sent in previous calls (i.e. inProgressIndex). 
+ if resp.len > r.pagState.inProgressIndex { + r.mu.currentLen -= r.pagState.inProgressIndex + } else { + r.mu.currentLen -= resp.len + } + } + if r.mu.currentLen >= r.limit { + atomic.StoreInt32(&r.done, 1) + close(r.responseChan) + } + r.mu.currentIdx++ + r.mu.turnCond.Broadcast() + } + if err := contextutil.RunWithTimeout(ctx, "dial node", base.NetworkTimeout, func(ctx context.Context) error { + var err error + client, err = r.dialFn(ctx, nodeID) + return err + }); err != nil { + err = errors.Wrapf(err, "failed to dial into node %d (%s)", + nodeID, r.nodeStatuses[nodeID].livenessStatus) + addNodeResp(paginatedNodeResponse{nodeID: nodeID, err: err}) + return + } + + res, err := r.nodeFn(ctx, client, nodeID) + if err != nil { + err = errors.Wrapf(err, "error requesting %s from node %d (%s)", + r.errorCtx, nodeID, r.nodeStatuses[nodeID].livenessStatus) + } + length := 0 + value := reflect.ValueOf(res) + if res != nil && !value.IsNil() { + length = 1 + if value.Kind() == reflect.Slice { + length = value.Len() + } + } + addNodeResp(paginatedNodeResponse{nodeID: nodeID, response: res, len: length, value: value, err: err}) +} + +// processResponses processes the responses returned into responseChan. Must only +// be called once. +func (r *rpcNodePaginator) processResponses(ctx context.Context) (next paginationState, err error) { + // Copy r.pagState, as concurrent invocations of queryNode expect it to not + // change. + next = r.pagState + limit := r.limit + numNodes := r.numNodes + for numNodes > 0 { + select { + case res, ok := <-r.responseChan: + if res.err != nil { + r.errorFn(res.nodeID, res.err) + } else { + start, end, newLimit := next.paginate(limit, res.nodeID, res.len) + var response interface{} + if res.value.Kind() == reflect.Slice { + response = res.value.Slice(start, end).Interface() + } else if end > start { + // res.len must be 1 if res.value.Kind is not Slice. 
+ response = res.value.Interface() + } + r.responseFn(res.nodeID, response) + limit = newLimit + } + if !ok { + return next, err + } + case <-ctx.Done(): + err = errors.Errorf("request of %s canceled before completion", r.errorCtx) + break + } + numNodes-- + } + return next, err +} + +func getRPCPaginationValues(r *http.Request) (limit int, start paginationState) { + var err error + if limit, err = strconv.Atoi(r.URL.Query().Get("limit")); err != nil || limit <= 0 { + return 0, paginationState{} + } + if err = start.UnmarshalText([]byte(r.URL.Query().Get("start"))); err != nil { + return limit, paginationState{} + } + return limit, start +} + +func getSimplePaginationValues(r *http.Request) (limit, offset int) { + var err error + if limit, err = strconv.Atoi(r.URL.Query().Get("limit")); err != nil || limit <= 0 { + return 0, 0 + } + if offset, err = strconv.Atoi(r.URL.Query().Get("offset")); err != nil || offset <= 0 { + return limit, 0 + } + return limit, offset +} diff --git a/pkg/server/pagination_test.go b/pkg/server/pagination_test.go new file mode 100644 index 000000000000..aaad4c4c60c4 --- /dev/null +++ b/pkg/server/pagination_test.go @@ -0,0 +1,284 @@ +// Copyright 2020 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package server + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/datadriven" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func TestSimplePaginate(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + datadriven.RunTest(t, "testdata/simple_paginate", func(t *testing.T, d *datadriven.TestData) string { + switch d.Cmd { + case "paginate": + var input interface{} + if len(d.CmdArgs) != 2 { + return "expected 2 args: paginate " + } + limit, err := strconv.Atoi(d.CmdArgs[0].Key) + if err != nil { + return err.Error() + } + offset, err := strconv.Atoi(d.CmdArgs[1].Key) + if err != nil { + return err.Error() + } + inputString := strings.TrimSpace(d.Input) + if len(inputString) > 0 { + var inputSlice []int + for _, part := range strings.Split(inputString, ",") { + val, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + return err.Error() + } + inputSlice = append(inputSlice, val) + } + input = inputSlice + } + result, next := simplePaginate(input, limit, offset) + return fmt.Sprintf("result=%v\nnext=%d", result, next) + default: + return fmt.Sprintf("unexpected command: %s", d.Cmd) + } + }) +} + +func TestPaginationState(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + parseNodesString := func(t *testing.T, nodesString string) []roachpb.NodeID { + var res []roachpb.NodeID + for _, node := range strings.Split(nodesString, ",") { + i, err := strconv.Atoi(strings.TrimSpace(node)) + require.NoError(t, err) + res = append(res, roachpb.NodeID(i)) + } + return res + } + printState := func(state paginationState) string { + var builder strings.Builder + fmt.Fprintf(&builder, "nodesQueried:", ) + for i, node := range state.nodesQueried { + if i > 0 { + 
fmt.Fprintf(&builder, ",") + } else { + fmt.Fprintf(&builder, " ") + } + fmt.Fprintf(&builder, "%d", node) + } + fmt.Fprintf(&builder, "\ninProgress: %d", state.inProgress) + fmt.Fprintf(&builder, "\ninProgressIndex: %d", state.inProgressIndex) + fmt.Fprintf(&builder, "\nnodesToQuery:") + for i, node := range state.nodesToQuery { + if i > 0 { + fmt.Fprintf(&builder, ",") + } else { + fmt.Fprintf(&builder, " ") + } + fmt.Fprintf(&builder, "%d", node) + } + return builder.String() + } + + var state paginationState + datadriven.RunTest(t, "testdata/pagination_state", func(t *testing.T, d *datadriven.TestData) string { + switch d.Cmd { + case "define": + state = paginationState{} + for _, line := range strings.Split(d.Input, "\n") { + parts := strings.Split(line, ":") + switch parts[0] { + case "queried": + state.nodesQueried = parseNodesString(t, parts[1]) + case "to-query": + state.nodesToQuery = parseNodesString(t, parts[1]) + case "in-progress": + inProgress, err := strconv.Atoi(strings.TrimSpace(parts[1])) + require.NoError(t, err) + state.inProgress = roachpb.NodeID(inProgress) + case "in-progress-index": + inProgressIdx, err := strconv.Atoi(strings.TrimSpace(parts[1])) + require.NoError(t, err) + state.inProgressIndex = inProgressIdx + default: + return fmt.Sprintf("unexpected keyword: %s", parts[0]) + } + } + return "ok" + + case "merge-node-ids": + state.mergeNodeIDs(parseNodesString(t, d.Input)) + return printState(state) + + case "paginate": + var limit, nodeID, length int + var err error + for _, line := range strings.Split(d.Input, "\n") { + fields := strings.Fields(line) + if len(fields) != 2 { + return "expected lines in the format " + } + switch fields[0] { + case "limit": + limit, err = strconv.Atoi(fields[1]) + case "nodeID": + nodeID, err = strconv.Atoi(fields[1]) + case "length": + length, err = strconv.Atoi(fields[1]) + default: + return fmt.Sprintf("unexpected field: %s", fields[0]) + } + require.NoError(t, err) + } + start, end, newLimit := 
state.paginate(limit, roachpb.NodeID(nodeID), length) + return fmt.Sprintf("start: %d\nend: %d\nnewLimit: %d\nstate:\n%s", start, end, newLimit, printState(state)) + + case "marshal": + textState, err := state.MarshalText() + require.NoError(t, err) + return string(textState) + + case "unmarshal": + require.NoError(t, state.UnmarshalText([]byte(d.Input))) + return printState(state) + + default: + return fmt.Sprintf("unexpected command: %s", d.Cmd) + } + }) +} + +type testNodeResponse struct { + nodeID roachpb.NodeID + val int +} + +func TestRPCPaginator(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCases := []struct{ + limits []int + numResponses map[roachpb.NodeID]int + errors int + }{ + {[]int{3,1,5,7,9}, map[roachpb.NodeID]int{1: 5, 2: 10, 3: 7, 5: 10}, 0}, + {[]int{1,5,10}, map[roachpb.NodeID]int{1: 5, 2: 0, 3: -1, 5: 2}, 1}, + } + + ctx, done := context.WithTimeout(context.Background(), 10*time.Second) + defer done() + + for i, tc := range testCases { + t.Run(fmt.Sprintf("testCase=%d", i), func(t *testing.T) { + // Build a reference response first, to compare each potential limit with. 
+ var referenceResp []testNodeResponse + for nodeID, numResponses := range tc.numResponses { + for i := 0; i < numResponses; i++ { + referenceResp = append(referenceResp, testNodeResponse{nodeID, i}) + } + } + sort.Slice(referenceResp, func(i, j int) bool { + if referenceResp[i].nodeID == referenceResp[j].nodeID { + return referenceResp[i].val < referenceResp[j].val + } + return referenceResp[i].nodeID < referenceResp[j].nodeID + }) + dialFn := func(ctx context.Context, id roachpb.NodeID) (client interface{}, err error) { + return id, nil + } + nodeFn := func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (res interface{}, err error) { + numResponses := tc.numResponses[nodeID] + if numResponses < 0 { + return nil, errors.New("injected") + } + var response []testNodeResponse + for i := 0; i < numResponses; i++ { + response = append(response, testNodeResponse{nodeID, i}) + } + return response, nil + } + + for _, limit := range tc.limits { + t.Run(fmt.Sprintf("limit=%d", limit), func(t *testing.T) { + var response []testNodeResponse + errorsDetected := 0 + responseFn := func(nodeID roachpb.NodeID, resp interface{}) { + if val, ok := resp.([]testNodeResponse); ok { + response = append(response, val...) + } + } + errorFn := func(nodeID roachpb.NodeID, nodeFnError error) { + errorsDetected++ + } + var pagState paginationState + sortedNodeIDs := make([]roachpb.NodeID, 0, len(tc.numResponses)) + for nodeID := range tc.numResponses { + sortedNodeIDs = append(sortedNodeIDs, nodeID) + } + sort.Slice(sortedNodeIDs, func(i, j int) bool { + return sortedNodeIDs[i] < sortedNodeIDs[j] + }) + pagState.mergeNodeIDs(sortedNodeIDs) + for { + nodesToQuery := []roachpb.NodeID{pagState.inProgress} + nodesToQuery = append(nodesToQuery, pagState.nodesToQuery...) 
+ paginator := rpcNodePaginator{ + limit: limit, + numNodes: len(nodesToQuery), + errorCtx: "test", + pagState: pagState, + nodeStatuses: make(map[roachpb.NodeID]nodeStatusWithLiveness), + dialFn: dialFn, + nodeFn: nodeFn, + responseFn: responseFn, + errorFn: errorFn, + } + paginator.init() + + // Issue requests in parallel. + for idx, nodeID := range nodesToQuery { + go func(nodeID roachpb.NodeID, idx int) { + paginator.queryNode(ctx, nodeID, idx) + }(nodeID, idx) + } + + var err error + pagState, err = paginator.processResponses(ctx) + require.NoError(t, err) + if pagState.inProgress == 0 { + break + } + } + require.Equal(t, referenceResp, response) + require.Equal(t, tc.errors, errorsDetected) + }) + } + }) + } + +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 8b5615dcf586..a08dcba39d45 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1720,6 +1720,9 @@ func (s *Server) PreStart(ctx context.Context) error { } s.mux.Handle(debug.Endpoint, debugHandler) + apiServer := newApiServer(ctx, s) + s.mux.Handle(apiV2Path, apiServer) + log.Event(ctx, "added http endpoints") // Record node start in telemetry. Get the right counter for this storage diff --git a/pkg/server/status.go b/pkg/server/status.go index 366dc6ec9f1a..387c896e32e0 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -28,6 +28,7 @@ import ( "regexp" "runtime" "runtime/pprof" + "sort" "strconv" "strings" "sync" @@ -92,6 +93,12 @@ const ( // that will be made at any point of time. maxConcurrentRequests = 100 + // maxConcurrentPaginatedRequests is the maximum number of RPC fan-out + // requests that will be made at any point of time for a row-limited / + // paginated request. This should be much lower than maxConcurrentRequests + // as too much concurrency here can result in wasted results. + maxConcurrentPaginatedRequests = 4 + // omittedKeyStr is the string returned in place of a key when keys aren't // permitted in responses. 
omittedKeyStr = "omitted (due to the 'server.remote_debugging.mode' setting)" @@ -1227,18 +1234,10 @@ func (s *statusServer) Profile( } } -// Nodes returns all node statuses. -// -// The LivenessByNodeID in the response returns the known liveness -// information according to gossip. Nodes for which there is no gossip -// information will not have an entry. Clients can exploit the fact -// that status "UNKNOWN" has value 0 (the default) when accessing the -// map. -func (s *statusServer) Nodes( +func (s *statusServer) nodesHelper( ctx context.Context, req *serverpb.NodesRequest, -) (*serverpb.NodesResponse, error) { - ctx = propagateGatewayMetadata(ctx) - ctx = s.AnnotateCtx(ctx) + limit, offset int, +) (*serverpb.NodesResponse, int, error) { startKey := keys.StatusNodePrefix endKey := startKey.PrefixEnd() @@ -1246,7 +1245,7 @@ func (s *statusServer) Nodes( b.Scan(startKey, endKey) if err := s.db.Run(ctx, b); err != nil { log.Errorf(ctx, "%v", err) - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, 0, status.Errorf(codes.Internal, err.Error()) } rows := b.Results[0].Rows @@ -1256,14 +1255,36 @@ func (s *statusServer) Nodes( for i, row := range rows { if err := row.ValueProto(&resp.Nodes[i]); err != nil { log.Errorf(ctx, "%v", err) - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, 0, status.Errorf(codes.Internal, err.Error()) } } + sort.Slice(resp.Nodes, func(i, j int) bool { + return resp.Nodes[i].Desc.NodeID < resp.Nodes[j].Desc.NodeID + }) + nodes, next := simplePaginate(resp.Nodes, limit, offset) + resp.Nodes = nodes.([]statuspb.NodeStatus) clock := s.admin.server.clock + // TODO(bilal): Truncate this map to just have resp.Nodes nodes. resp.LivenessByNodeID = getLivenessStatusMap(s.nodeLiveness, clock.Now().GoTime(), s.st) + return &resp, next, nil +} - return &resp, nil +// Nodes returns all node statuses. +// +// The LivenessByNodeID in the response returns the known liveness +// information according to gossip. 
Nodes for which there is no gossip +// information will not have an entry. Clients can exploit the fact +// that status "UNKNOWN" has value 0 (the default) when accessing the +// map. +func (s *statusServer) Nodes( + ctx context.Context, req *serverpb.NodesRequest, +) (*serverpb.NodesResponse, error) { + ctx = propagateGatewayMetadata(ctx) + ctx = s.AnnotateCtx(ctx) + + resp, _, err := s.nodesHelper(ctx, req, 0, 0) + return resp, err } // nodesStatusWithLiveness is like Nodes but for internal @@ -1650,17 +1671,10 @@ func (s *statusServer) Ranges( return &output, nil } -// HotRanges returns the hottest ranges on each store on the requested node(s). -func (s *statusServer) HotRanges( +func (s *statusServer) hotRangesHelper( ctx context.Context, req *serverpb.HotRangesRequest, -) (*serverpb.HotRangesResponse, error) { - ctx = propagateGatewayMetadata(ctx) - ctx = s.AnnotateCtx(ctx) - - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - return nil, err - } - + limit int, start paginationState, +) (*serverpb.HotRangesResponse, paginationState, error) { response := &serverpb.HotRangesResponse{ NodeID: s.gossip.NodeID.Get(), HotRangesByNodeID: make(map[roachpb.NodeID]serverpb.HotRangesResponse_NodeResponse), @@ -1669,21 +1683,22 @@ func (s *statusServer) HotRanges( if len(req.NodeID) > 0 { requestedNodeID, local, err := s.parseNodeID(req.NodeID) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, start, status.Errorf(codes.InvalidArgument, err.Error()) } // Only hot ranges from the local node. if local { response.HotRangesByNodeID[requestedNodeID] = s.localHotRanges(ctx) - return response, nil + return response, start, nil } // Only hot ranges from one non-local node. status, err := s.dialNode(ctx, requestedNodeID) if err != nil { - return nil, err + return nil, start, err } - return status.HotRanges(ctx, req) + resp, err := status.HotRanges(ctx, req) + return resp, start, err } // Hot ranges from all nodes. 
@@ -1692,9 +1707,9 @@ func (s *statusServer) HotRanges( return client, err } remoteRequest := serverpb.HotRangesRequest{NodeID: "local"} - nodeFn := func(ctx context.Context, client interface{}, _ roachpb.NodeID) (interface{}, error) { - status := client.(serverpb.StatusClient) - return status.HotRanges(ctx, &remoteRequest) + nodeFn := func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error) { + statusClient := client.(serverpb.StatusClient) + return statusClient.HotRanges(ctx, &remoteRequest) } responseFn := func(nodeID roachpb.NodeID, resp interface{}) { hotRangesResp := resp.(*serverpb.HotRangesResponse) @@ -1706,11 +1721,28 @@ func (s *statusServer) HotRanges( } } - if err := s.iterateNodes(ctx, "hot ranges", dialFn, nodeFn, responseFn, errorFn); err != nil { + var next paginationState + var err error + if next, err = s.paginatedIterateNodes( + ctx, "hot ranges", limit, start, dialFn, nodeFn, responseFn, errorFn); err != nil { + return nil, start, err + } + return response, next, nil +} + +// HotRanges returns the hottest ranges on each store on the requested node(s). +func (s *statusServer) HotRanges( + ctx context.Context, req *serverpb.HotRangesRequest, +) (*serverpb.HotRangesResponse, error) { + ctx = propagateGatewayMetadata(ctx) + ctx = s.AnnotateCtx(ctx) + + if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err } - return response, nil + response, _, err := s.hotRangesHelper(ctx, req, 0, paginationState{}) + return response, err } func (s *statusServer) localHotRanges(ctx context.Context) serverpb.HotRangesResponse_NodeResponse { @@ -1739,18 +1771,10 @@ func (s *statusServer) localHotRanges(ctx context.Context) serverpb.HotRangesRes return resp } -// Range returns rangeInfos for all nodes in the cluster about a specific -// range. It also returns the range history for that range as well. 
-func (s *statusServer) Range( +func (s *statusServer) rangeHelper( ctx context.Context, req *serverpb.RangeRequest, -) (*serverpb.RangeResponse, error) { - ctx = propagateGatewayMetadata(ctx) - ctx = s.AnnotateCtx(ctx) - - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - return nil, err - } - + limit int, start paginationState, +) (*serverpb.RangeResponse, paginationState, error) { response := &serverpb.RangeResponse{ RangeID: roachpb.RangeID(req.RangeId), NodeID: s.gossip.NodeID.Get(), @@ -1767,19 +1791,26 @@ func (s *statusServer) Range( } nodeFn := func(ctx context.Context, client interface{}, _ roachpb.NodeID) (interface{}, error) { status := client.(serverpb.StatusClient) - return status.Ranges(ctx, rangesRequest) + resp, err := status.Ranges(ctx, rangesRequest) + if err != nil { + return nil, err + } + sort.Slice(resp.Ranges, func(i, j int) bool { + return resp.Ranges[i].SourceNodeID < resp.Ranges[j].SourceNodeID + }) + return resp.Ranges, nil } nowNanos := timeutil.Now().UnixNano() responseFn := func(nodeID roachpb.NodeID, resp interface{}) { - rangesResp := resp.(*serverpb.RangesResponse) + rangesResp := resp.([]serverpb.RangeInfo) // Age the MVCCStats to a consistent current timestamp. An age that is // not up to date is less useful. 
- for i := range rangesResp.Ranges { - rangesResp.Ranges[i].State.Stats.AgeTo(nowNanos) + for i := range rangesResp { + rangesResp[i].State.Stats.AgeTo(nowNanos) } response.ResponsesByNodeID[nodeID] = serverpb.RangeResponse_NodeResponse{ Response: true, - Infos: rangesResp.Ranges, + Infos: rangesResp, } } errorFn := func(nodeID roachpb.NodeID, err error) { @@ -1788,12 +1819,31 @@ func (s *statusServer) Range( } } - if err := s.iterateNodes( - ctx, fmt.Sprintf("details about range %d", req.RangeId), dialFn, nodeFn, responseFn, errorFn, + var next paginationState + var err error + if next, err = s.paginatedIterateNodes( + ctx, fmt.Sprintf("details about range %d", req.RangeId), limit, start, + dialFn, nodeFn, responseFn, errorFn, ); err != nil { + return nil, start, err + } + return response, next, nil +} + +// Range returns rangeInfos for all nodes in the cluster about a specific +// range. It also returns the range history for that range as well. +func (s *statusServer) Range( + ctx context.Context, req *serverpb.RangeRequest, +) (*serverpb.RangeResponse, error) { + ctx = propagateGatewayMetadata(ctx) + ctx = s.AnnotateCtx(ctx) + + if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { return nil, err } - return response, nil + + resp, _, err := s.rangeHelper(ctx, req, 0, paginationState{}) + return resp, err } // ListLocalSessions returns a list of SQL sessions on this node. @@ -1890,17 +1940,89 @@ func (s *statusServer) iterateNodes( return resultErr } -// ListSessions returns a list of SQL sessions on all nodes in the cluster. -func (s *statusServer) ListSessions( - ctx context.Context, req *serverpb.ListSessionsRequest, -) (*serverpb.ListSessionsResponse, error) { - ctx = propagateGatewayMetadata(ctx) - ctx = s.AnnotateCtx(ctx) +// paginatedIterateNodes iterates nodeFn over all non-removed nodes +// sequentially. It then calls nodeResponse for every valid result of nodeFn, +// and nodeError on every error result. 
It returns the next `limit` results +// after `offset`. +func (s *statusServer) paginatedIterateNodes( + ctx context.Context, + errorCtx string, + limit int, pagState paginationState, + dialFn func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error), + nodeFn func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error), + responseFn func(nodeID roachpb.NodeID, resp interface{}), + errorFn func(nodeID roachpb.NodeID, nodeFnError error), +) (next paginationState, err error) { + if limit == 0 { + return paginationState{}, s.iterateNodes(ctx, errorCtx, dialFn, nodeFn, responseFn, errorFn) + } + nodeStatuses, err := s.nodesStatusWithLiveness(ctx) + if err != nil { + return paginationState{}, err + } - if _, _, err := s.privilegeChecker.getUserAndRole(ctx); err != nil { - return nil, err + // channels for responses and errors. + type nodeResponse struct { + nodeID roachpb.NodeID + response interface{} + value reflect.Value + len int + err error + } + + numNodes := len(nodeStatuses) + nodeIDs := make([]roachpb.NodeID, 0, numNodes) + for nodeID := range nodeStatuses { + nodeIDs = append(nodeIDs, nodeID) + } + // Sort all nodes by IDs, as this is what mergeNodeIDs expects. + sort.Slice(nodeIDs, func(i, j int) bool { + return nodeIDs[i] < nodeIDs[j] + }) + pagState.mergeNodeIDs(nodeIDs) + // Remove any node that have already been queried. + nodeIDs = nodeIDs[:0] + if pagState.inProgress != 0 { + nodeIDs = append(nodeIDs, pagState.inProgress) + } + nodeIDs = append(nodeIDs, pagState.nodesToQuery...) + + paginator := &rpcNodePaginator{ + limit: limit, + numNodes: len(nodeIDs), + errorCtx: errorCtx, + pagState: pagState, + nodeStatuses: nodeStatuses, + dialFn: dialFn, + nodeFn: nodeFn, + responseFn: responseFn, + errorFn: errorFn, + } + + paginator.init() + // Issue the requests concurrently. 
+ sem := quotapool.NewIntPool("node status", maxConcurrentPaginatedRequests) + ctx, cancel := s.stopper.WithCancelOnStop(ctx) + defer cancel() + for idx, nodeID := range nodeIDs { + nodeID := nodeID // needed to ensure the closure below captures a copy. + idx := idx + if err := s.stopper.RunLimitedAsyncTask( + ctx, fmt.Sprintf("server.statusServer: requesting %s", errorCtx), + sem, true, /* wait */ + func(ctx context.Context) { paginator.queryNode(ctx, nodeID, idx) }, + ); err != nil { + return pagState, err + } } + return paginator.processResponses(ctx) +} + +func (s *statusServer) listSessionsHelper( + ctx context.Context, req *serverpb.ListSessionsRequest, + limit int, start paginationState, +) (*serverpb.ListSessionsResponse, paginationState, error) { response := &serverpb.ListSessionsResponse{ Sessions: make([]serverpb.Session, 0), Errors: make([]serverpb.ListSessionsError, 0), @@ -1911,23 +2033,54 @@ func (s *statusServer) ListSessions( return client, err } nodeFn := func(ctx context.Context, client interface{}, _ roachpb.NodeID) (interface{}, error) { - status := client.(serverpb.StatusClient) - return status.ListLocalSessions(ctx, req) + statusClient := client.(serverpb.StatusClient) + resp, err := statusClient.ListLocalSessions(ctx, req) + if resp != nil && err == nil { + if len(resp.Errors) > 0 { + return nil, errors.New(resp.Errors[0].Message) + } + sort.Slice(resp.Sessions, func(i, j int) bool { + return resp.Sessions[i].Start.Before(resp.Sessions[j].Start) + }) + return resp.Sessions, nil + } + return nil, err } responseFn := func(_ roachpb.NodeID, nodeResp interface{}) { - sessions := nodeResp.(*serverpb.ListSessionsResponse) - response.Sessions = append(response.Sessions, sessions.Sessions...) + if nodeResp == nil { + return + } + sessions := nodeResp.([]serverpb.Session) + response.Sessions = append(response.Sessions, sessions...) 
} errorFn := func(nodeID roachpb.NodeID, err error) { errResponse := serverpb.ListSessionsError{NodeID: nodeID, Message: err.Error()} response.Errors = append(response.Errors, errResponse) } - if err := s.iterateNodes(ctx, "session list", dialFn, nodeFn, responseFn, errorFn); err != nil { + var err error + var pagState paginationState + if pagState, err = s.paginatedIterateNodes( + ctx, "session list", limit, start, dialFn, nodeFn, responseFn, errorFn); err != nil { err := serverpb.ListSessionsError{Message: err.Error()} response.Errors = append(response.Errors, err) } - return response, nil + return response, pagState, nil +} + +// ListSessions returns a list of SQL sessions on all nodes in the cluster. +func (s *statusServer) ListSessions( + ctx context.Context, req *serverpb.ListSessionsRequest, +) (*serverpb.ListSessionsResponse, error) { + ctx = propagateGatewayMetadata(ctx) + ctx = s.AnnotateCtx(ctx) + + if _, _, err := s.privilegeChecker.getUserAndRole(ctx); err != nil { + return nil, err + } + + resp, _, err := s.listSessionsHelper(ctx, req, 0 /* limit */, paginationState{}) + return resp, err } // CancelSession responds to a session cancellation request by canceling the diff --git a/pkg/server/status_test.go b/pkg/server/status_test.go index df473efc5c3f..64f1d54b4433 100644 --- a/pkg/server/status_test.go +++ b/pkg/server/status_test.go @@ -14,9 +14,11 @@ import ( "bytes" "context" gosql "database/sql" + "encoding/json" "fmt" "io/ioutil" "math" + "net/http" "net/url" "os" "path/filepath" @@ -1883,6 +1885,107 @@ func TestListSessionsSecurity(t *testing.T) { } } +func TestListSessionsV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + ts1 := testCluster.Server(0) + + var sqlConns []*gosql.Conn + for i := 0; i < 15; i++ { + serverConn := testCluster.ServerConn(i % 3) + conn, 
err := serverConn.Conn(ctx) + require.NoError(t, err) + sqlConns = append(sqlConns, conn) + } + + defer func() { + for _, conn := range sqlConns { + _ = conn.Close() + } + }() + + doSessionsRequest := func (client http.Client, limit int, start string) listSessionsResponse { + req, err := http.NewRequest("GET", ts1.AdminURL() + apiV2Path + "sessions/", nil) + require.NoError(t, err) + query := req.URL.Query() + if limit > 0 { + query.Add("limit", strconv.Itoa(limit)) + } + if len(start) > 0 { + query.Add("start", start) + } + req.URL.RawQuery = query.Encode() + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + bytesResponse, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + var sessionsResponse listSessionsResponse + if resp.StatusCode != 200 { + t.Fatal(string(bytesResponse)) + } + require.NoError(t, json.Unmarshal(bytesResponse, &sessionsResponse)) + return sessionsResponse + } + + rootClient, err := ts1.GetRootAuthenticatedHTTPClient() + require.NoError(t, err) + sessionsResponse := doSessionsRequest(rootClient, 0, "") + require.LessOrEqual(t, 15, len(sessionsResponse.Sessions)) + allSessions := sessionsResponse.Sessions + sort.Slice(allSessions, func(i, j int) bool { + return allSessions[i].Start.Before(allSessions[j].Start) + }) + + // Test the paginated version is identical to the non-paginated one. + for limit := 1; limit <= 15; limit++ { + var next string + var paginatedSessions []serverpb.Session + for { + sessionsResponse := doSessionsRequest(rootClient, limit, next) + paginatedSessions = append(paginatedSessions, sessionsResponse.Sessions...) 
+ next = sessionsResponse.Next + require.LessOrEqual(t, len(sessionsResponse.Sessions), limit) + if len(sessionsResponse.Sessions) < limit { + break + } + } + sort.Slice(paginatedSessions, func(i, j int) bool { + return paginatedSessions[i].Start.Before(paginatedSessions[j].Start) + }) + // Sometimes there can be a transient session that pops up in one of the two + // calls. Exclude it by only comparing the first 15 sessions. + require.Equal(t, paginatedSessions[:15], allSessions[:15]) + } + + // An non-superuser admin user cannot see sessions across all users. + adminClient, err := ts1.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + sessionsResponse2 := doSessionsRequest(adminClient, 0, "") + require.Equal(t, 0, len(sessionsResponse2.Sessions)) + + // An non-admin user cannot see sessions at all. + nonAdminClient, err := ts1.GetAuthenticatedHTTPClient(false) + require.NoError(t, err) + req, err := http.NewRequest("GET", ts1.AdminURL() + apiV2Path + "sessions/", nil) + require.NoError(t, err) + resp, err := nonAdminClient.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + bytesResponse, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + require.Contains(t, string(bytesResponse),"not allowed") +} + func TestCreateStatementDiagnosticsReport(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) diff --git a/pkg/server/testdata/pagination_state b/pkg/server/testdata/pagination_state new file mode 100644 index 000000000000..09dd13f24a2d --- /dev/null +++ b/pkg/server/testdata/pagination_state @@ -0,0 +1,212 @@ + +define +queried: 1,2,3 +in-progress: 4 +in-progress-index: 2 +to-query: 5,6,7 +---- +ok + +marshal +---- +MSwyLDMsfDR8Mnw1LDYsNyw= + +merge-node-ids +9,11 +---- +nodesQueried: 1,2,3 +inProgress: 4 +inProgressIndex: 2 +nodesToQuery: 5,6,7,9,11 + +merge-node-ids +8,9,10,11 +---- +nodesQueried: 1,2,3 +inProgress: 
4 +inProgressIndex: 2 +nodesToQuery: 5,6,7,9,11,8,10 + +# Do nothing if all nodes being merged are already in the struct. + +merge-node-ids +1,3,4,5,6,7,8 +---- +nodesQueried: 1,2,3 +inProgress: 4 +inProgressIndex: 2 +nodesToQuery: 5,6,7,9,11,8,10 + +marshal +---- +MSwyLDMsfDR8Mnw1LDYsNyw5LDExLDgsMTAs + +# The struct being unmarshalled below should match the one defined at the top. + +unmarshal +MSwyLDMsfDR8Mnw1LDYsNyw= +---- +nodesQueried: 1,2,3 +inProgress: 4 +inProgressIndex: 2 +nodesToQuery: 5,6,7 + +unmarshal +MSx8MnwyfDMs +---- +nodesQueried: 1 +inProgress: 2 +inProgressIndex: 2 +nodesToQuery: 3 + +# Tests for paginate(). + +define +queried: 1,2,3 +in-progress: 4 +in-progress-index: 2 +to-query: 5,6,7 +---- +ok + +# Simple case - get the next 5 elements a couple times. Note that each +# subsequent `limit` matches the previously returned `newLimit`, or 5 if the +# previous newLimit was 0 (i.e. denoting the "next request"). + +paginate +limit 5 +length 10 +nodeID 4 +---- +start: 2 +end: 7 +newLimit: 0 +state: +nodesQueried: 1,2,3 +inProgress: 4 +inProgressIndex: 7 +nodesToQuery: 5,6,7 + + +paginate +limit 5 +length 10 +nodeID 4 +---- +start: 7 +end: 10 +newLimit: 2 +state: +nodesQueried: 1,2,3,4 +inProgress: 5 +inProgressIndex: 0 +nodesToQuery: 6,7 + + +paginate +limit 2 +length 7 +nodeID 5 +---- +start: 0 +end: 2 +newLimit: 0 +state: +nodesQueried: 1,2,3,4 +inProgress: 5 +inProgressIndex: 2 +nodesToQuery: 6,7 + +paginate +limit 5 +length 7 +nodeID 5 +---- +start: 2 +end: 7 +newLimit: 0 +state: +nodesQueried: 1,2,3,4,5 +inProgress: 6 +inProgressIndex: 0 +nodesToQuery: 7 + +paginate +limit 5 +length 4 +nodeID 6 +---- +start: 0 +end: 4 +newLimit: 1 +state: +nodesQueried: 1,2,3,4,5,6 +inProgress: 7 +inProgressIndex: 0 +nodesToQuery: + +paginate +limit 1 +length 6 +nodeID 7 +---- +start: 0 +end: 1 +newLimit: 0 +state: +nodesQueried: 1,2,3,4,5,6 +inProgress: 7 +inProgressIndex: 1 +nodesToQuery: + +paginate +limit 5 +length 6 +nodeID 7 +---- +start: 1 +end: 6 
+newLimit: 0 +state: +nodesQueried: 1,2,3,4,5,6,7 +inProgress: 0 +inProgressIndex: 0 +nodesToQuery: + +# Test a case where node 5 returns an error and gets skipped. + +define +queried: 1,2,3 +in-progress: 4 +in-progress-index: 2 +to-query: 5,6,7 +---- +ok + +paginate +limit 5 +length 5 +nodeID 4 +---- +start: 2 +end: 5 +newLimit: 2 +state: +nodesQueried: 1,2,3,4 +inProgress: 5 +inProgressIndex: 0 +nodesToQuery: 6,7 + +paginate +limit 2 +length 5 +nodeID 6 +---- +start: 0 +end: 2 +newLimit: 0 +state: +nodesQueried: 1,2,3,4,5 +inProgress: 6 +inProgressIndex: 2 +nodesToQuery: 7 diff --git a/pkg/server/testdata/simple_paginate b/pkg/server/testdata/simple_paginate new file mode 100644 index 000000000000..b80ead49710e --- /dev/null +++ b/pkg/server/testdata/simple_paginate @@ -0,0 +1,57 @@ + +# usage: paginate <limit> <offset> +# <comma-separated input> + +# Simple two cases + +paginate 5 0 +1,2,3,4,5,6,7,8,9,10 +---- +result=[1 2 3 4 5] +next=5 + +paginate 5 5 +1,2,3,4,5,6,7,8,9,10 +---- +result=[6 7 8 9 10] +next=10 + +# Case where end index is greater than len. + +paginate 5 5 +1,2,3,4,5,6,7,8 +---- +result=[6 7 8] +next=8 + +# Offset beyond the end returns an empty slice. + +paginate 15 15 +1,2,3,4,5,6,7,8 +---- +result=[] +next=8 + +# Limits of 0 translate to returning the entire object +# (i.e. pagination disabled) + +paginate 0 0 +1,2,3,4,5,6,7,8,9,10 +---- +result=[1 2 3 4 5 6 7 8 9 10] +next=0 + +# Negative offsets silently translate to 0. 
+ +paginate 5 -1 +1,2,3,4,5,6,7,8,9,10 +---- +result=[1 2 3 4 5] +next=5 + +# Non-slice input always returns a nil output + +paginate 5 5 +---- +result= +next=0 diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index d3c6ee200c2d..84720cea8a79 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -12,6 +12,7 @@ package server import ( "context" + "encoding/base64" "fmt" "net/http" "net/http/cookiejar" @@ -60,6 +61,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/netutil" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" @@ -287,7 +289,7 @@ type TestServer struct { *Server // authClient is an http.Client that has been authenticated to access the // Admin UI. - authClient [2]struct { + authClient [3]struct { httpClient http.Client cookie *serverpb.SessionCookie once sync.Once @@ -916,7 +918,14 @@ func authenticatedUserNameNoAdmin() security.SQLUsername { // GetAdminAuthenticatedHTTPClient implements the TestServerInterface. func (ts *TestServer) GetAdminAuthenticatedHTTPClient() (http.Client, error) { httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie( - authenticatedUserName(), true) + authenticatedUserName(), true, false) + return httpClient, err +} + +// GetRootAuthenticatedHTTPClient implements the TestServerInterface. 
+func (ts *TestServer) GetRootAuthenticatedHTTPClient() (http.Client, error) { + httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie( + security.RootUserName(), true, true) return httpClient, err } @@ -926,24 +935,40 @@ func (ts *TestServer) GetAuthenticatedHTTPClient(isAdmin bool) (http.Client, err if !isAdmin { authUser = authenticatedUserNameNoAdmin() } - httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie(authUser, isAdmin) + httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie(authUser, isAdmin, false /* isRoot */) return httpClient, err } +type v2AuthDecorator struct { + http.RoundTripper + + session string +} + +func (v *v2AuthDecorator) RoundTrip(r *http.Request) (*http.Response, error) { + r.Header.Add(apiV2AuthHeader, v.session) + return v.RoundTripper.RoundTrip(r) +} + func (ts *TestServer) getAuthenticatedHTTPClientAndCookie( - authUser security.SQLUsername, isAdmin bool, + authUser security.SQLUsername, isAdmin bool, isRoot bool, ) (http.Client, *serverpb.SessionCookie, error) { authIdx := 0 if isAdmin { authIdx = 1 } + if isRoot { + authIdx = 2 + } authClient := &ts.authClient[authIdx] authClient.once.Do(func() { // Create an authentication session for an arbitrary admin user. authClient.err = func() error { // The user needs to exist as the admin endpoints will check its role. 
- if err := ts.createAuthUser(authUser, isAdmin); err != nil { - return err + if !isRoot { + if err := ts.createAuthUser(authUser, isAdmin); err != nil { + return err + } } id, secret, err := ts.authentication.newAuthSession(context.TODO(), authUser) @@ -973,6 +998,14 @@ func (ts *TestServer) getAuthenticatedHTTPClientAndCookie( if err != nil { return err } + rawCookieBytes, err := protoutil.Marshal(rawCookie) + if err != nil { + return err + } + authClient.httpClient.Transport = &v2AuthDecorator{ + RoundTripper: authClient.httpClient.Transport, + session: base64.StdEncoding.EncodeToString(rawCookieBytes), + } authClient.httpClient.Jar = cookieJar authClient.cookie = rawCookie return nil diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index 5aab3a6a9452..20ce13ff3dd0 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -150,6 +150,10 @@ type TestServerInterface interface { // authenticated to access Admin API methods (via a cookie). // The user has admin privileges. GetAdminAuthenticatedHTTPClient() (http.Client, error) + // GetRootAuthenticatedHTTPClient returns an http client which has been + // authenticated to access Admin API methods (via a cookie). + // The user has admin and superuser privileges. + GetRootAuthenticatedHTTPClient() (http.Client, error) // GetAuthenticatedHTTPClient returns an http client which has been // authenticated to access Admin API methods (via a cookie). GetAuthenticatedHTTPClient(isAdmin bool) (http.Client, error)