From 40f58aa6609372ad37507200af3e7fc6aab70272 Mon Sep 17 00:00:00 2001 From: Bilal Akhtar Date: Wed, 24 Feb 2021 16:42:36 -0500 Subject: [PATCH 1/3] server: Small tweaks to offset-based pagination code Makes small changes to simplePaginate such as returning a 0 for offset if the end of a slice has been reached. This works well with `omitempty` JSON fields as that'd let the Next value go ignored. Also adds a getSimplePaginationValues to parse request query string values for offset-based (aka "simple") pagination. Release note: None. Release justification: Small, low-risk change that only affects new endpoints that exist in parallel to existing ones. --- pkg/server/pagination.go | 24 ++++++++++++++++++++++-- pkg/server/testdata/simple_paginate | 6 +++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pkg/server/pagination.go b/pkg/server/pagination.go index 8522c410c924..a010dcd36bf1 100644 --- a/pkg/server/pagination.go +++ b/pkg/server/pagination.go @@ -34,7 +34,10 @@ import ( // simplePaginate takes in an input slice, and returns a sub-slice of the next // `limit` elements starting at `offset`. The second returned value is the // next offset that can be used to return the next "limit" results, or -// len(result) if there are no more results. +// 0 if there are no more results. The choice of a 0 return value for next +// in cases where input has been exhausted, helps when it's being returned +// back to the client as a `json:omitempty` field, as the JSON marshal code will +// simply ignore the field if it's a zero value. 
func simplePaginate(input interface{}, limit, offset int) (result interface{}, next int) { val := reflect.ValueOf(input) if limit <= 0 || val.Kind() != reflect.Slice { @@ -50,7 +53,11 @@ func simplePaginate(input interface{}, limit, offset int) (result interface{}, n if endIdx > val.Len() { endIdx = val.Len() } - return val.Slice(startIdx, endIdx).Interface(), endIdx + next = endIdx + if endIdx == val.Len() { + next = 0 + } + return val.Slice(startIdx, endIdx).Interface(), next } // paginationState represents the current state of pagination through the result @@ -441,3 +448,16 @@ func getRPCPaginationValues(r *http.Request) (limit int, start paginationState) } return limit, start } + +// getSimplePaginationValues parses offset-based pagination related values out +// of the query string of a Request. Meant for use with simplePaginate. +func getSimplePaginationValues(r *http.Request) (limit, offset int) { + var err error + if limit, err = strconv.Atoi(r.URL.Query().Get("limit")); err != nil || limit <= 0 { + return 0, 0 + } + if offset, err = strconv.Atoi(r.URL.Query().Get("offset")); err != nil || offset < 0 { + return limit, 0 + } + return limit, offset +} diff --git a/pkg/server/testdata/simple_paginate b/pkg/server/testdata/simple_paginate index b80ead49710e..44e296e8a46c 100644 --- a/pkg/server/testdata/simple_paginate +++ b/pkg/server/testdata/simple_paginate @@ -14,7 +14,7 @@ paginate 5 5 1,2,3,4,5,6,7,8,9,10 ---- result=[6 7 8 9 10] -next=10 +next=0 # Case where end index is greater than len. @@ -22,7 +22,7 @@ paginate 5 5 1,2,3,4,5,6,7,8 ---- result=[6 7 8] -next=8 +next=0 # Offset beyond the end returns an empty slice. @@ -30,7 +30,7 @@ paginate 15 15 1,2,3,4,5,6,7,8 ---- result=[] -next=8 +next=0 # Limits of 0 translate to returning the entire object # (i.e. 
pagination disabled) From c02dd6fae94bbf7516c0d975c6a9218aecb1aae0 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 18 Feb 2021 14:07:29 -0500 Subject: [PATCH 2/3] sql: drop default value when its dependent sequence is dropped Fixes https://github.com/cockroachdb/cockroach/issues/51889. Previously, when we drop a database or a schema, we do not check if it contains any sequences that are used by anything in other databases/schemas. As a result, if a sequence is used by a table in a different database/schema, we end up dropping a sequence that the table still relies on, resulting in a corrupted default expression. This patch addresses this issue by emulating Postgres behavior by removing the default value from any columns that use the dropped sequence. Release note (bug fix): drop default value when its dependent sequence is dropped. Release justification: low risk bug fix for existing functionality --- pkg/sql/alter_table.go | 2 +- pkg/sql/drop_cascade.go | 4 +- pkg/sql/drop_sequence.go | 81 ++++- pkg/sql/drop_table.go | 9 +- .../testdata/logic_test/drop_sequence | 317 ++++++++++++++++++ pkg/sql/sequence.go | 4 +- 6 files changed, 402 insertions(+), 15 deletions(-) diff --git a/pkg/sql/alter_table.go b/pkg/sql/alter_table.go index bf2c750e73c9..adaa756a9a92 100644 --- a/pkg/sql/alter_table.go +++ b/pkg/sql/alter_table.go @@ -439,7 +439,7 @@ func (n *alterTableNode) startExec(params runParams) error { return err } - if err := params.p.dropSequencesOwnedByCol(params.ctx, colToDrop.ColumnDesc(), true /* queueJob */); err != nil { return err } + if err := params.p.dropSequencesOwnedByCol(params.ctx, colToDrop.ColumnDesc(), true /* queueJob */, t.DropBehavior); err != nil { return err } diff --git a/pkg/sql/drop_cascade.go b/pkg/sql/drop_cascade.go index c04ada462af5..7e87976745f5 100644 --- a/pkg/sql/drop_cascade.go +++ b/pkg/sql/drop_cascade.go @@ -192,13 +192,11 @@ func (d *dropCascadeState) dropAllCollectedObjects(ctx context.Context, p *plann var cascadedObjects []string 
var err error if desc.IsView() { - // TODO(knz): The names of dependent dropped views should be qualified here. cascadedObjects, err = p.dropViewImpl(ctx, desc, false /* queueJob */, "", tree.DropCascade) } else if desc.IsSequence() { err = p.dropSequenceImpl(ctx, desc, false /* queueJob */, "", tree.DropCascade) } else { - // TODO(knz): The names of dependent dropped tables should be qualified here. - cascadedObjects, err = p.dropTableImpl(ctx, desc, true /* droppingParent */, "") + cascadedObjects, err = p.dropTableImpl(ctx, desc, true /* droppingParent */, "", tree.DropCascade) } if err != nil { return err diff --git a/pkg/sql/drop_sequence.go b/pkg/sql/drop_sequence.go index 5c2e57475bc4..796892903c0d 100644 --- a/pkg/sql/drop_sequence.go +++ b/pkg/sql/drop_sequence.go @@ -12,6 +12,7 @@ package sql import ( "context" + "fmt" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/sql/catalog" @@ -21,7 +22,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" - "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/iterutil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/log/eventpb" @@ -54,7 +54,7 @@ func (p *planner) DropSequence(ctx context.Context, n *tree.DropSequence) (planN continue } - if depErr := p.sequenceDependencyError(ctx, droppedDesc); depErr != nil { + if depErr := p.sequenceDependencyError(ctx, droppedDesc, n.DropBehavior); depErr != nil { return nil, depErr } @@ -116,6 +116,11 @@ func (p *planner) dropSequenceImpl( if err := removeSequenceOwnerIfExists(ctx, p, seqDesc.ID, seqDesc.GetSequenceOpts()); err != nil { return err } + if behavior == tree.DropCascade { + if err := dropDependentOnSequence(ctx, p, seqDesc); err != nil { + return err + } + } return p.initiateDropTable(ctx, 
seqDesc, queueJob, jobDesc, true /* drainName */) } @@ -123,9 +128,9 @@ func (p *planner) dropSequenceImpl( // a table uses it in a DEFAULT expression on one of its columns, or nil if there is no // such dependency. func (p *planner) sequenceDependencyError( - ctx context.Context, droppedDesc *tabledesc.Mutable, + ctx context.Context, droppedDesc *tabledesc.Mutable, behavior tree.DropBehavior, ) error { - if len(droppedDesc.DependedOnBy) > 0 { + if behavior != tree.DropCascade && len(droppedDesc.DependedOnBy) > 0 { return pgerror.Newf( pgcode.DependentObjectsStillExist, "cannot drop sequence %s because other objects depend on it", @@ -205,10 +210,9 @@ func (p *planner) canRemoveOwnedSequencesImpl( } } - // Once Drop Sequence Cascade actually respects the drop behavior, this - // check should go away. + // If cascade is enabled, allow the sequences to be dropped. if behavior == tree.DropCascade { - return unimplemented.NewWithIssue(20965, "DROP SEQUENCE CASCADE is currently unimplemented") + continue } // If Cascade is not enabled, and more than 1 columns depend on it, and the return pgerror.Newf( @@ -219,3 +223,66 @@ func (p *planner) canRemoveOwnedSequencesImpl( } return nil } + +// dropDependentOnSequence drops the default values of any columns that depend on the +// given sequence descriptor being dropped, and if the dependent object +// is a view, it drops the views. +// This is called when the DropBehavior is DropCascade. +func dropDependentOnSequence(ctx context.Context, p *planner, seqDesc *tabledesc.Mutable) error { + for _, dependent := range seqDesc.DependedOnBy { + tblDesc, err := p.Descriptors().GetMutableTableByID(ctx, p.txn, dependent.ID, + tree.ObjectLookupFlags{ + CommonLookupFlags: tree.CommonLookupFlags{ + IncludeOffline: true, + IncludeDropped: true, + }, + }) + if err != nil { + return err + } + + // If the table that uses the sequence has been dropped already, + // no need to update, so skip. 
+ if tblDesc.Dropped() { + continue + } + + // If the dependent object is a view, drop the view. + if tblDesc.IsView() { + _, err = p.dropViewImpl(ctx, tblDesc, false /* queueJob */, "", tree.DropCascade) + if err != nil { + return err + } + continue + } + + // Set of column IDs which will have their default values dropped. + colsToDropDefault := make(map[descpb.ColumnID]struct{}) + for _, colID := range dependent.ColumnIDs { + colsToDropDefault[colID] = struct{}{} + } + + // Iterate over all columns in the table, drop affected columns' default values + // and update back references. + for idx := range tblDesc.Columns { + column := &tblDesc.Columns[idx] + if _, ok := colsToDropDefault[column.ID]; ok { + column.DefaultExpr = nil + if err := p.removeSequenceDependencies(ctx, tblDesc, column); err != nil { + return err + } + } + } + + jobDesc := fmt.Sprintf( + "removing default expressions using sequence %q since it is being dropped", + seqDesc.Name, + ) + if err := p.writeSchemaChange( + ctx, tblDesc, descpb.InvalidMutationID, jobDesc, + ); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sql/drop_table.go b/pkg/sql/drop_table.go index f27ad26bcd60..56bbe4f0f1fa 100644 --- a/pkg/sql/drop_table.go +++ b/pkg/sql/drop_table.go @@ -133,6 +133,7 @@ func (n *dropTableNode) startExec(params runParams) error { droppedDesc, false, /* droppingDatabase */ tree.AsStringWithFQNames(n.n, params.Ann()), + n.n.DropBehavior, ) if err != nil { return err @@ -270,7 +271,11 @@ func (p *planner) removeInterleave(ctx context.Context, ref descpb.ForeignKeyRef // dropped due to `cascade` behavior. 
droppingParent indicates whether this // table's parent (either database or schema) is being dropped func (p *planner) dropTableImpl( - ctx context.Context, tableDesc *tabledesc.Mutable, droppingParent bool, jobDesc string, + ctx context.Context, + tableDesc *tabledesc.Mutable, + droppingParent bool, + jobDesc string, + behavior tree.DropBehavior, ) ([]string, error) { var droppedViews []string @@ -322,7 +327,7 @@ func (p *planner) dropTableImpl( // Drop sequences that the columns of the table own. for _, col := range tableDesc.Columns { - if err := p.dropSequencesOwnedByCol(ctx, &col, !droppingParent); err != nil { + if err := p.dropSequencesOwnedByCol(ctx, &col, !droppingParent, behavior); err != nil { return droppedViews, err } } diff --git a/pkg/sql/logictest/testdata/logic_test/drop_sequence b/pkg/sql/logictest/testdata/logic_test/drop_sequence index cc4757345a9c..6718f2b5050e 100644 --- a/pkg/sql/logictest/testdata/logic_test/drop_sequence +++ b/pkg/sql/logictest/testdata/logic_test/drop_sequence @@ -1,5 +1,11 @@ # see also file `sequences` +statement ok +SET sql_safe_updates = true + +# Test dropping sequences with/without CASCADE +subtest drop_sequence + statement ok CREATE SEQUENCE drop_test @@ -14,3 +20,314 @@ CREATE SEQUENCE drop_if_exists_test statement ok DROP SEQUENCE IF EXISTS drop_if_exists_test + +statement ok +CREATE SEQUENCE drop_test + +statement ok +CREATE TABLE t1 (i INT NOT NULL DEFAULT nextval('drop_test')) + +query TT +SHOW CREATE TABLE t1 +---- +t1 CREATE TABLE public.t1 ( + i INT8 NOT NULL DEFAULT nextval('test.public.drop_test':::STRING::REGCLASS), + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY "primary" (i, rowid) +) + +query T +SELECT pg_get_serial_sequence('t1', 'i') +---- +public.drop_test + +statement error pq: cannot drop sequence drop_test because other objects depend on it +DROP SEQUENCE drop_test + +statement ok +DROP SEQUENCE drop_test CASCADE + +query TT +SHOW 
CREATE TABLE t1 +---- +t1 CREATE TABLE public.t1 ( + i INT8 NOT NULL, + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY "primary" (i, rowid) +) + +query T +SELECT pg_get_serial_sequence('t1', 'i') +---- +NULL + +statement ok +INSERT INTO t1 VALUES (1) + + +# Test that if a database is dropped with CASCADE and it +# contains a sequence, that sequence is dropped and any DEFAULT +# expressions using that sequence will also be dropped. +subtest drop_database_cascade + +statement ok +CREATE DATABASE other_db + +statement ok +CREATE SEQUENCE other_db.s + +statement ok +CREATE SEQUENCE s + +statement ok +CREATE TABLE foo ( + i INT NOT NULL DEFAULT nextval('other_db.s'), + j INT NOT NULL DEFAULT nextval('s'), + FAMILY (i, j) +) + +query TT +SHOW CREATE TABLE foo +---- +foo CREATE TABLE public.foo ( + i INT8 NOT NULL DEFAULT nextval('other_db.public.s':::STRING::REGCLASS), + j INT8 NOT NULL DEFAULT nextval('test.public.s':::STRING::REGCLASS), + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY fam_0_i_j_rowid (i, j, rowid) +) + +query TT +SELECT pg_get_serial_sequence('foo', 'i'), pg_get_serial_sequence('foo', 'j') +---- +public.s public.s + +statement error DROP DATABASE on non-empty database without explicit CASCADE +DROP DATABASE other_db + +statement ok +DROP DATABASE other_db CASCADE + +query TT +SHOW CREATE TABLE foo +---- +foo CREATE TABLE public.foo ( + i INT8 NOT NULL, + j INT8 NOT NULL DEFAULT nextval('test.public.s':::STRING::REGCLASS), + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY fam_0_i_j_rowid (i, j, rowid) +) + +query TT +SELECT pg_get_serial_sequence('foo', 'i'), pg_get_serial_sequence('foo', 'j') +---- +NULL public.s + +statement ok +INSERT INTO foo VALUES (1, default) + + +# Test that if a schema is dropped and it contains a sequence, +# any DEFAULT expressions 
using that sequence will also be dropped. +subtest drop_schema_cascade + +statement ok +CREATE SCHEMA other_sc + +statement ok +CREATE SEQUENCE other_sc.s + +statement ok +CREATE TABLE bar ( + i INT NOT NULL DEFAULT nextval('other_sc.s'), + j INT NOT NULL DEFAULT nextval('s'), + FAMILY (i, j) +) + +query TT +SHOW CREATE TABLE bar +---- +bar CREATE TABLE public.bar ( + i INT8 NOT NULL DEFAULT nextval('test.other_sc.s':::STRING::REGCLASS), + j INT8 NOT NULL DEFAULT nextval('test.public.s':::STRING::REGCLASS), + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY fam_0_i_j_rowid (i, j, rowid) +) + +query TT +SELECT pg_get_serial_sequence('bar', 'i'), pg_get_serial_sequence('bar', 'j') +---- +other_sc.s public.s + +statement error schema "other_sc" is not empty and CASCADE was not specified +DROP SCHEMA other_sc + +statement ok +DROP SCHEMA other_sc CASCADE + +query TT +SHOW CREATE TABLE bar +---- +bar CREATE TABLE public.bar ( + i INT8 NOT NULL, + j INT8 NOT NULL DEFAULT nextval('test.public.s':::STRING::REGCLASS), + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY fam_0_i_j_rowid (i, j, rowid) +) + +query TT +SELECT pg_get_serial_sequence('bar', 'i'), pg_get_serial_sequence('bar', 'j') +---- +NULL public.s + +statement ok +INSERT INTO bar VALUES (1, default) + + +# Test that sequences owned by tables are dropped properly, +# and if CASCADE is specified, DEFAULT expressions are dropped +subtest drop_table_cascade + +statement ok +CREATE TABLE t2 (i INT NOT NULL) + +statement ok +CREATE SEQUENCE s2 OWNED BY t2.i + +statement ok +CREATE TABLE t3 (i INT NOT NULL DEFAULT nextval('s2')) + +query T +SELECT pg_get_serial_sequence('t3', 'i') +---- +public.s2 + +statement error cannot drop table t2 because other objects depend on it +DROP TABLE t2 + +statement ok +DROP TABLE t2 CASCADE + +query TT +SHOW CREATE TABLE t3 +---- +t3 CREATE TABLE public.t3 ( 
+ i INT8 NOT NULL, + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY "primary" (i, rowid) +) + +query T +SELECT pg_get_serial_sequence('t3', 'i') +---- +NULL + +statement ok +INSERT INTO t3 VALUES (1) + +statement ok +CREATE SEQUENCE s3 + +statement ok +CREATE TABLE t4 (i INT NOT NULL DEFAULT nextval('s3')) + +statement ok +ALTER SEQUENCE s3 OWNED BY t3.i + +query T +SELECT pg_get_serial_sequence('t4', 'i') +---- +public.s3 + +statement ok +DROP TABLE t3 CASCADE + +query TT +SHOW CREATE TABLE t4 +---- +t4 CREATE TABLE public.t4 ( + i INT8 NOT NULL, + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY "primary" (i, rowid) +) + +query T +SELECT pg_get_serial_sequence('t4', 'i') +---- +NULL + +statement ok +INSERT INTO t4 VALUES (1) + + +# Test that sequences owned by columns are dropped properly, +# and if CASCADE is specified, DEFAULT expressions are dropped +subtest drop_column_cascade + +statement ok +CREATE TABLE t5 (i INT NOT NULL) + +statement ok +CREATE SEQUENCE s5 OWNED BY t5.i + +statement ok +CREATE TABLE t6 (i INT NOT NULL DEFAULT nextval('s5')) + +query T +SELECT pg_get_serial_sequence('t6', 'i') +---- +public.s5 + +statement error ALTER TABLE DROP COLUMN will remove all data in that column +ALTER TABLE t5 DROP COLUMN i + +statement ok +SET sql_safe_updates = false + +statement ok +ALTER TABLE t5 DROP COLUMN i CASCADE + +query TT +SHOW CREATE TABLE t6 +---- +t6 CREATE TABLE public.t6 ( + i INT8 NOT NULL, + rowid INT8 NOT VISIBLE NOT NULL DEFAULT unique_rowid(), + CONSTRAINT "primary" PRIMARY KEY (rowid ASC), + FAMILY "primary" (i, rowid) +) + +query T +SELECT pg_get_serial_sequence('t6', 'i') +---- +NULL + +statement ok +INSERT INTO t6 VALUES (1) + + +# Test that sequences owned by columns are dropped properly, +# and if CASCADE is specified, DEFAULT expressions are dropped +subtest drop_view + +statement ok +CREATE SEQUENCE s6 + 
+statement ok +CREATE VIEW v AS SELECT nextval('s6') + +statement error cannot drop sequence s6 because other objects depend on it +DROP SEQUENCE s6 + +statement ok +DROP SEQUENCE s6 CASCADE + +statement error relation "v" does not exist +SELECT * from v diff --git a/pkg/sql/sequence.go b/pkg/sql/sequence.go index d4522a7e67c8..04a9b375e630 100644 --- a/pkg/sql/sequence.go +++ b/pkg/sql/sequence.go @@ -723,7 +723,7 @@ func maybeAddSequenceDependencies( // dropSequencesOwnedByCol drops all the sequences from col.OwnsSequenceIDs. // Called when the respective column (or the whole table) is being dropped. func (p *planner) dropSequencesOwnedByCol( - ctx context.Context, col *descpb.ColumnDescriptor, queueJob bool, + ctx context.Context, col *descpb.ColumnDescriptor, queueJob bool, behavior tree.DropBehavior, ) error { // Copy out the sequence IDs as the code to drop the sequence will reach // back around and update the descriptor from underneath us. @@ -749,7 +749,7 @@ func (p *planner) dropSequencesOwnedByCol( // Note that this call will end up resolving and modifying the table // descriptor. if err := p.dropSequenceImpl( - ctx, seqDesc, queueJob, jobDesc, tree.DropRestrict, + ctx, seqDesc, queueJob, jobDesc, behavior, ); err != nil { return err } From 7c55798b09b2afa94193cd8ea9e82e8eb44ce5b1 Mon Sep 17 00:00:00 2001 From: Bilal Akhtar Date: Thu, 11 Feb 2021 15:36:50 -0500 Subject: [PATCH 3/3] server: Migrate nodes, {hot-,}ranges, health endpoints to v2 API This change adds API v2 compatible versions of the node list, range info per node, ranges info across nodes, hot ranges, and health endpoints. These endpoints all support API v2 header-based authentication, pagination (if applicable), and only return relevant information in the response payloads. 
Release note (api change): Add these new HTTP API endpoints: - `/api/v2/nodes/`: Lists all nodes in the cluster - `/api/v2/nodes//ranges`: Lists all ranges on the specified node - `/api/v2/ranges/hot/`: Lists hot ranges in the cluster - `/api/v2/ranges//`: Describes range in more detail - `/api/v2/health/`: Returns an HTTP 200 response if node is healthy. Release justification: Adds more HTTP API endpoints in parallel that do not touch existing code. --- pkg/server/BUILD.bazel | 9 +- pkg/server/{api.go => api_v2.go} | 37 +- pkg/server/{api_auth.go => api_v2_auth.go} | 0 pkg/server/{api_error.go => api_v2_error.go} | 0 pkg/server/api_v2_ranges.go | 352 +++++++++++++++++++ pkg/server/api_v2_ranges_test.go | 160 +++++++++ pkg/server/api_v2_test.go | 154 ++++++++ pkg/server/status.go | 81 ++++- pkg/server/status_test.go | 99 ------ 9 files changed, 774 insertions(+), 118 deletions(-) rename pkg/server/{api.go => api_v2.go} (82%) rename pkg/server/{api_auth.go => api_v2_auth.go} (100%) rename pkg/server/{api_error.go => api_v2_error.go} (100%) create mode 100644 pkg/server/api_v2_ranges.go create mode 100644 pkg/server/api_v2_ranges_test.go create mode 100644 pkg/server/api_v2_test.go diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 65cdb1528d3b..fbbfbca8552e 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -4,9 +4,10 @@ go_library( name = "server", srcs = [ "admin.go", - "api.go", - "api_auth.go", - "api_error.go", + "api_v2.go", + "api_v2_auth.go", + "api_v2_error.go", + "api_v2_ranges.go", "authentication.go", "auto_tls_init.go", "auto_upgrade.go", @@ -236,6 +237,8 @@ go_test( srcs = [ "admin_cluster_test.go", "admin_test.go", + "api_v2_ranges_test.go", + "api_v2_test.go", "authentication_test.go", "auto_tls_init_test.go", "config_test.go", diff --git a/pkg/server/api.go b/pkg/server/api_v2.go similarity index 82% rename from pkg/server/api.go rename to pkg/server/api_v2.go index 2dabff3e4ebb..19f0fa092e88 100644 --- 
a/pkg/server/api.go +++ b/pkg/server/api_v2.go @@ -15,6 +15,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/telemetry" @@ -102,6 +103,13 @@ func (a *apiV2Server) registerRoutes(innerMux *mux.Router, authMux http.Handler) // Directly register other endpoints in the api server. {"sessions/", a.listSessions, true /* requiresAuth */, adminRole, noOption}, + {"nodes/", a.listNodes, true, adminRole, noOption}, + // Any endpoint returning range information requires an admin user. This is because range start/end keys + // are sensitive info. + {"nodes/{node_id}/ranges/", a.listNodeRanges, true, adminRole, noOption}, + {"ranges/hot/", a.listHotRanges, true, adminRole, noOption}, + {"ranges/{range_id:[0-9]+}/", a.listRange, true, adminRole, noOption}, + {"health/", a.health, false, regularRole, noOption}, } // For all routes requiring authentication, have the outer mux (a.mux) @@ -148,7 +156,7 @@ func (c *callCountDecorator) ServeHTTP(w http.ResponseWriter, req *http.Request) type listSessionsResponse struct { serverpb.ListSessionsResponse - Next string `json:"next"` + Next string `json:"next,omitempty"` } func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { @@ -174,3 +182,30 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { response.ListSessionsResponse = *responseProto writeJSONResponse(ctx, w, http.StatusOK, response) } + +func (a *apiV2Server) health(w http.ResponseWriter, r *http.Request) { + ready := false + readyStr := r.URL.Query().Get("ready") + if len(readyStr) > 0 { + var err error + ready, err = strconv.ParseBool(readyStr) + if err != nil { + http.Error(w, "invalid ready value", http.StatusBadRequest) + return + } + } + ctx := r.Context() + resp := &serverpb.HealthResponse{} + // If Ready is not set, the client doesn't want to know whether this node is + // ready to receive client traffic. 
+ if !ready { + writeJSONResponse(ctx, w, 200, resp) + return + } + + if err := a.admin.checkReadinessForHealthCheck(ctx); err != nil { + apiV2InternalError(ctx, err, w) + return + } + writeJSONResponse(ctx, w, 200, resp) +} diff --git a/pkg/server/api_auth.go b/pkg/server/api_v2_auth.go similarity index 100% rename from pkg/server/api_auth.go rename to pkg/server/api_v2_auth.go diff --git a/pkg/server/api_error.go b/pkg/server/api_v2_error.go similarity index 100% rename from pkg/server/api_error.go rename to pkg/server/api_v2_error.go diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go new file mode 100644 index 000000000000..381e89e292cd --- /dev/null +++ b/pkg/server/api_v2_ranges.go @@ -0,0 +1,352 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/gorilla/mux" +) + +type nodeStatus struct { + // Fields that are a subset of NodeDescriptor. 
+ NodeID roachpb.NodeID `json:"node_id"` + Address util.UnresolvedAddr `json:"address"` + Attrs roachpb.Attributes `json:"attrs"` + Locality roachpb.Locality `json:"locality"` + ServerVersion roachpb.Version `json:"ServerVersion"` + BuildTag string `json:"build_tag"` + StartedAt int64 `json:"started_at"` + ClusterName string `json:"cluster_name"` + SQLAddress util.UnresolvedAddr `json:"sql_address"` + + // Other fields that are a subset of roachpb.NodeStatus. + Metrics map[string]float64 `json:"metrics,omitempty"` + TotalSystemMemory int64 `json:"total_system_memory,omitempty"` + NumCpus int32 `json:"num_cpus,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + + // Retrieved from the liveness status map. + LivenessStatus livenesspb.NodeLivenessStatus `json:"liveness_status"` +} + +// Response struct for listNodes. +type nodesResponse struct { + Nodes []nodeStatus `json:"nodes"` + Next int `json:"next,omitempty"` +} + +func (a *apiV2Server) listNodes(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + limit, offset := getSimplePaginationValues(r) + ctx = apiToOutgoingGatewayCtx(ctx, r) + + nodes, next, err := a.status.nodesHelper(ctx, limit, offset) + if err != nil { + apiV2InternalError(ctx, err, w) + return + } + var resp nodesResponse + resp.Next = next + for _, n := range nodes.Nodes { + resp.Nodes = append(resp.Nodes, nodeStatus{ + NodeID: n.Desc.NodeID, + Address: n.Desc.Address, + Attrs: n.Desc.Attrs, + Locality: n.Desc.Locality, + ServerVersion: n.Desc.ServerVersion, + BuildTag: n.Desc.BuildTag, + StartedAt: n.Desc.StartedAt, + ClusterName: n.Desc.ClusterName, + SQLAddress: n.Desc.SQLAddress, + Metrics: n.Metrics, + TotalSystemMemory: n.TotalSystemMemory, + NumCpus: n.NumCpus, + UpdatedAt: n.UpdatedAt, + LivenessStatus: nodes.LivenessByNodeID[n.Desc.NodeID], + }) + } + writeJSONResponse(ctx, w, 200, resp) +} + +func parseRangeIDs(input string, w http.ResponseWriter) (ranges []roachpb.RangeID, ok bool) { + if len(input) == 0 { + return 
nil, true + } + for _, reqRange := range strings.Split(input, ",") { + rangeID, err := strconv.ParseInt(reqRange, 10, 64) + if err != nil { + http.Error(w, "invalid range ID", http.StatusBadRequest) + return nil, false + } + + ranges = append(ranges, roachpb.RangeID(rangeID)) + } + return ranges, true +} + +type nodeRangeResponse struct { + RangeInfo rangeInfo `json:"range_info"` + Error string `json:"error,omitempty"` +} + +type rangeResponse struct { + Responses map[roachpb.NodeID]nodeRangeResponse `json:"responses_by_node_id"` +} + +func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = apiToOutgoingGatewayCtx(ctx, r) + vars := mux.Vars(r) + rangeID, err := strconv.ParseInt(vars["range_id"], 10, 64) + if err != nil { + http.Error(w, "invalid range ID", http.StatusBadRequest) + return + } + + response := &rangeResponse{ + Responses: make(map[roachpb.NodeID]nodeRangeResponse), + } + + rangesRequest := &serverpb.RangesRequest{ + RangeIDs: []roachpb.RangeID{roachpb.RangeID(rangeID)}, + } + + dialFn := func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error) { + client, err := a.status.dialNode(ctx, nodeID) + return client, err + } + nodeFn := func(ctx context.Context, client interface{}, _ roachpb.NodeID) (interface{}, error) { + status := client.(serverpb.StatusClient) + return status.Ranges(ctx, rangesRequest) + } + responseFn := func(nodeID roachpb.NodeID, resp interface{}) { + rangesResp := resp.(*serverpb.RangesResponse) + // Age the MVCCStats to a consistent current timestamp. An age that is + // not up to date is less useful. 
+ if len(rangesResp.Ranges) == 0 { + return + } + var ri rangeInfo + ri.init(rangesResp.Ranges[0]) + response.Responses[nodeID] = nodeRangeResponse{RangeInfo: ri} + } + errorFn := func(nodeID roachpb.NodeID, err error) { + response.Responses[nodeID] = nodeRangeResponse{ + Error: err.Error(), + } + } + + if err := a.status.iterateNodes( + ctx, fmt.Sprintf("details about range %d", rangeID), dialFn, nodeFn, responseFn, errorFn, + ); err != nil { + apiV2InternalError(ctx, err, w) + return + } + writeJSONResponse(ctx, w, 200, response) +} + +// rangeDescriptorInfo contains a subset of fields from roachpb.RangeDescriptor +// that are safe to be returned from APIs. +type rangeDescriptorInfo struct { + RangeID roachpb.RangeID `json:"range_id"` + StartKey roachpb.RKey `json:"start_key,omitempty"` + EndKey roachpb.RKey `json:"end_key,omitempty"` + + // Set for HotRanges. + StoreID roachpb.StoreID `json:"store_id"` + QueriesPerSecond float64 `json:"queries_per_second"` +} + +func (r *rangeDescriptorInfo) init(rd *roachpb.RangeDescriptor) { + if rd == nil { + *r = rangeDescriptorInfo{} + return + } + *r = rangeDescriptorInfo{ + RangeID: rd.RangeID, + StartKey: rd.StartKey, + EndKey: rd.EndKey, + } +} + +type rangeInfo struct { + Desc rangeDescriptorInfo `json:"desc"` + + // Subset of fields copied from serverpb.RangeInfo + Span serverpb.PrettySpan `json:"span"` + SourceNodeID roachpb.NodeID `json:"source_node_id,omitempty"` + SourceStoreID roachpb.StoreID `json:"source_store_id,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + LeaseHistory []roachpb.Lease `json:"lease_history"` + Problems serverpb.RangeProblems `json:"problems"` + Stats serverpb.RangeStatistics `json:"stats"` + Quiescent bool `json:"quiescent,omitempty"` + Ticking bool `json:"ticking,omitempty"` +} + +func (ri *rangeInfo) init(r serverpb.RangeInfo) { + *ri = rangeInfo{ + Span: r.Span, + SourceNodeID: r.SourceNodeID, + SourceStoreID: r.SourceStoreID, + ErrorMessage: r.ErrorMessage, + 
LeaseHistory: r.LeaseHistory, + Problems: r.Problems, + Stats: r.Stats, + Quiescent: r.Quiescent, + Ticking: r.Ticking, + } + ri.Desc.init(r.State.Desc) +} + +// Response struct for listNodeRanges. +type nodeRangesResponse struct { + Ranges []rangeInfo `json:"ranges"` + Next int `json:"next,omitempty"` +} + +func (a *apiV2Server) listNodeRanges(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = apiToOutgoingGatewayCtx(ctx, r) + vars := mux.Vars(r) + nodeIDStr := vars["node_id"] + if nodeIDStr != "local" { + nodeID, err := strconv.ParseInt(nodeIDStr, 10, 32) + if err != nil || nodeID <= 0 { + http.Error(w, "invalid node ID", http.StatusBadRequest) + return + } + } + + ranges, ok := parseRangeIDs(r.URL.Query().Get("ranges"), w) + if !ok { + return + } + req := &serverpb.RangesRequest{ + NodeId: nodeIDStr, + RangeIDs: ranges, + } + limit, offset := getSimplePaginationValues(r) + statusResp, next, err := a.status.rangesHelper(ctx, req, limit, offset) + if err != nil { + apiV2InternalError(ctx, err, w) + return + } + resp := nodeRangesResponse{ + Ranges: make([]rangeInfo, 0, len(statusResp.Ranges)), + Next: next, + } + for _, r := range statusResp.Ranges { + var ri rangeInfo + ri.init(r) + resp.Ranges = append(resp.Ranges, ri) + } + writeJSONResponse(ctx, w, 200, resp) +} + +type responseError struct { + ErrorMessage string `json:"error_message"` + NodeID roachpb.NodeID `json:"node_id,omitempty"` +} + +// Response struct for listHotRanges. 
+type hotRangesResponse struct { + RangesByNodeID map[roachpb.NodeID][]rangeDescriptorInfo `json:"ranges_by_node_id"` + Errors []responseError `json:"response_error,omitempty"` + Next string `json:"next,omitempty"` +} + +func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = apiToOutgoingGatewayCtx(ctx, r) + nodeIDStr := r.URL.Query().Get("node_id") + limit, start := getRPCPaginationValues(r) + + response := &hotRangesResponse{ + RangesByNodeID: make(map[roachpb.NodeID][]rangeDescriptorInfo), + } + var requestedNodes []roachpb.NodeID + if len(nodeIDStr) > 0 { + requestedNodeID, _, err := a.status.parseNodeID(nodeIDStr) + if err != nil { + http.Error(w, "invalid node ID", http.StatusBadRequest) + return + } + requestedNodes = []roachpb.NodeID{requestedNodeID} + } + + dialFn := func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error) { + client, err := a.status.dialNode(ctx, nodeID) + return client, err + } + remoteRequest := serverpb.HotRangesRequest{NodeID: "local"} + nodeFn := func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error) { + status := client.(serverpb.StatusClient) + resp, err := status.HotRanges(ctx, &remoteRequest) + if err != nil || resp == nil { + return nil, err + } + rangeDescriptorInfos := make([]rangeDescriptorInfo, 0) + for _, store := range resp.HotRangesByNodeID[nodeID].Stores { + for _, hotRange := range store.HotRanges { + var r rangeDescriptorInfo + r.init(&hotRange.Desc) + r.StoreID = store.StoreID + r.QueriesPerSecond = hotRange.QueriesPerSecond + rangeDescriptorInfos = append(rangeDescriptorInfos, r) + } + } + sort.Slice(rangeDescriptorInfos, func(i, j int) bool { + if rangeDescriptorInfos[i].StoreID == rangeDescriptorInfos[j].StoreID { + return rangeDescriptorInfos[i].RangeID < rangeDescriptorInfos[j].RangeID + } + return rangeDescriptorInfos[i].StoreID < rangeDescriptorInfos[j].StoreID + }) + return rangeDescriptorInfos, nil + } + 
responseFn := func(nodeID roachpb.NodeID, resp interface{}) { + if hotRangesResp, ok := resp.([]rangeDescriptorInfo); ok { + response.RangesByNodeID[nodeID] = hotRangesResp + } + } + errorFn := func(nodeID roachpb.NodeID, err error) { + response.Errors = append(response.Errors, responseError{ + ErrorMessage: err.Error(), + NodeID: nodeID, + }) + } + + next, err := a.status.paginatedIterateNodes( + ctx, "hot ranges", limit, start, requestedNodes, dialFn, + nodeFn, responseFn, errorFn) + + if err != nil { + apiV2InternalError(ctx, err, w) + return + } + var nextBytes []byte + if nextBytes, err = next.MarshalText(); err != nil { + response.Errors = append(response.Errors, responseError{ErrorMessage: err.Error()}) + } else { + response.Next = string(nextBytes) + } + writeJSONResponse(ctx, w, 200, response) +} diff --git a/pkg/server/api_v2_ranges_test.go b/pkg/server/api_v2_ranges_test.go new file mode 100644 index 000000000000..7d518fac4f04 --- /dev/null +++ b/pkg/server/api_v2_ranges_test.go @@ -0,0 +1,160 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestHotRangesV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts := startServer(t) + defer ts.Stopper().Stop(context.Background()) + + var hotRangesResp hotRangesResponse + client, err := ts.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.AdminURL()+apiV2Path+"ranges/hot/", nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + + require.Equal(t, 200, resp.StatusCode) + require.NoError(t, json.NewDecoder(resp.Body).Decode(&hotRangesResp)) + require.NoError(t, resp.Body.Close()) + + if len(hotRangesResp.RangesByNodeID) == 0 { + t.Fatalf("didn't get hot range responses from any nodes") + } + if len(hotRangesResp.Errors) > 0 { + t.Errorf("got an error in hot range response from n%d: %v", + hotRangesResp.Errors[0].NodeID, hotRangesResp.Errors[0].ErrorMessage) + } + + for nodeID, nodeResp := range hotRangesResp.RangesByNodeID { + if len(nodeResp) == 0 { + t.Fatalf("didn't get hot range response from node n%d", nodeID) + } + // We don't check for ranges being sorted by QPS, as this hot ranges + // report does not use that as its sort key (for stability across multiple + // pagination calls). 
+ for _, r := range nodeResp { + if r.RangeID == 0 || (len(r.StartKey) == 0 && len(r.EndKey) == 0) { + t.Errorf("unexpected empty/unpopulated range descriptor: %+v", r) + } + } + } +} + +func TestNodeRangesV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ts := startServer(t) + defer ts.Stopper().Stop(context.Background()) + + // Perform a scan to ensure that all the raft groups are initialized. + if _, err := ts.db.Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { + t.Fatal(err) + } + + var nodeRangesResp nodeRangesResponse + client, err := ts.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.AdminURL()+apiV2Path+"nodes/local/ranges/", nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + + require.Equal(t, 200, resp.StatusCode) + require.NoError(t, json.NewDecoder(resp.Body).Decode(&nodeRangesResp)) + require.NoError(t, resp.Body.Close()) + + if len(nodeRangesResp.Ranges) == 0 { + t.Errorf("didn't get any ranges") + } + for _, ri := range nodeRangesResp.Ranges { + require.Equal(t, roachpb.NodeID(1), ri.SourceNodeID) + require.Equal(t, roachpb.StoreID(1), ri.SourceStoreID) + require.GreaterOrEqual(t, len(ri.LeaseHistory), 1) + require.NotEmpty(t, ri.Span.StartKey) + require.NotEmpty(t, ri.Span.EndKey) + } + + // Take the first range ID, and call the ranges/ endpoint with it. 
+ rangeID := nodeRangesResp.Ranges[0].Desc.RangeID + req, err = http.NewRequest("GET", fmt.Sprintf("%s%sranges/%d/", ts.AdminURL(), apiV2Path, rangeID), nil) + require.NoError(t, err) + resp, err = client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + + var rangeResp rangeResponse + require.Equal(t, 200, resp.StatusCode) + require.NoError(t, json.NewDecoder(resp.Body).Decode(&rangeResp)) + require.NoError(t, resp.Body.Close()) + + require.Greater(t, len(rangeResp.Responses), 0) + nodeRangeResp := rangeResp.Responses[roachpb.NodeID(1)] + require.NotZero(t, nodeRangeResp) + // The below comparison is from the response returned in the previous API call + // ("nodeRangesResp") vs the current one ("nodeRangeResp"). + require.Equal(t, nodeRangesResp.Ranges[0].Desc, nodeRangeResp.RangeInfo.Desc) + require.Empty(t, nodeRangeResp.Error) +} + +func TestNodesV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + ts1 := testCluster.Server(0) + + var nodesResp nodesResponse + client, err := ts1.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"nodes/", nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + + require.Equal(t, 200, resp.StatusCode) + require.NoError(t, json.NewDecoder(resp.Body).Decode(&nodesResp)) + require.NoError(t, resp.Body.Close()) + + require.Equal(t, 3, len(nodesResp.Nodes)) + for _, n := range nodesResp.Nodes { + require.Greater(t, int(n.NodeID), 0) + require.Less(t, int(n.NodeID), 4) + } +} diff --git a/pkg/server/api_v2_test.go b/pkg/server/api_v2_test.go new file mode 100644 index 000000000000..eab70447f925 --- /dev/null +++ b/pkg/server/api_v2_test.go @@ -0,0 +1,154 @@ +// Copyright 2021 The Cockroach Authors. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + gosql "database/sql" + "encoding/json" + "io/ioutil" + "net/http" + "sort" + "strconv" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestListSessionsV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + ts1 := testCluster.Server(0) + + var sqlConns []*gosql.Conn + for i := 0; i < 15; i++ { + serverConn := testCluster.ServerConn(i % 3) + conn, err := serverConn.Conn(ctx) + require.NoError(t, err) + sqlConns = append(sqlConns, conn) + } + + defer func() { + for _, conn := range sqlConns { + _ = conn.Close() + } + }() + + doSessionsRequest := func(client http.Client, limit int, start string) listSessionsResponse { + req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"sessions/", nil) + require.NoError(t, err) + query := req.URL.Query() + if limit > 0 { + query.Add("limit", strconv.Itoa(limit)) + } + if len(start) > 0 { + query.Add("start", start) + } + req.URL.RawQuery = query.Encode() + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + bytesResponse, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + var sessionsResponse 
listSessionsResponse + if resp.StatusCode != 200 { + t.Fatal(string(bytesResponse)) + } + require.NoError(t, json.Unmarshal(bytesResponse, &sessionsResponse)) + return sessionsResponse + } + + time.Sleep(500 * time.Millisecond) + adminClient, err := ts1.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + sessionsResponse := doSessionsRequest(adminClient, 0, "") + require.LessOrEqual(t, 15, len(sessionsResponse.Sessions)) + require.Equal(t, 0, len(sessionsResponse.Errors)) + allSessions := sessionsResponse.Sessions + sort.Slice(allSessions, func(i, j int) bool { + return allSessions[i].Start.Before(allSessions[j].Start) + }) + + // Test the paginated version is identical to the non-paginated one. + for limit := 1; limit <= 15; limit++ { + var next string + var paginatedSessions []serverpb.Session + for { + sessionsResponse := doSessionsRequest(adminClient, limit, next) + paginatedSessions = append(paginatedSessions, sessionsResponse.Sessions...) + next = sessionsResponse.Next + require.LessOrEqual(t, len(sessionsResponse.Sessions), limit) + if len(sessionsResponse.Sessions) < limit { + break + } + } + sort.Slice(paginatedSessions, func(i, j int) bool { + return paginatedSessions[i].Start.Before(paginatedSessions[j].Start) + }) + // Sometimes there can be a transient session that pops up in one of the two + // calls. Exclude it by only comparing the first 15 sessions. + require.Equal(t, paginatedSessions[:15], allSessions[:15]) + } + + // A non-admin user cannot see sessions at all. 
+ nonAdminClient, err := ts1.GetAuthenticatedHTTPClient(false) + require.NoError(t, err) + req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"sessions/", nil) + require.NoError(t, err) + resp, err := nonAdminClient.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + bytesResponse, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + require.Contains(t, string(bytesResponse), "not allowed") +} + +func TestHealthV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + ts1 := testCluster.Server(0) + + client, err := ts1.GetAdminAuthenticatedHTTPClient() + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"health/", nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Check if the response was a 200. + require.Equal(t, 200, resp.StatusCode) + // Check if an unmarshal into the (empty) HealthResponse struct works. 
+ var hr serverpb.HealthResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&hr)) + require.NoError(t, resp.Body.Close()) +} diff --git a/pkg/server/status.go b/pkg/server/status.go index 6b830df438f3..297894b9070c 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -1285,6 +1285,13 @@ func (s *statusServer) Profile( func (s *statusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { + resp, _, err := s.nodesHelper(ctx, 0, 0) + return resp, err +} + +func (s *statusServer) nodesHelper( + ctx context.Context, limit, offset int, +) (*serverpb.NodesResponse, int, error) { ctx = propagateGatewayMetadata(ctx) ctx = s.AnnotateCtx(ctx) startKey := keys.StatusNodePrefix @@ -1294,9 +1301,16 @@ func (s *statusServer) Nodes( b.Scan(startKey, endKey) if err := s.db.Run(ctx, b); err != nil { log.Errorf(ctx, "%v", err) - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, 0, status.Errorf(codes.Internal, err.Error()) + } + + var next int + var rows []kv.KeyValue + if len(b.Results[0].Rows) > 0 { + var rowsInterface interface{} + rowsInterface, next = simplePaginate(b.Results[0].Rows, limit, offset) + rows = rowsInterface.([]kv.KeyValue) } - rows := b.Results[0].Rows resp := serverpb.NodesResponse{ Nodes: make([]statuspb.NodeStatus, len(rows)), @@ -1304,14 +1318,13 @@ func (s *statusServer) Nodes( for i, row := range rows { if err := row.ValueProto(&resp.Nodes[i]); err != nil { log.Errorf(ctx, "%v", err) - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, 0, status.Errorf(codes.Internal, err.Error()) } } clock := s.admin.server.clock resp.LivenessByNodeID = getLivenessStatusMap(s.nodeLiveness, clock.Now().GoTime(), s.st) - - return &resp, nil + return &resp, next, nil } // nodesStatusWithLiveness is like Nodes but for internal @@ -1539,24 +1552,38 @@ func (s *statusServer) handleVars(w http.ResponseWriter, r *http.Request) { func (s *statusServer) Ranges( ctx 
context.Context, req *serverpb.RangesRequest, ) (*serverpb.RangesResponse, error) { + resp, _, err := s.rangesHelper(ctx, req, 0, 0) + return resp, err +} + +// Ranges returns range info for the specified node. +func (s *statusServer) rangesHelper( + ctx context.Context, req *serverpb.RangesRequest, limit, offset int, +) (*serverpb.RangesResponse, int, error) { ctx = propagateGatewayMetadata(ctx) ctx = s.AnnotateCtx(ctx) if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - return nil, err + return nil, 0, err } nodeID, local, err := s.parseNodeID(req.NodeId) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, 0, status.Errorf(codes.InvalidArgument, err.Error()) } if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, err + return nil, 0, err + } + resp, err := status.Ranges(ctx, req) + if resp != nil && len(resp.Ranges) > 0 { + resultInterface, next := simplePaginate(resp.Ranges, limit, offset) + resp.Ranges = resultInterface.([]serverpb.RangeInfo) + return resp, next, err } - return status.Ranges(ctx, req) + return resp, 0, err } output := serverpb.RangesResponse{ @@ -1646,6 +1673,18 @@ func (s *statusServer) Ranges( isLiveMap := s.nodeLiveness.GetIsLiveMap() clusterNodes := s.storePool.ClusterNodeCount() + // There are two possibilities for ordering of ranges in the results: + // it could either be determined by the RangeIDs in the request (if specified), + // or be in RangeID order if not (as that's the ordering that + // IterateRangeDescriptors works on). The latter is already sorted in a + // stable fashion, as far as pagination is concerned. The former case requires + // sorting. 
+ if len(req.RangeIDs) > 0 { + sort.Slice(req.RangeIDs, func(i, j int) bool { + return req.RangeIDs[i] < req.RangeIDs[j] + }) + } + err = s.stores.VisitStores(func(store *kvserver.Store) error { now := store.Clock().NowAsClockTimestamp() if len(req.RangeIDs) == 0 { @@ -1693,9 +1732,15 @@ func (s *statusServer) Ranges( return nil }) if err != nil { - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, 0, status.Errorf(codes.Internal, err.Error()) } - return &output, nil + var next int + if len(req.RangeIDs) > 0 { + var outputInterface interface{} + outputInterface, next = simplePaginate(output.Ranges, limit, offset) + output.Ranges = outputInterface.([]serverpb.RangeInfo) + } + return &output, next, nil } // HotRanges returns the hottest ranges on each store on the requested node(s). @@ -1961,12 +2006,14 @@ func (s *statusServer) iterateNodes( // paginatedIterateNodes iterates nodeFn over all non-removed nodes // sequentially. It then calls nodeResponse for every valid result of nodeFn, // and nodeError on every error result. It returns the next `limit` results -// after `offset`. +// after `start`. If `requestedNodes` is specified and non-empty, iteration is +// only done on that subset of nodes in addition to any nodes already in pagState. func (s *statusServer) paginatedIterateNodes( ctx context.Context, errorCtx string, limit int, pagState paginationState, + requestedNodes []roachpb.NodeID, dialFn func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error), nodeFn func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error), responseFn func(nodeID roachpb.NodeID, resp interface{}), @@ -1982,8 +2029,12 @@ func (s *statusServer) paginatedIterateNodes( numNodes := len(nodeStatuses) nodeIDs := make([]roachpb.NodeID, 0, numNodes) - for nodeID := range nodeStatuses { - nodeIDs = append(nodeIDs, nodeID) + if len(requestedNodes) > 0 { + nodeIDs = append(nodeIDs, requestedNodes...) 
+ } else { + for nodeID := range nodeStatuses { + nodeIDs = append(nodeIDs, nodeID) + } } // Sort all nodes by IDs, as this is what mergeNodeIDs expects. sort.Slice(nodeIDs, func(i, j int) bool { @@ -2070,7 +2121,7 @@ func (s *statusServer) listSessionsHelper( var err error var pagState paginationState if pagState, err = s.paginatedIterateNodes( - ctx, "session list", limit, start, dialFn, nodeFn, responseFn, errorFn); err != nil { + ctx, "session list", limit, start, nil, dialFn, nodeFn, responseFn, errorFn); err != nil { err := serverpb.ListSessionsError{Message: err.Error()} response.Errors = append(response.Errors, err) } diff --git a/pkg/server/status_test.go b/pkg/server/status_test.go index 62bf9944ff55..b5bfb9b0073f 100644 --- a/pkg/server/status_test.go +++ b/pkg/server/status_test.go @@ -14,11 +14,9 @@ import ( "bytes" "context" gosql "database/sql" - "encoding/json" "fmt" "io/ioutil" "math" - "net/http" "net/url" "os" "path/filepath" @@ -2008,103 +2006,6 @@ func TestListContentionEventsSecurity(t *testing.T) { } } -func TestListSessionsV2(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - ctx := context.Background() - defer testCluster.Stopper().Stop(ctx) - - ts1 := testCluster.Server(0) - - var sqlConns []*gosql.Conn - for i := 0; i < 15; i++ { - serverConn := testCluster.ServerConn(i % 3) - conn, err := serverConn.Conn(ctx) - require.NoError(t, err) - sqlConns = append(sqlConns, conn) - } - - defer func() { - for _, conn := range sqlConns { - _ = conn.Close() - } - }() - - doSessionsRequest := func(client http.Client, limit int, start string) listSessionsResponse { - req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"sessions/", nil) - require.NoError(t, err) - query := req.URL.Query() - if limit > 0 { - query.Add("limit", strconv.Itoa(limit)) - } - if len(start) > 0 { - query.Add("start", start) - } - req.URL.RawQuery = 
query.Encode() - resp, err := client.Do(req) - require.NoError(t, err) - require.NotNil(t, resp) - bytesResponse, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - - var sessionsResponse listSessionsResponse - if resp.StatusCode != 200 { - t.Fatal(string(bytesResponse)) - } - require.NoError(t, json.Unmarshal(bytesResponse, &sessionsResponse)) - return sessionsResponse - } - - time.Sleep(500 * time.Millisecond) - adminClient, err := ts1.GetAdminAuthenticatedHTTPClient() - require.NoError(t, err) - sessionsResponse := doSessionsRequest(adminClient, 0, "") - require.LessOrEqual(t, 15, len(sessionsResponse.Sessions)) - require.Equal(t, 0, len(sessionsResponse.Errors)) - allSessions := sessionsResponse.Sessions - sort.Slice(allSessions, func(i, j int) bool { - return allSessions[i].Start.Before(allSessions[j].Start) - }) - - // Test the paginated version is identical to the non-paginated one. - for limit := 1; limit <= 15; limit++ { - var next string - var paginatedSessions []serverpb.Session - for { - sessionsResponse := doSessionsRequest(adminClient, limit, next) - paginatedSessions = append(paginatedSessions, sessionsResponse.Sessions...) - next = sessionsResponse.Next - require.LessOrEqual(t, len(sessionsResponse.Sessions), limit) - if len(sessionsResponse.Sessions) < limit { - break - } - } - sort.Slice(paginatedSessions, func(i, j int) bool { - return paginatedSessions[i].Start.Before(paginatedSessions[j].Start) - }) - // Sometimes there can be a transient session that pops up in one of the two - // calls. Exclude it by only comparing the first 15 sessions. - require.Equal(t, paginatedSessions[:15], allSessions[:15]) - } - - // A non-admin user cannot see sessions at all. 
- nonAdminClient, err := ts1.GetAuthenticatedHTTPClient(false) - require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL()+apiV2Path+"sessions/", nil) - require.NoError(t, err) - resp, err := nonAdminClient.Do(req) - require.NoError(t, err) - require.NotNil(t, resp) - bytesResponse, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - require.Equal(t, http.StatusForbidden, resp.StatusCode) - require.Contains(t, string(bytesResponse), "not allowed") -} - func TestCreateStatementDiagnosticsReport(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t)