From 538cdc7de03244c1b029dfb5cd7faec7c8765377 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Mon, 17 Jul 2023 19:04:38 +0200 Subject: [PATCH 1/2] server: remove some occurrences of `TODOTestTenantDisabled` Release note: None --- pkg/server/drain_test.go | 2 +- pkg/server/server_controller_test.go | 6 +++--- pkg/server/structlogging/hot_ranges_log_test.go | 2 +- .../systemconfigwatchertest/test_system_config_watcher.go | 4 +--- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pkg/server/drain_test.go b/pkg/server/drain_test.go index 5134b14f8473..35b6650c5244 100644 --- a/pkg/server/drain_test.go +++ b/pkg/server/drain_test.go @@ -315,7 +315,7 @@ func TestServerShutdownReleasesSession(t *testing.T) { ctx := context.Background() s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - DefaultTestTenant: base.TODOTestTenantDisabled, + DefaultTestTenant: base.TestControlsTenantsExplicitly, }) defer s.Stopper().Stop(ctx) diff --git a/pkg/server/server_controller_test.go b/pkg/server/server_controller_test.go index 03e8d3fccecb..7c15f2d5c043 100644 --- a/pkg/server/server_controller_test.go +++ b/pkg/server/server_controller_test.go @@ -28,7 +28,7 @@ func TestServerController(t *testing.T) { ctx := context.Background() s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - DefaultTestTenant: base.TODOTestTenantDisabled, + DefaultTestTenant: base.TestControlsTenantsExplicitly, }) defer s.Stopper().Stop(ctx) @@ -60,7 +60,7 @@ func TestServerController(t *testing.T) { // controller itself: it's sufficient to see that the // tenant constructor was called. require.Error(t, err, "tenant connector requires a CCL binary") - // TODO(knz): test something about d + // TODO(knz): test something about d. 
} func TestSQLErrorUponInvalidTenant(t *testing.T) { @@ -70,7 +70,7 @@ func TestSQLErrorUponInvalidTenant(t *testing.T) { ctx := context.Background() s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - DefaultTestTenant: base.TODOTestTenantDisabled, + DefaultTestTenant: base.TestControlsTenantsExplicitly, }) defer s.Stopper().Stop(ctx) diff --git a/pkg/server/structlogging/hot_ranges_log_test.go b/pkg/server/structlogging/hot_ranges_log_test.go index 0121563ce3ea..ad026415f5c5 100644 --- a/pkg/server/structlogging/hot_ranges_log_test.go +++ b/pkg/server/structlogging/hot_ranges_log_test.go @@ -47,7 +47,7 @@ func TestHotRangesStats(t *testing.T) { defer cleanup() s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - DefaultTestTenant: base.TODOTestTenantDisabled, + DefaultTestTenant: base.TestControlsTenantsExplicitly, StoreSpecs: []base.StoreSpec{ base.DefaultTestStoreSpec, base.DefaultTestStoreSpec, diff --git a/pkg/server/systemconfigwatcher/systemconfigwatchertest/test_system_config_watcher.go b/pkg/server/systemconfigwatcher/systemconfigwatchertest/test_system_config_watcher.go index 532fa46927b4..f3e5ee4ae313 100644 --- a/pkg/server/systemconfigwatcher/systemconfigwatchertest/test_system_config_watcher.go +++ b/pkg/server/systemconfigwatcher/systemconfigwatchertest/test_system_config_watcher.go @@ -48,9 +48,7 @@ func TestSystemConfigWatcher(t *testing.T, skipSecondary bool) { ctx := context.Background() s, sqlDB, kvDB := serverutils.StartServer(t, base.TestServerArgs{ - // Test runs against tenant, so no need to create the default - // test tenant. 
- DefaultTestTenant: base.TODOTestTenantDisabled, + DefaultTestTenant: base.TestControlsTenantsExplicitly, }, ) defer s.Stopper().Stop(ctx) From c5e3cf49403c017c370d10d31071244835117c75 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Tue, 18 Jul 2023 17:39:45 +0200 Subject: [PATCH 2/2] server: split test code into sub-packages and files New sub-packages: - `storage_api`: unit tests for API endpoints specific to the KV/storage layer. - `application_api`: unit tests for API endpoints valid for application servers, including for secondary tenants. - `privchecker`: SQL authorization interface for API handlers. (previously: `adminPrivilegeChecker`) - `authserver`: HTTP/RPC authentication code. (previously: `authenticationServer`) - `srverrors`: error objects suitable for return from API handlers. - `srvtestutils`: common helpers for test code. Release note: None --- pkg/BUILD.bazel | 20 + pkg/ccl/oidcccl/BUILD.bazel | 2 +- pkg/ccl/oidcccl/authentication_oidc.go | 6 +- pkg/ccl/serverccl/BUILD.bazel | 1 + pkg/ccl/serverccl/server_controller_test.go | 7 +- pkg/cli/BUILD.bazel | 1 + pkg/cli/auth.go | 6 +- pkg/cli/democluster/BUILD.bazel | 1 + pkg/cli/democluster/demo_cluster.go | 3 +- pkg/cmd/roachtest/tests/BUILD.bazel | 2 +- pkg/cmd/roachtest/tests/cluster_init.go | 4 +- pkg/server/BUILD.bazel | 52 +- pkg/server/addjoin.go | 4 +- pkg/server/admin.go | 698 +-- pkg/server/admin_cluster_test.go | 254 -- pkg/server/admin_test.go | 3502 --------------- pkg/server/admin_test_utils.go | 60 - pkg/server/api_v2.go | 105 +- pkg/server/api_v2_error.go | 45 - pkg/server/api_v2_ranges.go | 31 +- pkg/server/api_v2_ranges_test.go | 16 +- pkg/server/api_v2_sql.go | 3 +- pkg/server/api_v2_sql_schema.go | 47 +- pkg/server/api_v2_sql_schema_test.go | 17 +- pkg/server/api_v2_test.go | 20 +- pkg/server/apiconstants/BUILD.bazel | 12 + pkg/server/apiconstants/constants.go | 47 + pkg/server/apiconstants/testutils.go | 35 + pkg/server/apiutil/BUILD.bazel | 9 + 
pkg/server/apiutil/apiutil.go | 32 + pkg/server/application_api/BUILD.bazel | 90 + pkg/server/application_api/activity_test.go | 144 + pkg/server/application_api/config_test.go | 248 ++ pkg/server/application_api/contention_test.go | 415 ++ pkg/server/application_api/dbconsole_test.go | 183 + pkg/server/application_api/doc.go | 15 + pkg/server/application_api/events_test.go | 155 + pkg/server/application_api/insights_test.go | 229 + pkg/server/application_api/jobs_test.go | 478 ++ pkg/server/application_api/main_test.go | 35 + pkg/server/application_api/metrics_test.go | 152 + pkg/server/application_api/query_plan_test.go | 66 + .../application_api/schema_inspection_test.go | 620 +++ pkg/server/application_api/security_test.go | 66 + pkg/server/application_api/sessions_test.go | 338 ++ pkg/server/application_api/sql_stats_test.go | 941 ++++ pkg/server/application_api/stmtdiag_test.go | 265 ++ .../storage_inspection_test.go | 494 +++ pkg/server/application_api/telemetry_test.go | 119 + pkg/server/application_api/util_test.go | 32 + pkg/server/application_api/zcfg_test.go | 138 + pkg/server/authserver/BUILD.bazel | 97 + pkg/server/authserver/api.go | 113 + pkg/server/authserver/api_v2.go | 74 + pkg/server/{ => authserver}/api_v2_auth.go | 125 +- pkg/server/{ => authserver}/authentication.go | 387 +- .../{ => authserver}/authentication_test.go | 201 +- pkg/server/authserver/context.go | 127 + pkg/server/authserver/cookie.go | 176 + pkg/server/authserver/main_test.go | 35 + pkg/server/combined_statement_stats.go | 114 +- pkg/server/debug/BUILD.bazel | 15 +- pkg/server/debug/debug_test.go | 223 + pkg/server/debug/main_test.go | 35 + pkg/server/decommission.go | 68 +- pkg/server/decommission_test.go | 319 -- pkg/server/distsql_flows.go | 63 + pkg/server/distsql_flows_test.go | 205 + pkg/server/drain.go | 3 +- pkg/server/fanout_clients.go | 6 +- pkg/server/grpc_gateway.go | 7 +- pkg/server/grpc_gateway_test.go | 57 + pkg/server/index_usage_stats.go | 20 +- 
pkg/server/index_usage_stats_test.go | 8 +- pkg/server/init_handshake.go | 7 +- pkg/server/main_test.go | 2 + pkg/server/multi_store_test.go | 4 +- pkg/server/nodes_response.go | 155 + pkg/server/nodes_response_test.go | 170 + pkg/server/privchecker/BUILD.bazel | 55 + pkg/server/privchecker/api.go | 98 + pkg/server/privchecker/main_test.go | 35 + pkg/server/privchecker/privchecker.go | 326 ++ pkg/server/privchecker/privchecker_test.go | 163 + pkg/server/problem_ranges.go | 2 +- pkg/server/purge_auth_session_test.go | 5 +- pkg/server/rangetestutils/BUILD.bazel | 12 + pkg/server/rangetestutils/rangetestutils.go | 49 + pkg/server/server.go | 25 +- pkg/server/server_controller_http.go | 45 +- pkg/server/server_http.go | 48 +- pkg/server/server_sql.go | 11 + pkg/server/server_test.go | 8 +- pkg/server/sql_stats.go | 5 +- pkg/server/srverrors/BUILD.bazel | 32 + pkg/server/srverrors/errors.go | 80 + pkg/server/srverrors/errors_test.go | 40 + pkg/server/srverrors/main_test.go | 22 + pkg/server/srvtestutils/BUILD.bazel | 18 + pkg/server/srvtestutils/testutils.go | 157 + pkg/server/statement_diagnostics_requests.go | 17 +- pkg/server/statements.go | 5 +- pkg/server/status.go | 620 +-- pkg/server/status_local_file_retrieval.go | 3 +- pkg/server/status_test.go | 3889 ----------------- pkg/server/storage_api/BUILD.bazel | 70 + pkg/server/storage_api/certs_test.go | 55 + pkg/server/storage_api/decommission_test.go | 1036 +++++ pkg/server/storage_api/doc.go | 15 + pkg/server/storage_api/engine_test.go | 95 + pkg/server/storage_api/enqueue_test.go | 124 + pkg/server/storage_api/files_test.go | 155 + pkg/server/storage_api/gossip_test.go | 47 + pkg/server/storage_api/health_test.go | 143 + pkg/server/storage_api/logfiles_test.go | 423 ++ pkg/server/storage_api/main_test.go | 33 + pkg/server/storage_api/network_test.go | 72 + pkg/server/storage_api/nodes_test.go | 214 + pkg/server/storage_api/raft_test.go | 75 + pkg/server/storage_api/rangelog_test.go | 178 + 
pkg/server/storage_api/ranges_test.go | 190 + pkg/server/tenant.go | 24 +- pkg/server/testserver.go | 129 +- pkg/server/testserver_http.go | 35 +- pkg/server/user.go | 16 +- pkg/server/user_test.go | 28 +- .../lint/passes/fmtsafe/functions.go | 4 +- pkg/testutils/serverutils/test_server_shim.go | 4 - pkg/testutils/serverutils/test_tenant_shim.go | 25 +- pkg/util/safesql/BUILD.bazel | 9 + pkg/util/safesql/safesql.go | 88 + 131 files changed, 12207 insertions(+), 9931 deletions(-) delete mode 100644 pkg/server/admin_cluster_test.go delete mode 100644 pkg/server/admin_test.go delete mode 100644 pkg/server/admin_test_utils.go delete mode 100644 pkg/server/api_v2_error.go create mode 100644 pkg/server/apiconstants/BUILD.bazel create mode 100644 pkg/server/apiconstants/constants.go create mode 100644 pkg/server/apiconstants/testutils.go create mode 100644 pkg/server/apiutil/BUILD.bazel create mode 100644 pkg/server/apiutil/apiutil.go create mode 100644 pkg/server/application_api/BUILD.bazel create mode 100644 pkg/server/application_api/activity_test.go create mode 100644 pkg/server/application_api/config_test.go create mode 100644 pkg/server/application_api/contention_test.go create mode 100644 pkg/server/application_api/dbconsole_test.go create mode 100644 pkg/server/application_api/doc.go create mode 100644 pkg/server/application_api/events_test.go create mode 100644 pkg/server/application_api/insights_test.go create mode 100644 pkg/server/application_api/jobs_test.go create mode 100644 pkg/server/application_api/main_test.go create mode 100644 pkg/server/application_api/metrics_test.go create mode 100644 pkg/server/application_api/query_plan_test.go create mode 100644 pkg/server/application_api/schema_inspection_test.go create mode 100644 pkg/server/application_api/security_test.go create mode 100644 pkg/server/application_api/sessions_test.go create mode 100644 pkg/server/application_api/sql_stats_test.go create mode 100644 pkg/server/application_api/stmtdiag_test.go 
create mode 100644 pkg/server/application_api/storage_inspection_test.go create mode 100644 pkg/server/application_api/telemetry_test.go create mode 100644 pkg/server/application_api/util_test.go create mode 100644 pkg/server/application_api/zcfg_test.go create mode 100644 pkg/server/authserver/BUILD.bazel create mode 100644 pkg/server/authserver/api.go create mode 100644 pkg/server/authserver/api_v2.go rename pkg/server/{ => authserver}/api_v2_auth.go (81%) rename pkg/server/{ => authserver}/authentication.go (55%) rename pkg/server/{ => authserver}/authentication_test.go (84%) create mode 100644 pkg/server/authserver/context.go create mode 100644 pkg/server/authserver/cookie.go create mode 100644 pkg/server/authserver/main_test.go create mode 100644 pkg/server/debug/debug_test.go create mode 100644 pkg/server/debug/main_test.go delete mode 100644 pkg/server/decommission_test.go create mode 100644 pkg/server/distsql_flows.go create mode 100644 pkg/server/distsql_flows_test.go create mode 100644 pkg/server/grpc_gateway_test.go create mode 100644 pkg/server/nodes_response.go create mode 100644 pkg/server/nodes_response_test.go create mode 100644 pkg/server/privchecker/BUILD.bazel create mode 100644 pkg/server/privchecker/api.go create mode 100644 pkg/server/privchecker/main_test.go create mode 100644 pkg/server/privchecker/privchecker.go create mode 100644 pkg/server/privchecker/privchecker_test.go create mode 100644 pkg/server/rangetestutils/BUILD.bazel create mode 100644 pkg/server/rangetestutils/rangetestutils.go create mode 100644 pkg/server/srverrors/BUILD.bazel create mode 100644 pkg/server/srverrors/errors.go create mode 100644 pkg/server/srverrors/errors_test.go create mode 100644 pkg/server/srverrors/main_test.go create mode 100644 pkg/server/srvtestutils/BUILD.bazel create mode 100644 pkg/server/srvtestutils/testutils.go delete mode 100644 pkg/server/status_test.go create mode 100644 pkg/server/storage_api/BUILD.bazel create mode 100644 
pkg/server/storage_api/certs_test.go create mode 100644 pkg/server/storage_api/decommission_test.go create mode 100644 pkg/server/storage_api/doc.go create mode 100644 pkg/server/storage_api/engine_test.go create mode 100644 pkg/server/storage_api/enqueue_test.go create mode 100644 pkg/server/storage_api/files_test.go create mode 100644 pkg/server/storage_api/gossip_test.go create mode 100644 pkg/server/storage_api/health_test.go create mode 100644 pkg/server/storage_api/logfiles_test.go create mode 100644 pkg/server/storage_api/main_test.go create mode 100644 pkg/server/storage_api/network_test.go create mode 100644 pkg/server/storage_api/nodes_test.go create mode 100644 pkg/server/storage_api/raft_test.go create mode 100644 pkg/server/storage_api/rangelog_test.go create mode 100644 pkg/server/storage_api/ranges_test.go create mode 100644 pkg/util/safesql/BUILD.bazel create mode 100644 pkg/util/safesql/safesql.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 100f0e42a109..85b8ffa655b3 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -285,6 +285,8 @@ ALL_TESTS = [ "//pkg/security/username:username_disallowed_imports_test", "//pkg/security/username:username_test", "//pkg/security:security_test", + "//pkg/server/application_api:application_api_test", + "//pkg/server/authserver:authserver_test", "//pkg/server/autoconfig:autoconfig_test", "//pkg/server/debug/goroutineui:goroutineui_test", "//pkg/server/debug/pprofui:pprofui_test", @@ -294,11 +296,14 @@ ALL_TESTS = [ "//pkg/server/dumpstore:dumpstore_test", "//pkg/server/goroutinedumper:goroutinedumper_test", "//pkg/server/pgurl:pgurl_test", + "//pkg/server/privchecker:privchecker_test", "//pkg/server/profiler:profiler_test", "//pkg/server/serverpb:serverpb_test", "//pkg/server/serverrules:serverrules_test", "//pkg/server/settingswatcher:settingswatcher_test", + "//pkg/server/srverrors:srverrors_test", "//pkg/server/status:status_test", + "//pkg/server/storage_api:storage_api_test", 
"//pkg/server/structlogging:structlogging_test", "//pkg/server/systemconfigwatcher:systemconfigwatcher_test", "//pkg/server/telemetry:telemetry_test", @@ -1503,6 +1508,12 @@ GO_TARGETS = [ "//pkg/security/username:username_test", "//pkg/security:security", "//pkg/security:security_test", + "//pkg/server/apiconstants:apiconstants", + "//pkg/server/apiutil:apiutil", + "//pkg/server/application_api:application_api", + "//pkg/server/application_api:application_api_test", + "//pkg/server/authserver:authserver", + "//pkg/server/authserver:authserver_test", "//pkg/server/autoconfig/acprovider:acprovider", "//pkg/server/autoconfig/autoconfigpb:autoconfigpb", "//pkg/server/autoconfig:autoconfig", @@ -1523,17 +1534,25 @@ GO_TARGETS = [ "//pkg/server/goroutinedumper:goroutinedumper_test", "//pkg/server/pgurl:pgurl", "//pkg/server/pgurl:pgurl_test", + "//pkg/server/privchecker:privchecker", + "//pkg/server/privchecker:privchecker_test", "//pkg/server/profiler:profiler", "//pkg/server/profiler:profiler_test", + "//pkg/server/rangetestutils:rangetestutils", "//pkg/server/serverpb:serverpb", "//pkg/server/serverpb:serverpb_test", "//pkg/server/serverrules:serverrules", "//pkg/server/serverrules:serverrules_test", "//pkg/server/settingswatcher:settingswatcher", "//pkg/server/settingswatcher:settingswatcher_test", + "//pkg/server/srverrors:srverrors", + "//pkg/server/srverrors:srverrors_test", + "//pkg/server/srvtestutils:srvtestutils", "//pkg/server/status/statuspb:statuspb", "//pkg/server/status:status", "//pkg/server/status:status_test", + "//pkg/server/storage_api:storage_api", + "//pkg/server/storage_api:storage_api_test", "//pkg/server/structlogging:structlogging", "//pkg/server/structlogging:structlogging_test", "//pkg/server/systemconfigwatcher/systemconfigwatchertest:systemconfigwatchertest", @@ -2343,6 +2362,7 @@ GO_TARGETS = [ "//pkg/util/retry:retry_test", "//pkg/util/ring:ring", "//pkg/util/ring:ring_test", + "//pkg/util/safesql:safesql", 
"//pkg/util/schedulerlatency:schedulerlatency", "//pkg/util/schedulerlatency:schedulerlatency_test", "//pkg/util/sdnotify:sdnotify", diff --git a/pkg/ccl/oidcccl/BUILD.bazel b/pkg/ccl/oidcccl/BUILD.bazel index b5305cffed92..ee25740fb320 100644 --- a/pkg/ccl/oidcccl/BUILD.bazel +++ b/pkg/ccl/oidcccl/BUILD.bazel @@ -15,7 +15,7 @@ go_library( "//pkg/ccl/utilccl", "//pkg/roachpb", "//pkg/security/username", - "//pkg/server", + "//pkg/server/authserver", "//pkg/server/serverpb", "//pkg/server/telemetry", "//pkg/settings", diff --git a/pkg/ccl/oidcccl/authentication_oidc.go b/pkg/ccl/oidcccl/authentication_oidc.go index 97984cbc9b93..005450991958 100644 --- a/pkg/ccl/oidcccl/authentication_oidc.go +++ b/pkg/ccl/oidcccl/authentication_oidc.go @@ -23,7 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/utilccl" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" - "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -293,7 +293,7 @@ var ConfigureOIDC = func( userLoginFromSSO func(ctx context.Context, username string) (*http.Cookie, error), ambientCtx log.AmbientContext, cluster uuid.UUID, -) (server.OIDC, error) { +) (authserver.OIDC, error) { oidcAuthentication := &oidcAuthenticationServer{} // Don't want to use GRPC here since these endpoints require HTTP-Redirect behaviors and the @@ -719,5 +719,5 @@ var ConfigureOIDC = func( } func init() { - server.ConfigureOIDC = ConfigureOIDC + authserver.ConfigureOIDC = ConfigureOIDC } diff --git a/pkg/ccl/serverccl/BUILD.bazel b/pkg/ccl/serverccl/BUILD.bazel index 650356f386ef..9d38ea243b13 100644 --- a/pkg/ccl/serverccl/BUILD.bazel +++ b/pkg/ccl/serverccl/BUILD.bazel @@ -66,6 +66,7 @@ go_test( "//pkg/security/securitytest", "//pkg/security/username", 
"//pkg/server", + "//pkg/server/authserver", "//pkg/server/serverpb", "//pkg/server/systemconfigwatcher/systemconfigwatchertest", "//pkg/settings/cluster", diff --git a/pkg/ccl/serverccl/server_controller_test.go b/pkg/ccl/serverccl/server_controller_test.go index 9e1b1f77d547..3680156b38ad 100644 --- a/pkg/ccl/serverccl/server_controller_test.go +++ b/pkg/ccl/serverccl/server_controller_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/multitenant/tenantcapabilities" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/lexbase" @@ -285,7 +286,7 @@ VALUES($1, $2, $3, $4, $5, (SELECT user_id FROM system.users WHERE username = $3 t.Logf("retrieving session list from system tenant via cookie") c := &http.Cookie{ - Name: server.TenantSelectCookieName, + Name: authserver.TenantSelectCookieName, Value: catconstants.SystemTenantName, Path: "/", HttpOnly: true, @@ -362,7 +363,7 @@ func TestServerControllerDefaultHTTPTenant(t *testing.T) { tenantCookie := "" for _, c := range resp.Cookies() { - if c.Name == server.TenantSelectCookieName { + if c.Name == authserver.TenantSelectCookieName { tenantCookie = c.Value } } @@ -387,7 +388,7 @@ func TestServerControllerBadHTTPCookies(t *testing.T) { require.NoError(t, err) c := &http.Cookie{ - Name: server.TenantSelectCookieName, + Name: authserver.TenantSelectCookieName, Value: "some-nonexistent-tenant", Path: "/", HttpOnly: true, diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 4ba294554af5..0ba18823c628 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -144,6 +144,7 @@ go_library( "//pkg/security/securitytest", "//pkg/security/username", "//pkg/server", + "//pkg/server/authserver", "//pkg/server/autoconfig/acprovider", 
"//pkg/server/pgurl", "//pkg/server/profiler", diff --git a/pkg/cli/auth.go b/pkg/cli/auth.go index c00c16ca1ede..f70c005328bc 100644 --- a/pkg/cli/auth.go +++ b/pkg/cli/auth.go @@ -21,7 +21,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/cli/clisqlclient" "github.com/cockroachdb/cockroach/pkg/cli/clisqlexec" "github.com/cockroachdb/cockroach/pkg/clusterversion" - "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/timeutil" @@ -117,7 +117,7 @@ func createAuthSessionToken( } // Make a secret. - secret, hashedSecret, err := server.CreateAuthSecret() + secret, hashedSecret, err := authserver.CreateAuthSecret() if err != nil { return -1, nil, err } @@ -185,7 +185,7 @@ RETURNING id // Spell out the cookie. sCookie := &serverpb.SessionCookie{ID: id, Secret: secret} - httpCookie, err = server.EncodeSessionCookie(sCookie, false /* forHTTPSOnly */) + httpCookie, err = authserver.EncodeSessionCookie(sCookie, false /* forHTTPSOnly */) return id, httpCookie, err } diff --git a/pkg/cli/democluster/BUILD.bazel b/pkg/cli/democluster/BUILD.bazel index c44780dd2ca3..f6fc10cf3525 100644 --- a/pkg/cli/democluster/BUILD.bazel +++ b/pkg/cli/democluster/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//pkg/security/certnames", "//pkg/security/username", "//pkg/server", + "//pkg/server/authserver", "//pkg/server/autoconfig/acprovider", "//pkg/server/pgurl", "//pkg/server/serverpb", diff --git a/pkg/cli/democluster/demo_cluster.go b/pkg/cli/democluster/demo_cluster.go index 9ebe17b64e31..b9d53745a1cd 100644 --- a/pkg/cli/democluster/demo_cluster.go +++ b/pkg/cli/democluster/demo_cluster.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/security/certnames" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" + 
"github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/pgurl" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/status" @@ -2002,7 +2003,7 @@ func (c *transientCluster) addDemoLoginToURL(uiURL *url.URL, includeTenantName b // in that case. q.Add("username", c.adminUser.Normalized()) q.Add("password", c.adminPassword) - uiURL.Path = server.DemoLoginPath + uiURL.Path = authserver.DemoLoginPath } if !includeTenantName { diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel index 22f2563bc5e9..b09eb6c0e559 100644 --- a/pkg/cmd/roachtest/tests/BUILD.bazel +++ b/pkg/cmd/roachtest/tests/BUILD.bazel @@ -215,7 +215,7 @@ go_library( "//pkg/roachprod/logger", "//pkg/roachprod/prometheus", "//pkg/roachprod/vm", - "//pkg/server", + "//pkg/server/authserver", "//pkg/server/serverpb", "//pkg/sql", "//pkg/sql/pgwire/pgcode", diff --git a/pkg/cmd/roachtest/tests/cluster_init.go b/pkg/cmd/roachtest/tests/cluster_init.go index 45e5b0b6ebd1..cf1576003e86 100644 --- a/pkg/cmd/roachtest/tests/cluster_init.go +++ b/pkg/cmd/roachtest/tests/cluster_init.go @@ -23,7 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" "github.com/cockroachdb/cockroach/pkg/roachprod/install" - "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/util/httputil" "github.com/cockroachdb/cockroach/pkg/util/retry" @@ -138,7 +138,7 @@ func runClusterInit(ctx context.Context, t test.Test, c cluster.Cluster) { // Prevent regression of #25771 by also sending authenticated // requests, like would be sent if an admin UI were open against // this node while it booted. 
- cookie, err := server.EncodeSessionCookie(&serverpb.SessionCookie{ + cookie, err := authserver.EncodeSessionCookie(&serverpb.SessionCookie{ // The actual contents of the cookie don't matter; the presence of // a valid encoded cookie is enough to trigger the authentication // code paths. diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index f6ac69d75461..abe86994394c 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -5,15 +5,11 @@ go_library( srcs = [ "addjoin.go", "admin.go", - "admin_test_utils.go", "admission.go", "api_v2.go", - "api_v2_auth.go", - "api_v2_error.go", "api_v2_ranges.go", "api_v2_sql.go", "api_v2_sql_schema.go", - "authentication.go", "auto_tls_init.go", "auto_upgrade.go", "clock_monotonicity.go", @@ -23,6 +19,7 @@ go_library( "config_unix.go", "config_windows.go", "decommission.go", + "distsql_flows.go", "doc.go", "drain.go", "env_sampler.go", @@ -44,6 +41,7 @@ go_library( "node_http_router.go", "node_tenant.go", "node_tombstone_storage.go", + "nodes_response.go", "pagination.go", "problem_ranges.go", "rlimit_bsd.go", @@ -124,6 +122,7 @@ go_library( "//pkg/kv/kvprober", "//pkg/kv/kvserver", "//pkg/kv/kvserver/allocator/allocatorimpl", + "//pkg/kv/kvserver/allocator/plan", "//pkg/kv/kvserver/allocator/storepool", "//pkg/kv/kvserver/closedts/ctpb", "//pkg/kv/kvserver/closedts/sidetransport", @@ -169,9 +168,11 @@ go_library( "//pkg/security", "//pkg/security/certnames", "//pkg/security/clientsecopts", - "//pkg/security/password", "//pkg/security/securityassets", "//pkg/security/username", + "//pkg/server/apiconstants", + "//pkg/server/apiutil", + "//pkg/server/authserver", "//pkg/server/autoconfig", "//pkg/server/autoconfig/acprovider", "//pkg/server/debug", @@ -180,10 +181,12 @@ go_library( "//pkg/server/diagnostics/diagnosticspb", "//pkg/server/goroutinedumper", "//pkg/server/pgurl", + "//pkg/server/privchecker", "//pkg/server/profiler", "//pkg/server/serverpb", "//pkg/server/serverrules", 
"//pkg/server/settingswatcher", + "//pkg/server/srverrors", "//pkg/server/status", "//pkg/server/status/statuspb", "//pkg/server/structlogging", @@ -241,7 +244,6 @@ go_library( "//pkg/sql/parser", "//pkg/sql/parser/statements", "//pkg/sql/pgwire", - "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/pgwire/pgwirecancel", "//pkg/sql/physicalplan", @@ -271,11 +273,9 @@ go_library( "//pkg/sql/sqlstats/persistedsqlstats/sqlstatsutil", "//pkg/sql/stats", "//pkg/sql/stmtdiagnostics", - "//pkg/sql/syntheticprivilege", "//pkg/sql/syntheticprivilegecache", "//pkg/sql/ttl/ttljob", "//pkg/sql/ttl/ttlschedule", - "//pkg/sql/types", "//pkg/storage", "//pkg/storage/enginepb", "//pkg/storage/fs", @@ -316,6 +316,7 @@ go_library( "//pkg/util/quotapool", "//pkg/util/rangedesc", "//pkg/util/retry", + "//pkg/util/safesql", "//pkg/util/schedulerlatency", "//pkg/util/startup", "//pkg/util/stop", @@ -408,21 +409,19 @@ go_test( size = "enormous", srcs = [ "addjoin_test.go", - "admin_cluster_test.go", - "admin_test.go", "api_v2_ranges_test.go", "api_v2_sql_schema_test.go", "api_v2_sql_test.go", "api_v2_test.go", - "authentication_test.go", "auto_tls_init_test.go", "bench_test.go", "config_test.go", "connectivity_test.go", "critical_nodes_test.go", - "decommission_test.go", + "distsql_flows_test.go", "drain_test.go", "graphite_test.go", + "grpc_gateway_test.go", "helpers_test.go", "index_usage_stats_test.go", "init_handshake_test.go", @@ -435,6 +434,7 @@ go_test( "node_tenant_test.go", "node_test.go", "node_tombstone_storage_test.go", + "nodes_response_test.go", "pagination_test.go", "purge_auth_session_test.go", "server_controller_test.go", @@ -450,7 +450,6 @@ go_test( "statements_test.go", "stats_test.go", "status_ext_test.go", - "status_test.go", "sticky_engine_test.go", "tenant_range_lookup_test.go", "testserver_test.go", @@ -473,22 +472,15 @@ go_test( "//pkg/config", "//pkg/config/zonepb", "//pkg/gossip", - "//pkg/jobs", - "//pkg/jobs/jobspb", "//pkg/keys", "//pkg/kv", 
"//pkg/kv/kvclient/kvtenant", "//pkg/kv/kvpb", "//pkg/kv/kvserver", - "//pkg/kv/kvserver/allocator", - "//pkg/kv/kvserver/allocator/allocatorimpl", - "//pkg/kv/kvserver/allocator/plan", "//pkg/kv/kvserver/closedts", - "//pkg/kv/kvserver/closedts/ctpb", "//pkg/kv/kvserver/kvserverbase", "//pkg/kv/kvserver/kvserverpb", "//pkg/kv/kvserver/kvstorage", - "//pkg/kv/kvserver/liveness", "//pkg/kv/kvserver/liveness/livenesspb", "//pkg/roachpb", "//pkg/rpc", @@ -497,10 +489,12 @@ go_test( "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", - "//pkg/server/debug", + "//pkg/server/apiconstants", + "//pkg/server/authserver", "//pkg/server/diagnostics", - "//pkg/server/diagnostics/diagnosticspb", + "//pkg/server/rangetestutils", "//pkg/server/serverpb", + "//pkg/server/srvtestutils", "//pkg/server/status", "//pkg/server/status/statuspb", "//pkg/server/telemetry", @@ -509,21 +503,14 @@ go_test( "//pkg/spanconfig", "//pkg/sql", "//pkg/sql/appstatspb", - "//pkg/sql/catalog/descpb", - "//pkg/sql/clusterunique", "//pkg/sql/execinfrapb", - "//pkg/sql/idxusage", - "//pkg/sql/pgwire/pgcode", - "//pkg/sql/pgwire/pgerror", "//pkg/sql/roleoption", - "//pkg/sql/sem/catconstants", "//pkg/sql/sem/tree", "//pkg/sql/sessiondata", "//pkg/sql/sqlstats", "//pkg/sql/sqlstats/persistedsqlstats", "//pkg/sql/tests", "//pkg/storage", - "//pkg/storage/enginepb", "//pkg/testutils", "//pkg/testutils/datapathutils", "//pkg/testutils/diagutils", @@ -543,20 +530,16 @@ go_test( "//pkg/util/encoding", "//pkg/util/envutil", "//pkg/util/grpcutil", - "//pkg/util/grunning", "//pkg/util/hlc", "//pkg/util/httputil", "//pkg/util/humanizeutil", "//pkg/util/leaktest", "//pkg/util/log", - "//pkg/util/log/logpb", "//pkg/util/metric", "//pkg/util/netutil", "//pkg/util/netutil/addr", "//pkg/util/protoutil", - "//pkg/util/randident", "//pkg/util/randutil", - "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", "//pkg/util/tracing", @@ -572,7 +555,6 @@ go_test( 
"@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library", "@com_github_jackc_pgx_v4//:pgx", "@com_github_kr_pretty//:pretty", - "@com_github_lib_pq//:pq", "@com_github_prometheus_client_model//go", "@com_github_prometheus_common//expfmt", "@com_github_stretchr_testify//assert", @@ -581,9 +563,7 @@ go_test( "@io_opentelemetry_go_otel//attribute", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", - "@org_golang_google_grpc//credentials", "@org_golang_google_grpc//metadata", "@org_golang_google_grpc//status", - "@org_golang_x_crypto//bcrypt", ], ) diff --git a/pkg/server/addjoin.go b/pkg/server/addjoin.go index 2f96dda8c513..d09215e7dc82 100644 --- a/pkg/server/addjoin.go +++ b/pkg/server/addjoin.go @@ -57,7 +57,7 @@ func (s *adminServer) RequestCA( func (s *adminServer) consumeJoinToken(ctx context.Context, clientToken security.JoinToken) error { return s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { - row, err := s.ie.QueryRow( + row, err := s.internalExecutor.QueryRow( ctx, "select-consume-join-token", txn, "SELECT id, secret FROM system.join_tokens WHERE id = $1 AND now() < expiration", clientToken.TokenID.String()) @@ -72,7 +72,7 @@ func (s *adminServer) consumeJoinToken(ctx context.Context, clientToken security return errors.New("invalid shared secret") } - i, err := s.ie.Exec(ctx, "delete-consume-join-token", txn, + i, err := s.internalExecutor.Exec(ctx, "delete-consume-join-token", txn, "DELETE FROM system.join_tokens WHERE id = $1", clientToken.TokenID.String()) if err != nil { diff --git a/pkg/server/admin.go b/pkg/server/admin.go index 4fb20ab10c69..ba0d7035df6b 100644 --- a/pkg/server/admin.go +++ b/pkg/server/admin.go @@ -39,7 +39,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + 
"github.com/cockroachdb/cockroach/pkg/server/authserver" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/server/status" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/settings" @@ -50,13 +54,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/deprecatedshowranges" "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/parser" - "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" - "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" - "github.com/cockroachdb/cockroach/pkg/sql/privilege" - "github.com/cockroachdb/cockroach/pkg/sql/roleoption" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" - "github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege" "github.com/cockroachdb/cockroach/pkg/ts/catalog" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -68,6 +67,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/quotapool" + "github.com/cockroachdb/cockroach/pkg/util/safesql" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" @@ -82,18 +82,6 @@ import ( grpcstatus "google.golang.org/grpc/status" ) -const ( - // adminPrefix is the prefix for RESTful endpoints used to provide an - // administrative interface to the cockroach cluster. - adminPrefix = "/_admin/v1/" - - adminHealth = adminPrefix + "health" - - // defaultAPIEventLimit is the default maximum number of events returned by any - // endpoints returning events. 
- defaultAPIEventLimit = 1000 -) - // Number of empty ranges for table descriptors that aren't actually tables. These // cause special cases in range count computations because we split along them anyway, // but they're not SQL tables. @@ -115,7 +103,8 @@ type adminServer struct { serverpb.UnimplementedAdminServer log.AmbientContext - *adminPrivilegeChecker + privilegeChecker privchecker.CheckerForRPCHandlers + internalExecutor *sql.InternalExecutor sqlServer *SQLServer metricsRecorder *status.MetricsRecorder @@ -160,7 +149,7 @@ var tableStatsMaxFetcherConcurrency = settings.RegisterIntSetting( func newSystemAdminServer( sqlServer *SQLServer, cs *cluster.Settings, - adminAuthzCheck *adminPrivilegeChecker, + adminAuthzCheck privchecker.CheckerForRPCHandlers, ie *sql.InternalExecutor, ambient log.AmbientContext, metricsRecorder *status.MetricsRecorder, @@ -204,7 +193,7 @@ func newSystemAdminServer( func newAdminServer( sqlServer *SQLServer, cs *cluster.Settings, - adminAuthzCheck *adminPrivilegeChecker, + adminAuthzCheck privchecker.CheckerForRPCHandlers, ie *sql.InternalExecutor, ambient log.AmbientContext, metricsRecorder *status.MetricsRecorder, @@ -217,11 +206,11 @@ func newAdminServer( drainServer *drainServer, ) *adminServer { server := &adminServer{ - AmbientContext: ambient, - adminPrivilegeChecker: adminAuthzCheck, - internalExecutor: ie, - sqlServer: sqlServer, - metricsRecorder: metricsRecorder, + AmbientContext: ambient, + privilegeChecker: adminAuthzCheck, + internalExecutor: ie, + sqlServer: sqlServer, + metricsRecorder: metricsRecorder, statsLimiter: quotapool.NewIntPool( "table stats", uint64(tableStatsMaxFetcherConcurrency.Get(&cs.SV)), @@ -299,7 +288,7 @@ func (s *adminServer) RegisterGateway( // but from HTTP metadata (which does not). 
if s.sqlServer.cfg.Insecure { ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, username.RootUser) + ctx = authserver.ContextWithHTTPAuthInfo(ctx, username.RootUser, 0) req = req.WithContext(ctx) } s.getStatementBundle(req.Context(), id, w) @@ -309,36 +298,6 @@ func (s *adminServer) RegisterGateway( return serverpb.RegisterAdminHandler(ctx, mux, conn) } -// serverError logs the provided error and returns an error that should be returned by -// the RPC endpoint method. -func serverError(ctx context.Context, err error) error { - log.ErrorfDepth(ctx, 1, "%+v", err) - - // Include the PGCode in the message for easier troubleshooting - errCode := pgerror.GetPGCode(err).String() - if errCode != pgcode.Uncategorized.String() { - errMessage := fmt.Sprintf("%s Error Code: %s", errAPIInternalErrorString, errCode) - return grpcstatus.Errorf(codes.Internal, errMessage) - } - - // The error is already grpcstatus formatted error. - // Likely calling serverError multiple times on same error. - grpcCode := grpcstatus.Code(err) - if grpcCode != codes.Unknown { - return err - } - - // Fallback to generic message - return errAPIInternalError -} - -// serverErrorf logs the provided error and returns an error that should be returned by -// the RPC endpoint method. -func serverErrorf(ctx context.Context, format string, args ...interface{}) error { - log.ErrorfDepth(ctx, 1, format, args...) - return errAPIInternalError -} - // isNotFoundError returns true if err is a table/database not found error. 
func isNotFoundError(err error) bool { // TODO(cdo): Replace this crude suffix-matching with something more structured once we have @@ -366,7 +325,7 @@ func (s *adminServer) ChartCatalog( chartCatalog, err := catalog.GenerateCatalog(metricsMetadata) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp := &serverpb.ChartCatalogResponse{ @@ -382,12 +341,12 @@ func (s *adminServer) Databases( ) (_ *serverpb.DatabasesResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - sessionUser, err := userFromIncomingRPCContext(ctx) + sessionUser, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - if err := s.requireViewActivityPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityPermission(ctx); err != nil { return nil, err } @@ -396,7 +355,7 @@ func (s *adminServer) Databases( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *adminServer) databasesHelper( ctx context.Context, req *serverpb.DatabasesRequest, @@ -443,7 +402,7 @@ func maybeHandleNotFoundError(ctx context.Context, err error) error { if isNotFoundError(err) { return grpcstatus.Errorf(codes.NotFound, "%s", err) } - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } // DatabaseDetails is an endpoint that returns grants and a list of table names @@ -452,12 +411,12 @@ func (s *adminServer) DatabaseDetails( ctx context.Context, req *serverpb.DatabaseDetailsRequest, ) (_ *serverpb.DatabaseDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - if err := s.requireViewActivityPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityPermission(ctx); err != nil { return nil, err } @@ -466,7 +425,7 @@ func (s *adminServer) DatabaseDetails( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) getDatabaseGrants( ctx context.Context, req *serverpb.DatabaseDetailsRequest, @@ -480,7 +439,7 @@ func (s *adminServer) getDatabaseGrants( // TODO(cdo): Use placeholders when they're supported by SHOW. // Marshal grants. - query := makeSQLQuery() + query := safesql.NewQuery() // We use Sprintf instead of the more canonical query argument approach, as // that doesn't support arguments inside a SHOW subquery yet. query.Append(fmt.Sprintf("SELECT * FROM [SHOW GRANTS ON DATABASE %s]", escDBName)) @@ -538,14 +497,14 @@ func (s *adminServer) getDatabaseGrants( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. 
+// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) getDatabaseTables( ctx context.Context, req *serverpb.DatabaseDetailsRequest, userName username.SQLUsername, limit, offset int, ) (resp []string, retErr error) { - query := makeSQLQuery() + query := safesql.NewQuery() query.Append(`SELECT table_schema, table_name FROM information_schema.tables WHERE table_catalog = $ AND table_type != 'SYSTEM VIEW'`, req.Database) query.Append(" ORDER BY table_name") @@ -596,7 +555,7 @@ WHERE table_catalog = $ AND table_type != 'SYSTEM VIEW'`, req.Database) } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) getMiscDatabaseDetails( ctx context.Context, req *serverpb.DatabaseDetailsRequest, @@ -633,7 +592,7 @@ func (s *adminServer) getMiscDatabaseDetails( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) databaseDetailsHelper( ctx context.Context, req *serverpb.DatabaseDetailsRequest, userName username.SQLUsername, ) (_ *serverpb.DatabaseDetailsResponse, retErr error) { @@ -664,7 +623,7 @@ func (s *adminServer) databaseDetailsHelper( return nil, err } dbIndexRecommendations, err := getDatabaseIndexRecommendations( - ctx, req.Database, s.ie, s.st, s.sqlServer.execCfg.UnusedIndexRecommendationsKnobs, + ctx, req.Database, s.internalExecutor, s.st, s.sqlServer.execCfg.UnusedIndexRecommendationsKnobs, ) if err != nil { return nil, err @@ -675,7 +634,7 @@ func (s *adminServer) databaseDetailsHelper( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *adminServer) getDatabaseTableSpans( ctx context.Context, userName username.SQLUsername, dbName string, tableNames []string, ) (map[string]roachpb.Span, error) { @@ -696,7 +655,7 @@ func (s *adminServer) getDatabaseTableSpans( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) getDatabaseStats( ctx context.Context, tableSpans map[string]roachpb.Span, ) (*serverpb.DatabaseDetailsResponse_Stats, error) { @@ -789,7 +748,7 @@ func (s *adminServer) getDatabaseStats( // or database.schema.table if it was. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func getFullyQualifiedTableName(dbName string, tableName string) (string, error) { name, err := parser.ParseQualifiedTableName(tableName) if err != nil { @@ -820,12 +779,12 @@ func (s *adminServer) TableDetails( ctx context.Context, req *serverpb.TableDetailsRequest, ) (_ *serverpb.TableDetailsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - if err := s.requireViewActivityPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityPermission(ctx); err != nil { return nil, err } @@ -834,7 +793,7 @@ func (s *adminServer) TableDetails( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *adminServer) tableDetailsHelper( ctx context.Context, req *serverpb.TableDetailsRequest, userName username.SQLUsername, ) (_ *serverpb.TableDetailsResponse, retErr error) { @@ -1238,7 +1197,12 @@ func (s *adminServer) tableDetailsHelper( Database: req.Database, Table: req.Table, } - tableIndexStatsResponse, err := getTableIndexUsageStats(ctx, tableIndexStatsRequest, idxUsageStatsProvider, s.ie, s.st, s.sqlServer.execCfg) + tableIndexStatsResponse, err := getTableIndexUsageStats(ctx, + tableIndexStatsRequest, + idxUsageStatsProvider, + s.internalExecutor, + s.st, + s.sqlServer.execCfg) if err != nil { return nil, err } @@ -1263,14 +1227,14 @@ func (s *adminServer) TableStats( ) (*serverpb.TableStatsResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - err = s.requireViewActivityPermission(ctx) + err = s.privilegeChecker.RequireViewActivityPermission(ctx) if err != nil { - // NB: not using serverError() here since the priv checker + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1281,13 +1245,13 @@ func (s *adminServer) TableStats( tableID, err := s.queryTableID(ctx, userName, req.Database, escQualTable) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } tableSpan := generateTableSpan(tableID, s.sqlServer.execCfg.Codec) r, err := s.statsForSpan(ctx, tableSpan) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } @@ -1298,8 +1262,8 @@ func (s *adminServer) NonTableStats( ctx context.Context, req *serverpb.NonTableStatsRequest, ) (*serverpb.NonTableStatsResponse, error) { ctx = s.AnnotateCtx(ctx) - if err := s.requireViewActivityPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewActivityPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -1309,7 +1273,7 @@ func (s *adminServer) NonTableStats( EndKey: keys.TimeseriesPrefix.PrefixEnd(), }) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } response := serverpb.NonTableStatsResponse{ TimeSeriesStats: timeSeriesStats, @@ -1332,7 +1296,7 @@ func (s *adminServer) NonTableStats( for _, span := range spansForInternalUse { nonTableStats, err := s.statsForSpan(ctx, span) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if response.InternalUseStats == nil { response.InternalUseStats = nonTableStats @@ -1354,7 +1318,10 @@ func (s *adminServer) NonTableStats( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. +// +// TODO(clust-obs): This method should not be implemented on top of +// `adminServer`. There should be a better place for it. 
func (s *adminServer) statsForSpan( ctx context.Context, span roachpb.Span, ) (*serverpb.TableStatsResponse, error) { @@ -1454,7 +1421,7 @@ func (s *adminServer) statsForSpan( // is missing. For successful calls, aggregate statistics. if resp.err != nil { if s, ok := grpcstatus.FromError(errors.UnwrapAll(resp.err)); ok && s.Code() == codes.PermissionDenied { - return nil, serverError(ctx, resp.err) + return nil, srverrors.ServerError(ctx, resp.err) } // If this node is unreachable, @@ -1519,19 +1486,19 @@ func (s *adminServer) Users( ctx context.Context, req *serverpb.UsersRequest, ) (_ *serverpb.UsersResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } r, err := s.usersHelper(ctx, req, userName) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) usersHelper( ctx context.Context, req *serverpb.UsersRequest, userName username.SQLUsername, ) (_ *serverpb.UsersResponse, retErr error) { @@ -1580,9 +1547,9 @@ func (s *adminServer) Events( ) (_ *serverpb.EventsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := s.requireAdminUser(ctx) + userName, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { - // NB: not using serverError() here since the priv checker + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1590,18 +1557,18 @@ func (s *adminServer) Events( limit := req.Limit if limit == 0 { - limit = defaultAPIEventLimit + limit = apiconstants.DefaultAPIEventLimit } r, err := s.eventsHelper(ctx, req, userName, int(limit), 0, redactEvents) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) eventsHelper( ctx context.Context, req *serverpb.EventsRequest, @@ -1610,7 +1577,7 @@ func (s *adminServer) eventsHelper( redactEvents bool, ) (_ *serverpb.EventsResponse, retErr error) { // Execute the query. - q := makeSQLQuery() + q := safesql.NewQuery() q.Append(`SELECT timestamp, "eventType", "reportingID", info, "uniqueID" `) q.Append("FROM system.eventlog ") q.Append("WHERE true ") // This simplifies the WHERE clause logic below. @@ -1723,19 +1690,19 @@ func (s *adminServer) RangeLog( ctx = s.AnnotateCtx(ctx) // Range keys, even when pretty-printed, contain PII. - user, err := userFromIncomingRPCContext(ctx) + user, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { return nil, err } - err = s.requireViewClusterMetadataPermission(ctx) + err = s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, err } r, err := s.rangeLogHelper(ctx, req, user) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } @@ -1745,11 +1712,11 @@ func (s *adminServer) rangeLogHelper( ) (_ *serverpb.RangeLogResponse, retErr error) { limit := req.Limit if limit == 0 { - limit = defaultAPIEventLimit + limit = apiconstants.DefaultAPIEventLimit } // Execute the query. 
- q := makeSQLQuery() + q := safesql.NewQuery() q.Append(`SELECT timestamp, "rangeID", "storeID", "eventType", "otherRangeID", info `) q.Append("FROM system.rangelog ") if req.RangeId > 0 { @@ -1866,7 +1833,7 @@ func (s *adminServer) rangeLogHelper( // that are not found will not be returned. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) getUIData( ctx context.Context, userName username.SQLUsername, keys []string, ) (_ *serverpb.GetUIDataResponse, retErr error) { @@ -1875,7 +1842,7 @@ func (s *adminServer) getUIData( } // Query database. - query := makeSQLQuery() + query := safesql.NewQuery() query.Append(`SELECT key, value, "lastUpdated" FROM system.ui WHERE key IN (`) for i, key := range keys { if i != 0 { @@ -1957,9 +1924,9 @@ func (s *adminServer) SetUIData( ) (*serverpb.SetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if len(req.KeyValues) == 0 { @@ -1976,10 +1943,10 @@ func (s *adminServer) SetUIData( sessiondata.RootUserSessionDataOverride, query, makeUIKey(userName, key), val) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if rowsAffected != 1 { - return nil, serverErrorf(ctx, "rows affected %d != expected %d", rowsAffected, 1) + return nil, srverrors.ServerErrorf(ctx, "rows affected %d != expected %d", rowsAffected, 1) } } return &serverpb.SetUIDataResponse{}, nil @@ -1996,9 +1963,9 @@ func (s *adminServer) GetUIData( ) (*serverpb.GetUIDataResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, 
serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if len(req.Keys) == 0 { @@ -2007,7 +1974,7 @@ func (s *adminServer) GetUIData( resp, err := s.getUIData(ctx, userName, req.Keys) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return resp, nil @@ -2024,9 +1991,9 @@ func (s *adminServer) Settings( keys = settings.Keys(settings.ForSystemTenant) } - _, isAdmin, err := s.getUserAndRole(ctx) + _, isAdmin, err := s.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } redactValues := true @@ -2039,7 +2006,7 @@ func (s *adminServer) Settings( } } else { // Non-root access cannot see the values in any case. - if err := s.adminPrivilegeChecker.requireViewClusterSettingOrModifyClusterSettingPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewClusterSettingOrModifyClusterSettingPermission(ctx); err != nil { return nil, err } } @@ -2187,7 +2154,7 @@ func (s *systemAdminServer) checkReadinessForHealthCheck(ctx context.Context) er case modeOperational: break default: - return serverError(ctx, errors.Newf("unknown mode: %v", serveMode)) + return srverrors.ServerError(ctx, errors.Newf("unknown mode: %v", serveMode)) } status := s.nodeLiveness.GetNodeVitalityFromCache(roachpb.NodeID(s.serverIterator.getID())) @@ -2211,7 +2178,7 @@ func getLivenessResponse( nodeVitalityMap, err := nl.ScanNodeVitalityFromKV(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } livenesses := make([]livenesspb.Liveness, 0, len(nodeVitalityMap)) @@ -2238,7 +2205,7 @@ func getLivenessResponse( func (s *adminServer) Liveness( ctx context.Context, req *serverpb.LivenessRequest, ) (*serverpb.LivenessResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) return 
s.sqlServer.tenantConnect.Liveness(ctx, req) @@ -2258,9 +2225,9 @@ func (s *adminServer) Jobs( ) (_ *serverpb.JobsResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } j, err := jobsHelper( @@ -2272,13 +2239,13 @@ func (s *adminServer) Jobs( &s.sqlServer.cfg.Settings.SV, ) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return j, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func jobsHelper( ctx context.Context, req *serverpb.JobsRequest, @@ -2288,7 +2255,7 @@ func jobsHelper( sv *settings.Values, ) (_ *serverpb.JobsResponse, retErr error) { - q := makeSQLQuery() + q := safesql.NewQuery() q.Append(` SELECT job_id, @@ -2462,19 +2429,19 @@ func (s *adminServer) Job( ) (_ *serverpb.JobResponse, retErr error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } r, err := jobHelper(ctx, request, userName, s.sqlServer) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func jobHelper( ctx context.Context, request *serverpb.JobRequest, @@ -2523,14 +2490,14 @@ func (s *adminServer) Locations( ctx = s.AnnotateCtx(ctx) // Require authentication. 
- _, err := userFromIncomingRPCContext(ctx) + _, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } r, err := s.locationsHelper(ctx, req) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } @@ -2538,7 +2505,7 @@ func (s *adminServer) Locations( func (s *adminServer) locationsHelper( ctx context.Context, req *serverpb.LocationsRequest, ) (_ *serverpb.LocationsResponse, retErr error) { - q := makeSQLQuery() + q := safesql.NewQuery() q.Append(`SELECT "localityKey", "localityValue", latitude, longitude FROM system.locations`) it, err := s.internalExecutor.QueryIteratorEx( ctx, "admin-locations", nil, /* txn */ @@ -2593,19 +2560,19 @@ func (s *adminServer) QueryPlan( ) (*serverpb.QueryPlanResponse, error) { ctx = s.AnnotateCtx(ctx) - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } // As long as there's only one query provided it's safe to construct the // explain query. 
stmts, err := parser.Parse(req.Query) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if len(stmts) > 1 { - return nil, serverErrorf(ctx, "more than one query provided") + return nil, srverrors.ServerErrorf(ctx, "more than one query provided") } explain := fmt.Sprintf( @@ -2617,15 +2584,15 @@ func (s *adminServer) QueryPlan( explain, ) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if row == nil { - return nil, serverErrorf(ctx, "failed to query the physical plan") + return nil, srverrors.ServerErrorf(ctx, "failed to query the physical plan") } dbDatum, ok := tree.AsDString(row[0]) if !ok { - return nil, serverErrorf(ctx, "type assertion failed on json: %T", row) + return nil, srverrors.ServerErrorf(ctx, "type assertion failed on json: %T", row) } return &serverpb.QueryPlanResponse{ @@ -2636,7 +2603,7 @@ func (s *adminServer) QueryPlan( // getStatementBundle retrieves the statement bundle with the given id and // writes it out as an attachment. func (s *adminServer) getStatementBundle(ctx context.Context, id int64, w http.ResponseWriter) { - sqlUsername := userFromHTTPAuthInfoContext(ctx) + sqlUsername := authserver.UserFromHTTPAuthInfoContext(ctx) row, err := s.internalExecutor.QueryRowEx( ctx, "admin-stmt-bundle", nil, /* txn */ sessiondata.InternalExecutorOverride{User: sqlUsername}, @@ -2696,7 +2663,7 @@ func (s *systemAdminServer) DecommissionPreCheck( var nodesToCheck []roachpb.NodeID vitality, err := s.nodeLiveness.ScanNodeVitalityFromKV(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp := &serverpb.DecommissionPreCheckResponse{} @@ -2731,16 +2698,16 @@ func (s *systemAdminServer) DecommissionPreCheck( // exist. Ranges with replicas on multiple checked nodes will result in the // error being reported for each nodeID. 
rangeCheckErrsByNode := make(map[roachpb.NodeID][]serverpb.DecommissionPreCheckResponse_RangeCheckResult) - for _, rangeWithErr := range results.rangesNotReady { + for _, rangeWithErr := range results.RangesNotReady { rangeCheckResult := serverpb.DecommissionPreCheckResponse_RangeCheckResult{ - RangeID: rangeWithErr.desc.RangeID, - Action: rangeWithErr.action, - Events: recordedSpansToTraceEvents(rangeWithErr.tracingSpans), - Error: rangeWithErr.err.Error(), + RangeID: rangeWithErr.Desc.RangeID, + Action: rangeWithErr.Action, + Events: recordedSpansToTraceEvents(rangeWithErr.TracingSpans), + Error: rangeWithErr.Err.Error(), } for _, nID := range nodesToCheck { - if rangeWithErr.desc.Replicas().HasReplicaOnNode(nID) { + if rangeWithErr.Desc.Replicas().HasReplicaOnNode(nID) { rangeCheckErrsByNode[nID] = append(rangeCheckErrsByNode[nID], rangeCheckResult) } } @@ -2749,7 +2716,7 @@ func (s *systemAdminServer) DecommissionPreCheck( // Evaluate readiness by validating that there are no ranges with replicas on // the given node(s) that did not pass checks. for _, nID := range nodesToCheck { - numReplicas := len(results.replicasByNode[nID]) + numReplicas := len(results.ReplicasByNode[nID]) var readiness serverpb.DecommissionPreCheckResponse_NodeReadiness if len(rangeCheckErrsByNode[nID]) > 0 { readiness = serverpb.DecommissionPreCheckResponse_ALLOCATION_ERRORS @@ -2779,13 +2746,13 @@ func (s *systemAdminServer) DecommissionStatus( ) (*serverpb.DecommissionStatusResponse, error) { r, err := s.decommissionStatusHelper(ctx, req) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *systemAdminServer) decommissionStatusHelper( ctx context.Context, req *serverpb.DecommissionStatusRequest, ) (*serverpb.DecommissionStatusResponse, error) { @@ -2924,7 +2891,7 @@ func (s *systemAdminServer) Decommission( // Mark the target nodes with their new membership status. They'll find out // as they heartbeat their liveness. if err := s.server.Decommission(ctx, req.TargetMembership, nodeIDs); err != nil { - // NB: not using serverError() here since Decommission + // NB: not using srverrors.ServerError() here since Decommission // already returns a proper gRPC error status. return nil, err } @@ -2950,26 +2917,26 @@ func (s *systemAdminServer) Decommission( func (s *systemAdminServer) DataDistribution( ctx context.Context, req *serverpb.DataDistributionRequest, ) (_ *serverpb.DataDistributionResponse, retErr error) { - if err := s.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } r, err := s.dataDistributionHelper(ctx, req, userName) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return r, nil } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *adminServer) dataDistributionHelper( ctx context.Context, req *serverpb.DataDistributionRequest, userName username.SQLUsername, ) (resp *serverpb.DataDistributionResponse, retErr error) { @@ -3172,11 +3139,11 @@ func (s *adminServer) dataDistributionHelper( func (s *systemAdminServer) EnqueueRange( ctx context.Context, req *serverpb.EnqueueRangeRequest, ) (*serverpb.EnqueueRangeResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3199,7 +3166,7 @@ func (s *systemAdminServer) EnqueueRange( } else if req.NodeID != 0 { admin, err := s.dialNode(ctx, req.NodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return admin.EnqueueRange(ctx, req) } @@ -3235,7 +3202,7 @@ func (s *systemAdminServer) EnqueueRange( ) }); err != nil { if len(response.Details) == 0 { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } response.Details = append(response.Details, &serverpb.EnqueueRangeResponse_Details{ Error: err.Error(), @@ -3251,7 +3218,7 @@ func (s *systemAdminServer) EnqueueRange( // response. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *systemAdminServer) enqueueRangeLocal( ctx context.Context, req *serverpb.EnqueueRangeRequest, ) (*serverpb.EnqueueRangeResponse, error) { @@ -3322,9 +3289,9 @@ func (s *systemAdminServer) SendKVBatch( ctx = s.AnnotateCtx(ctx) // Note: the root user will bypass SQL auth checks, which is useful in case of // a cluster outage. - user, err := s.requireAdminUser(ctx) + user, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { - // NB: not using serverError() here since the priv checker + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3336,7 +3303,7 @@ func (s *systemAdminServer) SendKVBatch( jsonpb := protoutil.JSONPb{} baJSON, err := jsonpb.Marshal(ba) if err != nil { - return nil, serverError(ctx, errors.Wrap(err, "failed to encode BatchRequest as JSON")) + return nil, srverrors.ServerError(ctx, errors.Wrap(err, "failed to encode BatchRequest as JSON")) } event := &eventpb.DebugSendKvBatch{ CommonEventDetails: logpb.CommonEventDetails{ @@ -3380,7 +3347,7 @@ func (s *systemAdminServer) RecoveryCollectReplicaInfo( ) error { ctx := stream.Context() ctx = s.server.AnnotateCtx(ctx) - _, err := s.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return err } @@ -3395,7 +3362,7 @@ func (s *systemAdminServer) RecoveryCollectLocalReplicaInfo( ) error { ctx := stream.Context() ctx = s.server.AnnotateCtx(ctx) - _, err := s.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return err } @@ -3408,7 +3375,7 @@ func (s *systemAdminServer) RecoveryStagePlan( ctx context.Context, request *serverpb.RecoveryStagePlanRequest, ) (*serverpb.RecoveryStagePlanResponse, error) { ctx = s.server.AnnotateCtx(ctx) - _, err := s.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return nil, err } @@ -3421,7 +3388,7 @@ func (s *systemAdminServer) RecoveryNodeStatus( ctx 
context.Context, request *serverpb.RecoveryNodeStatusRequest, ) (*serverpb.RecoveryNodeStatusResponse, error) { ctx = s.server.AnnotateCtx(ctx) - _, err := s.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return nil, err } @@ -3433,7 +3400,7 @@ func (s *systemAdminServer) RecoveryVerify( ctx context.Context, request *serverpb.RecoveryVerifyRequest, ) (*serverpb.RecoveryVerifyResponse, error) { ctx = s.server.AnnotateCtx(ctx) - _, err := s.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return nil, err } @@ -3441,76 +3408,6 @@ func (s *systemAdminServer) RecoveryVerify( return s.server.recoveryServer.Verify(ctx, request, s.nodeLiveness, s.db) } -// sqlQuery allows you to incrementally build a SQL query that uses -// placeholders. Instead of specific placeholders like $1, you instead use the -// temporary placeholder $. -type sqlQuery struct { - buf bytes.Buffer - pidx int - qargs []interface{} - errs []error -} - -func makeSQLQuery() *sqlQuery { - res := &sqlQuery{} - return res -} - -// String returns the full query. -func (q *sqlQuery) String() string { - if len(q.errs) > 0 { - return "couldn't generate query: please check Errors()" - } - return q.buf.String() -} - -// Errors returns a slice containing all errors that have happened during the -// construction of this query. -func (q *sqlQuery) Errors() []error { - return q.errs -} - -// QueryArguments returns a filled map of placeholders containing all arguments -// provided to this query through Append. -func (q *sqlQuery) QueryArguments() []interface{} { - return q.qargs -} - -// Append appends the provided string and any number of query parameters. -// Instead of using normal placeholders (e.g. $1, $2), use meta-placeholder $. -// This method rewrites the query so that it uses proper placeholders. 
-// -// For example, suppose we have the following calls: -// -// query.Append("SELECT * FROM foo WHERE a > $ AND a < $ ", arg1, arg2) -// query.Append("LIMIT $", limit) -// -// The query is rewritten into: -// -// SELECT * FROM foo WHERE a > $1 AND a < $2 LIMIT $3 -// /* $1 = arg1, $2 = arg2, $3 = limit */ -// -// Note that this method does NOT return any errors. Instead, we queue up -// errors, which can later be accessed. Returning an error here would make -// query construction code exceedingly tedious. -func (q *sqlQuery) Append(s string, params ...interface{}) { - var placeholders int - for _, r := range s { - q.buf.WriteRune(r) - if r == '$' { - q.pidx++ - placeholders++ - q.buf.WriteString(strconv.Itoa(q.pidx)) // SQL placeholders are 1-based - } - } - - if placeholders != len(params) { - q.errs = append(q.errs, - errors.Errorf("# of placeholders %d != # of params %d", placeholders, len(params))) - } - q.qargs = append(q.qargs, params...) -} - // resultScanner scans columns from sql.ResultRow instances into variables, // performing the appropriate casting and error detection along the way. type resultScanner struct { @@ -3709,7 +3606,7 @@ func (rs resultScanner) Scan(row tree.Datums, colName string, dst interface{}) e // if it exists. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) queryZone( ctx context.Context, userName username.SQLUsername, id descpb.ID, ) (zonepb.ZoneConfig, bool, error) { @@ -3755,7 +3652,7 @@ func (s *adminServer) queryZone( // ZoneConfig specified for the object IDs in the path. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *adminServer) queryZonePath( ctx context.Context, userName username.SQLUsername, path []descpb.ID, ) (descpb.ID, zonepb.ZoneConfig, bool, error) { @@ -3770,7 +3667,7 @@ func (s *adminServer) queryZonePath( // queryDatabaseID queries for the ID of the database with the given name. // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) queryDatabaseID( ctx context.Context, userName username.SQLUsername, name string, ) (descpb.ID, error) { @@ -3807,7 +3704,10 @@ func (s *adminServer) queryDatabaseID( // queryTableID queries for the ID of the table with the given name in the // database with the given name. The table name may contain a schema qualifier. // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. +// +// TODO(clust-obs): This method should not be implemented on top of +// `adminServer`. There should be a better place for it. func (s *adminServer) queryTableID( ctx context.Context, username username.SQLUsername, database string, tableName string, ) (descpb.ID, error) { @@ -3826,7 +3726,7 @@ func (s *adminServer) queryTableID( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *adminServer) dialNode( ctx context.Context, nodeID roachpb.NodeID, ) (serverpb.AdminClient, error) { @@ -3837,303 +3737,11 @@ func (s *adminServer) dialNode( return serverpb.NewAdminClient(conn), nil } -// adminPrivilegeChecker is a helper struct to check whether given usernames -// have admin privileges. 
-type adminPrivilegeChecker struct { - ie isql.Executor - st *cluster.Settings - // makePlanner is a function that calls NewInternalPlanner - // to make a planner outside of the sql package. This is a hack - // to get around a Go package dependency cycle. See comment - // in pkg/scheduledjobs/env.go on planHookMaker. It should - // be cast to AuthorizationAccessor in order to use - // privilege checking functions. - makePlanner func(opName string) (interface{}, func()) -} - -// requireAdminUser's error return is a gRPC error. -func (c *adminPrivilegeChecker) requireAdminUser( - ctx context.Context, -) (userName username.SQLUsername, err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return userName, serverError(ctx, err) - } - if !isAdmin { - return userName, errRequiresAdmin - } - return userName, nil -} - -// requireViewActivityPermission's error return is a gRPC error. -func (c *adminPrivilegeChecker) requireViewActivityPermission(ctx context.Context) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if isAdmin { - return nil - } - if hasView, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITY); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - if hasView, err := c.hasRoleOption(ctx, userName, roleoption.VIEWACTIVITY); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires the %s system privilege", - roleoption.VIEWACTIVITY) -} - -// requireViewActivityOrViewActivityRedactedPermission's error return is a gRPC error. 
-func (c *adminPrivilegeChecker) requireViewActivityOrViewActivityRedactedPermission( - ctx context.Context, -) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if isAdmin { - return nil - } - if hasView, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITY); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - if hasViewRedacted, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITYREDACTED); err != nil { - return serverError(ctx, err) - } else if hasViewRedacted { - return nil - } - if hasView, err := c.hasRoleOption(ctx, userName, roleoption.VIEWACTIVITY); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - if hasViewRedacted, err := c.hasRoleOption(ctx, userName, roleoption.VIEWACTIVITYREDACTED); err != nil { - return serverError(ctx, err) - } else if hasViewRedacted { - return nil - } - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires the %s or %s system privileges", - roleoption.VIEWACTIVITY, roleoption.VIEWACTIVITYREDACTED) -} - -// requireViewClusterSettingOrModifyClusterSettingPermission's error return is a gRPC error. 
-func (c *adminPrivilegeChecker) requireViewClusterSettingOrModifyClusterSettingPermission( - ctx context.Context, -) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if isAdmin { - return nil - } - if hasView, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWCLUSTERSETTING); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - if hasModify, err := c.hasGlobalPrivilege(ctx, userName, privilege.MODIFYCLUSTERSETTING); err != nil { - return serverError(ctx, err) - } else if hasModify { - return nil - } - if hasView, err := c.hasRoleOption(ctx, userName, roleoption.VIEWCLUSTERSETTING); err != nil { - return serverError(ctx, err) - } else if hasView { - return nil - } - if hasModify, err := c.hasRoleOption(ctx, userName, roleoption.MODIFYCLUSTERSETTING); err != nil { - return serverError(ctx, err) - } else if hasModify { - return nil - } - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires the %s or %s system privileges", - privilege.VIEWCLUSTERSETTING, privilege.MODIFYCLUSTERSETTING) -} - -// This function requires that the user have the VIEWACTIVITY role, but does not -// have the VIEWACTIVITYREDACTED role. -// This function's error return is a gRPC error. 
-func (c *adminPrivilegeChecker) requireViewActivityAndNoViewActivityRedactedPermission( - ctx context.Context, -) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - - if !isAdmin { - hasViewRedacted, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITYREDACTED) - if err != nil { - return serverError(ctx, err) - } - if !hasViewRedacted { - hasViewRedacted, err := c.hasRoleOption(ctx, userName, roleoption.VIEWACTIVITYREDACTED) - if err != nil { - return serverError(ctx, err) - } - if hasViewRedacted { - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires %s role option and is not allowed for %s role option", - roleoption.VIEWACTIVITY, roleoption.VIEWACTIVITYREDACTED) - } - } else { - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires %s system privilege and is not allowed for %s system privilege", - privilege.VIEWACTIVITY, privilege.VIEWACTIVITYREDACTED) - } - return c.requireViewActivityPermission(ctx) - } - return nil -} - -// requireViewClusterMetadataPermission requires the user have admin or the VIEWCLUSTERMETADATA -// system privilege and returns an error if the user does not have it. -func (c *adminPrivilegeChecker) requireViewClusterMetadataPermission( - ctx context.Context, -) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if isAdmin { - return nil - } - if hasViewClusterMetadata, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWCLUSTERMETADATA); err != nil { - return serverError(ctx, err) - } else if hasViewClusterMetadata { - return nil - } - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires the %s system privilege", - privilege.VIEWCLUSTERMETADATA) -} - -// requireViewDebugPermission requires the user have admin or the VIEWDEBUG system privilege -// and returns an error if the user does not have it. 
-func (c *adminPrivilegeChecker) requireViewDebugPermission(ctx context.Context) (err error) { - userName, isAdmin, err := c.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if isAdmin { - return nil - } - if hasViewDebug, err := c.hasGlobalPrivilege(ctx, userName, privilege.VIEWDEBUG); err != nil { - return serverError(ctx, err) - } else if hasViewDebug { - return nil - } - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires the %s system privilege", - privilege.VIEWDEBUG) -} - -// Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. -func (c *adminPrivilegeChecker) getUserAndRole( - ctx context.Context, -) (userName username.SQLUsername, isAdmin bool, err error) { - userName, err = userFromIncomingRPCContext(ctx) - if err != nil { - return userName, false, err - } - isAdmin, err = c.hasAdminRole(ctx, userName) - return userName, isAdmin, err -} - -// Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. -func (c *adminPrivilegeChecker) hasAdminRole( - ctx context.Context, user username.SQLUsername, -) (bool, error) { - if user.IsRootUser() { - // Shortcut. - return true, nil - } - row, err := c.ie.QueryRowEx( - ctx, "check-is-admin", nil, /* txn */ - sessiondata.InternalExecutorOverride{User: user}, - "SELECT crdb_internal.is_admin()") - if err != nil { - return false, err - } - if row == nil { - return false, errors.AssertionFailedf("hasAdminRole: expected 1 row, got 0") - } - if len(row) != 1 { - return false, errors.AssertionFailedf("hasAdminRole: expected 1 column, got %d", len(row)) - } - dbDatum, ok := tree.AsDBool(row[0]) - if !ok { - return false, errors.AssertionFailedf("hasAdminRole: expected bool, got %T", row[0]) - } - return bool(dbDatum), nil -} - -// Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. 
-func (c *adminPrivilegeChecker) hasRoleOption( - ctx context.Context, user username.SQLUsername, roleOption roleoption.Option, -) (bool, error) { - if user.IsRootUser() { - // Shortcut. - return true, nil - } - row, err := c.ie.QueryRowEx( - ctx, "check-role-option", nil, /* txn */ - sessiondata.InternalExecutorOverride{User: user}, - "SELECT crdb_internal.has_role_option($1)", roleOption.String()) - if err != nil { - return false, err - } - if row == nil { - return false, errors.AssertionFailedf("hasRoleOption: expected 1 row, got 0") - } - if len(row) != 1 { - return false, errors.AssertionFailedf("hasRoleOption: expected 1 column, got %d", len(row)) - } - dbDatum, ok := tree.AsDBool(row[0]) - if !ok { - return false, errors.AssertionFailedf("hasRoleOption: expected bool, got %T", row[0]) - } - return bool(dbDatum), nil -} - -// hasGlobalPrivilege is a helper function which calls -// CheckPrivilege and returns a true/false based on the returned -// result. -func (c *adminPrivilegeChecker) hasGlobalPrivilege( - ctx context.Context, user username.SQLUsername, privilege privilege.Kind, -) (bool, error) { - planner, cleanup := c.makePlanner("check-system-privilege") - defer cleanup() - aa := planner.(sql.AuthorizationAccessor) - return aa.HasPrivilege(ctx, syntheticprivilege.GlobalPrivilegeObject, privilege, user) -} - -var errRequiresAdmin = grpcstatus.Error(codes.PermissionDenied, "this operation requires admin privilege") - -func errRequiresRoleOption(option roleoption.Option) error { - return grpcstatus.Errorf( - codes.PermissionDenied, "this operation requires %s privilege", option) -} - func (s *adminServer) ListTracingSnapshots( ctx context.Context, req *serverpb.ListTracingSnapshotsRequest, ) (*serverpb.ListTracingSnapshotsResponse, error) { ctx = s.AnnotateCtx(ctx) - err := s.requireViewDebugPermission(ctx) + err := s.privilegeChecker.RequireViewDebugPermission(ctx) if err != nil { return nil, err } @@ -4171,7 +3779,7 @@ func (s *adminServer) 
TakeTracingSnapshot( ctx context.Context, req *serverpb.TakeTracingSnapshotRequest, ) (*serverpb.TakeTracingSnapshotResponse, error) { ctx = s.AnnotateCtx(ctx) - err := s.requireViewDebugPermission(ctx) + err := s.privilegeChecker.RequireViewDebugPermission(ctx) if err != nil { return nil, err } @@ -4215,7 +3823,7 @@ func (s *adminServer) GetTracingSnapshot( ctx context.Context, req *serverpb.GetTracingSnapshotRequest, ) (*serverpb.GetTracingSnapshotResponse, error) { ctx = s.AnnotateCtx(ctx) - err := s.requireViewDebugPermission(ctx) + err := s.privilegeChecker.RequireViewDebugPermission(ctx) if err != nil { return nil, err } @@ -4292,7 +3900,7 @@ func (s *adminServer) GetTrace( ctx context.Context, req *serverpb.GetTraceRequest, ) (*serverpb.GetTraceResponse, error) { ctx = s.AnnotateCtx(ctx) - err := s.requireViewDebugPermission(ctx) + err := s.privilegeChecker.RequireViewDebugPermission(ctx) if err != nil { return nil, err } diff --git a/pkg/server/admin_cluster_test.go b/pkg/server/admin_cluster_test.go deleted file mode 100644 index 450e369d3a7e..000000000000 --- a/pkg/server/admin_cluster_test.go +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -package server_test - -import ( - "context" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" - "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" - "github.com/cockroachdb/cockroach/pkg/util/httputil" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAdminAPIDatabaseDetails(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - const numServers = 3 - tc := testcluster.StartTestCluster(t, numServers, base.TestClusterArgs{}) - defer tc.Stopper().Stop(context.Background()) - - db := tc.ServerConn(0) - - _, err := db.Exec("CREATE DATABASE test") - require.NoError(t, err) - - _, err = db.Exec("CREATE TABLE test.foo (id INT PRIMARY KEY, val STRING)") - require.NoError(t, err) - - for i := 0; i < 10; i++ { - _, err := db.Exec("INSERT INTO test.foo VALUES($1, $2)", i, "test") - require.NoError(t, err) - } - - // Flush all stores here so that we can read the ApproximateDiskBytes field without waiting for a flush. 
- for i := 0; i < numServers; i++ { - s := tc.Server(i) - err = s.GetStores().(*kvserver.Stores).VisitStores(func(store *kvserver.Store) error { - return store.TODOEngine().Flush() - }) - require.NoError(t, err) - } - - s := tc.Server(0) - - var resp serverpb.DatabaseDetailsResponse - require.NoError(t, serverutils.GetJSONProto(s, "/_admin/v1/databases/test", &resp)) - assert.Nil(t, resp.Stats, "No Stats unless we ask for them explicitly.") - - nodeIDs := tc.NodeIDs() - testutils.SucceedsSoon(t, func() error { - var resp serverpb.DatabaseDetailsResponse - require.NoError(t, serverutils.GetJSONProto(s, "/_admin/v1/databases/test?include_stats=true", &resp)) - - if resp.Stats.RangeCount != int64(1) { - return errors.Newf("expected range-count=1, got %d", resp.Stats.RangeCount) - } - if len(resp.Stats.NodeIDs) != len(nodeIDs) { - return errors.Newf("expected node-ids=%s, got %s", nodeIDs, resp.Stats.NodeIDs) - } - assert.Equal(t, nodeIDs, resp.Stats.NodeIDs, "NodeIDs") - - // We've flushed data so this estimation should be non-zero. - assert.Positive(t, resp.Stats.ApproximateDiskBytes, "ApproximateDiskBytes") - - return nil - }) -} - -func TestAdminAPITableStats(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - const nodeCount = 3 - tc := testcluster.StartTestCluster(t, nodeCount, base.TestClusterArgs{ - ReplicationMode: base.ReplicationAuto, - ServerArgs: base.TestServerArgs{ - ScanInterval: time.Millisecond, - ScanMinIdleTime: time.Millisecond, - ScanMaxIdleTime: time.Millisecond, - }, - }) - defer tc.Stopper().Stop(context.Background()) - server0 := tc.Server(0) - - // Create clients (SQL, HTTP) connected to server 0. - db := tc.ServerConn(0) - - client, err := server0.GetAdminHTTPClient() - if err != nil { - t.Fatal(err) - } - - client.Timeout = time.Hour // basically no timeout - - // Make a single table and insert some data. 
The database and test have - // names which require escaping, in order to verify that database and - // table names are being handled correctly. - if _, err := db.Exec(`CREATE DATABASE "test test"`); err != nil { - t.Fatal(err) - } - if _, err := db.Exec(` - CREATE TABLE "test test"."foo foo" ( - id INT PRIMARY KEY, - val STRING - )`, - ); err != nil { - t.Fatal(err) - } - for i := 0; i < 10; i++ { - if _, err := db.Exec(` - INSERT INTO "test test"."foo foo" VALUES( - $1, $2 - )`, i, "test", - ); err != nil { - t.Fatal(err) - } - } - - url := server0.AdminURL().String() + "/_admin/v1/databases/test test/tables/foo foo/stats" - var tsResponse serverpb.TableStatsResponse - - // The new SQL table may not yet have split into its own range. Wait for - // this to occur, and for full replication. - testutils.SucceedsSoon(t, func() error { - if err := httputil.GetJSON(client, url, &tsResponse); err != nil { - t.Fatal(err) - } - if len(tsResponse.MissingNodes) != 0 { - return errors.Errorf("missing nodes: %+v", tsResponse.MissingNodes) - } - if tsResponse.RangeCount != 1 { - return errors.Errorf("Table range not yet separated.") - } - if tsResponse.NodeCount != nodeCount { - return errors.Errorf("Table range not yet replicated to %d nodes.", 3) - } - if a, e := tsResponse.ReplicaCount, int64(nodeCount); a != e { - return errors.Errorf("expected %d replicas, found %d", e, a) - } - if a, e := tsResponse.Stats.KeyCount, int64(30); a < e { - return errors.Errorf("expected at least %d total keys, found %d", e, a) - } - return nil - }) - - if len(tsResponse.MissingNodes) > 0 { - t.Fatalf("expected no missing nodes, found %v", tsResponse.MissingNodes) - } - - // Kill a node, ensure it shows up in MissingNodes and that ReplicaCount is - // lower. 
- tc.StopServer(1) - - if err := httputil.GetJSON(client, url, &tsResponse); err != nil { - t.Fatal(err) - } - if a, e := tsResponse.NodeCount, int64(nodeCount); a != e { - t.Errorf("expected %d nodes, found %d", e, a) - } - if a, e := tsResponse.RangeCount, int64(1); a != e { - t.Errorf("expected %d ranges, found %d", e, a) - } - if a, e := tsResponse.ReplicaCount, int64((nodeCount/2)+1); a != e { - t.Errorf("expected %d replicas, found %d", e, a) - } - if a, e := tsResponse.Stats.KeyCount, int64(10); a < e { - t.Errorf("expected at least 10 total keys, found %d", a) - } - if len(tsResponse.MissingNodes) != 1 { - t.Errorf("expected one missing node, found %v", tsResponse.MissingNodes) - } - if len(tsResponse.NodeIDs) == 0 { - t.Error("expected at least one node in NodeIds list") - } - - // Call TableStats with a very low timeout. This tests that fan-out queries - // do not leak goroutines if the calling context is abandoned. - // Interestingly, the call can actually sometimes succeed, despite the small - // timeout; however, in aggregate (or in stress tests) this will suffice for - // detecting leaks. - client.Timeout = 1 * time.Nanosecond - _ = httputil.GetJSON(client, url, &tsResponse) -} - -func TestLivenessAPI(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - tc := testcluster.StartTestCluster(t, 3, base.TestClusterArgs{}) - defer tc.Stopper().Stop(context.Background()) - - startTime := tc.Server(0).Clock().PhysicalNow() - - // We need to retry because the gossiping of liveness status is an - // asynchronous process. 
- testutils.SucceedsSoon(t, func() error { - var resp serverpb.LivenessResponse - if err := serverutils.GetJSONProto(tc.Server(0), "/_admin/v1/liveness", &resp); err != nil { - return err - } - if a, e := len(resp.Livenesses), tc.NumServers(); a != e { - return errors.Errorf("found %d liveness records, wanted %d", a, e) - } - livenessMap := make(map[roachpb.NodeID]livenesspb.Liveness) - for _, l := range resp.Livenesses { - livenessMap[l.NodeID] = l - } - for i := 0; i < tc.NumServers(); i++ { - s := tc.Server(i) - sl, ok := livenessMap[s.NodeID()] - if !ok { - return errors.Errorf("found no liveness record for node %d", s.NodeID()) - } - if sl.Expiration.WallTime < startTime { - return errors.Errorf( - "expected node %d liveness to expire in future (after %d), expiration was %d", - s.NodeID(), - startTime, - sl.Expiration, - ) - } - status, ok := resp.Statuses[s.NodeID()] - if !ok { - return errors.Errorf("found no liveness status for node %d", s.NodeID()) - } - if a, e := status, livenesspb.NodeLivenessStatus_LIVE; a != e { - return errors.Errorf( - "liveness status for node %s was %s, wanted %s", s.NodeID(), a, e, - ) - } - } - return nil - }) -} diff --git a/pkg/server/admin_test.go b/pkg/server/admin_test.go deleted file mode 100644 index 7088f112f207..000000000000 --- a/pkg/server/admin_test.go +++ /dev/null @@ -1,3502 +0,0 @@ -// Copyright 2014 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -package server - -import ( - "bytes" - "context" - gosql "database/sql" - "encoding/json" - "fmt" - "io" - "math" - "net/http" - "net/url" - "reflect" - "regexp" - "sort" - "strings" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/config/zonepb" - "github.com/cockroachdb/cockroach/pkg/jobs" - "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" - "github.com/cockroachdb/cockroach/pkg/keys" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverpb" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" - "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" - "github.com/cockroachdb/cockroach/pkg/server/debug" - "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/server/telemetry" - "github.com/cockroachdb/cockroach/pkg/settings" - "github.com/cockroachdb/cockroach/pkg/settings/cluster" - "github.com/cockroachdb/cockroach/pkg/sql" - "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" - "github.com/cockroachdb/cockroach/pkg/sql/idxusage" - "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" - "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" - "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" - "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/skip" - "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" - "github.com/cockroachdb/cockroach/pkg/util/hlc" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/cockroach/pkg/util/protoutil" - 
"github.com/cockroachdb/cockroach/pkg/util/randident" - "github.com/cockroachdb/cockroach/pkg/util/randutil" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/cockroach/pkg/util/uuid" - "github.com/cockroachdb/errors" - "github.com/gogo/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" -) - -func getAdminJSONProto( - ts serverutils.TestServerInterface, path string, response protoutil.Message, -) error { - return getAdminJSONProtoWithAdminOption(ts, path, response, true) -} - -func getAdminJSONProtoWithAdminOption( - ts serverutils.TestServerInterface, path string, response protoutil.Message, isAdmin bool, -) error { - return serverutils.GetJSONProtoWithAdminOption(ts, adminPrefix+path, response, isAdmin) -} - -func postAdminJSONProto( - ts serverutils.TestServerInterface, path string, request, response protoutil.Message, -) error { - return postAdminJSONProtoWithAdminOption(ts, path, request, response, true) -} - -func postAdminJSONProtoWithAdminOption( - ts serverutils.TestServerInterface, - path string, - request, response protoutil.Message, - isAdmin bool, -) error { - return serverutils.PostJSONProtoWithAdminOption(ts, adminPrefix+path, request, response, isAdmin) -} - -// getText fetches the HTTP response body as text in the form of a -// byte slice from the specified URL. -func getText(ts serverutils.TestServerInterface, url string) ([]byte, error) { - httpClient, err := ts.GetAdminHTTPClient() - if err != nil { - return nil, err - } - resp, err := httpClient.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - return io.ReadAll(resp.Body) -} - -// getJSON fetches the JSON from the specified URL and returns -// it as unmarshaled JSON. Returns an error on any failure to fetch -// or unmarshal response body. 
-func getJSON(ts serverutils.TestServerInterface, url string) (interface{}, error) { - body, err := getText(ts, url) - if err != nil { - return nil, err - } - var jI interface{} - if err := json.Unmarshal(body, &jI); err != nil { - return nil, errors.Wrapf(err, "body is:\n%s", body) - } - return jI, nil -} - -// debugURL returns the root debug URL. -func debugURL(s serverutils.TestServerInterface) string { - return s.AdminURL().WithPath(debug.Endpoint).String() -} - -// TestAdminDebugExpVar verifies that cmdline and memstats variables are -// available via the /debug/vars link. -func TestAdminDebugExpVar(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - jI, err := getJSON(s, debugURL(s)+"vars") - if err != nil { - t.Fatalf("failed to fetch JSON: %v", err) - } - j := jI.(map[string]interface{}) - if _, ok := j["cmdline"]; !ok { - t.Error("cmdline not found in JSON response") - } - if _, ok := j["memstats"]; !ok { - t.Error("memstats not found in JSON response") - } -} - -// TestAdminDebugMetrics verifies that cmdline and memstats variables are -// available via the /debug/metrics link. -func TestAdminDebugMetrics(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - jI, err := getJSON(s, debugURL(s)+"metrics") - if err != nil { - t.Fatalf("failed to fetch JSON: %v", err) - } - j := jI.(map[string]interface{}) - if _, ok := j["cmdline"]; !ok { - t.Error("cmdline not found in JSON response") - } - if _, ok := j["memstats"]; !ok { - t.Error("memstats not found in JSON response") - } -} - -// TestAdminDebugPprof verifies that pprof tools are available. -// via the /debug/pprof/* links. -func TestAdminDebugPprof(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - body, err := getText(s, debugURL(s)+"pprof/block?debug=1") - if err != nil { - t.Fatal(err) - } - if exp := "contention:\ncycles/second="; !bytes.Contains(body, []byte(exp)) { - t.Errorf("expected %s to contain %s", body, exp) - } -} - -// TestAdminDebugTrace verifies that the net/trace endpoints are available -// via /debug/{requests,events}. -func TestAdminDebugTrace(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - tc := []struct { - segment, search string - }{ - {"requests", "/debug/requests"}, - {"events", "events"}, - } - - for _, c := range tc { - body, err := getText(s, debugURL(s)+c.segment) - if err != nil { - t.Fatal(err) - } - if !bytes.Contains(body, []byte(c.search)) { - t.Errorf("expected %s to be contained in %s", c.search, body) - } - } -} - -func TestAdminDebugAuth(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - url := debugURL(s) - - // Unauthenticated. - client, err := ts.GetUnauthenticatedHTTPClient() - if err != nil { - t.Fatal(err) - } - resp, err := client.Get(url) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusUnauthorized { - t.Errorf("expected status code %d; got %d", http.StatusUnauthorized, resp.StatusCode) - } - - // Authenticated as non-admin. - client, err = ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession) - if err != nil { - t.Fatal(err) - } - resp, err = client.Get(url) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusUnauthorized { - t.Errorf("expected status code %d; got %d", http.StatusUnauthorized, resp.StatusCode) - } - - // Authenticated as admin. 
- client, err = ts.GetAuthenticatedHTTPClient(true, serverutils.SingleTenantSession) - if err != nil { - t.Fatal(err) - } - resp, err = client.Get(url) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status code %d; got %d", http.StatusOK, resp.StatusCode) - } -} - -// TestAdminDebugRedirect verifies that the /debug/ endpoint is redirected to on -// incorrect /debug/ paths. -func TestAdminDebugRedirect(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - expURL := debugURL(s) - origURL := expURL + "incorrect" - - // Must be admin to access debug endpoints - client, err := ts.GetAdminHTTPClient() - if err != nil { - t.Fatal(err) - } - - // Don't follow redirects automatically. - redirectAttemptedError := errors.New("redirect") - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return redirectAttemptedError - } - - resp, err := client.Get(origURL) - if urlError := (*url.Error)(nil); errors.As(err, &urlError) && - errors.Is(urlError.Err, redirectAttemptedError) { - // Ignore the redirectAttemptedError. 
- err = nil - } - if err != nil { - t.Fatal(err) - } else { - resp.Body.Close() - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("expected status code %d; got %d", http.StatusMovedPermanently, resp.StatusCode) - } - if redirectURL, err := resp.Location(); err != nil { - t.Error(err) - } else if foundURL := redirectURL.String(); foundURL != expURL { - t.Errorf("expected location %s; got %s", expURL, foundURL) - } - } -} - -func generateRandomName() string { - rand, _ := randutil.NewTestRand() - cfg := randident.DefaultNameGeneratorConfig() - // REST api can not handle `/`. This is fixed in - // the UI by using sql-over-http endpoint instead. - cfg.Punctuate = -1 - cfg.Finalize() - - ng := randident.NewNameGenerator( - &cfg, - rand, - "a b%s-c.d", - ) - return ng.GenerateOne(42) -} - -func TestAdminAPIStatementDiagnosticsBundle(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - query := "EXPLAIN ANALYZE (DEBUG) SELECT 'secret'" - _, err := db.Exec(query) - require.NoError(t, err) - - query = "SELECT id FROM system.statement_diagnostics LIMIT 1" - idRow, err := db.Query(query) - require.NoError(t, err) - var diagnosticRow string - if idRow.Next() { - err = idRow.Scan(&diagnosticRow) - require.NoError(t, err) - } else { - t.Fatal("no results") - } - - client, err := ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession) - require.NoError(t, err) - resp, err := client.Get(ts.AdminURL().WithPath("/_admin/v1/stmtbundle/" + diagnosticRow).String()) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, 500, resp.StatusCode) - - adminClient, err := ts.GetAuthenticatedHTTPClient(true, serverutils.SingleTenantSession) - require.NoError(t, err) - adminResp, err := adminClient.Get(ts.AdminURL().WithPath("/_admin/v1/stmtbundle/" + diagnosticRow).String()) - 
require.NoError(t, err) - defer adminResp.Body.Close() - require.Equal(t, 200, adminResp.StatusCode) -} - -func TestAdminAPIDatabases(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - ac := ts.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - - testDbName := generateRandomName() - testDbEscaped := tree.NameString(testDbName) - query := "CREATE DATABASE " + testDbEscaped - if _, err := db.Exec(query); err != nil { - t.Fatal(err) - } - // Test needs to revoke CONNECT on the public database to properly exercise - // fine-grained permissions logic. - if _, err := db.Exec(fmt.Sprintf("REVOKE CONNECT ON DATABASE %s FROM public", testDbEscaped)); err != nil { - t.Fatal(err) - } - if _, err := db.Exec("REVOKE CONNECT ON DATABASE defaultdb FROM public"); err != nil { - t.Fatal(err) - } - - // We have to create the non-admin user before calling - // "GRANT ... TO authenticatedUserNameNoAdmin". - // This is done in "GetAuthenticatedHTTPClient". - if _, err := ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession); err != nil { - t.Fatal(err) - } - - // Grant permissions to view the tables for the given viewing user. - privileges := []string{"CONNECT"} - query = fmt.Sprintf( - "GRANT %s ON DATABASE %s TO %s", - strings.Join(privileges, ", "), - testDbEscaped, - authenticatedUserNameNoAdmin().SQLIdentifier(), - ) - if _, err := db.Exec(query); err != nil { - t.Fatal(err) - } - // Non admins now also require VIEWACTIVITY. 
- query = fmt.Sprintf( - "GRANT SYSTEM %s TO %s", - "VIEWACTIVITY", - authenticatedUserNameNoAdmin().SQLIdentifier(), - ) - if _, err := db.Exec(query); err != nil { - t.Fatal(err) - } - - for _, tc := range []struct { - expectedDBs []string - isAdmin bool - }{ - {[]string{"defaultdb", "postgres", "system", testDbName}, true}, - {[]string{"postgres", testDbName}, false}, - } { - t.Run(fmt.Sprintf("isAdmin:%t", tc.isAdmin), func(t *testing.T) { - // Test databases endpoint. - var resp serverpb.DatabasesResponse - if err := getAdminJSONProtoWithAdminOption( - s, - "databases", - &resp, - tc.isAdmin, - ); err != nil { - t.Fatal(err) - } - - if a, e := len(resp.Databases), len(tc.expectedDBs); a != e { - t.Fatalf("length of result %d != expected %d", a, e) - } - - sort.Strings(tc.expectedDBs) - sort.Strings(resp.Databases) - for i, e := range tc.expectedDBs { - if a := resp.Databases[i]; a != e { - t.Fatalf("database name %s != expected %s", a, e) - } - } - - // Test database details endpoint. - var details serverpb.DatabaseDetailsResponse - urlEscapeDbName := url.PathEscape(testDbName) - - if err := getAdminJSONProtoWithAdminOption( - s, - "databases/"+urlEscapeDbName, - &details, - tc.isAdmin, - ); err != nil { - t.Fatal(err) - } - - if a, e := len(details.Grants), 3; a != e { - t.Fatalf("# of grants %d != expected %d", a, e) - } - - userGrants := make(map[string][]string) - for _, grant := range details.Grants { - switch grant.User { - case username.AdminRole, username.RootUser, authenticatedUserNoAdmin: - userGrants[grant.User] = append(userGrants[grant.User], grant.Privileges...) 
- default: - t.Fatalf("unknown grant to user %s", grant.User) - } - } - for u, p := range userGrants { - switch u { - case username.AdminRole: - if !reflect.DeepEqual(p, []string{"ALL"}) { - t.Fatalf("privileges %v != expected %v", p, privileges) - } - case username.RootUser: - if !reflect.DeepEqual(p, []string{"ALL"}) { - t.Fatalf("privileges %v != expected %v", p, privileges) - } - case authenticatedUserNoAdmin: - sort.Strings(p) - if !reflect.DeepEqual(p, privileges) { - t.Fatalf("privileges %v != expected %v", p, privileges) - } - default: - t.Fatalf("unknown grant to user %s", u) - } - } - - // Verify Descriptor ID. - databaseID, err := ts.admin.queryDatabaseID(ctx, username.RootUserName(), testDbName) - if err != nil { - t.Fatal(err) - } - if a, e := details.DescriptorID, int64(databaseID); a != e { - t.Fatalf("db had descriptorID %d, expected %d", a, e) - } - }) - } -} - -func TestAdminAPIDatabaseDoesNotExist(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - const errPattern = "database.+does not exist" - if err := getAdminJSONProto(s, "databases/i_do_not_exist", nil); !testutils.IsError(err, errPattern) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) - } -} - -func TestAdminAPIDatabaseSQLInjection(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - const fakedb = "system;DROP DATABASE system;" - const path = "databases/" + fakedb - const errPattern = `target database or schema does not exist` - if err := getAdminJSONProto(s, path, nil); !testutils.IsError(err, errPattern) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) - } -} - -func TestAdminAPINonTableStats(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - defer testCluster.Stopper().Stop(context.Background()) - s := testCluster.Server(0) - - // Skip TableStatsResponse.Stats comparison, since it includes data which - // aren't consistent (time, bytes). - expectedResponse := serverpb.NonTableStatsResponse{ - TimeSeriesStats: &serverpb.TableStatsResponse{ - RangeCount: 1, - ReplicaCount: 3, - NodeCount: 3, - }, - InternalUseStats: &serverpb.TableStatsResponse{ - RangeCount: 11, - ReplicaCount: 15, - NodeCount: 3, - }, - } - - var resp serverpb.NonTableStatsResponse - if err := getAdminJSONProto(s, "nontablestats", &resp); err != nil { - t.Fatal(err) - } - - assertExpectedStatsResponse := func(expected, actual *serverpb.TableStatsResponse) { - assert.Equal(t, expected.RangeCount, actual.RangeCount) - assert.Equal(t, expected.ReplicaCount, actual.ReplicaCount) - assert.Equal(t, expected.NodeCount, actual.NodeCount) - } - - assertExpectedStatsResponse(expectedResponse.TimeSeriesStats, resp.TimeSeriesStats) - assertExpectedStatsResponse(expectedResponse.InternalUseStats, resp.InternalUseStats) -} - -// Verify that for a cluster with no user data, all the ranges on the Databases -// page consist of: -// 1) the total ranges listed for the system database -// 2) the total ranges listed for the Non-Table data -func TestRangeCount(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - testCluster := 
serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - require.NoError(t, testCluster.WaitForFullReplication()) - defer testCluster.Stopper().Stop(context.Background()) - s := testCluster.Server(0) - - // Sum up ranges for non-table parts of the system returned - // from the "nontablestats" enpoint. - getNonTableRangeCount := func() (ts, internal int64) { - var resp serverpb.NonTableStatsResponse - if err := getAdminJSONProto(s, "nontablestats", &resp); err != nil { - t.Fatal(err) - } - return resp.TimeSeriesStats.RangeCount, resp.InternalUseStats.RangeCount - } - - // Return map tablename=>count obtained from the - // "databases/system/tables/{table}" endpoints. - getSystemTableRangeCount := func() map[string]int64 { - m := map[string]int64{} - var dbResp serverpb.DatabaseDetailsResponse - if err := getAdminJSONProto(s, "databases/system", &dbResp); err != nil { - t.Fatal(err) - } - for _, tableName := range dbResp.TableNames { - var tblResp serverpb.TableStatsResponse - path := "databases/system/tables/" + tableName + "/stats" - if err := getAdminJSONProto(s, path, &tblResp); err != nil { - t.Fatal(err) - } - m[tableName] = tblResp.RangeCount - } - // Hardcode the single range used by each system sequence, the above - // request does not return sequences. - // TODO(richardjcai): Maybe update the request to return - // sequences as well? 
- m[fmt.Sprintf("public.%s", catconstants.DescIDSequenceTableName)] = 1 - m[fmt.Sprintf("public.%s", catconstants.RoleIDSequenceName)] = 1 - m[fmt.Sprintf("public.%s", catconstants.TenantIDSequenceTableName)] = 1 - return m - } - - getRangeCountFromFullSpan := func() int64 { - adminServer := s.(*TestServer).Server.admin - stats, err := adminServer.statsForSpan(context.Background(), roachpb.Span{ - Key: keys.LocalMax, - EndKey: keys.MaxKey, - }) - if err != nil { - t.Fatal(err) - } - return stats.RangeCount - } - - exp := getRangeCountFromFullSpan() - - var systemTableRangeCount int64 - sysDBMap := getSystemTableRangeCount() - for _, n := range sysDBMap { - systemTableRangeCount += n - } - - tsCount, internalCount := getNonTableRangeCount() - - act := tsCount + internalCount + systemTableRangeCount - - if !assert.Equal(t, - exp, - act, - ) { - t.Log("did nonTableDescriptorRangeCount() change?") - t.Logf( - "claimed numbers:\ntime series = %d\ninternal = %d\nsystemdb = %d (%v)", - tsCount, internalCount, systemTableRangeCount, sysDBMap, - ) - db := testCluster.ServerConn(0) - defer db.Close() - - runner := sqlutils.MakeSQLRunner(db) - s := sqlutils.MatrixToStr(runner.QueryStr(t, `SHOW CLUSTER RANGES`)) - t.Logf("actual ranges:\n%s", s) - } -} - -func TestAdminAPITableDoesNotExist(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - const fakename = "i_do_not_exist" - const badDBPath = "databases/" + fakename + "/tables/foo" - const dbErrPattern = `relation \\"` + fakename + `.foo\\" does not exist` - if err := getAdminJSONProto(s, badDBPath, nil); !testutils.IsError(err, dbErrPattern) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, dbErrPattern) - } - - const badTablePath = "databases/system/tables/" + fakename - const tableErrPattern = `relation \\"system.` + fakename + `\\" does not exist` - if err := getAdminJSONProto(s, badTablePath, nil); !testutils.IsError(err, tableErrPattern) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, tableErrPattern) - } -} - -func TestAdminAPITableSQLInjection(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails with - // it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - const fakeTable = "users;DROP DATABASE system;" - const path = "databases/system/tables/" + fakeTable - const errPattern = `relation \"system.` + fakeTable + `\" does not exist` - if err := getAdminJSONProto(s, path, nil); !testutils.IsError(err, regexp.QuoteMeta(errPattern)) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) - } -} - -func TestAdminAPITableDetails(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - for _, tc := range []struct { - name, dbName, tblName, pkName string - }{ - {name: "lower", dbName: "test", tblName: "tbl", pkName: "tbl_pkey"}, - {name: "lower other schema", dbName: "test", tblName: `testschema.tbl`, pkName: "tbl_pkey"}, - {name: "lower with space", dbName: "test test", tblName: `"tbl tbl"`, pkName: "tbl tbl_pkey"}, - {name: "upper", dbName: "TEST", tblName: `"TBL"`, pkName: "TBL_pkey"}, // Regression test for issue #14056 - } { - t.Run(tc.name, func(t *testing.T) { - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - escDBName := tree.NameStringP(&tc.dbName) - tblName := tc.tblName - schemaName := "testschema" - - ac := ts.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - - tableSchema := `nulls_allowed INT8, - nulls_not_allowed INT8 NOT NULL DEFAULT 1000, - default2 INT8 DEFAULT 2, - string_default STRING DEFAULT 'default_string', - INDEX descidx (default2 DESC)` - - setupQueries := []string{ - fmt.Sprintf("CREATE DATABASE %s", escDBName), - fmt.Sprintf("CREATE SCHEMA %s", schemaName), - fmt.Sprintf(`CREATE TABLE %s.%s (%s)`, escDBName, tblName, tableSchema), - "CREATE USER readonly", - "CREATE USER app", - fmt.Sprintf("GRANT SELECT ON %s.%s TO readonly", escDBName, tblName), - fmt.Sprintf("GRANT SELECT,UPDATE,DELETE ON %s.%s TO app", escDBName, tblName), - fmt.Sprintf("CREATE STATISTICS test_stats FROM %s.%s", escDBName, tblName), - } - pgURL, cleanupGoDB := sqlutils.PGUrl( - t, s.ServingSQLAddr(), "StartServer" /* prefix */, url.User(username.RootUser)) - defer cleanupGoDB() - pgURL.Path = tc.dbName - db, err := gosql.Open("postgres", pgURL.String()) - if err != nil { - t.Fatal(err) - } - defer db.Close() - for _, q := range setupQueries { - if _, err := db.Exec(q); err != nil { - t.Fatal(err) - } - } - - // Perform API call. - var resp serverpb.TableDetailsResponse - url := fmt.Sprintf("databases/%s/tables/%s", tc.dbName, tblName) - if err := getAdminJSONProto(s, url, &resp); err != nil { - t.Fatal(err) - } - - // Verify columns. 
- expColumns := []serverpb.TableDetailsResponse_Column{ - {Name: "nulls_allowed", Type: "INT8", Nullable: true, DefaultValue: ""}, - {Name: "nulls_not_allowed", Type: "INT8", Nullable: false, DefaultValue: "1000"}, - {Name: "default2", Type: "INT8", Nullable: true, DefaultValue: "2"}, - {Name: "string_default", Type: "STRING", Nullable: true, DefaultValue: "'default_string'"}, - {Name: "rowid", Type: "INT8", Nullable: false, DefaultValue: "unique_rowid()", Hidden: true}, - } - testutils.SortStructs(expColumns, "Name") - testutils.SortStructs(resp.Columns, "Name") - if a, e := len(resp.Columns), len(expColumns); a != e { - t.Fatalf("# of result columns %d != expected %d (got: %#v)", a, e, resp.Columns) - } - for i, a := range resp.Columns { - e := expColumns[i] - if a.String() != e.String() { - t.Fatalf("mismatch at column %d: actual %#v != %#v", i, a, e) - } - } - - // Verify grants. - expGrants := []serverpb.TableDetailsResponse_Grant{ - {User: username.AdminRole, Privileges: []string{"ALL"}}, - {User: username.RootUser, Privileges: []string{"ALL"}}, - {User: "app", Privileges: []string{"DELETE"}}, - {User: "app", Privileges: []string{"SELECT"}}, - {User: "app", Privileges: []string{"UPDATE"}}, - {User: "readonly", Privileges: []string{"SELECT"}}, - } - testutils.SortStructs(expGrants, "User") - testutils.SortStructs(resp.Grants, "User") - if a, e := len(resp.Grants), len(expGrants); a != e { - t.Fatalf("# of grant columns %d != expected %d (got: %#v)", a, e, resp.Grants) - } - for i, a := range resp.Grants { - e := expGrants[i] - sort.Strings(a.Privileges) - sort.Strings(e.Privileges) - if a.String() != e.String() { - t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e) - } - } - - // Verify indexes. 
- expIndexes := []serverpb.TableDetailsResponse_Index{ - {Name: tc.pkName, Column: "string_default", Direction: "N/A", Unique: true, Seq: 5, Storing: true}, - {Name: tc.pkName, Column: "default2", Direction: "N/A", Unique: true, Seq: 4, Storing: true}, - {Name: tc.pkName, Column: "nulls_not_allowed", Direction: "N/A", Unique: true, Seq: 3, Storing: true}, - {Name: tc.pkName, Column: "nulls_allowed", Direction: "N/A", Unique: true, Seq: 2, Storing: true}, - {Name: tc.pkName, Column: "rowid", Direction: "ASC", Unique: true, Seq: 1}, - {Name: "descidx", Column: "rowid", Direction: "ASC", Unique: false, Seq: 2, Implicit: true}, - {Name: "descidx", Column: "default2", Direction: "DESC", Unique: false, Seq: 1}, - } - testutils.SortStructs(expIndexes, "Name", "Seq") - testutils.SortStructs(resp.Indexes, "Name", "Seq") - for i, a := range resp.Indexes { - e := expIndexes[i] - if a.String() != e.String() { - t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e) - } - } - - // Verify range count. - if a, e := resp.RangeCount, int64(1); a != e { - t.Fatalf("# of ranges %d != expected %d", a, e) - } - - // Verify Create Table Statement. - { - - showCreateTableQuery := fmt.Sprintf("SHOW CREATE TABLE %s.%s", escDBName, tblName) - - row := db.QueryRow(showCreateTableQuery) - var createStmt, tableName string - if err := row.Scan(&tableName, &createStmt); err != nil { - t.Fatal(err) - } - - if a, e := resp.CreateTableStatement, createStmt; a != e { - t.Fatalf("mismatched create table statement; expected %s, got %s", e, a) - } - } - - // Verify statistics last updated. 
- { - - showStatisticsForTableQuery := fmt.Sprintf("SELECT max(created) AS created FROM [SHOW STATISTICS FOR TABLE %s.%s]", escDBName, tblName) - - row := db.QueryRow(showStatisticsForTableQuery) - var createdTs time.Time - if err := row.Scan(&createdTs); err != nil { - t.Fatal(err) - } - - if a, e := resp.StatsLastCreatedAt, createdTs; reflect.DeepEqual(a, e) { - t.Fatalf("mismatched statistics creation timestamp; expected %s, got %s", e, a) - } - } - - // Verify Descriptor ID. - tableID, err := ts.admin.queryTableID(ctx, username.RootUserName(), tc.dbName, tc.tblName) - if err != nil { - t.Fatal(err) - } - if a, e := resp.DescriptorID, int64(tableID); a != e { - t.Fatalf("table had descriptorID %d, expected %d", a, e) - } - }) - } -} - -// TestAdminAPIZoneDetails verifies the zone configuration information returned -// for both DatabaseDetailsResponse AND TableDetailsResponse. -func TestAdminAPIZoneDetails(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - // Create database and table. - ac := ts.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - setupQueries := []string{ - "CREATE DATABASE test", - "CREATE TABLE test.tbl (val STRING)", - } - for _, q := range setupQueries { - if _, err := db.Exec(q); err != nil { - t.Fatalf("error executing '%s': %s", q, err) - } - } - - // Function to verify the zone for table "test.tbl" as returned by the Admin - // API. 
- verifyTblZone := func( - expectedZone zonepb.ZoneConfig, expectedLevel serverpb.ZoneConfigurationLevel, - ) { - var resp serverpb.TableDetailsResponse - if err := getAdminJSONProto(s, "databases/test/tables/tbl", &resp); err != nil { - t.Fatal(err) - } - if a, e := &resp.ZoneConfig, &expectedZone; !a.Equal(e) { - t.Errorf("actual table zone config %v did not match expected value %v", a, e) - } - if a, e := resp.ZoneConfigLevel, expectedLevel; a != e { - t.Errorf("actual table ZoneConfigurationLevel %s did not match expected value %s", a, e) - } - if t.Failed() { - t.FailNow() - } - } - - // Function to verify the zone for database "test" as returned by the Admin - // API. - verifyDbZone := func( - expectedZone zonepb.ZoneConfig, expectedLevel serverpb.ZoneConfigurationLevel, - ) { - var resp serverpb.DatabaseDetailsResponse - if err := getAdminJSONProto(s, "databases/test", &resp); err != nil { - t.Fatal(err) - } - if a, e := &resp.ZoneConfig, &expectedZone; !a.Equal(e) { - t.Errorf("actual db zone config %v did not match expected value %v", a, e) - } - if a, e := resp.ZoneConfigLevel, expectedLevel; a != e { - t.Errorf("actual db ZoneConfigurationLevel %s did not match expected value %s", a, e) - } - if t.Failed() { - t.FailNow() - } - } - - // Function to store a zone config for a given object ID. - setZone := func(zoneCfg zonepb.ZoneConfig, id descpb.ID) { - zoneBytes, err := protoutil.Marshal(&zoneCfg) - if err != nil { - t.Fatal(err) - } - const query = `INSERT INTO system.zones VALUES($1, $2)` - if _, err := db.Exec(query, id, zoneBytes); err != nil { - t.Fatalf("error executing '%s': %s", query, err) - } - } - - // Verify zone matches cluster default. 
- verifyDbZone(s.(*TestServer).Cfg.DefaultZoneConfig, serverpb.ZoneConfigurationLevel_CLUSTER) - verifyTblZone(s.(*TestServer).Cfg.DefaultZoneConfig, serverpb.ZoneConfigurationLevel_CLUSTER) - - databaseID, err := ts.admin.queryDatabaseID(ctx, username.RootUserName(), "test") - if err != nil { - t.Fatal(err) - } - tableID, err := ts.admin.queryTableID(ctx, username.RootUserName(), "test", "tbl") - if err != nil { - t.Fatal(err) - } - - // Apply zone configuration to database and check again. - dbZone := zonepb.ZoneConfig{ - RangeMinBytes: proto.Int64(456), - } - setZone(dbZone, databaseID) - verifyDbZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) - verifyTblZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) - - // Apply zone configuration to table and check again. - tblZone := zonepb.ZoneConfig{ - RangeMinBytes: proto.Int64(789), - } - setZone(tblZone, tableID) - verifyDbZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) - verifyTblZone(tblZone, serverpb.ZoneConfigurationLevel_TABLE) -} - -func TestAdminAPIUsers(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - // Create sample users. - query := ` -INSERT INTO system.users (username, "hashedPassword", user_id) -VALUES ('adminUser', 'abc', 200), ('bob', 'xyz', 201)` - if _, err := db.Exec(query); err != nil { - t.Fatal(err) - } - - // Query the API for users. - var resp serverpb.UsersResponse - if err := getAdminJSONProto(s, "users", &resp); err != nil { - t.Fatal(err) - } - expResult := serverpb.UsersResponse{ - Users: []serverpb.UsersResponse_User{ - {Username: "adminUser"}, - {Username: "authentic_user"}, - {Username: "bob"}, - {Username: "root"}, - }, - } - - // Verify results. 
- const sortKey = "Username" - testutils.SortStructs(resp.Users, sortKey) - testutils.SortStructs(expResult.Users, sortKey) - if !reflect.DeepEqual(resp, expResult) { - t.Fatalf("result %v != expected %v", resp, expResult) - } -} - -func TestAdminAPIEvents(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - setupQueries := []string{ - "CREATE DATABASE api_test", - "CREATE TABLE api_test.tbl1 (a INT)", - "CREATE TABLE api_test.tbl2 (a INT)", - "CREATE TABLE api_test.tbl3 (a INT)", - "DROP TABLE api_test.tbl1", - "DROP TABLE api_test.tbl2", - "SET CLUSTER SETTING cluster.organization = 'somestring';", - } - for _, q := range setupQueries { - if _, err := db.Exec(q); err != nil { - t.Fatalf("error executing '%s': %s", q, err) - } - } - - const allEvents = "" - type testcase struct { - eventType string - hasLimit bool - limit int - unredacted bool - expCount int - } - testcases := []testcase{ - {"node_join", false, 0, false, 1}, - {"node_restart", false, 0, false, 0}, - {"drop_database", false, 0, false, 0}, - {"create_database", false, 0, false, 3}, - {"drop_table", false, 0, false, 2}, - {"create_table", false, 0, false, 3}, - {"set_cluster_setting", false, 0, false, 2}, - // We use limit=true with no limit here because otherwise the - // expCount will mess up the expected total count below. 
- {"set_cluster_setting", true, 0, true, 2}, - {"create_table", true, 0, false, 3}, - {"create_table", true, -1, false, 3}, - {"create_table", true, 2, false, 2}, - } - minTotalEvents := 0 - for _, tc := range testcases { - if !tc.hasLimit { - minTotalEvents += tc.expCount - } - } - testcases = append(testcases, testcase{allEvents, false, 0, false, minTotalEvents}) - - for i, tc := range testcases { - url := "events" - if tc.eventType != allEvents { - url += "?type=" + tc.eventType - if tc.hasLimit { - url += fmt.Sprintf("&limit=%d", tc.limit) - } - if tc.unredacted { - url += "&unredacted_events=true" - } - } - - t.Run(url, func(t *testing.T) { - var resp serverpb.EventsResponse - if err := getAdminJSONProto(s, url, &resp); err != nil { - t.Fatal(err) - } - if tc.eventType == allEvents { - // When retrieving all events, we expect that there will be some system - // database migrations, unrelated to this test, that add to the log entry - // count. So, we do a looser check here. - if a, min := len(resp.Events), tc.expCount; a < tc.expCount { - t.Fatalf("%d: total # of events %d < min %d", i, a, min) - } - } else { - if a, e := len(resp.Events), tc.expCount; a != e { - t.Fatalf("%d: # of %s events %d != expected %d", i, tc.eventType, a, e) - } - } - - // Ensure we don't have blank / nonsensical fields. 
- for _, e := range resp.Events { - if e.Timestamp == (time.Time{}) { - t.Errorf("%d: missing/empty timestamp", i) - } - - if len(tc.eventType) > 0 { - if a, e := e.EventType, tc.eventType; a != e { - t.Errorf("%d: event type %s != expected %s", i, a, e) - } - } else { - if len(e.EventType) == 0 { - t.Errorf("%d: missing event type in event", i) - } - } - - isSettingChange := e.EventType == "set_cluster_setting" - - if e.ReportingID == 0 { - t.Errorf("%d: missing/empty ReportingID", i) - } - if len(e.Info) == 0 { - t.Errorf("%d: missing/empty Info", i) - } - if isSettingChange && strings.Contains(e.Info, "cluster.organization") { - if tc.unredacted { - if !strings.Contains(e.Info, "somestring") { - t.Errorf("%d: require 'somestring' in Info", i) - } - } else { - if strings.Contains(e.Info, "somestring") { - t.Errorf("%d: un-redacted 'somestring' in Info", i) - } - } - } - if len(e.UniqueID) == 0 { - t.Errorf("%d: missing/empty UniqueID", i) - } - } - }) - } -} - -func TestAdminAPISettings(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - // Any bool that defaults to true will work here. 
- const settingKey = "sql.metrics.statement_details.enabled" - st := s.ClusterSettings() - allKeys := settings.Keys(settings.ForSystemTenant) - - checkSetting := func(t *testing.T, k string, v serverpb.SettingsResponse_Value) { - ref, ok := settings.LookupForReporting(k, settings.ForSystemTenant) - if !ok { - t.Fatalf("%s: not found after initial lookup", k) - } - typ := ref.Typ() - - if !settings.TestingIsReportable(ref) { - if v.Value != "" && v.Value != "" { - t.Errorf("%s: expected redacted value for %v, got %s", k, ref, v.Value) - } - } else { - if ref.String(&st.SV) != v.Value { - t.Errorf("%s: expected value %v, got %s", k, ref, v.Value) - } - } - - if expectedPublic := ref.Visibility() == settings.Public; expectedPublic != v.Public { - t.Errorf("%s: expected public %v, got %v", k, expectedPublic, v.Public) - } - - if desc := ref.Description(); desc != v.Description { - t.Errorf("%s: expected description %s, got %s", k, desc, v.Description) - } - if typ != v.Type { - t.Errorf("%s: expected type %s, got %s", k, typ, v.Type) - } - if v.LastUpdated != nil { - db := sqlutils.MakeSQLRunner(conn) - q := makeSQLQuery() - q.Append(`SELECT name, "lastUpdated" FROM system.settings WHERE name=$`, k) - rows := db.Query( - t, - q.String(), - q.QueryArguments()..., - ) - defer rows.Close() - if rows.Next() == false { - t.Errorf("missing sql row for %s", k) - } - } - } - - t.Run("all", func(t *testing.T) { - var resp serverpb.SettingsResponse - - if err := getAdminJSONProto(s, "settings", &resp); err != nil { - t.Fatal(err) - } - - // Check that all expected keys were returned - if len(allKeys) != len(resp.KeyValues) { - t.Fatalf("expected %d keys, got %d", len(allKeys), len(resp.KeyValues)) - } - for _, k := range allKeys { - if _, ok := resp.KeyValues[k]; !ok { - t.Fatalf("expected key %s not found in response", k) - } - } - - // Check that the test key is listed and the values come indeed - // from the settings package unchanged. 
- seenRef := false - for k, v := range resp.KeyValues { - if k == settingKey { - seenRef = true - if v.Value != "true" { - t.Errorf("%s: expected true, got %s", k, v.Value) - } - } - - checkSetting(t, k, v) - } - - if !seenRef { - t.Fatalf("failed to observe test setting %s, got %+v", settingKey, resp.KeyValues) - } - }) - - t.Run("one-by-one", func(t *testing.T) { - var resp serverpb.SettingsResponse - - // All the settings keys must be retrievable, and their - // type and description must match. - for _, k := range allKeys { - q := make(url.Values) - q.Add("keys", k) - url := "settings?" + q.Encode() - if err := getAdminJSONProto(s, url, &resp); err != nil { - t.Fatalf("%s: %v", k, err) - } - if len(resp.KeyValues) != 1 { - t.Fatalf("%s: expected 1 response, got %d", k, len(resp.KeyValues)) - } - v, ok := resp.KeyValues[k] - if !ok { - t.Fatalf("%s: response does not contain key", k) - } - - checkSetting(t, k, v) - } - }) -} - -// TestAdminAPIUIData checks that UI customizations are properly -// persisted for both admin and non-admin users. -func TestAdminAPIUIData(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { - start := timeutil.Now() - - mustSetUIData := func(keyValues map[string][]byte) { - if err := postAdminJSONProtoWithAdminOption(s, "uidata", &serverpb.SetUIDataRequest{ - KeyValues: keyValues, - }, &serverpb.SetUIDataResponse{}, isAdmin); err != nil { - t.Fatal(err) - } - } - - expectKeyValues := func(expKeyValues map[string][]byte) { - var resp serverpb.GetUIDataResponse - queryValues := make(url.Values) - for key := range expKeyValues { - queryValues.Add("keys", key) - } - url := "uidata?" + queryValues.Encode() - if err := getAdminJSONProtoWithAdminOption(s, url, &resp, isAdmin); err != nil { - t.Fatal(err) - } - // Do a two-way comparison. We can't use reflect.DeepEqual(), because - // resp.KeyValues has timestamps and expKeyValues doesn't. - for key, actualVal := range resp.KeyValues { - if a, e := actualVal.Value, expKeyValues[key]; !bytes.Equal(a, e) { - t.Fatalf("key %s: value = %v, expected = %v", key, a, e) - } - } - for key, expVal := range expKeyValues { - if a, e := resp.KeyValues[key].Value, expVal; !bytes.Equal(a, e) { - t.Fatalf("key %s: value = %v, expected = %v", key, a, e) - } - } - - // Sanity check LastUpdated. 
- for _, val := range resp.KeyValues { - now := timeutil.Now() - if val.LastUpdated.Before(start) { - t.Fatalf("val.LastUpdated %s < start %s", val.LastUpdated, start) - } - if val.LastUpdated.After(now) { - t.Fatalf("val.LastUpdated %s > now %s", val.LastUpdated, now) - } - } - } - - expectValueEquals := func(key string, expVal []byte) { - expectKeyValues(map[string][]byte{key: expVal}) - } - - expectKeyNotFound := func(key string) { - var resp serverpb.GetUIDataResponse - url := "uidata?keys=" + key - if err := getAdminJSONProtoWithAdminOption(s, url, &resp, isAdmin); err != nil { - t.Fatal(err) - } - if len(resp.KeyValues) != 0 { - t.Fatal("key unexpectedly found") - } - } - - // Basic tests. - var badResp serverpb.GetUIDataResponse - const errPattern = "400 Bad Request" - if err := getAdminJSONProtoWithAdminOption(s, "uidata", &badResp, isAdmin); !testutils.IsError(err, errPattern) { - t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) - } - - mustSetUIData(map[string][]byte{"k1": []byte("v1")}) - expectValueEquals("k1", []byte("v1")) - - expectKeyNotFound("NON_EXISTENT_KEY") - - mustSetUIData(map[string][]byte{ - "k2": []byte("v2"), - "k3": []byte("v3"), - }) - expectValueEquals("k2", []byte("v2")) - expectValueEquals("k3", []byte("v3")) - expectKeyValues(map[string][]byte{ - "k2": []byte("v2"), - "k3": []byte("v3"), - }) - - mustSetUIData(map[string][]byte{"k2": []byte("v2-updated")}) - expectKeyValues(map[string][]byte{ - "k2": []byte("v2-updated"), - "k3": []byte("v3"), - }) - - // Write a binary blob with all possible byte values, then verify it. - var buf bytes.Buffer - for i := 0; i < 997; i++ { - buf.WriteByte(byte(i % 256)) - } - mustSetUIData(map[string][]byte{"bin": buf.Bytes()}) - expectValueEquals("bin", buf.Bytes()) - }) -} - -// TestAdminAPIUISeparateData check that separate users have separate customizations. 
-func TestAdminAPIUISeparateData(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - // Make a setting for an admin user. - if err := postAdminJSONProtoWithAdminOption(s, "uidata", - &serverpb.SetUIDataRequest{KeyValues: map[string][]byte{"k": []byte("v1")}}, - &serverpb.SetUIDataResponse{}, - true /*isAdmin*/); err != nil { - t.Fatal(err) - } - - // Make a setting for a non-admin user. - if err := postAdminJSONProtoWithAdminOption(s, "uidata", - &serverpb.SetUIDataRequest{KeyValues: map[string][]byte{"k": []byte("v2")}}, - &serverpb.SetUIDataResponse{}, - false /*isAdmin*/); err != nil { - t.Fatal(err) - } - - var resp serverpb.GetUIDataResponse - url := "uidata?keys=k" - - if err := getAdminJSONProtoWithAdminOption(s, url, &resp, true /* isAdmin */); err != nil { - t.Fatal(err) - } - if len(resp.KeyValues) != 1 || !bytes.Equal(resp.KeyValues["k"].Value, []byte("v1")) { - t.Fatalf("unexpected admin values: %+v", resp.KeyValues) - } - if err := getAdminJSONProtoWithAdminOption(s, url, &resp, false /* isAdmin */); err != nil { - t.Fatal(err) - } - if len(resp.KeyValues) != 1 || !bytes.Equal(resp.KeyValues["k"].Value, []byte("v2")) { - t.Fatalf("unexpected non-admin values: %+v", resp.KeyValues) - } -} - -func TestClusterAPI(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - testutils.RunTrueAndFalse(t, "reportingOn", func(t *testing.T, reportingOn bool) { - testutils.RunTrueAndFalse(t, "enterpriseOn", func(t *testing.T, enterpriseOn bool) { - // Override server license check. - if enterpriseOn { - old := base.CheckEnterpriseEnabled - base.CheckEnterpriseEnabled = func(_ *cluster.Settings, _ uuid.UUID, _ string) error { - return nil - } - defer func() { base.CheckEnterpriseEnabled = old }() - } - - if _, err := db.Exec(`SET CLUSTER SETTING diagnostics.reporting.enabled = $1`, reportingOn); err != nil { - t.Fatal(err) - } - - // We need to retry, because the cluster ID isn't set until after - // bootstrapping and because setting a cluster setting isn't necessarily - // instantaneous. - // - // Also note that there's a migration that affects `diagnostics.reporting.enabled`, - // so manipulating the cluster setting var directly is a bad idea. - testutils.SucceedsSoon(t, func() error { - var resp serverpb.ClusterResponse - if err := getAdminJSONProto(s, "cluster", &resp); err != nil { - return err - } - if a, e := resp.ClusterID, s.RPCContext().StorageClusterID.String(); a != e { - return errors.Errorf("cluster ID %s != expected %s", a, e) - } - if a, e := resp.ReportingEnabled, reportingOn; a != e { - return errors.Errorf("reportingEnabled = %t, wanted %t", a, e) - } - if a, e := resp.EnterpriseEnabled, enterpriseOn; a != e { - return errors.Errorf("enterpriseEnabled = %t, wanted %t", a, e) - } - return nil - }) - }) - }) -} - -func TestHealthAPI(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(ctx) - ts := s.(*TestServer) - - // We need to retry because the node ID isn't set until after - // bootstrapping. - testutils.SucceedsSoon(t, func() error { - var resp serverpb.HealthResponse - return getAdminJSONProto(s, "health", &resp) - }) - - // Make the SQL listener appear unavailable. Verify that health fails after that. - ts.sqlServer.isReady.Set(false) - var resp serverpb.HealthResponse - err := getAdminJSONProto(s, "health?ready=1", &resp) - if err == nil { - t.Error("server appears ready even though SQL listener is not") - } - ts.sqlServer.isReady.Set(true) - err = getAdminJSONProto(s, "health?ready=1", &resp) - if err != nil { - t.Errorf("server not ready after SQL listener is ready again: %v", err) - } - - // Expire this node's liveness record by pausing heartbeats and advancing the - // server's clock. - defer ts.nodeLiveness.PauseAllHeartbeatsForTest()() - self, ok := ts.nodeLiveness.Self() - assert.True(t, ok) - s.Clock().Update(self.Expiration.ToTimestamp().Add(1, 0).UnsafeToClockTimestamp()) - - testutils.SucceedsSoon(t, func() error { - err := getAdminJSONProto(s, "health?ready=1", &resp) - if err == nil { - return errors.New("health OK, still waiting for unhealth") - } - - t.Logf("observed error: %v", err) - if !testutils.IsError(err, `(?s)503 Service Unavailable.*"error": "node is not healthy"`) { - return err - } - return nil - }) - - // After the node reports an error with `?ready=1`, the health - // endpoint must still succeed without error when `?ready=1` is not specified. - if err := getAdminJSONProto(s, "health", &resp); err != nil { - t.Fatal(err) - } -} - -// getSystemJobIDsForNonAutoJobs queries the jobs table for all job IDs that have -// the given status. Sorted by decreasing creation time. 
-func getSystemJobIDsForNonAutoJobs( - t testing.TB, db *sqlutils.SQLRunner, status jobs.Status, -) []int64 { - q := makeSQLQuery() - q.Append(`SELECT job_id FROM crdb_internal.jobs WHERE status=$`, status) - q.Append(` AND (`) - for i, jobType := range jobspb.AutomaticJobTypes { - q.Append(`job_type != $`, jobType.String()) - if i < len(jobspb.AutomaticJobTypes)-1 { - q.Append(" AND ") - } - } - q.Append(` OR job_type IS NULL)`) - q.Append(` ORDER BY created DESC`) - rows := db.Query( - t, - q.String(), - q.QueryArguments()..., - ) - defer rows.Close() - - res := []int64{} - for rows.Next() { - var id int64 - if err := rows.Scan(&id); err != nil { - t.Fatal(err) - } - res = append(res, id) - } - return res -} - -func TestAdminAPIJobs(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - now := timeutil.Now() - retentionTime := 336 * time.Hour - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - Knobs: base.TestingKnobs{ - JobsTestingKnobs: &jobs.TestingKnobs{ - IntervalOverrides: jobs.TestingIntervalOverrides{ - RetentionTime: &retentionTime, - }, - }, - Server: &TestingKnobs{ - StubTimeNow: func() time.Time { return now }, - }, - }, - }) - - defer s.Stopper().Stop(context.Background()) - sqlDB := sqlutils.MakeSQLRunner(conn) - - testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { - // Creating this client causes a user to be created, which causes jobs - // to be created, so we do it up-front rather than inside the test. 
- _, err := s.GetAuthenticatedHTTPClient(isAdmin, serverutils.SingleTenantSession) - if err != nil { - t.Fatal(err) - } - }) - - existingSucceededIDs := getSystemJobIDsForNonAutoJobs(t, sqlDB, jobs.StatusSucceeded) - existingRunningIDs := getSystemJobIDsForNonAutoJobs(t, sqlDB, jobs.StatusRunning) - existingIDs := append(existingSucceededIDs, existingRunningIDs...) - - runningOnlyIds := []int64{1, 2, 4, 11, 12} - revertingOnlyIds := []int64{7, 8, 9} - retryRunningIds := []int64{6} - retryRevertingIds := []int64{10} - ef := &jobspb.RetriableExecutionFailure{ - TruncatedError: "foo", - } - // Add a regression test for #84139 where a string with a quote in it - // caused a failure in the admin API. - efQuote := &jobspb.RetriableExecutionFailure{ - TruncatedError: "foo\"abc\"", - } - - testJobs := []struct { - id int64 - status jobs.Status - details jobspb.Details - progress jobspb.ProgressDetails - username username.SQLUsername - numRuns int64 - lastRun time.Time - executionFailures []*jobspb.RetriableExecutionFailure - }{ - {1, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, nil}, - {2, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, timeutil.Now().Add(10 * time.Minute), nil}, - {3, jobs.StatusSucceeded, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, time.Time{}, nil}, - {4, jobs.StatusRunning, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, time.Time{}, nil}, - {5, jobs.StatusSucceeded, jobspb.BackupDetails{}, jobspb.BackupProgress{}, authenticatedUserNameNoAdmin(), 1, time.Time{}, nil}, - {6, jobs.StatusRunning, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, timeutil.Now().Add(10 * time.Minute), nil}, - {7, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 1, time.Time{}, nil}, - {8, jobs.StatusReverting, 
jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 1, timeutil.Now().Add(10 * time.Minute), nil}, - {9, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, time.Time{}, nil}, - {10, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, timeutil.Now().Add(10 * time.Minute), nil}, - {11, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, []*jobspb.RetriableExecutionFailure{ef}}, - {12, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, []*jobspb.RetriableExecutionFailure{efQuote}}, - } - for _, job := range testJobs { - payload := jobspb.Payload{ - UsernameProto: job.username.EncodeProto(), - Details: jobspb.WrapPayloadDetails(job.details), - RetriableExecutionFailureLog: job.executionFailures, - } - payloadBytes, err := protoutil.Marshal(&payload) - if err != nil { - t.Fatal(err) - } - - progress := jobspb.Progress{Details: jobspb.WrapProgressDetails(job.progress)} - // Populate progress.Progress field with a specific progress type based on - // the job type. 
- if _, ok := job.progress.(jobspb.ChangefeedProgress); ok { - progress.Progress = &jobspb.Progress_HighWater{ - HighWater: &hlc.Timestamp{}, - } - } else { - progress.Progress = &jobspb.Progress_FractionCompleted{ - FractionCompleted: 1.0, - } - } - - progressBytes, err := protoutil.Marshal(&progress) - if err != nil { - t.Fatal(err) - } - sqlDB.Exec(t, - `INSERT INTO system.jobs (id, status, num_runs, last_run, job_type) VALUES ($1, $2, $3, $4, $5)`, - job.id, job.status, job.numRuns, job.lastRun, payload.Type().String(), - ) - sqlDB.Exec(t, - `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, - job.id, jobs.GetLegacyPayloadKey(), payloadBytes, - ) - sqlDB.Exec(t, - `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, - job.id, jobs.GetLegacyProgressKey(), progressBytes, - ) - } - - const invalidJobType = math.MaxInt32 - - testCases := []struct { - uri string - expectedIDsViaAdmin []int64 - expectedIDsViaNonAdmin []int64 - }{ - { - "jobs", - append([]int64{12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, existingIDs...), - []int64{5}, - }, - { - "jobs?limit=1", - []int64{12}, - []int64{5}, - }, - { - "jobs?status=succeeded", - append([]int64{5, 3}, existingSucceededIDs...), - []int64{5}, - }, - { - "jobs?status=running", - append(append(append([]int64{}, runningOnlyIds...), retryRunningIds...), existingRunningIDs...), - []int64{}, - }, - { - "jobs?status=reverting", - append(append([]int64{}, revertingOnlyIds...), retryRevertingIds...), - []int64{}, - }, - { - "jobs?status=pending", - []int64{}, - []int64{}, - }, - { - "jobs?status=garbage", - []int64{}, - []int64{}, - }, - { - fmt.Sprintf("jobs?type=%d", jobspb.TypeBackup), - []int64{5, 3, 2}, - []int64{5}, - }, - { - fmt.Sprintf("jobs?type=%d", jobspb.TypeRestore), - []int64{1, 11, 12}, - []int64{}, - }, - { - fmt.Sprintf("jobs?type=%d", invalidJobType), - []int64{}, - []int64{}, - }, - { - fmt.Sprintf("jobs?status=running&type=%d", jobspb.TypeBackup), - []int64{2}, - 
[]int64{}, - }, - } - - testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { - for i, testCase := range testCases { - var res serverpb.JobsResponse - if err := getAdminJSONProtoWithAdminOption(s, testCase.uri, &res, isAdmin); err != nil { - t.Fatal(err) - } - resIDs := []int64{} - for _, job := range res.Jobs { - resIDs = append(resIDs, job.ID) - } - - expected := testCase.expectedIDsViaAdmin - if !isAdmin { - expected = testCase.expectedIDsViaNonAdmin - } - - sort.Slice(expected, func(i, j int) bool { - return expected[i] < expected[j] - }) - - sort.Slice(resIDs, func(i, j int) bool { - return resIDs[i] < resIDs[j] - }) - if e, a := expected, resIDs; !reflect.DeepEqual(e, a) { - t.Errorf("%d - %v: expected job IDs %v, but got %v", i, testCase.uri, e, a) - } - // We don't use require.Equal() because timestamps don't necessarily - // compare == due to only one of them having a monotonic clock reading. - require.True(t, now.Add(-retentionTime).Equal(res.EarliestRetainedTime)) - } - }) -} - -func TestAdminAPIJobsDetails(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - sqlDB := sqlutils.MakeSQLRunner(conn) - - now := timeutil.Now() - - encodedError := func(err error) *errors.EncodedError { - ee := errors.EncodeError(context.Background(), err) - return &ee - } - testJobs := []struct { - id int64 - status jobs.Status - details jobspb.Details - progress jobspb.ProgressDetails - username username.SQLUsername - numRuns int64 - lastRun time.Time - executionLog []*jobspb.RetriableExecutionFailure - }{ - {1, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, nil}, - {2, jobs.StatusReverting, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, time.Time{}, nil}, - {3, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, now.Add(10 * time.Minute), nil}, - {4, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 1, now.Add(10 * time.Minute), nil}, - {5, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 2, time.Time{}, nil}, - {6, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, time.Time{}, nil}, - {7, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 2, now.Add(10 * time.Minute), nil}, - {8, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, now.Add(10 * time.Minute), []*jobspb.RetriableExecutionFailure{ - { - Status: string(jobs.StatusRunning), - ExecutionStartMicros: now.Add(-time.Minute).UnixMicro(), - ExecutionEndMicros: now.Add(-30 * time.Second).UnixMicro(), - InstanceID: 1, - Error: encodedError(errors.New("foo")), - }, - { - Status: string(jobs.StatusReverting), - ExecutionStartMicros: now.Add(-29 * time.Minute).UnixMicro(), - ExecutionEndMicros: 
now.Add(-time.Second).UnixMicro(), - InstanceID: 1, - TruncatedError: "bar", - }, - }}, - } - for _, job := range testJobs { - payload := jobspb.Payload{ - UsernameProto: job.username.EncodeProto(), - Details: jobspb.WrapPayloadDetails(job.details), - RetriableExecutionFailureLog: job.executionLog, - } - payloadBytes, err := protoutil.Marshal(&payload) - if err != nil { - t.Fatal(err) - } - - progress := jobspb.Progress{Details: jobspb.WrapProgressDetails(job.progress)} - // Populate progress.Progress field with a specific progress type based on - // the job type. - if _, ok := job.progress.(jobspb.ChangefeedProgress); ok { - progress.Progress = &jobspb.Progress_HighWater{ - HighWater: &hlc.Timestamp{}, - } - } else { - progress.Progress = &jobspb.Progress_FractionCompleted{ - FractionCompleted: 1.0, - } - } - - progressBytes, err := protoutil.Marshal(&progress) - if err != nil { - t.Fatal(err) - } - sqlDB.Exec(t, - `INSERT INTO system.jobs (id, status, num_runs, last_run) VALUES ($1, $2, $3, $4)`, - job.id, job.status, job.numRuns, job.lastRun, - ) - sqlDB.Exec(t, - `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, - job.id, jobs.GetLegacyPayloadKey(), payloadBytes, - ) - sqlDB.Exec(t, - `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, - job.id, jobs.GetLegacyProgressKey(), progressBytes, - ) - } - - var res serverpb.JobsResponse - if err := getAdminJSONProto(s, "jobs", &res); err != nil { - t.Fatal(err) - } - - // Trim down our result set to the jobs we injected. - resJobs := append([]serverpb.JobResponse(nil), res.Jobs...) 
- sort.Slice(resJobs, func(i, j int) bool { - return resJobs[i].ID < resJobs[j].ID - }) - resJobs = resJobs[:len(testJobs)] - - for i, job := range resJobs { - require.Equal(t, testJobs[i].id, job.ID) - require.Equal(t, len(testJobs[i].executionLog), len(job.ExecutionFailures)) - for j, f := range job.ExecutionFailures { - tf := testJobs[i].executionLog[j] - require.Equal(t, tf.Status, f.Status) - require.Equal(t, tf.ExecutionStartMicros, f.Start.UnixMicro()) - require.Equal(t, tf.ExecutionEndMicros, f.End.UnixMicro()) - var expErr string - if tf.Error != nil { - expErr = errors.DecodeError(context.Background(), *tf.Error).Error() - } else { - expErr = tf.TruncatedError - } - require.Equal(t, expErr, f.Error) - } - } -} - -func TestAdminAPILocations(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - sqlDB := sqlutils.MakeSQLRunner(conn) - - testLocations := []struct { - localityKey string - localityValue string - latitude float64 - longitude float64 - }{ - {"city", "Des Moines", 41.60054, -93.60911}, - {"city", "New York City", 40.71427, -74.00597}, - {"city", "Seattle", 47.60621, -122.33207}, - } - for _, loc := range testLocations { - sqlDB.Exec(t, - `INSERT INTO system.locations ("localityKey", "localityValue", latitude, longitude) VALUES ($1, $2, $3, $4)`, - loc.localityKey, loc.localityValue, loc.latitude, loc.longitude, - ) - } - var res serverpb.LocationsResponse - if err := getAdminJSONProtoWithAdminOption(s, "locations", &res, false /* isAdmin */); err != nil { - t.Fatal(err) - } - for i, loc := range testLocations { - expLoc := serverpb.LocationsResponse_Location{ - LocalityKey: loc.localityKey, - LocalityValue: loc.localityValue, - Latitude: loc.latitude, - Longitude: loc.longitude, - } - if !reflect.DeepEqual(res.Locations[i], expLoc) { - t.Errorf("%d: expected location %v, but got %v", i, expLoc, res.Locations[i]) - } - } -} - -func TestAdminAPIQueryPlan(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - sqlDB := sqlutils.MakeSQLRunner(conn) - - sqlDB.Exec(t, `CREATE DATABASE api_test`) - sqlDB.Exec(t, `CREATE TABLE api_test.t1 (id int primary key, name string)`) - sqlDB.Exec(t, `CREATE TABLE api_test.t2 (id int primary key, name string)`) - - testCases := []struct { - query string - exp []string - }{ - {"SELECT sum(id) FROM api_test.t1", []string{"nodeNames\":[\"1\"]", "Columns: id"}}, - {"SELECT sum(1) FROM api_test.t1 JOIN api_test.t2 on t1.id = t2.id", []string{"nodeNames\":[\"1\"]", "Columns: id"}}, - } - for i, testCase := range testCases { - var res serverpb.QueryPlanResponse - queryParam := url.QueryEscape(testCase.query) - if err := getAdminJSONProto(s, fmt.Sprintf("queryplan?query=%s", queryParam), &res); err != nil { - t.Errorf("%d: got error %s", i, err) - } - - for _, exp := range testCase.exp { - if !strings.Contains(res.DistSQLPhysicalQueryPlan, exp) { - t.Errorf("%d: expected response %v to contain %s", i, res, exp) - } - } - } - -} - -func TestAdminAPIRangeLogByRangeID(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - rangeID := 654321 - testCases := []struct { - rangeID int - hasLimit bool - limit int - expected int - }{ - {rangeID, true, 0, 2}, - {rangeID, true, -1, 2}, - {rangeID, true, 1, 1}, - {rangeID, false, 0, 2}, - // We'll create one event that has rangeID+1 as the otherRangeID. 
- {rangeID + 1, false, 0, 1}, - } - - for _, otherRangeID := range []int{rangeID + 1, rangeID + 2} { - if _, err := db.Exec( - `INSERT INTO system.rangelog ( - timestamp, "rangeID", "otherRangeID", "storeID", "eventType" - ) VALUES ( - now(), $1, $2, $3, $4 - )`, - rangeID, otherRangeID, - 1, // storeID - kvserverpb.RangeLogEventType_add_voter.String(), - ); err != nil { - t.Fatal(err) - } - } - - for _, tc := range testCases { - url := fmt.Sprintf("rangelog/%d", tc.rangeID) - if tc.hasLimit { - url += fmt.Sprintf("?limit=%d", tc.limit) - } - t.Run(url, func(t *testing.T) { - var resp serverpb.RangeLogResponse - if err := getAdminJSONProto(s, url, &resp); err != nil { - t.Fatal(err) - } - - if e, a := tc.expected, len(resp.Events); e != a { - t.Fatalf("expected %d events, got %d", e, a) - } - - for _, event := range resp.Events { - expID := roachpb.RangeID(tc.rangeID) - if event.Event.RangeID != expID && event.Event.OtherRangeID != expID { - t.Errorf("expected rangeID or otherRangeID to be %d, got %d and r%d", - expID, event.Event.RangeID, event.Event.OtherRangeID) - } - } - }) - } -} - -// Test the range log API when queries are not filtered by a range ID (like in -// TestAdminAPIRangeLogByRangeID). -func TestAdminAPIFullRangeLog(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, - base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - Knobs: base.TestingKnobs{ - Store: &kvserver.StoreTestingKnobs{ - DisableSplitQueue: true, - }, - }, - }) - defer s.Stopper().Stop(context.Background()) - - // Insert something in the rangelog table, otherwise it's empty for new - // clusters. 
- rows, err := db.Query(`SELECT count(1) FROM system.rangelog`) - if err != nil { - t.Fatal(err) - } - if !rows.Next() { - t.Fatal("missing row") - } - var cnt int - if err := rows.Scan(&cnt); err != nil { - t.Fatal(err) - } - if err := rows.Close(); err != nil { - t.Fatal(err) - } - if cnt != 0 { - t.Fatalf("expected 0 rows in system.rangelog, found: %d", cnt) - } - const rangeID = 100 - for i := 0; i < 10; i++ { - if _, err := db.Exec( - `INSERT INTO system.rangelog ( - timestamp, "rangeID", "storeID", "eventType" - ) VALUES (now(), $1, 1, $2)`, - rangeID, - kvserverpb.RangeLogEventType_add_voter.String(), - ); err != nil { - t.Fatal(err) - } - } - expectedEvents := 10 - - testCases := []struct { - hasLimit bool - limit int - expected int - }{ - {false, 0, expectedEvents}, - {true, 0, expectedEvents}, - {true, -1, expectedEvents}, - {true, 1, 1}, - } - - for _, tc := range testCases { - url := "rangelog" - if tc.hasLimit { - url += fmt.Sprintf("?limit=%d", tc.limit) - } - t.Run(url, func(t *testing.T) { - var resp serverpb.RangeLogResponse - if err := getAdminJSONProto(s, url, &resp); err != nil { - t.Fatal(err) - } - events := resp.Events - if e, a := tc.expected, len(events); e != a { - var sb strings.Builder - for _, ev := range events { - sb.WriteString(ev.String() + "\n") - } - t.Fatalf("expected %d events, got %d:\n%s", e, a, sb.String()) - } - }) - } -} - -func TestAdminAPIDataDistribution(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - defer testCluster.Stopper().Stop(context.Background()) - - firstServer := testCluster.Server(0) - sqlDB := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) - - // TODO(irfansharif): The data-distribution page and underyling APIs don't - // know how to deal with coalesced ranges. See #97942. - sqlDB.Exec(t, `SET CLUSTER SETTING spanconfig.storage_coalesce_adjacent.enabled = false`) - - // Create some tables. 
- sqlDB.Exec(t, `CREATE DATABASE roachblog`) - sqlDB.Exec(t, `CREATE TABLE roachblog.posts (id INT PRIMARY KEY, title text, body text)`) - sqlDB.Exec(t, `CREATE TABLE roachblog.comments ( - id INT PRIMARY KEY, - post_id INT REFERENCES roachblog.posts, - body text - )`) - sqlDB.Exec(t, `CREATE SCHEMA roachblog."foo bar"`) - sqlDB.Exec(t, `CREATE TABLE roachblog."foo bar".other_stuff(id INT PRIMARY KEY, body TEXT)`) - // Test special characters in DB and table names. - sqlDB.Exec(t, `CREATE DATABASE "sp'ec\ch""ars"`) - sqlDB.Exec(t, `CREATE TABLE "sp'ec\ch""ars"."more\spec'chars" (id INT PRIMARY KEY)`) - - // Make sure secondary tenants don't cause the endpoint to error. - sqlDB.Exec(t, "CREATE TENANT 'app'") - - // Verify that we see their replicas in the DataDistribution response, evenly spread - // across the test cluster's three nodes. - - expectedDatabaseInfo := map[string]serverpb.DataDistributionResponse_DatabaseInfo{ - "roachblog": { - TableInfo: map[string]serverpb.DataDistributionResponse_TableInfo{ - "public.posts": { - ReplicaCountByNodeId: map[roachpb.NodeID]int64{ - 1: 1, - 2: 1, - 3: 1, - }, - }, - "public.comments": { - ReplicaCountByNodeId: map[roachpb.NodeID]int64{ - 1: 1, - 2: 1, - 3: 1, - }, - }, - `"foo bar".other_stuff`: { - ReplicaCountByNodeId: map[roachpb.NodeID]int64{ - 1: 1, - 2: 1, - 3: 1, - }, - }, - }, - }, - `sp'ec\ch"ars`: { - TableInfo: map[string]serverpb.DataDistributionResponse_TableInfo{ - `public."more\spec'chars"`: { - ReplicaCountByNodeId: map[roachpb.NodeID]int64{ - 1: 1, - 2: 1, - 3: 1, - }, - }, - }, - }, - } - - // Wait for the new tables' ranges to be created and replicated. - testutils.SucceedsSoon(t, func() error { - var resp serverpb.DataDistributionResponse - if err := getAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { - t.Fatal(err) - } - - delete(resp.DatabaseInfo, "system") // delete results for system database. 
- if !reflect.DeepEqual(resp.DatabaseInfo, expectedDatabaseInfo) { - return fmt.Errorf("expected %v; got %v", expectedDatabaseInfo, resp.DatabaseInfo) - } - - // Don't test anything about the zone configs for now; just verify that something is there. - if len(resp.ZoneConfigs) == 0 { - return fmt.Errorf("no zone configs returned") - } - - return nil - }) - - // Verify that the request still works after a table has been dropped, - // and that dropped_at is set on the dropped table. - sqlDB.Exec(t, `DROP TABLE roachblog.comments`) - - var resp serverpb.DataDistributionResponse - if err := getAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { - t.Fatal(err) - } - - if resp.DatabaseInfo["roachblog"].TableInfo["public.comments"].DroppedAt == nil { - t.Fatal("expected roachblog.comments to have dropped_at set but it's nil") - } - - // Verify that the request still works after a database has been dropped. - sqlDB.Exec(t, `DROP DATABASE roachblog CASCADE`) - - if err := getAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { - t.Fatal(err) - } -} - -func BenchmarkAdminAPIDataDistribution(b *testing.B) { - skip.UnderShort(b, "TODO: fix benchmark") - testCluster := serverutils.StartNewTestCluster(b, 3, base.TestClusterArgs{}) - defer testCluster.Stopper().Stop(context.Background()) - - firstServer := testCluster.Server(0) - sqlDB := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) - - sqlDB.Exec(b, `CREATE DATABASE roachblog`) - - // Create a bunch of tables. 
- for i := 0; i < 200; i++ { - sqlDB.Exec( - b, - fmt.Sprintf(`CREATE TABLE roachblog.t%d (id INT PRIMARY KEY, title text, body text)`, i), - ) - // TODO(vilterp): split to increase the number of ranges for each table - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var resp serverpb.DataDistributionResponse - if err := getAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { - b.Fatal(err) - } - } - b.StopTimer() -} - -func TestEnqueueRange(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - }) - defer testCluster.Stopper().Stop(context.Background()) - - // Up-replicate r1 to all 3 nodes. We use manual replication to avoid lease - // transfers causing temporary conditions in which no store is the - // leaseholder, which can break the tests below. - _, err := testCluster.AddVoters(roachpb.KeyMin, testCluster.Target(1), testCluster.Target(2)) - if err != nil { - t.Fatal(err) - } - - // RangeID being queued - const realRangeID = 1 - const fakeRangeID = 999 - - // Who we expect responses from. 
- const none = 0 - const leaseholder = 1 - const allReplicas = 3 - - testCases := []struct { - nodeID roachpb.NodeID - queue string - rangeID roachpb.RangeID - expectedDetails int - expectedNonErrors int - }{ - // Success cases - {0, "mvccGC", realRangeID, allReplicas, leaseholder}, - {0, "split", realRangeID, allReplicas, leaseholder}, - {0, "replicaGC", realRangeID, allReplicas, allReplicas}, - {0, "RaFtLoG", realRangeID, allReplicas, allReplicas}, - {0, "RAFTSNAPSHOT", realRangeID, allReplicas, allReplicas}, - {0, "consistencyChecker", realRangeID, allReplicas, leaseholder}, - {0, "TIMESERIESmaintenance", realRangeID, allReplicas, leaseholder}, - {1, "raftlog", realRangeID, leaseholder, leaseholder}, - {2, "raftlog", realRangeID, leaseholder, 1}, - {3, "raftlog", realRangeID, leaseholder, 1}, - // Compatibility cases. - // TODO(nvanbenschoten): remove this in v23.1. - {0, "gc", realRangeID, allReplicas, leaseholder}, - {0, "GC", realRangeID, allReplicas, leaseholder}, - // Error cases - {0, "gv", realRangeID, allReplicas, none}, - {0, "GC", fakeRangeID, allReplicas, none}, - } - - for _, tc := range testCases { - t.Run(tc.queue, func(t *testing.T) { - req := &serverpb.EnqueueRangeRequest{ - NodeID: tc.nodeID, - Queue: tc.queue, - RangeID: tc.rangeID, - } - var resp serverpb.EnqueueRangeResponse - if err := postAdminJSONProto(testCluster.Server(0), "enqueue_range", req, &resp); err != nil { - t.Fatal(err) - } - if e, a := tc.expectedDetails, len(resp.Details); e != a { - t.Errorf("expected %d details; got %d: %+v", e, a, resp) - } - var numNonErrors int - for _, details := range resp.Details { - if len(details.Events) > 0 && details.Error == "" { - numNonErrors++ - } - } - if tc.expectedNonErrors != numNonErrors { - t.Errorf("expected %d non-error details; got %d: %+v", tc.expectedNonErrors, numNonErrors, resp) - } - }) - } - - // Finally, test a few more basic error cases. 
- reqs := []*serverpb.EnqueueRangeRequest{ - {NodeID: -1, Queue: "mvccGC"}, - {Queue: ""}, - {RangeID: -1, Queue: "mvccGC"}, - } - for _, req := range reqs { - t.Run(fmt.Sprint(req), func(t *testing.T) { - var resp serverpb.EnqueueRangeResponse - err := postAdminJSONProto(testCluster.Server(0), "enqueue_range", req, &resp) - if err == nil { - t.Fatalf("unexpected success: %+v", resp) - } - if !testutils.IsError(err, "400 Bad Request") { - t.Fatalf("unexpected error type: %+v", err) - } - }) - } -} - -func TestStatsforSpanOnLocalMax(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - defer testCluster.Stopper().Stop(context.Background()) - firstServer := testCluster.Server(0) - adminServer := firstServer.(*TestServer).Server.admin - - underTest := roachpb.Span{ - Key: keys.LocalMax, - EndKey: keys.SystemPrefix, - } - - _, err := adminServer.statsForSpan(context.Background(), underTest) - if err != nil { - t.Fatal(err) - } -} - -// TestEndpointTelemetryBasic tests that the telemetry collection on the usage of -// CRDB's endpoints works as expected by recording the call counts of `Admin` & -// `Status` requests. -func TestEndpointTelemetryBasic(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(context.Background()) - - // Check that calls over HTTP are recorded. 
- var details serverpb.LocationsResponse - if err := getAdminJSONProto(s, "locations", &details); err != nil { - t.Fatal(err) - } - require.GreaterOrEqual(t, telemetry.Read(getServerEndpointCounter( - "/cockroach.server.serverpb.Admin/Locations", - )), int32(1)) - - var resp serverpb.StatementsResponse - if err := getStatusJSONProto(s, "statements", &resp); err != nil { - t.Fatal(err) - } - require.Equal(t, int32(1), telemetry.Read(getServerEndpointCounter( - "/cockroach.server.serverpb.Status/Statements", - ))) -} - -// checkNodeCheckResultReady is a helper function for validating that the -// results of a decommission pre-check on a single node show it is ready. -func checkNodeCheckResultReady( - t *testing.T, - nID roachpb.NodeID, - replicaCount int64, - checkResult serverpb.DecommissionPreCheckResponse_NodeCheckResult, -) { - require.Equal(t, serverpb.DecommissionPreCheckResponse_NodeCheckResult{ - NodeID: nID, - DecommissionReadiness: serverpb.DecommissionPreCheckResponse_READY, - ReplicaCount: replicaCount, - CheckedRanges: nil, - }, checkResult) -} - -// checkRangeCheckResult is a helper function for validating a range error -// returned as part of a decommission pre-check. 
-func checkRangeCheckResult( - t *testing.T, - desc roachpb.RangeDescriptor, - checkResult serverpb.DecommissionPreCheckResponse_RangeCheckResult, - expectedAction string, - expectedErrSubstr string, - expectTraces bool, -) { - passed := false - defer func() { - if !passed { - t.Logf("failed checking %s", desc) - if expectTraces { - var traceBuilder strings.Builder - for _, event := range checkResult.Events { - fmt.Fprintf(&traceBuilder, "\n(%s) %s", event.Time, event.Message) - } - t.Logf("trace events: %s", traceBuilder.String()) - } - } - }() - require.Equalf(t, desc.RangeID, checkResult.RangeID, "expected r%d, got r%d with error: \"%s\"", - desc.RangeID, checkResult.RangeID, checkResult.Error) - require.Equalf(t, expectedAction, checkResult.Action, "r%d expected action %s, got action %s with error: \"%s\"", - desc.RangeID, expectedAction, checkResult.Action, checkResult.Error) - require.NotEmptyf(t, checkResult.Error, "r%d expected non-empty error", checkResult.RangeID) - if len(expectedErrSubstr) > 0 { - require.Containsf(t, checkResult.Error, expectedErrSubstr, "r%d expected error with \"%s\", got error: \"%s\"", - desc.RangeID, expectedErrSubstr, checkResult.Error) - } - if expectTraces { - require.NotEmptyf(t, checkResult.Events, "r%d expected traces, got none with error: \"%s\"", - checkResult.RangeID, checkResult.Error) - } else { - require.Emptyf(t, checkResult.Events, "r%d expected no traces with error: \"%s\"", - checkResult.RangeID, checkResult.Error) - } - passed = true -} - -// TestDecommissionPreCheckBasicReadiness tests the basic functionality of the -// DecommissionPreCheck endpoint. 
-func TestDecommissionPreCheckBasicReadiness(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - skip.UnderRace(t) // can't handle 7-node clusters - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - }) - defer tc.Stopper().Stop(ctx) - - adminSrv := tc.Server(4) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - - resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: []roachpb.NodeID{tc.Server(5).NodeID()}, - }) - require.NoError(t, err) - require.Len(t, resp.CheckedNodes, 1) - checkNodeCheckResultReady(t, tc.Server(5).NodeID(), 0, resp.CheckedNodes[0]) -} - -// TestDecommissionPreCheckUnready tests the functionality of the -// DecommissionPreCheck endpoint with some nodes not ready. -func TestDecommissionPreCheckUnready(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - skip.UnderRace(t) // can't handle 7-node clusters - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - }) - defer tc.Stopper().Stop(ctx) - - // Add replicas to a node we will check. - // Scratch range should have RF=3, liveness range should have RF=5. 
- adminSrvIdx := 3 - decommissioningSrvIdx := 5 - scratchKey := tc.ScratchRange(t) - scratchDesc := tc.AddVotersOrFatal(t, scratchKey, tc.Target(decommissioningSrvIdx)) - livenessDesc := tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix) - livenessDesc = tc.AddVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), tc.Target(decommissioningSrvIdx)) - - adminSrv := tc.Server(adminSrvIdx) - decommissioningSrv := tc.Server(decommissioningSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - - checkNodeReady := func(nID roachpb.NodeID, replicaCount int64, strict bool) { - resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: []roachpb.NodeID{nID}, - StrictReadiness: strict, - }) - require.NoError(t, err) - require.Len(t, resp.CheckedNodes, 1) - checkNodeCheckResultReady(t, nID, replicaCount, resp.CheckedNodes[0]) - } - - awaitDecommissioned := func(nID roachpb.NodeID) { - testutils.SucceedsSoon(t, func() error { - livenesses, err := adminSrv.NodeLiveness().(*liveness.NodeLiveness).ScanNodeVitalityFromKV(ctx) - if err != nil { - return err - } - for nodeID, nodeLiveness := range livenesses { - if nodeID == nID { - if nodeLiveness.IsDecommissioned() { - return nil - } else { - return errors.Errorf("n%d has membership: %s", nID, nodeLiveness.MembershipStatus()) - } - } - } - return errors.Errorf("n%d liveness not found", nID) - }) - } - - checkAndDecommission := func(srvIdx int, replicaCount int64, strict bool) { - nID := tc.Server(srvIdx).NodeID() - checkNodeReady(nID, replicaCount, strict) - require.NoError(t, adminSrv.Decommission( - ctx, livenesspb.MembershipStatus_DECOMMISSIONING, []roachpb.NodeID{nID})) - require.NoError(t, adminSrv.Decommission( - ctx, livenesspb.MembershipStatus_DECOMMISSIONED, []roachpb.NodeID{nID})) - awaitDecommissioned(nID) - } - - // In non-strict mode, 
this decommission appears "ready". This is because the - // ranges with replicas on decommissioningSrv have priority action "AddVoter", - // and they have valid targets. - checkNodeReady(decommissioningSrv.NodeID(), 2, false) - - // In strict mode, we would expect the readiness check to fail. - resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: []roachpb.NodeID{decommissioningSrv.NodeID()}, - NumReplicaReport: 50, - StrictReadiness: true, - CollectTraces: true, - }) - require.NoError(t, err) - nodeCheckResult := resp.CheckedNodes[0] - require.Equalf(t, serverpb.DecommissionPreCheckResponse_ALLOCATION_ERRORS, nodeCheckResult.DecommissionReadiness, - "expected n%d to have allocation errors, got %s", nodeCheckResult.NodeID, nodeCheckResult.DecommissionReadiness) - require.Len(t, nodeCheckResult.CheckedRanges, 2) - checkRangeCheckResult(t, livenessDesc, nodeCheckResult.CheckedRanges[0], - "add voter", "needs repair beyond replacing/removing", true, - ) - checkRangeCheckResult(t, scratchDesc, nodeCheckResult.CheckedRanges[1], - "add voter", "needs repair beyond replacing/removing", true, - ) - - // Add replicas to ensure we have the correct number of replicas for each range. - scratchDesc = tc.AddVotersOrFatal(t, scratchKey, tc.Target(adminSrvIdx)) - livenessDesc = tc.AddVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), - tc.Target(adminSrvIdx), tc.Target(4), tc.Target(6), - ) - require.True(t, hasReplicaOnServers(tc, &scratchDesc, 0, adminSrvIdx, decommissioningSrvIdx)) - require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx, 4, 6)) - require.Len(t, scratchDesc.InternalReplicas, 3) - require.Len(t, livenessDesc.InternalReplicas, 5) - - // Decommissioning pre-check should pass on decommissioningSrv in both strict - // and non-strict modes, as each range can find valid upreplication targets. 
- checkNodeReady(decommissioningSrv.NodeID(), 2, true) - - // Check and decommission empty nodes, decreasing to a 5-node cluster. - checkAndDecommission(1, 0, true) - checkAndDecommission(2, 0, true) - - // Check that we can still decommission. - // Below 5 nodes, system ranges will have an effective RF=3. - checkNodeReady(decommissioningSrv.NodeID(), 2, true) - - // Check that we can decommission the nodes with liveness replicas only. - checkAndDecommission(4, 1, true) - checkAndDecommission(6, 1, true) - - // Check range descriptors are as expected. - scratchDesc = tc.LookupRangeOrFatal(t, scratchDesc.StartKey.AsRawKey()) - livenessDesc = tc.LookupRangeOrFatal(t, livenessDesc.StartKey.AsRawKey()) - require.True(t, hasReplicaOnServers(tc, &scratchDesc, 0, adminSrvIdx, decommissioningSrvIdx)) - require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx, 4, 6)) - require.Len(t, scratchDesc.InternalReplicas, 3) - require.Len(t, livenessDesc.InternalReplicas, 5) - - // Cleanup orphaned liveness replicas and check. - livenessDesc = tc.RemoveVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), tc.Target(4), tc.Target(6)) - require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx)) - require.Len(t, livenessDesc.InternalReplicas, 3) - - // Validate that the node is not ready to decommission. - resp, err = adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: []roachpb.NodeID{decommissioningSrv.NodeID()}, - NumReplicaReport: 1, // Test that we limit errors. 
- StrictReadiness: true, - }) - require.NoError(t, err) - nodeCheckResult = resp.CheckedNodes[0] - require.Equalf(t, serverpb.DecommissionPreCheckResponse_ALLOCATION_ERRORS, nodeCheckResult.DecommissionReadiness, - "expected n%d to have allocation errors, got %s", nodeCheckResult.NodeID, nodeCheckResult.DecommissionReadiness) - require.Equal(t, int64(2), nodeCheckResult.ReplicaCount) - require.Len(t, nodeCheckResult.CheckedRanges, 1) - checkRangeCheckResult(t, livenessDesc, nodeCheckResult.CheckedRanges[0], - "replace decommissioning voter", - "0 of 2 live stores are able to take a new replica for the range "+ - "(2 already have a voter, 0 already have a non-voter); "+ - "likely not enough nodes in cluster", - false, - ) -} - -// TestDecommissionPreCheckMultiple tests the functionality of the -// DecommissionPreCheck endpoint with multiple nodes. -func TestDecommissionPreCheckMultiple(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - }) - defer tc.Stopper().Stop(ctx) - - // TODO(sarkesian): Once #95909 is merged, test checks on a 3-node decommission. - // e.g. Test both server idxs 3,4 and 2,3,4 (which should not pass checks). - adminSrvIdx := 1 - decommissioningSrvIdxs := []int{3, 4} - decommissioningSrvNodeIDs := make([]roachpb.NodeID, len(decommissioningSrvIdxs)) - for i, srvIdx := range decommissioningSrvIdxs { - decommissioningSrvNodeIDs[i] = tc.Server(srvIdx).NodeID() - } - - // Add replicas to nodes we will check. - // Scratch range should have RF=3, liveness range should have RF=5. 
- rangeDescs := []roachpb.RangeDescriptor{ - tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix), - tc.LookupRangeOrFatal(t, tc.ScratchRange(t)), - } - rangeDescSrvIdxs := [][]int{ - {0, 1, 2, 3, 4}, - {0, 3, 4}, - } - rangeDescSrvTargets := make([][]roachpb.ReplicationTarget, len(rangeDescs)) - for i, srvIdxs := range rangeDescSrvIdxs { - for _, srvIdx := range srvIdxs { - if srvIdx != 0 { - rangeDescSrvTargets[i] = append(rangeDescSrvTargets[i], tc.Target(srvIdx)) - } - } - } - - for i, rangeDesc := range rangeDescs { - rangeDescs[i] = tc.AddVotersOrFatal(t, rangeDesc.StartKey.AsRawKey(), rangeDescSrvTargets[i]...) - } - - for i, rangeDesc := range rangeDescs { - require.True(t, hasReplicaOnServers(tc, &rangeDesc, rangeDescSrvIdxs[i]...)) - require.Len(t, rangeDesc.InternalReplicas, len(rangeDescSrvIdxs[i])) - } - - adminSrv := tc.Server(adminSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - - // We expect to be able to decommission the targeted nodes simultaneously. - resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: decommissioningSrvNodeIDs, - NumReplicaReport: 50, - StrictReadiness: true, - CollectTraces: true, - }) - require.NoError(t, err) - require.Len(t, resp.CheckedNodes, len(decommissioningSrvIdxs)) - for i, nID := range decommissioningSrvNodeIDs { - checkNodeCheckResultReady(t, nID, int64(len(rangeDescs)), resp.CheckedNodes[i]) - } -} - -// TestDecommissionPreCheckInvalidNode tests the functionality of the -// DecommissionPreCheck endpoint where some nodes are invalid. 
-func TestDecommissionPreCheckInvalidNode(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - }) - defer tc.Stopper().Stop(ctx) - - adminSrvIdx := 1 - validDecommissioningNodeID := roachpb.NodeID(5) - invalidDecommissioningNodeID := roachpb.NodeID(34) - decommissioningNodeIDs := []roachpb.NodeID{validDecommissioningNodeID, invalidDecommissioningNodeID} - - // Add replicas to nodes we will check. - // Scratch range should have RF=3, liveness range should have RF=5. - rangeDescs := []roachpb.RangeDescriptor{ - tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix), - tc.LookupRangeOrFatal(t, tc.ScratchRange(t)), - } - rangeDescSrvIdxs := [][]int{ - {0, 1, 2, 3, 4}, - {0, 3, 4}, - } - rangeDescSrvTargets := make([][]roachpb.ReplicationTarget, len(rangeDescs)) - for i, srvIdxs := range rangeDescSrvIdxs { - for _, srvIdx := range srvIdxs { - if srvIdx != 0 { - rangeDescSrvTargets[i] = append(rangeDescSrvTargets[i], tc.Target(srvIdx)) - } - } - } - - for i, rangeDesc := range rangeDescs { - rangeDescs[i] = tc.AddVotersOrFatal(t, rangeDesc.StartKey.AsRawKey(), rangeDescSrvTargets[i]...) - } - - for i, rangeDesc := range rangeDescs { - require.True(t, hasReplicaOnServers(tc, &rangeDesc, rangeDescSrvIdxs[i]...)) - require.Len(t, rangeDesc.InternalReplicas, len(rangeDescSrvIdxs[i])) - } - - adminSrv := tc.Server(adminSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - - // We expect the pre-check to fail as some node IDs are invalid. 
- resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ - NodeIDs: decommissioningNodeIDs, - NumReplicaReport: 50, - StrictReadiness: true, - CollectTraces: true, - }) - require.NoError(t, err) - require.Len(t, resp.CheckedNodes, len(decommissioningNodeIDs)) - checkNodeCheckResultReady(t, validDecommissioningNodeID, int64(len(rangeDescs)), resp.CheckedNodes[0]) - require.Equal(t, serverpb.DecommissionPreCheckResponse_NodeCheckResult{ - NodeID: invalidDecommissioningNodeID, - DecommissionReadiness: serverpb.DecommissionPreCheckResponse_UNKNOWN, - ReplicaCount: 0, - CheckedRanges: nil, - }, resp.CheckedNodes[1]) -} - -func TestDecommissionSelf(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - skip.UnderRace(t) // can't handle 7-node clusters - - // Set up test cluster. - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - }) - defer tc.Stopper().Stop(ctx) - - // Decommission several nodes, including the node we're submitting the - // decommission request to. We use the admin client in order to test the - // admin server's logic, which involves a subsequent DecommissionStatus - // call which could fail if used from a node that's just decommissioned. - adminSrv := tc.Server(4) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - decomNodeIDs := []roachpb.NodeID{ - tc.Server(4).NodeID(), - tc.Server(5).NodeID(), - tc.Server(6).NodeID(), - } - - // The DECOMMISSIONING call should return a full status response. 
- resp, err := adminClient.Decommission(ctx, &serverpb.DecommissionRequest{ - NodeIDs: decomNodeIDs, - TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, - }) - require.NoError(t, err) - require.Len(t, resp.Status, len(decomNodeIDs)) - for i, nodeID := range decomNodeIDs { - status := resp.Status[i] - require.Equal(t, nodeID, status.NodeID) - // Liveness entries may not have been updated yet. - require.Contains(t, []livenesspb.MembershipStatus{ - livenesspb.MembershipStatus_ACTIVE, - livenesspb.MembershipStatus_DECOMMISSIONING, - }, status.Membership, "unexpected membership status %v for node %v", status, nodeID) - } - - // The DECOMMISSIONED call should return an empty response, to avoid - // erroring due to loss of cluster RPC access when decommissioning self. - resp, err = adminClient.Decommission(ctx, &serverpb.DecommissionRequest{ - NodeIDs: decomNodeIDs, - TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONED, - }) - require.NoError(t, err) - require.Empty(t, resp.Status) - - // The nodes should now have been (or soon become) decommissioned. - for i := 0; i < tc.NumServers(); i++ { - srv := tc.Server(i) - expect := livenesspb.MembershipStatus_ACTIVE - for _, nodeID := range decomNodeIDs { - if srv.NodeID() == nodeID { - expect = livenesspb.MembershipStatus_DECOMMISSIONED - break - } - } - require.Eventually(t, func() bool { - liveness, ok := srv.NodeLiveness().(*liveness.NodeLiveness).GetLiveness(srv.NodeID()) - return ok && liveness.Membership == expect - }, 5*time.Second, 100*time.Millisecond, "timed out waiting for node %v status %v", i, expect) - } -} - -// TestDecommissionEnqueueReplicas tests that a decommissioning node's replicas -// are proactively enqueued into their replicateQueues by the other nodes in the -// system. 
-func TestDecommissionEnqueueReplicas(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - skip.UnderRace(t) // can't handle 7-node clusters - - ctx := context.Background() - enqueuedRangeIDs := make(chan roachpb.RangeID) - tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - ServerArgs: base.TestServerArgs{ - Insecure: true, // allows admin client without setting up certs - Knobs: base.TestingKnobs{ - Store: &kvserver.StoreTestingKnobs{ - EnqueueReplicaInterceptor: func( - queueName string, repl *kvserver.Replica, - ) { - require.Equal(t, queueName, "replicate") - enqueuedRangeIDs <- repl.RangeID - }, - }, - }, - }, - }) - defer tc.Stopper().Stop(ctx) - - decommissionAndCheck := func(decommissioningSrvIdx int) { - t.Logf("decommissioning n%d", tc.Target(decommissioningSrvIdx).NodeID) - // Add a scratch range's replica to a node we will decommission. - scratchKey := tc.ScratchRange(t) - decommissioningSrv := tc.Server(decommissioningSrvIdx) - tc.AddVotersOrFatal(t, scratchKey, tc.Target(decommissioningSrvIdx)) - - conn, err := decommissioningSrv.RPCContext().GRPCDialNode( - decommissioningSrv.RPCAddr(), decommissioningSrv.NodeID(), rpc.DefaultClass, - ).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - decomNodeIDs := []roachpb.NodeID{tc.Server(decommissioningSrvIdx).NodeID()} - _, err = adminClient.Decommission( - ctx, - &serverpb.DecommissionRequest{ - NodeIDs: decomNodeIDs, - TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, - }, - ) - require.NoError(t, err) - - // Ensure that the scratch range's replica was proactively enqueued. - require.Equal(t, <-enqueuedRangeIDs, tc.LookupRangeOrFatal(t, scratchKey).RangeID) - - // Check that the node was marked as decommissioning in each of the nodes' - // decommissioningNodeMap. This needs to be wrapped in a SucceedsSoon to - // deal with gossip propagation delays. 
- testutils.SucceedsSoon(t, func() error { - for i := 0; i < tc.NumServers(); i++ { - srv := tc.Server(i) - if _, exists := srv.DecommissioningNodeMap()[decommissioningSrv.NodeID()]; !exists { - return errors.Newf("node %d not detected to be decommissioning", decommissioningSrv.NodeID()) - } - } - return nil - }) - } - - decommissionAndCheck(2 /* decommissioningSrvIdx */) - decommissionAndCheck(3 /* decommissioningSrvIdx */) - decommissionAndCheck(5 /* decommissioningSrvIdx */) -} - -func TestAdminDecommissionedOperations(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - skip.UnderRace(t, "test uses timeouts, and race builds cause the timeouts to be exceeded") - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 2, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, // saves time - ServerArgs: base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - Insecure: true, // allows admin client without setting up certs - }, - }) - defer tc.Stopper().Stop(ctx) - - serverutils.SetClusterSetting(t, tc, "server.shutdown.jobs_wait", 0) - - scratchKey := tc.ScratchRange(t) - scratchRange := tc.LookupRangeOrFatal(t, scratchKey) - require.Len(t, scratchRange.InternalReplicas, 1) - require.Equal(t, tc.Server(0).NodeID(), scratchRange.InternalReplicas[0].NodeID) - - // Decommission server 1 and wait for it to lose cluster access. 
- srv := tc.Server(0) - decomSrv := tc.Server(1) - for _, status := range []livenesspb.MembershipStatus{ - livenesspb.MembershipStatus_DECOMMISSIONING, livenesspb.MembershipStatus_DECOMMISSIONED, - } { - require.NoError(t, srv.Decommission(ctx, status, []roachpb.NodeID{decomSrv.NodeID()})) - } - - testutils.SucceedsWithin(t, func() error { - timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - _, err := decomSrv.DB().Scan(timeoutCtx, keys.LocalMax, keys.MaxKey, 0) - if err == nil { - return errors.New("expected error") - } - s, ok := status.FromError(errors.UnwrapAll(err)) - if ok && s.Code() == codes.PermissionDenied { - return nil - } - return err - }, 10*time.Second) - - // Set up an admin client. - //lint:ignore SA1019 grpc.WithInsecure is deprecated - conn, err := grpc.Dial(decomSrv.ServingRPCAddr(), grpc.WithInsecure()) - require.NoError(t, err) - defer func() { - _ = conn.Close() // nolint:grpcconnclose - }() - adminClient := serverpb.NewAdminClient(conn) - - // Run some operations on the decommissioned node. The ones that require - // access to the cluster should fail, other should succeed. We're mostly - // concerned with making sure they return rather than hang due to internal - // retries. 
- testcases := []struct { - name string - expectCode codes.Code - op func(context.Context, serverpb.AdminClient) error - }{ - {"Cluster", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Cluster(ctx, &serverpb.ClusterRequest{}) - return err - }}, - {"Databases", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Databases(ctx, &serverpb.DatabasesRequest{}) - return err - }}, - {"DatabaseDetails", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.DatabaseDetails(ctx, &serverpb.DatabaseDetailsRequest{Database: "foo"}) - return err - }}, - {"DataDistribution", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.DataDistribution(ctx, &serverpb.DataDistributionRequest{}) - return err - }}, - {"Decommission", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Decommission(ctx, &serverpb.DecommissionRequest{ - NodeIDs: []roachpb.NodeID{srv.NodeID(), decomSrv.NodeID()}, - TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONED, - }) - return err - }}, - {"DecommissionStatus", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.DecommissionStatus(ctx, &serverpb.DecommissionStatusRequest{ - NodeIDs: []roachpb.NodeID{srv.NodeID(), decomSrv.NodeID()}, - }) - return err - }}, - {"EnqueueRange", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.EnqueueRange(ctx, &serverpb.EnqueueRangeRequest{ - RangeID: scratchRange.RangeID, - Queue: "replicaGC", - }) - return err - }}, - {"Events", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Events(ctx, &serverpb.EventsRequest{}) - return err - }}, - {"Health", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Health(ctx, &serverpb.HealthRequest{}) - return err - }}, - {"Jobs", codes.Internal, func(ctx 
context.Context, c serverpb.AdminClient) error { - _, err := c.Jobs(ctx, &serverpb.JobsRequest{}) - return err - }}, - {"Liveness", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Liveness(ctx, &serverpb.LivenessRequest{}) - return err - }}, - {"Locations", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Locations(ctx, &serverpb.LocationsRequest{}) - return err - }}, - {"NonTableStats", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.NonTableStats(ctx, &serverpb.NonTableStatsRequest{}) - return err - }}, - {"QueryPlan", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.QueryPlan(ctx, &serverpb.QueryPlanRequest{Query: "SELECT 1"}) - return err - }}, - {"RangeLog", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.RangeLog(ctx, &serverpb.RangeLogRequest{}) - return err - }}, - {"Settings", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Settings(ctx, &serverpb.SettingsRequest{}) - return err - }}, - {"TableStats", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.TableStats(ctx, &serverpb.TableStatsRequest{Database: "foo", Table: "bar"}) - return err - }}, - {"TableDetails", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.TableDetails(ctx, &serverpb.TableDetailsRequest{Database: "foo", Table: "bar"}) - return err - }}, - {"Users", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { - _, err := c.Users(ctx, &serverpb.UsersRequest{}) - return err - }}, - // We drain at the end, since it may evict us. 
- {"Drain", codes.Unknown, func(ctx context.Context, c serverpb.AdminClient) error { - stream, err := c.Drain(ctx, &serverpb.DrainRequest{DoDrain: true}) - if err != nil { - return err - } - _, err = stream.Recv() - return err - }}, - } - - for _, tc := range testcases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - testutils.SucceedsWithin(t, func() error { - timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - err := tc.op(timeoutCtx, adminClient) - if tc.expectCode == codes.OK { - require.NoError(t, err) - return nil - } - if err == nil { - // This will cause SuccessWithin to retry. - return errors.New("expected error, got no error") - } - s, ok := status.FromError(errors.UnwrapAll(err)) - if !ok { - // Not a gRPC error. - // This will cause SuccessWithin to retry. - return err - } - require.Equal(t, tc.expectCode, s.Code(), "%+v", err) - return nil - }, 10*time.Second) - }) - } -} - -func TestAdminPrivilegeChecker(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - s, db, kvDB := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. 
- DefaultTestTenant: base.TODOTestTenantDisabled, - }) - defer s.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(db) - sqlDB.Exec(t, "CREATE USER withadmin") - sqlDB.Exec(t, "GRANT admin TO withadmin") - sqlDB.Exec(t, "CREATE USER withva") - sqlDB.Exec(t, "ALTER ROLE withva WITH VIEWACTIVITY") - sqlDB.Exec(t, "CREATE USER withvaredacted") - sqlDB.Exec(t, "ALTER ROLE withvaredacted WITH VIEWACTIVITYREDACTED") - sqlDB.Exec(t, "CREATE USER withvaandredacted") - sqlDB.Exec(t, "ALTER ROLE withvaandredacted WITH VIEWACTIVITY") - sqlDB.Exec(t, "ALTER ROLE withvaandredacted WITH VIEWACTIVITYREDACTED") - sqlDB.Exec(t, "CREATE USER withoutprivs") - sqlDB.Exec(t, "CREATE USER withvaglobalprivilege") - sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITY TO withvaglobalprivilege") - sqlDB.Exec(t, "CREATE USER withvaredactedglobalprivilege") - sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITYREDACTED TO withvaredactedglobalprivilege") - sqlDB.Exec(t, "CREATE USER withvaandredactedglobalprivilege") - sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITY TO withvaandredactedglobalprivilege") - sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITYREDACTED TO withvaandredactedglobalprivilege") - sqlDB.Exec(t, "CREATE USER withviewclustermetadata") - sqlDB.Exec(t, "GRANT SYSTEM VIEWCLUSTERMETADATA TO withviewclustermetadata") - sqlDB.Exec(t, "CREATE USER withviewdebug") - sqlDB.Exec(t, "GRANT SYSTEM VIEWDEBUG TO withviewdebug") - - execCfg := s.ExecutorConfig().(sql.ExecutorConfig) - - plannerFn := func(opName string) (interface{}, func()) { - // This is a hack to get around a Go package dependency cycle. See comment - // in sql/jobs/registry.go on planHookMaker. 
- txn := kvDB.NewTxn(ctx, "test") - return sql.NewInternalPlanner( - opName, - txn, - username.RootUserName(), - &sql.MemoryMetrics{}, - &execCfg, - sql.NewInternalSessionData(ctx, execCfg.Settings, opName), - ) - } - - underTest := &adminPrivilegeChecker{ - ie: s.InternalExecutor().(*sql.InternalExecutor), - st: s.ClusterSettings(), - makePlanner: plannerFn, - } - - withAdmin, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withadmin") - require.NoError(t, err) - withVa, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withva") - require.NoError(t, err) - withVaRedacted, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withvaredacted") - require.NoError(t, err) - withVaAndRedacted, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withvaandredacted") - require.NoError(t, err) - withoutPrivs, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withoutprivs") - require.NoError(t, err) - withVaGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaglobalprivilege") - withVaRedactedGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaredactedglobalprivilege") - withVaAndRedactedGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaandredactedglobalprivilege") - withviewclustermetadata := username.MakeSQLUsernameFromPreNormalizedString("withviewclustermetadata") - withViewDebug := username.MakeSQLUsernameFromPreNormalizedString("withviewdebug") - - tests := []struct { - name string - checkerFun func(context.Context) error - usernameWantErr map[username.SQLUsername]bool - }{ - { - "requireViewActivityPermission", - underTest.requireViewActivityPermission, - map[username.SQLUsername]bool{ - withAdmin: false, withVa: false, withVaRedacted: true, withVaAndRedacted: false, withoutPrivs: true, - withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: true, withVaAndRedactedGlobalPrivilege: false, - }, - }, - { - 
"requireViewActivityOrViewActivityRedactedPermission", - underTest.requireViewActivityOrViewActivityRedactedPermission, - map[username.SQLUsername]bool{ - withAdmin: false, withVa: false, withVaRedacted: false, withVaAndRedacted: false, withoutPrivs: true, - withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: false, withVaAndRedactedGlobalPrivilege: false, - }, - }, - { - "requireViewActivityAndNoViewActivityRedactedPermission", - underTest.requireViewActivityAndNoViewActivityRedactedPermission, - map[username.SQLUsername]bool{ - withAdmin: false, withVa: false, withVaRedacted: true, withVaAndRedacted: true, withoutPrivs: true, - withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: true, withVaAndRedactedGlobalPrivilege: true, - }, - }, - { - "requireViewClusterMetadataPermission", - underTest.requireViewClusterMetadataPermission, - map[username.SQLUsername]bool{ - withAdmin: false, withoutPrivs: true, withviewclustermetadata: false, - }, - }, - { - "requireViewDebugPermission", - underTest.requireViewDebugPermission, - map[username.SQLUsername]bool{ - withAdmin: false, withoutPrivs: true, withViewDebug: false, - }, - }, - } - - for _, tt := range tests { - for userName, wantErr := range tt.usernameWantErr { - t.Run(fmt.Sprintf("%s-%s", tt.name, userName), func(t *testing.T) { - ctx := metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"websessionuser": userName.SQLIdentifier()})) - err := tt.checkerFun(ctx) - if wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - }) - } - } -} - -func TestServerError(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ctx := context.Background() - pgError := pgerror.New(pgcode.OutOfMemory, "TestServerError.OutOfMemory") - err := serverError(ctx, pgError) - require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details. 
Error Code: 53200", err.Error()) - - err = serverError(ctx, err) - require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details. Error Code: 53200", err.Error()) - - err = fmt.Errorf("random error that is not pgerror or grpcstatus") - err = serverError(ctx, err) - require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details.", err.Error()) -} - -func TestDatabaseAndTableIndexRecommendations(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - stubTime := stubUnusedIndexTime{} - stubDropUnusedDuration := time.Hour - - s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{ - // Disable the default test tenant for now as this tests fails - // with it enabled. Tracked with #81590. - DefaultTestTenant: base.TODOTestTenantDisabled, - Knobs: base.TestingKnobs{ - UnusedIndexRecommendKnobs: &idxusage.UnusedIndexRecommendationTestingKnobs{ - GetCreatedAt: stubTime.getCreatedAt, - GetLastRead: stubTime.getLastRead, - GetCurrentTime: stubTime.getCurrent, - }, - }, - }) - idxusage.DropUnusedIndexDuration.Override(context.Background(), &s.ClusterSettings().SV, stubDropUnusedDuration) - defer s.Stopper().Stop(context.Background()) - - db := sqlutils.MakeSQLRunner(sqlDB) - db.Exec(t, "CREATE DATABASE test") - db.Exec(t, "USE test") - // Create a table and secondary index. - db.Exec(t, "CREATE TABLE test.test_table (num INT PRIMARY KEY, letter char)") - db.Exec(t, "CREATE INDEX test_idx ON test.test_table (letter)") - - // Test when last read does not exist and there is no creation time. Expect - // an index recommendation (index never used). - stubTime.setLastRead(time.Time{}) - stubTime.setCreatedAt(nil) - - // Test database details endpoint. 
- var dbDetails serverpb.DatabaseDetailsResponse - if err := getAdminJSONProto( - s, - "databases/test?include_stats=true", - &dbDetails, - ); err != nil { - t.Fatal(err) - } - // Expect 1 index recommendation (no index recommendation on primary index). - require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) - - // Test table details endpoint. - var tableDetails serverpb.TableDetailsResponse - if err := getAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { - t.Fatal(err) - } - require.Equal(t, true, tableDetails.HasIndexRecommendations) - - // Test when last read does not exist and there is a creation time, and the - // unused index duration has been exceeded. Expect an index recommendation. - currentTime := timeutil.Now() - createdTime := currentTime.Add(-stubDropUnusedDuration) - stubTime.setCurrent(currentTime) - stubTime.setLastRead(time.Time{}) - stubTime.setCreatedAt(&createdTime) - - // Test database details endpoint. - dbDetails = serverpb.DatabaseDetailsResponse{} - if err := getAdminJSONProto( - s, - "databases/test?include_stats=true", - &dbDetails, - ); err != nil { - t.Fatal(err) - } - require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) - - // Test table details endpoint. - tableDetails = serverpb.TableDetailsResponse{} - if err := getAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { - t.Fatal(err) - } - require.Equal(t, true, tableDetails.HasIndexRecommendations) - - // Test when last read does not exist and there is a creation time, and the - // unused index duration has not been exceeded. Expect no index - // recommendation. - currentTime = timeutil.Now() - stubTime.setCurrent(currentTime) - stubTime.setLastRead(time.Time{}) - stubTime.setCreatedAt(&currentTime) - - // Test database details endpoint. 
- dbDetails = serverpb.DatabaseDetailsResponse{} - if err := getAdminJSONProto( - s, - "databases/test?include_stats=true", - &dbDetails, - ); err != nil { - t.Fatal(err) - } - require.Equal(t, int32(0), dbDetails.Stats.NumIndexRecommendations) - - // Test table details endpoint. - tableDetails = serverpb.TableDetailsResponse{} - if err := getAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { - t.Fatal(err) - } - require.Equal(t, false, tableDetails.HasIndexRecommendations) - - // Test when last read exists and the unused index duration has been - // exceeded. Expect an index recommendation. - currentTime = timeutil.Now() - lastRead := currentTime.Add(-stubDropUnusedDuration) - stubTime.setCurrent(currentTime) - stubTime.setLastRead(lastRead) - stubTime.setCreatedAt(nil) - - // Test database details endpoint. - dbDetails = serverpb.DatabaseDetailsResponse{} - if err := getAdminJSONProto( - s, - "databases/test?include_stats=true", - &dbDetails, - ); err != nil { - t.Fatal(err) - } - require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) - - // Test table details endpoint. - tableDetails = serverpb.TableDetailsResponse{} - if err := getAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { - t.Fatal(err) - } - require.Equal(t, true, tableDetails.HasIndexRecommendations) - - // Test when last read exists and the unused index duration has not been - // exceeded. Expect no index recommendation. - currentTime = timeutil.Now() - stubTime.setCurrent(currentTime) - stubTime.setLastRead(currentTime) - stubTime.setCreatedAt(nil) - - // Test database details endpoint. - dbDetails = serverpb.DatabaseDetailsResponse{} - if err := getAdminJSONProto( - s, - "databases/test?include_stats=true", - &dbDetails, - ); err != nil { - t.Fatal(err) - } - require.Equal(t, int32(0), dbDetails.Stats.NumIndexRecommendations) - - // Test table details endpoint. 
- tableDetails = serverpb.TableDetailsResponse{} - if err := getAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { - t.Fatal(err) - } - require.Equal(t, false, tableDetails.HasIndexRecommendations) -} diff --git a/pkg/server/admin_test_utils.go b/pkg/server/admin_test_utils.go deleted file mode 100644 index 99a17af78f49..000000000000 --- a/pkg/server/admin_test_utils.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2022 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package server - -import ( - "time" - - "github.com/cockroachdb/cockroach/pkg/util/syncutil" -) - -type stubUnusedIndexTime struct { - syncutil.RWMutex - current time.Time - lastRead time.Time - createdAt *time.Time -} - -func (s *stubUnusedIndexTime) setCurrent(t time.Time) { - s.RWMutex.Lock() - defer s.RWMutex.Unlock() - s.current = t -} - -func (s *stubUnusedIndexTime) setLastRead(t time.Time) { - s.RWMutex.Lock() - defer s.RWMutex.Unlock() - s.lastRead = t -} - -func (s *stubUnusedIndexTime) setCreatedAt(t *time.Time) { - s.RWMutex.Lock() - defer s.RWMutex.Unlock() - s.createdAt = t -} - -func (s *stubUnusedIndexTime) getCurrent() time.Time { - s.RWMutex.RLock() - defer s.RWMutex.RUnlock() - return s.current -} - -func (s *stubUnusedIndexTime) getLastRead() time.Time { - s.RWMutex.RLock() - defer s.RWMutex.RUnlock() - return s.lastRead -} - -func (s *stubUnusedIndexTime) getCreatedAt() *time.Time { - s.RWMutex.RLock() - defer s.RWMutex.RUnlock() - return s.createdAt -} diff --git a/pkg/server/api_v2.go b/pkg/server/api_v2.go index 1f14f891394b..1d9b69fdd7f4 100644 --- a/pkg/server/api_v2.go +++ b/pkg/server/api_v2.go @@ -35,13 +35,16 @@ package server 
import ( "context" - "encoding/json" "fmt" "net/http" "strconv" "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" "github.com/cockroachdb/cockroach/pkg/util/httputil" @@ -49,23 +52,6 @@ import ( "github.com/gorilla/mux" ) -const ( - apiV2Path = "/api/v2/" - apiV2AuthHeader = "X-Cockroach-API-Session" -) - -func writeJSONResponse(ctx context.Context, w http.ResponseWriter, code int, payload interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - - res, err := json.Marshal(payload) - if err != nil { - apiV2InternalError(ctx, err, w) - return - } - _, _ = w.Write(res) -} - type ApiV2System interface { health(w http.ResponseWriter, r *http.Request) listNodes(w http.ResponseWriter, r *http.Request) @@ -80,7 +66,7 @@ type apiV2ServerOpts struct { db *kv.DB } -// apiV2Server implements version 2 API endpoints, under apiV2Path. The +// apiV2Server implements version 2 API endpoints, under apiconstants.APIV2Path. The // implementation of some endpoints is delegated to sub-servers (eg. auth // endpoints like `/login` and `/logout` are passed onto authServer), while // others are implemented directly by apiV2Server. @@ -89,7 +75,7 @@ type apiV2ServerOpts struct { // registerRoutes(). type apiV2Server struct { admin *adminServer - authServer *authenticationV2Server + authServer authserver.ServerV2 status *statusServer promRuleExporter *metric.PrometheusRuleExporter mux *mux.Router @@ -112,9 +98,10 @@ var _ http.Handler = &apiV2Server{} // newAPIV2Server returns a new apiV2Server. 
func newAPIV2Server(ctx context.Context, opts *apiV2ServerOpts) http.Handler { - authServer := newAuthenticationV2Server(ctx, opts.sqlServer, opts.sqlServer.cfg.Config, apiV2Path) + authServer := authserver.NewV2Server(ctx, opts.sqlServer, opts.sqlServer.cfg.Config, apiconstants.APIV2Path) innerMux := mux.NewRouter() - authMux := newAuthenticationV2Mux(authServer, innerMux) + allowAnonymous := opts.sqlServer.cfg.Insecure + authMux := authserver.NewV2Mux(authServer, innerMux, allowAnonymous) outerMux := mux.NewRouter() systemAdmin, saOk := opts.admin.(*systemAdminServer) @@ -177,33 +164,33 @@ func registerRoutes( url string handler http.HandlerFunc requiresAuth bool - role apiRole + role authserver.APIRole option roleoption.Option tenantEnabled bool }{ // Pass through auth-related endpoints to the auth server. - {"login/", a.authServer.ServeHTTP, false /* requiresAuth */, regularRole, noOption, false}, - {"logout/", a.authServer.ServeHTTP, false /* requiresAuth */, regularRole, noOption, false}, + {"login/", a.authServer.ServeHTTP, false /* requiresAuth */, authserver.RegularRole, noOption, false}, + {"logout/", a.authServer.ServeHTTP, false /* requiresAuth */, authserver.RegularRole, noOption, false}, // Directly register other endpoints in the api server. - {"sessions/", a.listSessions, true /* requiresAuth */, adminRole, noOption, false}, - {"nodes/", systemRoutes.listNodes, true, adminRole, noOption, false}, + {"sessions/", a.listSessions, true /* requiresAuth */, authserver.AdminRole, noOption, false}, + {"nodes/", systemRoutes.listNodes, true, authserver.AdminRole, noOption, false}, // Any endpoint returning range information requires an admin user. This is because range start/end keys // are sensitive info. 
- {"nodes/{node_id}/ranges/", systemRoutes.listNodeRanges, true, adminRole, noOption, false}, - {"ranges/hot/", a.listHotRanges, true, adminRole, noOption, false}, - {"ranges/{range_id:[0-9]+}/", a.listRange, true, adminRole, noOption, false}, - {"health/", systemRoutes.health, false, regularRole, noOption, false}, - {"users/", a.listUsers, true, regularRole, noOption, false}, - {"events/", a.listEvents, true, adminRole, noOption, false}, - {"databases/", a.listDatabases, true, regularRole, noOption, false}, - {"databases/{database_name:[\\w.]+}/", a.databaseDetails, true, regularRole, noOption, false}, - {"databases/{database_name:[\\w.]+}/grants/", a.databaseGrants, true, regularRole, noOption, false}, - {"databases/{database_name:[\\w.]+}/tables/", a.databaseTables, true, regularRole, noOption, false}, - {"databases/{database_name:[\\w.]+}/tables/{table_name:[\\w.]+}/", a.tableDetails, true, regularRole, noOption, false}, - {"rules/", a.listRules, false, regularRole, noOption, true}, + {"nodes/{node_id}/ranges/", systemRoutes.listNodeRanges, true, authserver.AdminRole, noOption, false}, + {"ranges/hot/", a.listHotRanges, true, authserver.AdminRole, noOption, false}, + {"ranges/{range_id:[0-9]+}/", a.listRange, true, authserver.AdminRole, noOption, false}, + {"health/", systemRoutes.health, false, authserver.RegularRole, noOption, false}, + {"users/", a.listUsers, true, authserver.RegularRole, noOption, false}, + {"events/", a.listEvents, true, authserver.AdminRole, noOption, false}, + {"databases/", a.listDatabases, true, authserver.RegularRole, noOption, false}, + {"databases/{database_name:[\\w.]+}/", a.databaseDetails, true, authserver.RegularRole, noOption, false}, + {"databases/{database_name:[\\w.]+}/grants/", a.databaseGrants, true, authserver.RegularRole, noOption, false}, + {"databases/{database_name:[\\w.]+}/tables/", a.databaseTables, true, authserver.RegularRole, noOption, false}, + {"databases/{database_name:[\\w.]+}/tables/{table_name:[\\w.]+}/", 
a.tableDetails, true, authserver.RegularRole, noOption, false}, + {"rules/", a.listRules, false, authserver.RegularRole, noOption, true}, - {"sql/", a.execSQL, true, regularRole, noOption, true}, + {"sql/", a.execSQL, true, authserver.RegularRole, noOption, true}, } // For all routes requiring authentication, have the outer mux (a.mux) @@ -216,23 +203,23 @@ func registerRoutes( inner: route.handler, } if !route.tenantEnabled && !a.sqlServer.execCfg.Codec.ForSystemTenant() { - a.mux.Handle(apiV2Path+route.url, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + a.mux.Handle(apiconstants.APIV2Path+route.url, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Not Available on Tenants", http.StatusNotImplemented) })) } if route.requiresAuth { - a.mux.Handle(apiV2Path+route.url, authMux) - if route.role != regularRole { - handler = &roleAuthorizationMux{ - ie: a.sqlServer.internalExecutor, - role: route.role, - option: route.option, - inner: handler, - } + a.mux.Handle(apiconstants.APIV2Path+route.url, authMux) + if route.role != authserver.RegularRole { + handler = authserver.NewRoleAuthzMux( + a.sqlServer.internalExecutor, + route.role, + route.option, + handler, + ) } - innerMux.Handle(apiV2Path+route.url, handler) + innerMux.Handle(apiconstants.APIV2Path+route.url, handler) } else { - a.mux.Handle(apiV2Path+route.url, handler) + a.mux.Handle(apiconstants.APIV2Path+route.url, handler) } } } @@ -314,11 +301,11 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { reqExcludeClosed := r.URL.Query().Get("exclude_closed_sessions") == "true" req := &serverpb.ListSessionsRequest{Username: reqUsername, ExcludeClosedSessions: reqExcludeClosed} response := &listSessionsResponse{} - outgoingCtx := forwardHTTPAuthInfoToRPCCalls(ctx, r) + outgoingCtx := authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, r) responseProto, pagState, err := a.status.listSessionsHelper(outgoingCtx, req, limit, start) if err != nil { - 
apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } var nextBytes []byte @@ -329,7 +316,7 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) { response.Next = string(nextBytes) } response.ListSessionsResponse = *responseProto - writeJSONResponse(ctx, w, http.StatusOK, response) + apiutil.WriteJSONResponse(ctx, w, http.StatusOK, response) } // swagger:operation GET /health/ health @@ -375,19 +362,19 @@ func (a *apiV2SystemServer) health(w http.ResponseWriter, r *http.Request) { // If Ready is not set, the client doesn't want to know whether this node is // ready to receive client traffic. if !ready { - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) return } if err := a.systemAdmin.checkReadinessForHealthCheck(ctx); err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } func (a *apiV2Server) health(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) } // swagger:operation GET /rules/ rules @@ -414,7 +401,7 @@ func (a *apiV2Server) listRules(w http.ResponseWriter, r *http.Request) { a.promRuleExporter.ScrapeRegistry(r.Context()) response, err := a.promRuleExporter.PrintAsYAML() if err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } w.Header().Set(httputil.ContentTypeHeader, httputil.PlaintextContentType) diff --git a/pkg/server/api_v2_error.go b/pkg/server/api_v2_error.go deleted file mode 100644 index 102ad0587d13..000000000000 --- a/pkg/server/api_v2_error.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2017 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. 
-// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package server - -import ( - "context" - "net/http" - - "github.com/cockroachdb/cockroach/pkg/util/log" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -var errAPIInternalErrorString = "An internal server error has occurred. Please check your CockroachDB logs for more details." - -var errAPIInternalError = status.Errorf( - codes.Internal, - errAPIInternalErrorString, -) - -// apiInternalError should be used to wrap server-side errors during API -// requests. This method records the contents of the error to the server log, -// and returns a standard GRPC error which is appropriate to return to the -// client. -func apiInternalError(ctx context.Context, err error) error { - log.ErrorfDepth(ctx, 1, "%s", err) - return errAPIInternalError -} - -// apiV2InternalError should be used to wrap server-side errors during API -// requests for V2 (non-GRPC) endpoints. This method records the contents -// of the error to the server log, and sends the standard internal error string -// over the http.ResponseWriter. 
-func apiV2InternalError(ctx context.Context, err error, w http.ResponseWriter) { - log.ErrorfDepth(ctx, 1, "%s", err) - http.Error(w, errAPIInternalErrorString, http.StatusInternalServerError) -} diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go index 13bfd39737bb..a6589bcbe511 100644 --- a/pkg/server/api_v2_ranges.go +++ b/pkg/server/api_v2_ranges.go @@ -18,7 +18,10 @@ import ( "strings" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/util" "github.com/gorilla/mux" ) @@ -108,11 +111,11 @@ type nodesResponse struct { func (a *apiV2SystemServer) listNodes(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) + ctx = authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, r) nodes, next, err := a.systemStatus.nodesHelper(ctx, limit, offset) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } var resp nodesResponse @@ -140,11 +143,11 @@ func (a *apiV2SystemServer) listNodes(w http.ResponseWriter, r *http.Request) { LivenessStatus: int32(nodes.LivenessByNodeID[n.Desc.NodeID]), }) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } func (a *apiV2Server) listNodes(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) } func parseRangeIDs(input string, w http.ResponseWriter) (ranges []roachpb.RangeID, ok bool) { @@ -202,7 +205,7 @@ type rangeResponse struct { // "$ref": "#/definitions/rangeResponse" func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { ctx := 
r.Context() - ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) + ctx = authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) rangeID, err := strconv.ParseInt(vars["range_id"], 10, 64) if err != nil { @@ -246,10 +249,10 @@ func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { if err := a.status.iterateNodes( ctx, fmt.Sprintf("details about range %d", rangeID), dialFn, nodeFn, responseFn, errorFn, ); err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } - writeJSONResponse(ctx, w, 200, response) + apiutil.WriteJSONResponse(ctx, w, 200, response) } // rangeDescriptorInfo contains a subset of fields from the Cockroach-internal @@ -385,7 +388,7 @@ type nodeRangesResponse struct { // "$ref": "#/definitions/nodeRangesResponse" func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) + ctx = authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, r) vars := mux.Vars(r) nodeIDStr := vars["node_id"] if nodeIDStr != "local" { @@ -407,7 +410,7 @@ func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Reques limit, offset := getSimplePaginationValues(r) statusResp, next, err := a.systemStatus.rangesHelper(ctx, req, limit, offset) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } resp := nodeRangesResponse{ @@ -419,11 +422,11 @@ func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Reques ri.init(r) resp.Ranges = append(resp.Ranges, ri) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } func (a *apiV2Server) listNodeRanges(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) } type responseError struct { @@ -504,7 +507,7 @@ type hotRangeInfo struct { // "$ref": 
"#/definitions/hotRangesResponse" func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r) + ctx = authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, r) nodeIDStr := r.URL.Query().Get("node_id") limit, start := getRPCPaginationValues(r) @@ -568,7 +571,7 @@ func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { nodeFn, responseFn, errorFn) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } var nextBytes []byte @@ -577,5 +580,5 @@ func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) { } else { response.Next = string(nextBytes) } - writeJSONResponse(ctx, w, 200, response) + apiutil.WriteJSONResponse(ctx, w, 200, response) } diff --git a/pkg/server/api_v2_ranges_test.go b/pkg/server/api_v2_ranges_test.go index b8fde8bed994..c6c92c3313bf 100644 --- a/pkg/server/api_v2_ranges_test.go +++ b/pkg/server/api_v2_ranges_test.go @@ -20,6 +20,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/rangetestutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -29,14 +31,14 @@ import ( func TestHotRangesV2(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - ts := startServer(t) + ts := rangetestutils.StartServer(t) defer ts.Stopper().Stop(context.Background()) var hotRangesResp hotRangesResponse client, err := ts.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiV2Path+"ranges/hot/").String(), nil) + req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiconstants.APIV2Path+"ranges/hot/").String(), nil) 
require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -64,11 +66,11 @@ func TestNodeRangesV2(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - ts := startServer(t) + ts := rangetestutils.StartServer(t) defer ts.Stopper().Stop(context.Background()) // Perform a scan to ensure that all the raft groups are initialized. - if _, err := ts.db.Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { + if _, err := ts.DB().Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { t.Fatal(err) } @@ -76,7 +78,7 @@ func TestNodeRangesV2(t *testing.T) { client, err := ts.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiV2Path+"nodes/local/ranges/").String(), nil) + req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiconstants.APIV2Path+"nodes/local/ranges/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -99,7 +101,7 @@ func TestNodeRangesV2(t *testing.T) { // Take the first range ID, and call the ranges/ endpoint with it. 
rangeID := nodeRangesResp.Ranges[0].Desc.RangeID - req, err = http.NewRequest("GET", fmt.Sprintf("%s%sranges/%d/", ts.AdminURL(), apiV2Path, rangeID), nil) + req, err = http.NewRequest("GET", fmt.Sprintf("%s%sranges/%d/", ts.AdminURL(), apiconstants.APIV2Path, rangeID), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -141,7 +143,7 @@ func TestNodesV2(t *testing.T) { client, err := ts1.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"nodes/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"nodes/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) diff --git a/pkg/server/api_v2_sql.go b/pkg/server/api_v2_sql.go index a5fc50fdab11..3d49e46bb348 100644 --- a/pkg/server/api_v2_sql.go +++ b/pkg/server/api_v2_sql.go @@ -19,6 +19,7 @@ import ( "net/http" "time" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/parser" @@ -358,7 +359,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { } // The SQL username that owns this session. 
- username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) options := []isql.TxnOption{ isql.WithPriority(admissionpb.NormalPri), diff --git a/pkg/server/api_v2_sql_schema.go b/pkg/server/api_v2_sql_schema.go index f26b905891f1..a783c10edeb7 100644 --- a/pkg/server/api_v2_sql_schema.go +++ b/pkg/server/api_v2_sql_schema.go @@ -13,7 +13,10 @@ package server import ( "net/http" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/gorilla/mux" @@ -62,7 +65,7 @@ type usersResponse struct { func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) query := `SELECT username FROM system.users WHERE "isRole" = false ORDER BY username` @@ -81,7 +84,7 @@ func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) { query, qargs..., ) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } @@ -92,13 +95,13 @@ func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) { resp.Users = append(resp.Users, serverpb.UsersResponse_User{Username: string(tree.MustBeDString(row[0]))}) } if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } if limit > 0 && len(resp.Users) >= limit { resp.Next = offset + len(resp.Users) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for listEvents. 
@@ -149,7 +152,7 @@ type eventsResponse struct { func (a *apiV2Server) listEvents(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) queryValues := r.URL.Query() @@ -162,14 +165,14 @@ func (a *apiV2Server) listEvents(w http.ResponseWriter, r *http.Request) { eventsResp, err := a.admin.eventsHelper( ctx, req, username, limit, offset, true /* redactEvents */) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } resp.EventsResponse = *eventsResp if limit > 0 && len(resp.Events) >= limit { resp.Next = offset + len(resp.Events) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for listDatabases. @@ -213,20 +216,20 @@ type databasesResponse struct { func (a *apiV2Server) listDatabases(w http.ResponseWriter, r *http.Request) { limit, offset := getSimplePaginationValues(r) ctx := r.Context() - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) var resp databasesResponse req := &serverpb.DatabasesRequest{} dbsResp, err := a.admin.databasesHelper(ctx, req, username, limit, offset) if err != nil { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) return } var databases interface{} databases, resp.Next = simplePaginate(dbsResp.Databases, limit, offset) resp.Databases = databases.([]string) - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for databaseDetails. 
@@ -263,7 +266,7 @@ type databaseDetailsResponse struct { // description: Database not found func (a *apiV2Server) databaseDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -277,14 +280,14 @@ func (a *apiV2Server) databaseDetails(w http.ResponseWriter, r *http.Request) { if status.Code(err) == codes.NotFound || isNotFoundError(err) { http.Error(w, "database not found", http.StatusNotFound) } else { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) } return } resp := databaseDetailsResponse{ DescriptorID: dbDetailsResp.DescriptorID, } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for databaseGrants. @@ -337,7 +340,7 @@ type databaseGrantsResponse struct { func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -348,7 +351,7 @@ func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) { if status.Code(err) == codes.NotFound || isNotFoundError(err) { http.Error(w, "database not found", http.StatusNotFound) } else { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) } return } @@ -356,7 +359,7 @@ func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) { if limit > 0 && len(grants) >= limit { resp.Next = offset + len(grants) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for databaseTables. 
@@ -412,7 +415,7 @@ type databaseTablesResponse struct { func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) { ctx := r.Context() limit, offset := getSimplePaginationValues(r) - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.DatabaseDetailsRequest{ @@ -423,7 +426,7 @@ func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) { if status.Code(err) == codes.NotFound || isNotFoundError(err) { http.Error(w, "database not found", http.StatusNotFound) } else { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) } return } @@ -431,7 +434,7 @@ func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) { if limit > 0 && len(tables) >= limit { resp.Next = offset + len(tables) } - writeJSONResponse(ctx, w, 200, resp) + apiutil.WriteJSONResponse(ctx, w, 200, resp) } // Response for tableDetails. 
@@ -473,7 +476,7 @@ type tableDetailsResponse serverpb.TableDetailsResponse // description: Database or table not found func (a *apiV2Server) tableDetails(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - username := userFromHTTPAuthInfoContext(ctx) + username := authserver.UserFromHTTPAuthInfoContext(ctx) ctx = a.sqlServer.AnnotateCtx(ctx) pathVars := mux.Vars(r) req := &serverpb.TableDetailsRequest{ @@ -486,9 +489,9 @@ func (a *apiV2Server) tableDetails(w http.ResponseWriter, r *http.Request) { if status.Code(err) == codes.NotFound || isNotFoundError(err) { http.Error(w, "database or table not found", http.StatusNotFound) } else { - apiV2InternalError(ctx, err, w) + srverrors.APIV2InternalError(ctx, err, w) } return } - writeJSONResponse(ctx, w, 200, tableDetailsResponse(*resp)) + apiutil.WriteJSONResponse(ctx, w, 200, tableDetailsResponse(*resp)) } diff --git a/pkg/server/api_v2_sql_schema_test.go b/pkg/server/api_v2_sql_schema_test.go index 67bd9d13bdf1..a3063375d708 100644 --- a/pkg/server/api_v2_sql_schema_test.go +++ b/pkg/server/api_v2_sql_schema_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -43,7 +44,7 @@ func TestUsersV2(t *testing.T) { client, err := ts1.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"users/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"users/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -84,7 +85,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.NoError(t, err) defer client.CloseIdleConnections() - req, err := http.NewRequest("GET", 
ts1.AdminURL().WithPath(apiV2Path+"databases/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -97,7 +98,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.Contains(t, dr.Databases, "testdb") - req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"databases/testdb/").String(), nil) + req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/testdb/").String(), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -108,7 +109,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.NoError(t, json.NewDecoder(resp.Body).Decode(&ddr)) require.NoError(t, resp.Body.Close()) - req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"databases/testdb/grants/").String(), nil) + req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/testdb/grants/").String(), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -120,7 +121,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.NoError(t, resp.Body.Close()) require.NotEmpty(t, dgr.Grants) - req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"databases/testdb/tables/").String(), nil) + req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/testdb/tables/").String(), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -133,7 +134,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.Contains(t, dtr.TableNames, "public.testtable") // Test that querying the wrong db name returns 404. 
- req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"databases/testdb2/tables/").String(), nil) + req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/testdb2/tables/").String(), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -141,7 +142,7 @@ func TestDatabasesTablesV2(t *testing.T) { require.Equal(t, 404, resp.StatusCode) require.NoError(t, resp.Body.Close()) - req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"databases/testdb/tables/public.testtable/").String(), nil) + req, err = http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"databases/testdb/tables/public.testtable/").String(), nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -179,7 +180,7 @@ func TestEventsV2(t *testing.T) { client, err := ts1.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"events/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"events/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) diff --git a/pkg/server/api_v2_test.go b/pkg/server/api_v2_test.go index 7dd3747bf37f..e79857dd4b2c 100644 --- a/pkg/server/api_v2_test.go +++ b/pkg/server/api_v2_test.go @@ -23,6 +23,8 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -31,7 +33,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + yaml "gopkg.in/yaml.v2" ) func TestListSessionsV2(t *testing.T) { 
@@ -59,7 +61,7 @@ func TestListSessionsV2(t *testing.T) { }() doSessionsRequest := func(client http.Client, limit int, start string) listSessionsResponse { - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"sessions/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"sessions/").String(), nil) require.NoError(t, err) query := req.URL.Query() query.Add("exclude_closed_sessions", "true") @@ -120,7 +122,7 @@ func TestListSessionsV2(t *testing.T) { // A non-admin user cannot see sessions at all. nonAdminClient, err := ts1.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession) require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"sessions/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"sessions/").String(), nil) require.NoError(t, err) resp, err := nonAdminClient.Do(req) require.NoError(t, err) @@ -145,7 +147,7 @@ func TestHealthV2(t *testing.T) { client, err := ts1.GetAdminHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiV2Path+"health/").String(), nil) + req, err := http.NewRequest("GET", ts1.AdminURL().WithPath(apiconstants.APIV2Path+"health/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -173,7 +175,7 @@ func TestRulesV2(t *testing.T) { client, err := ts.GetUnauthenticatedHTTPClient() require.NoError(t, err) - req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiV2Path+"rules/").String(), nil) + req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiconstants.APIV2Path+"rules/").String(), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) @@ -229,7 +231,7 @@ func TestAuthV2(t *testing.T) { { name: "cookie auth with correct magic header", cookie: sessionEncoded, - header: apiV2UseCookieBasedAuth, + header: authserver.APIV2UseCookieBasedAuth, 
expectedStatus: http.StatusOK, }, { @@ -246,14 +248,14 @@ func TestAuthV2(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiV2Path+"sessions/").String(), nil) + req, err := http.NewRequest("GET", ts.AdminURL().WithPath(apiconstants.APIV2Path+"sessions/").String(), nil) require.NoError(t, err) if tc.header != "" { - req.Header.Set(apiV2AuthHeader, tc.header) + req.Header.Set(authserver.APIV2AuthHeader, tc.header) } if tc.cookie != "" { req.AddCookie(&http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: tc.cookie, }) } diff --git a/pkg/server/apiconstants/BUILD.bazel b/pkg/server/apiconstants/BUILD.bazel new file mode 100644 index 000000000000..b6251be75256 --- /dev/null +++ b/pkg/server/apiconstants/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "apiconstants", + srcs = [ + "constants.go", + "testutils.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/server/apiconstants", + visibility = ["//visibility:public"], + deps = ["//pkg/security/username"], +) diff --git a/pkg/server/apiconstants/constants.go b/pkg/server/apiconstants/constants.go new file mode 100644 index 000000000000..9f2222d3f5a0 --- /dev/null +++ b/pkg/server/apiconstants/constants.go @@ -0,0 +1,47 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package apiconstants + +const ( + // APIV2Path is the prefix for the RESTful v2 API. + APIV2Path = "/api/v2/" + + // AdminPrefix is the prefix for RESTful endpoints used to provide an + // administrative interface to the cockroach cluster. 
+ AdminPrefix = "/_admin/v1/" + + // AdminHealth is the canonical URL path to the health endpoint. + // (This is also aliased via /health.) + AdminHealth = AdminPrefix + "health" + + // StatusPrefix is the root of the cluster statistics and metrics API. + StatusPrefix = "/_status/" + + // StatusVars exposes Prometheus metrics for monitoring consumption. + StatusVars = StatusPrefix + "vars" + + // LoadStatusVars exposes Prometheus metrics for instant monitoring of CPU load. + LoadStatusVars = StatusPrefix + "load" + + // DefaultAPIEventLimit is the default maximum number of events + // returned by any endpoints returning events. + DefaultAPIEventLimit = 1000 + + // MaxConcurrentRequests is the maximum number of RPC fan-out requests + // that will be made at any point of time. + MaxConcurrentRequests = 100 + + // MaxConcurrentPaginatedRequests is the maximum number of RPC fan-out + // requests that will be made at any point of time for a row-limited / + // paginated request. This should be much lower than MaxConcurrentRequests + // as too much concurrency here can result in wasted results. + MaxConcurrentPaginatedRequests = 4 +) diff --git a/pkg/server/apiconstants/testutils.go b/pkg/server/apiconstants/testutils.go new file mode 100644 index 000000000000..9e7b73c876df --- /dev/null +++ b/pkg/server/apiconstants/testutils.go @@ -0,0 +1,35 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package apiconstants + +import "github.com/cockroachdb/cockroach/pkg/security/username" + +const ( + // TestingUser is a username available in test servers, + // that has been granted the admin role. 
+ TestingUser = "authentic_user" + + // TestingUserNoAdmin is a username available in test servers, + // that has not been granted the admin role. + TestingUserNoAdmin = "authentic_user_noadmin" +) + +// TestingUserName returns the username of the authenticated +// user with an admin role. +func TestingUserName() username.SQLUsername { + return username.MakeSQLUsernameFromPreNormalizedString(TestingUser) +} + +// TestingUserNameNoAdmin returns the username of the +// authenticated user without an admin role. +func TestingUserNameNoAdmin() username.SQLUsername { + return username.MakeSQLUsernameFromPreNormalizedString(TestingUserNoAdmin) +} diff --git a/pkg/server/apiutil/BUILD.bazel b/pkg/server/apiutil/BUILD.bazel new file mode 100644 index 000000000000..0d6bba28b829 --- /dev/null +++ b/pkg/server/apiutil/BUILD.bazel @@ -0,0 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "apiutil", + srcs = ["apiutil.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/server/apiutil", + visibility = ["//visibility:public"], + deps = ["//pkg/server/srverrors"], +) diff --git a/pkg/server/apiutil/apiutil.go b/pkg/server/apiutil/apiutil.go new file mode 100644 index 000000000000..af8372a007f9 --- /dev/null +++ b/pkg/server/apiutil/apiutil.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package apiutil + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/server/srverrors" +) + +// WriteJSONResponse returns a payload as JSON to the HTTP client. 
+func WriteJSONResponse(ctx context.Context, w http.ResponseWriter, code int, payload interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + + res, err := json.Marshal(payload) + if err != nil { + srverrors.APIV2InternalError(ctx, err, w) + return + } + _, _ = w.Write(res) +} diff --git a/pkg/server/application_api/BUILD.bazel b/pkg/server/application_api/BUILD.bazel new file mode 100644 index 000000000000..e9a5a56496f1 --- /dev/null +++ b/pkg/server/application_api/BUILD.bazel @@ -0,0 +1,90 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "application_api", + srcs = ["doc.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/server/application_api", + visibility = ["//visibility:public"], +) + +go_test( + name = "application_api_test", + srcs = [ + "activity_test.go", + "config_test.go", + "contention_test.go", + "dbconsole_test.go", + "events_test.go", + "insights_test.go", + "jobs_test.go", + "main_test.go", + "metrics_test.go", + "query_plan_test.go", + "schema_inspection_test.go", + "security_test.go", + "sessions_test.go", + "sql_stats_test.go", + "stmtdiag_test.go", + "storage_inspection_test.go", + "telemetry_test.go", + "util_test.go", + "zcfg_test.go", + ], + args = ["-test.timeout=295s"], + deps = [ + "//pkg/base", + "//pkg/config/zonepb", + "//pkg/jobs", + "//pkg/jobs/jobspb", + "//pkg/keys", + "//pkg/kv/kvclient/kvtenant", + "//pkg/kv/kvserver", + "//pkg/roachpb", + "//pkg/rpc", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/security/username", + "//pkg/server", + "//pkg/server/apiconstants", + "//pkg/server/diagnostics/diagnosticspb", + "//pkg/server/rangetestutils", + "//pkg/server/serverpb", + "//pkg/server/srvtestutils", + "//pkg/settings", + "//pkg/settings/cluster", + "//pkg/spanconfig", + "//pkg/sql", + "//pkg/sql/appstatspb", + "//pkg/sql/catalog/descpb", + "//pkg/sql/clusterunique", + "//pkg/sql/idxusage", + 
"//pkg/sql/sem/catconstants", + "//pkg/sql/sem/tree", + "//pkg/sql/sessiondata", + "//pkg/sql/sqlstats", + "//pkg/sql/tests", + "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", + "//pkg/testutils/sqlutils", + "//pkg/testutils/testcluster", + "//pkg/util/grunning", + "//pkg/util/hlc", + "//pkg/util/httputil", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/protoutil", + "//pkg/util/randident", + "//pkg/util/randutil", + "//pkg/util/safesql", + "//pkg/util/stop", + "//pkg/util/syncutil", + "//pkg/util/timeutil", + "//pkg/util/uuid", + "@com_github_cockroachdb_errors//:errors", + "@com_github_gogo_protobuf//proto", + "@com_github_kr_pretty//:pretty", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/server/application_api/activity_test.go b/pkg/server/application_api/activity_test.go new file mode 100644 index 000000000000..6d97f84b1222 --- /dev/null +++ b/pkg/server/application_api/activity_test.go @@ -0,0 +1,144 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/stretchr/testify/require" +) + +func TestListActivitySecurity(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + ts := s.(*server.TestServer) + defer ts.Stopper().Stop(ctx) + + expectedErrNoPermission := "this operation requires the VIEWACTIVITY or VIEWACTIVITYREDACTED system privilege" + contentionMsg := &serverpb.ListContentionEventsResponse{} + flowsMsg := &serverpb.ListDistSQLFlowsResponse{} + getErrors := func(msg protoutil.Message) []serverpb.ListActivityError { + switch r := msg.(type) { + case *serverpb.ListContentionEventsResponse: + return r.Errors + case *serverpb.ListDistSQLFlowsResponse: + return r.Errors + default: + t.Fatal("unexpected message type") + return nil + } + } + + // HTTP requests respect the authenticated username from the HTTP session. 
+ testCases := []struct { + endpoint string + expectedErr string + requestWithAdmin bool + requestWithViewActivityGranted bool + response protoutil.Message + }{ + {"local_contention_events", expectedErrNoPermission, false, false, contentionMsg}, + {"contention_events", expectedErrNoPermission, false, false, contentionMsg}, + {"local_contention_events", "", true, false, contentionMsg}, + {"contention_events", "", true, false, contentionMsg}, + {"local_contention_events", "", false, true, contentionMsg}, + {"contention_events", "", false, true, contentionMsg}, + {"local_distsql_flows", expectedErrNoPermission, false, false, flowsMsg}, + {"distsql_flows", expectedErrNoPermission, false, false, flowsMsg}, + {"local_distsql_flows", "", true, false, flowsMsg}, + {"distsql_flows", "", true, false, flowsMsg}, + {"local_distsql_flows", "", false, true, flowsMsg}, + {"distsql_flows", "", false, true, flowsMsg}, + } + myUser := apiconstants.TestingUserNameNoAdmin().Normalized() + for _, tc := range testCases { + if tc.requestWithViewActivityGranted { + // Note that for this query to work, it is crucial that + // srvtestutils.GetStatusJSONProtoWithAdminOption below is called at least once, + // on the previous test case, so that the user exists. 
+ _, err := db.Exec(fmt.Sprintf("ALTER USER %s VIEWACTIVITY", myUser)) + require.NoError(t, err) + } + err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, tc.endpoint, tc.response, tc.requestWithAdmin) + responseErrors := getErrors(tc.response) + if tc.expectedErr == "" { + if err != nil || len(responseErrors) > 0 { + t.Errorf("unexpected failure listing the activity; error: %v; response errors: %v", + err, responseErrors) + } + } else { + respErr := "" + if len(responseErrors) > 0 { + respErr = responseErrors[0].Message + } + if !testutils.IsError(err, tc.expectedErr) && + !strings.Contains(respErr, tc.expectedErr) { + t.Errorf("did not get expected error %q when listing the activity from %s: %v", + tc.expectedErr, tc.endpoint, err) + } + } + if tc.requestWithViewActivityGranted { + _, err := db.Exec(fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", myUser)) + require.NoError(t, err) + } + } + + // gRPC requests behave as root and thus are always allowed. + rootConfig := testutils.NewTestBaseContext(username.RootUserName()) + rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, rootConfig) + url := ts.ServingRPCAddr() + nodeID := ts.NodeID() + conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + { + request := &serverpb.ListContentionEventsRequest{} + if resp, err := client.ListLocalContentionEvents(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected failure listing local contention events; error: %v; response errors: %v", + err, resp.Errors) + } + if resp, err := client.ListContentionEvents(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected failure listing contention events; error: %v; response errors: %v", + err, resp.Errors) + } + } + { + request := &serverpb.ListDistSQLFlowsRequest{} + if resp, err := client.ListLocalDistSQLFlows(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected 
failure listing local distsql flows; error: %v; response errors: %v", + err, resp.Errors) + } + if resp, err := client.ListDistSQLFlows(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected failure listing distsql flows; error: %v; response errors: %v", + err, resp.Errors) + } + } +} diff --git a/pkg/server/application_api/config_test.go b/pkg/server/application_api/config_test.go new file mode 100644 index 000000000000..e73131e965b2 --- /dev/null +++ b/pkg/server/application_api/config_test.go @@ -0,0 +1,248 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + "net/url" + "reflect" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/settings" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/safesql" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" +) + +func TestAdminAPISettings(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this test fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + // Any bool that defaults to true will work here. + const settingKey = "sql.metrics.statement_details.enabled" + st := s.ClusterSettings() + allKeys := settings.Keys(settings.ForSystemTenant) + + checkSetting := func(t *testing.T, k string, v serverpb.SettingsResponse_Value) { + ref, ok := settings.LookupForReporting(k, settings.ForSystemTenant) + if !ok { + t.Fatalf("%s: not found after initial lookup", k) + } + typ := ref.Typ() + + if !settings.TestingIsReportable(ref) { + if v.Value != "<redacted>" && v.Value != "" { + t.Errorf("%s: expected redacted value for %v, got %s", k, ref, v.Value) + } + } else { + if ref.String(&st.SV) != v.Value { + t.Errorf("%s: expected value %v, got %s", k, ref, v.Value) + } + } + + if expectedPublic := ref.Visibility() == settings.Public; expectedPublic != v.Public { + t.Errorf("%s: expected public %v, got %v", k, expectedPublic, v.Public) + } + + if desc := ref.Description(); desc != v.Description { + t.Errorf("%s: expected description %s, got %s", k, desc, v.Description) + } + if typ != v.Type { + t.Errorf("%s: expected type %s, got %s", k, typ, v.Type) + } + if v.LastUpdated != nil { + db := sqlutils.MakeSQLRunner(conn) + q := safesql.NewQuery() + q.Append(`SELECT name, "lastUpdated" FROM system.settings WHERE name=$`, k) + rows := db.Query( + t, + q.String(), + q.QueryArguments()..., + ) + defer rows.Close() + if rows.Next() == false { + t.Errorf("missing sql row for %s", k) + } + } + } + + t.Run("all", func(t *testing.T) { + var resp serverpb.SettingsResponse + + if err := srvtestutils.GetAdminJSONProto(s, "settings", &resp); err != nil { + t.Fatal(err) + } + + // Check that all expected keys were returned + if len(allKeys) != len(resp.KeyValues) { + t.Fatalf("expected %d keys, got %d", len(allKeys), len(resp.KeyValues)) + } + for _, k := range allKeys { + if _, ok := resp.KeyValues[k]; !ok { + t.Fatalf("expected key %s not found in 
response", k) + } + } + + // Check that the test key is listed and the values come indeed + // from the settings package unchanged. + seenRef := false + for k, v := range resp.KeyValues { + if k == settingKey { + seenRef = true + if v.Value != "true" { + t.Errorf("%s: expected true, got %s", k, v.Value) + } + } + + checkSetting(t, k, v) + } + + if !seenRef { + t.Fatalf("failed to observe test setting %s, got %+v", settingKey, resp.KeyValues) + } + }) + + t.Run("one-by-one", func(t *testing.T) { + var resp serverpb.SettingsResponse + + // All the settings keys must be retrievable, and their + // type and description must match. + for _, k := range allKeys { + q := make(url.Values) + q.Add("keys", k) + url := "settings?" + q.Encode() + if err := srvtestutils.GetAdminJSONProto(s, url, &resp); err != nil { + t.Fatalf("%s: %v", k, err) + } + if len(resp.KeyValues) != 1 { + t.Fatalf("%s: expected 1 response, got %d", k, len(resp.KeyValues)) + } + v, ok := resp.KeyValues[k] + if !ok { + t.Fatalf("%s: response does not contain key", k) + } + + checkSetting(t, k, v) + } + }) +} + +func TestClusterAPI(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + testutils.RunTrueAndFalse(t, "reportingOn", func(t *testing.T, reportingOn bool) { + testutils.RunTrueAndFalse(t, "enterpriseOn", func(t *testing.T, enterpriseOn bool) { + // Override server license check. 
+ if enterpriseOn { + old := base.CheckEnterpriseEnabled + base.CheckEnterpriseEnabled = func(_ *cluster.Settings, _ uuid.UUID, _ string) error { + return nil + } + defer func() { base.CheckEnterpriseEnabled = old }() + } + + if _, err := db.Exec(`SET CLUSTER SETTING diagnostics.reporting.enabled = $1`, reportingOn); err != nil { + t.Fatal(err) + } + + // We need to retry, because the cluster ID isn't set until after + // bootstrapping and because setting a cluster setting isn't necessarily + // instantaneous. + // + // Also note that there's a migration that affects `diagnostics.reporting.enabled`, + // so manipulating the cluster setting var directly is a bad idea. + testutils.SucceedsSoon(t, func() error { + var resp serverpb.ClusterResponse + if err := srvtestutils.GetAdminJSONProto(s, "cluster", &resp); err != nil { + return err + } + if a, e := resp.ClusterID, s.RPCContext().StorageClusterID.String(); a != e { + return errors.Errorf("cluster ID %s != expected %s", a, e) + } + if a, e := resp.ReportingEnabled, reportingOn; a != e { + return errors.Errorf("reportingEnabled = %t, wanted %t", a, e) + } + if a, e := resp.EnterpriseEnabled, enterpriseOn; a != e { + return errors.Errorf("enterpriseEnabled = %t, wanted %t", a, e) + } + return nil + }) + }) + }) +} + +func TestAdminAPILocations(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + sqlDB := sqlutils.MakeSQLRunner(conn) + + testLocations := []struct { + localityKey string + localityValue string + latitude float64 + longitude float64 + }{ + {"city", "Des Moines", 41.60054, -93.60911}, + {"city", "New York City", 40.71427, -74.00597}, + {"city", "Seattle", 47.60621, -122.33207}, + } + for _, loc := range testLocations { + sqlDB.Exec(t, + `INSERT INTO system.locations ("localityKey", "localityValue", latitude, longitude) VALUES ($1, $2, $3, $4)`, + loc.localityKey, loc.localityValue, loc.latitude, loc.longitude, + ) + } + var res serverpb.LocationsResponse + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, "locations", &res, false /* isAdmin */); err != nil { + t.Fatal(err) + } + for i, loc := range testLocations { + expLoc := serverpb.LocationsResponse_Location{ + LocalityKey: loc.localityKey, + LocalityValue: loc.localityValue, + Latitude: loc.latitude, + Longitude: loc.longitude, + } + if !reflect.DeepEqual(res.Locations[i], expLoc) { + t.Errorf("%d: expected location %v, but got %v", i, expLoc, res.Locations[i]) + } + } +} diff --git a/pkg/server/application_api/contention_test.go b/pkg/server/application_api/contention_test.go new file mode 100644 index 000000000000..39dce8b83b8c --- /dev/null +++ b/pkg/server/application_api/contention_test.go @@ -0,0 +1,415 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/clusterunique" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/tests" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func TestStatusAPIContentionEvents(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + ctx := context.Background() + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: params, + }) + + defer testCluster.Stopper().Stop(ctx) + + server1Conn := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) + server2Conn := sqlutils.MakeSQLRunner(testCluster.ServerConn(1)) + + contentionCountBefore := testCluster.Server(1).SQLServer().(*sql.Server). 
+ Metrics.EngineMetrics.SQLContendedTxns.Count() + + sqlutils.CreateTable( + t, + testCluster.ServerConn(0), + "test", + "x INT PRIMARY KEY", + 1, /* numRows */ + sqlutils.ToRowFn(sqlutils.RowIdxFn), + ) + + testTableID, err := + strconv.Atoi(server1Conn.QueryStr(t, "SELECT 'test.test'::regclass::oid")[0][0]) + require.NoError(t, err) + + server1Conn.Exec(t, "USE test") + server2Conn.Exec(t, "USE test") + server2Conn.Exec(t, "SET application_name = 'contentionTest'") + + server1Conn.Exec(t, ` +SET TRACING=on; +BEGIN; +UPDATE test SET x = 100 WHERE x = 1; +`) + server2Conn.Exec(t, ` +SET TRACING=on; +BEGIN PRIORITY HIGH; +UPDATE test SET x = 1000 WHERE x = 1; +COMMIT; +SET TRACING=off; +`) + server1Conn.ExpectErr( + t, + "^pq: restart transaction.+", + ` +COMMIT; +SET TRACING=off; +`, + ) + + var resp serverpb.ListContentionEventsResponse + require.NoError(t, + srvtestutils.GetStatusJSONProtoWithAdminOption( + testCluster.Server(2), + "contention_events", + &resp, + true /* isAdmin */), + ) + + require.GreaterOrEqualf(t, len(resp.Events.IndexContentionEvents), 1, + "expecting at least 1 contention event, but found none") + + found := false + for _, event := range resp.Events.IndexContentionEvents { + if event.TableID == descpb.ID(testTableID) && event.IndexID == descpb.IndexID(1) { + found = true + break + } + } + + require.True(t, found, + "expect to find contention event for table %d, but found %+v", testTableID, resp) + + server1Conn.CheckQueryResults(t, ` + SELECT count(*) + FROM crdb_internal.statement_statistics + WHERE + (statistics -> 'execution_statistics' -> 'contentionTime' ->> 'mean')::FLOAT > 0 + AND app_name = 'contentionTest' +`, [][]string{{"1"}}) + + server1Conn.CheckQueryResults(t, ` + SELECT count(*) + FROM crdb_internal.transaction_statistics + WHERE + (statistics -> 'execution_statistics' -> 'contentionTime' ->> 'mean')::FLOAT > 0 + AND app_name = 'contentionTest' +`, [][]string{{"1"}}) + + contentionCountNow := 
testCluster.Server(1).SQLServer().(*sql.Server). + Metrics.EngineMetrics.SQLContendedTxns.Count() + + require.Greaterf(t, contentionCountNow, contentionCountBefore, + "expected txn contention count to be more than %d, but it is %d", + contentionCountBefore, contentionCountNow) +} + +func TestTransactionContentionEvents(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + s, conn1, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + sqlutils.CreateTable( + t, + conn1, + "test", + "x INT PRIMARY KEY", + 1, /* numRows */ + sqlutils.ToRowFn(sqlutils.RowIdxFn), + ) + + conn2 := + serverutils.OpenDBConn(t, s.ServingSQLAddr(), "", false /* insecure */, s.Stopper()) + defer func() { + require.NoError(t, conn2.Close()) + }() + + sqlConn1 := sqlutils.MakeSQLRunner(conn1) + sqlConn1.Exec(t, "SET CLUSTER SETTING sql.contention.txn_id_cache.max_size = '1GB'") + sqlConn1.Exec(t, "USE test") + sqlConn1.Exec(t, "SET application_name='conn1'") + + sqlConn2 := sqlutils.MakeSQLRunner(conn2) + sqlConn2.Exec(t, "USE test") + sqlConn2.Exec(t, "SET application_name='conn2'") + + // Start the first transaction. + sqlConn1.Exec(t, ` + SET TRACING=on; + BEGIN; + `) + + txnID1 := sqlConn1.QueryStr(t, ` + SELECT txn_id + FROM [SHOW TRANSACTIONS] + WHERE application_name = 'conn1'`)[0][0] + + sqlConn1.Exec(t, "UPDATE test SET x = 100 WHERE x = 1") + + // Start the second transaction with higher priority. This will cause the + // first transaction to be aborted. + sqlConn2.Exec(t, ` + SET TRACING=on; + BEGIN PRIORITY HIGH; + `) + + txnID2 := sqlConn1.QueryStr(t, ` + SELECT txn_id + FROM [SHOW TRANSACTIONS] + WHERE application_name = 'conn2'`)[0][0] + + sqlConn2.Exec(t, ` + UPDATE test SET x = 1000 WHERE x = 1; + COMMIT;`) + + // Ensure that the first transaction is aborted. 
+ sqlConn1.ExpectErr( + t, + "^pq: restart transaction.+", + ` + COMMIT; + SET TRACING=off;`, + ) + + // Sanity check to see the first transaction has been aborted. + sqlConn1.CheckQueryResults(t, "SELECT * FROM test", + [][]string{{"1000"}}) + + txnIDCache := s.SQLServer().(*sql.Server).GetTxnIDCache() + + // Since contention event store's resolver only retries once in the case of + // missing txn fingerprint ID for a given txnID, we ensure that the txnIDCache + // write buffer is properly drained before we go on to test the contention + // registry. + testutils.SucceedsSoon(t, func() error { + txnIDCache.DrainWriteBuffer() + + txnID, err := uuid.FromString(txnID1) + require.NoError(t, err) + + if _, found := txnIDCache.Lookup(txnID); !found { + return errors.Newf("expected the txn fingerprint ID for txn %s to be "+ + "stored in txnID cache, but it is not", txnID1) + } + + txnID, err = uuid.FromString(txnID2) + require.NoError(t, err) + + if _, found := txnIDCache.Lookup(txnID); !found { + return errors.Newf("expected the txn fingerprint ID for txn %s to be "+ + "stored in txnID cache, but it is not", txnID2) + } + + return nil + }) + + testutils.SucceedsWithin(t, func() error { + err := s.ExecutorConfig().(sql.ExecutorConfig).ContentionRegistry.FlushEventsForTest(ctx) + require.NoError(t, err) + + notEmpty := sqlConn1.QueryStr(t, ` + SELECT count(*) > 0 + FROM crdb_internal.transaction_contention_events + WHERE + blocking_txn_id = $1::UUID AND + waiting_txn_id = $2::UUID AND + encode(blocking_txn_fingerprint_id, 'hex') != '0000000000000000' AND + encode(waiting_txn_fingerprint_id, 'hex') != '0000000000000000' AND + length(contending_key) > 0`, txnID1, txnID2)[0][0] + + if notEmpty != "true" { + return errors.Newf("expected at least one contention events, but " + + "none was found") + } + + return nil + }, 10*time.Second) + + nonAdminUser := apiconstants.TestingUserNameNoAdmin().Normalized() + adminUser := apiconstants.TestingUserName().Normalized() + + // N.B. 
We need both test users to be created before establishing SQL + // connections with their usernames. We use + // srvtestutils.GetStatusJSONProtoWithAdminOption() to implicitly create those + // usernames instead of regular CREATE USER statements, since the helper + // srvtestutils.GetStatusJSONProtoWithAdminOption() couldn't handle the case where + // those two usernames already exist. + // This is the reason why we don't check for returning errors. + _ = srvtestutils.GetStatusJSONProtoWithAdminOption( + s, + "transactioncontentionevents", + &serverpb.TransactionContentionEventsResponse{}, + true, /* isAdmin */ + ) + _ = srvtestutils.GetStatusJSONProtoWithAdminOption( + s, + "transactioncontentionevents", + &serverpb.TransactionContentionEventsResponse{}, + false, /* isAdmin */ + ) + + type testCase struct { + testName string + userName string + canViewContendingKey bool + grantPerm string + revokePerm string + isAdmin bool + } + + tcs := []testCase{ + { + testName: "nopermission", + userName: nonAdminUser, + canViewContendingKey: false, + }, + { + testName: "viewactivityredacted", + userName: nonAdminUser, + canViewContendingKey: false, + grantPerm: fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", nonAdminUser), + revokePerm: fmt.Sprintf("ALTER USER %s NOVIEWACTIVITYREDACTED", nonAdminUser), + }, + { + testName: "viewactivity", + userName: nonAdminUser, + canViewContendingKey: true, + grantPerm: fmt.Sprintf("ALTER USER %s VIEWACTIVITY", nonAdminUser), + revokePerm: fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", nonAdminUser), + }, + { + testName: "viewactivity_and_viewactivtyredacted", + userName: nonAdminUser, + canViewContendingKey: false, + grantPerm: fmt.Sprintf(`ALTER USER %s VIEWACTIVITY; + ALTER USER %s VIEWACTIVITYREDACTED;`, + nonAdminUser, nonAdminUser), + revokePerm: fmt.Sprintf(`ALTER USER %s NOVIEWACTIVITY; + ALTER USER %s NOVIEWACTIVITYREDACTED;`, + nonAdminUser, nonAdminUser), + }, + { + testName: "adminuser", + userName: adminUser, + 
canViewContendingKey: true, + isAdmin: true, + }, + } + + expectationStringHelper := func(canViewContendingKey bool) string { + if canViewContendingKey { + return "able to view contending keys" + } + return "not able to view contending keys" + } + + for _, tc := range tcs { + t.Run(tc.testName, func(t *testing.T) { + if tc.grantPerm != "" { + sqlConn1.Exec(t, tc.grantPerm) + } + if tc.revokePerm != "" { + defer sqlConn1.Exec(t, tc.revokePerm) + } + + expectationStr := expectationStringHelper(tc.canViewContendingKey) + t.Run("sql_cli", func(t *testing.T) { + // Check we have proper permission control in SQL CLI. We use internal + // executor here since we can easily override the username without opening + // new SQL sessions. + row, err := s.InternalExecutor().(*sql.InternalExecutor).QueryRowEx( + ctx, + "test-contending-key-redaction", + nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: username.MakeSQLUsernameFromPreNormalizedString(tc.userName), + }, + ` + SELECT count(*) + FROM crdb_internal.transaction_contention_events + WHERE length(contending_key) > 0`, + ) + if tc.testName == "nopermission" { + require.Contains(t, err.Error(), "does not have VIEWACTIVITY") + } else { + require.NoError(t, err) + visibleContendingKeysCount := tree.MustBeDInt(row[0]) + + require.Equal(t, tc.canViewContendingKey, visibleContendingKeysCount > 0, + "expected to %s, but %d keys have been retrieved", + expectationStr, visibleContendingKeysCount) + } + }) + + t.Run("http", func(t *testing.T) { + // Check we have proper permission control in RPC/HTTP endpoint. 
+ resp := serverpb.TransactionContentionEventsResponse{} + err := srvtestutils.GetStatusJSONProtoWithAdminOption( + s, + "transactioncontentionevents", + &resp, + tc.isAdmin, + ) + + if tc.testName == "nopermission" { + require.Contains(t, err.Error(), "status: 403") + } else { + require.NoError(t, err) + } + + for _, event := range resp.Events { + require.NotEqual(t, event.WaitingStmtFingerprintID, 0) + require.NotEqual(t, event.WaitingStmtID.String(), clusterunique.ID{}.String()) + + require.Equal(t, tc.canViewContendingKey, len(event.BlockingEvent.Key) > 0, + "expected to %s, but the contending key has length of %d", + expectationStr, + len(event.BlockingEvent.Key), + ) + } + }) + + }) + } +} diff --git a/pkg/server/application_api/dbconsole_test.go b/pkg/server/application_api/dbconsole_test.go new file mode 100644 index 000000000000..8b64a9b2a2e4 --- /dev/null +++ b/pkg/server/application_api/dbconsole_test.go @@ -0,0 +1,183 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "bytes" + "context" + "net/url" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" +) + +// TestAdminAPIUIData checks that UI customizations are properly +// persisted for both admin and non-admin users. 
+func TestAdminAPIUIData(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { + start := timeutil.Now() + + mustSetUIData := func(keyValues map[string][]byte) { + if err := srvtestutils.PostAdminJSONProtoWithAdminOption(s, "uidata", &serverpb.SetUIDataRequest{ + KeyValues: keyValues, + }, &serverpb.SetUIDataResponse{}, isAdmin); err != nil { + t.Fatal(err) + } + } + + expectKeyValues := func(expKeyValues map[string][]byte) { + var resp serverpb.GetUIDataResponse + queryValues := make(url.Values) + for key := range expKeyValues { + queryValues.Add("keys", key) + } + url := "uidata?" + queryValues.Encode() + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, url, &resp, isAdmin); err != nil { + t.Fatal(err) + } + // Do a two-way comparison. We can't use reflect.DeepEqual(), because + // resp.KeyValues has timestamps and expKeyValues doesn't. + for key, actualVal := range resp.KeyValues { + if a, e := actualVal.Value, expKeyValues[key]; !bytes.Equal(a, e) { + t.Fatalf("key %s: value = %v, expected = %v", key, a, e) + } + } + for key, expVal := range expKeyValues { + if a, e := resp.KeyValues[key].Value, expVal; !bytes.Equal(a, e) { + t.Fatalf("key %s: value = %v, expected = %v", key, a, e) + } + } + + // Sanity check LastUpdated. 
+ for _, val := range resp.KeyValues { + now := timeutil.Now() + if val.LastUpdated.Before(start) { + t.Fatalf("val.LastUpdated %s < start %s", val.LastUpdated, start) + } + if val.LastUpdated.After(now) { + t.Fatalf("val.LastUpdated %s > now %s", val.LastUpdated, now) + } + } + } + + expectValueEquals := func(key string, expVal []byte) { + expectKeyValues(map[string][]byte{key: expVal}) + } + + expectKeyNotFound := func(key string) { + var resp serverpb.GetUIDataResponse + url := "uidata?keys=" + key + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, url, &resp, isAdmin); err != nil { + t.Fatal(err) + } + if len(resp.KeyValues) != 0 { + t.Fatal("key unexpectedly found") + } + } + + // Basic tests. + var badResp serverpb.GetUIDataResponse + const errPattern = "400 Bad Request" + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, "uidata", &badResp, isAdmin); !testutils.IsError(err, errPattern) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) + } + + mustSetUIData(map[string][]byte{"k1": []byte("v1")}) + expectValueEquals("k1", []byte("v1")) + + expectKeyNotFound("NON_EXISTENT_KEY") + + mustSetUIData(map[string][]byte{ + "k2": []byte("v2"), + "k3": []byte("v3"), + }) + expectValueEquals("k2", []byte("v2")) + expectValueEquals("k3", []byte("v3")) + expectKeyValues(map[string][]byte{ + "k2": []byte("v2"), + "k3": []byte("v3"), + }) + + mustSetUIData(map[string][]byte{"k2": []byte("v2-updated")}) + expectKeyValues(map[string][]byte{ + "k2": []byte("v2-updated"), + "k3": []byte("v3"), + }) + + // Write a binary blob with all possible byte values, then verify it. + var buf bytes.Buffer + for i := 0; i < 997; i++ { + buf.WriteByte(byte(i % 256)) + } + mustSetUIData(map[string][]byte{"bin": buf.Bytes()}) + expectValueEquals("bin", buf.Bytes()) + }) +} + +// TestAdminAPIUISeparateData check that separate users have separate customizations. 
+func TestAdminAPIUISeparateData(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + // Make a setting for an admin user. + if err := srvtestutils.PostAdminJSONProtoWithAdminOption(s, "uidata", + &serverpb.SetUIDataRequest{KeyValues: map[string][]byte{"k": []byte("v1")}}, + &serverpb.SetUIDataResponse{}, + true /*isAdmin*/); err != nil { + t.Fatal(err) + } + + // Make a setting for a non-admin user. + if err := srvtestutils.PostAdminJSONProtoWithAdminOption(s, "uidata", + &serverpb.SetUIDataRequest{KeyValues: map[string][]byte{"k": []byte("v2")}}, + &serverpb.SetUIDataResponse{}, + false /*isAdmin*/); err != nil { + t.Fatal(err) + } + + var resp serverpb.GetUIDataResponse + url := "uidata?keys=k" + + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, url, &resp, true /* isAdmin */); err != nil { + t.Fatal(err) + } + if len(resp.KeyValues) != 1 || !bytes.Equal(resp.KeyValues["k"].Value, []byte("v1")) { + t.Fatalf("unexpected admin values: %+v", resp.KeyValues) + } + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, url, &resp, false /* isAdmin */); err != nil { + t.Fatal(err) + } + if len(resp.KeyValues) != 1 || !bytes.Equal(resp.KeyValues["k"].Value, []byte("v2")) { + t.Fatalf("unexpected non-admin values: %+v", resp.KeyValues) + } +} diff --git a/pkg/server/application_api/doc.go b/pkg/server/application_api/doc.go new file mode 100644 index 000000000000..fc07bd35c9e6 --- /dev/null +++ b/pkg/server/application_api/doc.go @@ -0,0 +1,15 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// Package application_api pertains to the RPC and HTTP APIs exposed by +// the application layers, including SQL and tenant-scoped HTTP. +// Storage-level APIs (e.g. KV node inspection) are in the +// storage_api package. +package application_api diff --git a/pkg/server/application_api/events_test.go b/pkg/server/application_api/events_test.go new file mode 100644 index 000000000000..77911ef7e0aa --- /dev/null +++ b/pkg/server/application_api/events_test.go @@ -0,0 +1,155 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestAdminAPIEvents(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + setupQueries := []string{ + "CREATE DATABASE api_test", + "CREATE TABLE api_test.tbl1 (a INT)", + "CREATE TABLE api_test.tbl2 (a INT)", + "CREATE TABLE api_test.tbl3 (a INT)", + "DROP TABLE api_test.tbl1", + "DROP TABLE api_test.tbl2", + "SET CLUSTER SETTING cluster.organization = 'somestring';", + } + for _, q := range setupQueries { + if _, err := db.Exec(q); err != nil { + t.Fatalf("error executing '%s': %s", q, err) + } + } + + const allEvents = "" + type testcase struct { + eventType string + hasLimit bool + limit int + unredacted bool + expCount int + } + testcases := []testcase{ + {"node_join", false, 0, false, 1}, + {"node_restart", false, 0, false, 0}, + {"drop_database", false, 0, false, 0}, + {"create_database", false, 0, false, 3}, + {"drop_table", false, 0, false, 2}, + {"create_table", false, 0, false, 3}, + {"set_cluster_setting", false, 0, false, 2}, + // We use limit=true with no limit here because otherwise the + // expCount will mess up the expected total count below. 
+ {"set_cluster_setting", true, 0, true, 2}, + {"create_table", true, 0, false, 3}, + {"create_table", true, -1, false, 3}, + {"create_table", true, 2, false, 2}, + } + minTotalEvents := 0 + for _, tc := range testcases { + if !tc.hasLimit { + minTotalEvents += tc.expCount + } + } + testcases = append(testcases, testcase{allEvents, false, 0, false, minTotalEvents}) + + for i, tc := range testcases { + url := "events" + if tc.eventType != allEvents { + url += "?type=" + tc.eventType + if tc.hasLimit { + url += fmt.Sprintf("&limit=%d", tc.limit) + } + if tc.unredacted { + url += "&unredacted_events=true" + } + } + + t.Run(url, func(t *testing.T) { + var resp serverpb.EventsResponse + if err := srvtestutils.GetAdminJSONProto(s, url, &resp); err != nil { + t.Fatal(err) + } + if tc.eventType == allEvents { + // When retrieving all events, we expect that there will be some system + // database migrations, unrelated to this test, that add to the log entry + // count. So, we do a looser check here. + if a, min := len(resp.Events), tc.expCount; a < tc.expCount { + t.Fatalf("%d: total # of events %d < min %d", i, a, min) + } + } else { + if a, e := len(resp.Events), tc.expCount; a != e { + t.Fatalf("%d: # of %s events %d != expected %d", i, tc.eventType, a, e) + } + } + + // Ensure we don't have blank / nonsensical fields. 
+ for _, e := range resp.Events { + if e.Timestamp == (time.Time{}) { + t.Errorf("%d: missing/empty timestamp", i) + } + + if len(tc.eventType) > 0 { + if a, e := e.EventType, tc.eventType; a != e { + t.Errorf("%d: event type %s != expected %s", i, a, e) + } + } else { + if len(e.EventType) == 0 { + t.Errorf("%d: missing event type in event", i) + } + } + + isSettingChange := e.EventType == "set_cluster_setting" + + if e.ReportingID == 0 { + t.Errorf("%d: missing/empty ReportingID", i) + } + if len(e.Info) == 0 { + t.Errorf("%d: missing/empty Info", i) + } + if isSettingChange && strings.Contains(e.Info, "cluster.organization") { + if tc.unredacted { + if !strings.Contains(e.Info, "somestring") { + t.Errorf("%d: require 'somestring' in Info", i) + } + } else { + if strings.Contains(e.Info, "somestring") { + t.Errorf("%d: un-redacted 'somestring' in Info", i) + } + } + } + if len(e.UniqueID) == 0 { + t.Errorf("%d: missing/empty UniqueID", i) + } + } + }) + } +} diff --git a/pkg/server/application_api/insights_test.go b/pkg/server/application_api/insights_test.go new file mode 100644 index 000000000000..7622f9762c3c --- /dev/null +++ b/pkg/server/application_api/insights_test.go @@ -0,0 +1,229 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/idxusage" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/stretchr/testify/require" +) + +type stubUnusedIndexTime struct { + syncutil.RWMutex + current time.Time + lastRead time.Time + createdAt *time.Time +} + +func (s *stubUnusedIndexTime) setCurrent(t time.Time) { + s.RWMutex.Lock() + defer s.RWMutex.Unlock() + s.current = t +} + +func (s *stubUnusedIndexTime) setLastRead(t time.Time) { + s.RWMutex.Lock() + defer s.RWMutex.Unlock() + s.lastRead = t +} + +func (s *stubUnusedIndexTime) setCreatedAt(t *time.Time) { + s.RWMutex.Lock() + defer s.RWMutex.Unlock() + s.createdAt = t +} + +func (s *stubUnusedIndexTime) getCurrent() time.Time { + s.RWMutex.RLock() + defer s.RWMutex.RUnlock() + return s.current +} + +func (s *stubUnusedIndexTime) getLastRead() time.Time { + s.RWMutex.RLock() + defer s.RWMutex.RUnlock() + return s.lastRead +} + +func (s *stubUnusedIndexTime) getCreatedAt() *time.Time { + s.RWMutex.RLock() + defer s.RWMutex.RUnlock() + return s.createdAt +} + +func TestDatabaseAndTableIndexRecommendations(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + stubTime := stubUnusedIndexTime{} + stubDropUnusedDuration := time.Hour + + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + Knobs: base.TestingKnobs{ + UnusedIndexRecommendKnobs: &idxusage.UnusedIndexRecommendationTestingKnobs{ + GetCreatedAt: stubTime.getCreatedAt, + GetLastRead: stubTime.getLastRead, + GetCurrentTime: stubTime.getCurrent, + }, + }, + }) + idxusage.DropUnusedIndexDuration.Override(context.Background(), &s.ClusterSettings().SV, stubDropUnusedDuration) + defer s.Stopper().Stop(context.Background()) + + db := sqlutils.MakeSQLRunner(sqlDB) + db.Exec(t, "CREATE DATABASE test") + db.Exec(t, "USE test") + // Create a table and secondary index. + db.Exec(t, "CREATE TABLE test.test_table (num INT PRIMARY KEY, letter char)") + db.Exec(t, "CREATE INDEX test_idx ON test.test_table (letter)") + + // Test when last read does not exist and there is no creation time. Expect + // an index recommendation (index never used). + stubTime.setLastRead(time.Time{}) + stubTime.setCreatedAt(nil) + + // Test database details endpoint. + var dbDetails serverpb.DatabaseDetailsResponse + if err := srvtestutils.GetAdminJSONProto( + s, + "databases/test?include_stats=true", + &dbDetails, + ); err != nil { + t.Fatal(err) + } + // Expect 1 index recommendation (no index recommendation on primary index). + require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) + + // Test table details endpoint. + var tableDetails serverpb.TableDetailsResponse + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { + t.Fatal(err) + } + require.Equal(t, true, tableDetails.HasIndexRecommendations) + + // Test when last read does not exist and there is a creation time, and the + // unused index duration has been exceeded. Expect an index recommendation. + currentTime := timeutil.Now() + createdTime := currentTime.Add(-stubDropUnusedDuration) + stubTime.setCurrent(currentTime) + stubTime.setLastRead(time.Time{}) + stubTime.setCreatedAt(&createdTime) + + // Test database details endpoint. 
+ dbDetails = serverpb.DatabaseDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto( + s, + "databases/test?include_stats=true", + &dbDetails, + ); err != nil { + t.Fatal(err) + } + require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) + + // Test table details endpoint. + tableDetails = serverpb.TableDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { + t.Fatal(err) + } + require.Equal(t, true, tableDetails.HasIndexRecommendations) + + // Test when last read does not exist and there is a creation time, and the + // unused index duration has not been exceeded. Expect no index + // recommendation. + currentTime = timeutil.Now() + stubTime.setCurrent(currentTime) + stubTime.setLastRead(time.Time{}) + stubTime.setCreatedAt(&currentTime) + + // Test database details endpoint. + dbDetails = serverpb.DatabaseDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto( + s, + "databases/test?include_stats=true", + &dbDetails, + ); err != nil { + t.Fatal(err) + } + require.Equal(t, int32(0), dbDetails.Stats.NumIndexRecommendations) + + // Test table details endpoint. + tableDetails = serverpb.TableDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { + t.Fatal(err) + } + require.Equal(t, false, tableDetails.HasIndexRecommendations) + + // Test when last read exists and the unused index duration has been + // exceeded. Expect an index recommendation. + currentTime = timeutil.Now() + lastRead := currentTime.Add(-stubDropUnusedDuration) + stubTime.setCurrent(currentTime) + stubTime.setLastRead(lastRead) + stubTime.setCreatedAt(nil) + + // Test database details endpoint. 
+ dbDetails = serverpb.DatabaseDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto( + s, + "databases/test?include_stats=true", + &dbDetails, + ); err != nil { + t.Fatal(err) + } + require.Equal(t, int32(1), dbDetails.Stats.NumIndexRecommendations) + + // Test table details endpoint. + tableDetails = serverpb.TableDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { + t.Fatal(err) + } + require.Equal(t, true, tableDetails.HasIndexRecommendations) + + // Test when last read exists and the unused index duration has not been + // exceeded. Expect no index recommendation. + currentTime = timeutil.Now() + stubTime.setCurrent(currentTime) + stubTime.setLastRead(currentTime) + stubTime.setCreatedAt(nil) + + // Test database details endpoint. + dbDetails = serverpb.DatabaseDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto( + s, + "databases/test?include_stats=true", + &dbDetails, + ); err != nil { + t.Fatal(err) + } + require.Equal(t, int32(0), dbDetails.Stats.NumIndexRecommendations) + + // Test table details endpoint. + tableDetails = serverpb.TableDetailsResponse{} + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/test_table", &tableDetails); err != nil { + t.Fatal(err) + } + require.Equal(t, false, tableDetails.HasIndexRecommendations) +} diff --git a/pkg/server/application_api/jobs_test.go b/pkg/server/application_api/jobs_test.go new file mode 100644 index 000000000000..d97c3ee8bb3e --- /dev/null +++ b/pkg/server/application_api/jobs_test.go @@ -0,0 +1,478 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "fmt" + "math" + "reflect" + "sort" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/jobs" + "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/cockroach/pkg/util/safesql" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +// getSystemJobIDsForNonAutoJobs queries the jobs table for all job IDs that have +// the given status. Sorted by decreasing creation time. 
+func getSystemJobIDsForNonAutoJobs( + t testing.TB, db *sqlutils.SQLRunner, status jobs.Status, +) []int64 { + q := safesql.NewQuery() + q.Append(`SELECT job_id FROM crdb_internal.jobs WHERE status=$`, status) + q.Append(` AND (`) + for i, jobType := range jobspb.AutomaticJobTypes { + q.Append(`job_type != $`, jobType.String()) + if i < len(jobspb.AutomaticJobTypes)-1 { + q.Append(" AND ") + } + } + q.Append(` OR job_type IS NULL)`) + q.Append(` ORDER BY created DESC`) + rows := db.Query( + t, + q.String(), + q.QueryArguments()..., + ) + defer rows.Close() + + res := []int64{} + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + t.Fatal(err) + } + res = append(res, id) + } + return res +} + +func TestAdminAPIJobs(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + now := timeutil.Now() + retentionTime := 336 * time.Hour + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + Knobs: base.TestingKnobs{ + JobsTestingKnobs: &jobs.TestingKnobs{ + IntervalOverrides: jobs.TestingIntervalOverrides{ + RetentionTime: &retentionTime, + }, + }, + Server: &server.TestingKnobs{ + StubTimeNow: func() time.Time { return now }, + }, + }, + }) + + defer s.Stopper().Stop(context.Background()) + sqlDB := sqlutils.MakeSQLRunner(conn) + + testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { + // Creating this client causes a user to be created, which causes jobs + // to be created, so we do it up-front rather than inside the test. 
+ _, err := s.GetAuthenticatedHTTPClient(isAdmin, serverutils.SingleTenantSession) + if err != nil { + t.Fatal(err) + } + }) + + existingSucceededIDs := getSystemJobIDsForNonAutoJobs(t, sqlDB, jobs.StatusSucceeded) + existingRunningIDs := getSystemJobIDsForNonAutoJobs(t, sqlDB, jobs.StatusRunning) + existingIDs := append(existingSucceededIDs, existingRunningIDs...) + + runningOnlyIds := []int64{1, 2, 4, 11, 12} + revertingOnlyIds := []int64{7, 8, 9} + retryRunningIds := []int64{6} + retryRevertingIds := []int64{10} + ef := &jobspb.RetriableExecutionFailure{ + TruncatedError: "foo", + } + // Add a regression test for #84139 where a string with a quote in it + // caused a failure in the admin API. + efQuote := &jobspb.RetriableExecutionFailure{ + TruncatedError: "foo\"abc\"", + } + + testJobs := []struct { + id int64 + status jobs.Status + details jobspb.Details + progress jobspb.ProgressDetails + username username.SQLUsername + numRuns int64 + lastRun time.Time + executionFailures []*jobspb.RetriableExecutionFailure + }{ + {1, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, nil}, + {2, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, timeutil.Now().Add(10 * time.Minute), nil}, + {3, jobs.StatusSucceeded, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, time.Time{}, nil}, + {4, jobs.StatusRunning, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, time.Time{}, nil}, + {5, jobs.StatusSucceeded, jobspb.BackupDetails{}, jobspb.BackupProgress{}, apiconstants.TestingUserNameNoAdmin(), 1, time.Time{}, nil}, + {6, jobs.StatusRunning, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, timeutil.Now().Add(10 * time.Minute), nil}, + {7, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 1, time.Time{}, nil}, + {8, jobs.StatusReverting, 
jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 1, timeutil.Now().Add(10 * time.Minute), nil}, + {9, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, time.Time{}, nil}, + {10, jobs.StatusReverting, jobspb.ImportDetails{}, jobspb.ImportProgress{}, username.RootUserName(), 2, timeutil.Now().Add(10 * time.Minute), nil}, + {11, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, []*jobspb.RetriableExecutionFailure{ef}}, + {12, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, []*jobspb.RetriableExecutionFailure{efQuote}}, + } + for _, job := range testJobs { + payload := jobspb.Payload{ + UsernameProto: job.username.EncodeProto(), + Details: jobspb.WrapPayloadDetails(job.details), + RetriableExecutionFailureLog: job.executionFailures, + } + payloadBytes, err := protoutil.Marshal(&payload) + if err != nil { + t.Fatal(err) + } + + progress := jobspb.Progress{Details: jobspb.WrapProgressDetails(job.progress)} + // Populate progress.Progress field with a specific progress type based on + // the job type. 
+ if _, ok := job.progress.(jobspb.ChangefeedProgress); ok { + progress.Progress = &jobspb.Progress_HighWater{ + HighWater: &hlc.Timestamp{}, + } + } else { + progress.Progress = &jobspb.Progress_FractionCompleted{ + FractionCompleted: 1.0, + } + } + + progressBytes, err := protoutil.Marshal(&progress) + if err != nil { + t.Fatal(err) + } + sqlDB.Exec(t, + `INSERT INTO system.jobs (id, status, num_runs, last_run, job_type) VALUES ($1, $2, $3, $4, $5)`, + job.id, job.status, job.numRuns, job.lastRun, payload.Type().String(), + ) + sqlDB.Exec(t, + `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, + job.id, jobs.GetLegacyPayloadKey(), payloadBytes, + ) + sqlDB.Exec(t, + `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, + job.id, jobs.GetLegacyProgressKey(), progressBytes, + ) + } + + const invalidJobType = math.MaxInt32 + + testCases := []struct { + uri string + expectedIDsViaAdmin []int64 + expectedIDsViaNonAdmin []int64 + }{ + { + "jobs", + append([]int64{12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, existingIDs...), + []int64{5}, + }, + { + "jobs?limit=1", + []int64{12}, + []int64{5}, + }, + { + "jobs?status=succeeded", + append([]int64{5, 3}, existingSucceededIDs...), + []int64{5}, + }, + { + "jobs?status=running", + append(append(append([]int64{}, runningOnlyIds...), retryRunningIds...), existingRunningIDs...), + []int64{}, + }, + { + "jobs?status=reverting", + append(append([]int64{}, revertingOnlyIds...), retryRevertingIds...), + []int64{}, + }, + { + "jobs?status=pending", + []int64{}, + []int64{}, + }, + { + "jobs?status=garbage", + []int64{}, + []int64{}, + }, + { + fmt.Sprintf("jobs?type=%d", jobspb.TypeBackup), + []int64{5, 3, 2}, + []int64{5}, + }, + { + fmt.Sprintf("jobs?type=%d", jobspb.TypeRestore), + []int64{1, 11, 12}, + []int64{}, + }, + { + fmt.Sprintf("jobs?type=%d", invalidJobType), + []int64{}, + []int64{}, + }, + { + fmt.Sprintf("jobs?status=running&type=%d", jobspb.TypeBackup), + []int64{2}, + 
[]int64{}, + }, + } + + testutils.RunTrueAndFalse(t, "isAdmin", func(t *testing.T, isAdmin bool) { + for i, testCase := range testCases { + var res serverpb.JobsResponse + if err := srvtestutils.GetAdminJSONProtoWithAdminOption(s, testCase.uri, &res, isAdmin); err != nil { + t.Fatal(err) + } + resIDs := []int64{} + for _, job := range res.Jobs { + resIDs = append(resIDs, job.ID) + } + + expected := testCase.expectedIDsViaAdmin + if !isAdmin { + expected = testCase.expectedIDsViaNonAdmin + } + + sort.Slice(expected, func(i, j int) bool { + return expected[i] < expected[j] + }) + + sort.Slice(resIDs, func(i, j int) bool { + return resIDs[i] < resIDs[j] + }) + if e, a := expected, resIDs; !reflect.DeepEqual(e, a) { + t.Errorf("%d - %v: expected job IDs %v, but got %v", i, testCase.uri, e, a) + } + // We don't use require.Equal() because timestamps don't necessarily + // compare == due to only one of them having a monotonic clock reading. + require.True(t, now.Add(-retentionTime).Equal(res.EarliestRetainedTime)) + } + }) +} + +func TestAdminAPIJobsDetails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + sqlDB := sqlutils.MakeSQLRunner(conn) + + now := timeutil.Now() + + encodedError := func(err error) *errors.EncodedError { + ee := errors.EncodeError(context.Background(), err) + return &ee + } + testJobs := []struct { + id int64 + status jobs.Status + details jobspb.Details + progress jobspb.ProgressDetails + username username.SQLUsername + numRuns int64 + lastRun time.Time + executionLog []*jobspb.RetriableExecutionFailure + }{ + {1, jobs.StatusRunning, jobspb.RestoreDetails{}, jobspb.RestoreProgress{}, username.RootUserName(), 1, time.Time{}, nil}, + {2, jobs.StatusReverting, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, time.Time{}, nil}, + {3, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 1, now.Add(10 * time.Minute), nil}, + {4, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 1, now.Add(10 * time.Minute), nil}, + {5, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 2, time.Time{}, nil}, + {6, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, time.Time{}, nil}, + {7, jobs.StatusRunning, jobspb.BackupDetails{}, jobspb.BackupProgress{}, username.RootUserName(), 2, now.Add(10 * time.Minute), nil}, + {8, jobs.StatusReverting, jobspb.ChangefeedDetails{}, jobspb.ChangefeedProgress{}, username.RootUserName(), 2, now.Add(10 * time.Minute), []*jobspb.RetriableExecutionFailure{ + { + Status: string(jobs.StatusRunning), + ExecutionStartMicros: now.Add(-time.Minute).UnixMicro(), + ExecutionEndMicros: now.Add(-30 * time.Second).UnixMicro(), + InstanceID: 1, + Error: encodedError(errors.New("foo")), + }, + { + Status: string(jobs.StatusReverting), + ExecutionStartMicros: now.Add(-29 * time.Minute).UnixMicro(), + ExecutionEndMicros: 
now.Add(-time.Second).UnixMicro(), + InstanceID: 1, + TruncatedError: "bar", + }, + }}, + } + for _, job := range testJobs { + payload := jobspb.Payload{ + UsernameProto: job.username.EncodeProto(), + Details: jobspb.WrapPayloadDetails(job.details), + RetriableExecutionFailureLog: job.executionLog, + } + payloadBytes, err := protoutil.Marshal(&payload) + if err != nil { + t.Fatal(err) + } + + progress := jobspb.Progress{Details: jobspb.WrapProgressDetails(job.progress)} + // Populate progress.Progress field with a specific progress type based on + // the job type. + if _, ok := job.progress.(jobspb.ChangefeedProgress); ok { + progress.Progress = &jobspb.Progress_HighWater{ + HighWater: &hlc.Timestamp{}, + } + } else { + progress.Progress = &jobspb.Progress_FractionCompleted{ + FractionCompleted: 1.0, + } + } + + progressBytes, err := protoutil.Marshal(&progress) + if err != nil { + t.Fatal(err) + } + sqlDB.Exec(t, + `INSERT INTO system.jobs (id, status, num_runs, last_run) VALUES ($1, $2, $3, $4)`, + job.id, job.status, job.numRuns, job.lastRun, + ) + sqlDB.Exec(t, + `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, + job.id, jobs.GetLegacyPayloadKey(), payloadBytes, + ) + sqlDB.Exec(t, + `INSERT INTO system.job_info (job_id, info_key, value) VALUES ($1, $2, $3)`, + job.id, jobs.GetLegacyProgressKey(), progressBytes, + ) + } + + var res serverpb.JobsResponse + if err := srvtestutils.GetAdminJSONProto(s, "jobs", &res); err != nil { + t.Fatal(err) + } + + // Trim down our result set to the jobs we injected. + resJobs := append([]serverpb.JobResponse(nil), res.Jobs...) 
+ sort.Slice(resJobs, func(i, j int) bool { + return resJobs[i].ID < resJobs[j].ID + }) + resJobs = resJobs[:len(testJobs)] + + for i, job := range resJobs { + require.Equal(t, testJobs[i].id, job.ID) + require.Equal(t, len(testJobs[i].executionLog), len(job.ExecutionFailures)) + for j, f := range job.ExecutionFailures { + tf := testJobs[i].executionLog[j] + require.Equal(t, tf.Status, f.Status) + require.Equal(t, tf.ExecutionStartMicros, f.Start.UnixMicro()) + require.Equal(t, tf.ExecutionEndMicros, f.End.UnixMicro()) + var expErr string + if tf.Error != nil { + expErr = errors.DecodeError(context.Background(), *tf.Error).Error() + } else { + expErr = tf.TruncatedError + } + require.Equal(t, expErr, f.Error) + } + } +} + +func TestJobStatusResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer ts.Stopper().Stop(context.Background()) + + rootConfig := testutils.NewTestBaseContext(username.RootUserName()) + rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts.(*server.TestServer), rootConfig) + + url := ts.ServingRPCAddr() + nodeID := ts.NodeID() + conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + + request := &serverpb.JobStatusRequest{JobId: -1} + response, err := client.JobStatus(context.Background(), request) + require.Regexp(t, `job with ID -1 does not exist`, err) + require.Nil(t, response) + + ctx := context.Background() + jr := ts.JobRegistry().(*jobs.Registry) + job, err := jr.CreateJobWithTxn( + ctx, + jobs.Record{ + Description: "testing", + Statements: []string{"SELECT 1"}, + Username: username.RootUserName(), + Details: jobspb.ImportDetails{ + Tables: []jobspb.ImportDetails_Table{ + { + Desc: &descpb.TableDescriptor{ + ID: 1, + }, + }, + { + Desc: &descpb.TableDescriptor{ + ID: 2, + }, + }, + }, + URIs: 
[]string{"a", "b"}, + }, + Progress: jobspb.ImportProgress{}, + DescriptorIDs: []descpb.ID{1, 2, 3}, + }, + jr.MakeJobID(), + nil) + if err != nil { + t.Fatal(err) + } + request.JobId = int64(job.ID()) + response, err = client.JobStatus(context.Background(), request) + if err != nil { + t.Fatal(err) + } + require.Equal(t, job.ID(), response.Job.Id) + require.Equal(t, job.Payload(), *response.Job.Payload) + require.Equal(t, job.Progress(), *response.Job.Progress) +} diff --git a/pkg/server/application_api/main_test.go b/pkg/server/application_api/main_test.go new file mode 100644 index 000000000000..48112562f90f --- /dev/null +++ b/pkg/server/application_api/main_test.go @@ -0,0 +1,35 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvtenant" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/rangetestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" +) + +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + rangetestutils.InitRangeTestServerFactory(server.TestServerFactory) + kvtenant.InitTestConnectorFactory() + os.Exit(m.Run()) +} + +//go:generate ../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/application_api/metrics_test.go b/pkg/server/application_api/metrics_test.go new file mode 100644 index 000000000000..0f985f2a4a88 --- /dev/null +++ b/pkg/server/application_api/metrics_test.go @@ -0,0 +1,152 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "bytes" + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/httputil" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +// TestMetricsMetadata ensures that the server's recorder returns metrics and +// that each metric has a Name, Help, Measurement, and Unit defined. +func TestMetricsMetadata(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(context.Background()) + + s := srv.(*server.TestServer) + + metricsMetadata := s.MetricsRecorder().GetMetricsMetadata() + + if len(metricsMetadata) < 200 { + t.Fatal("s.recorder.GetMetricsMetadata() failed sanity check; didn't return enough metrics.") + } + + for _, v := range metricsMetadata { + if v.Name == "" { + t.Fatal("metric missing name.") + } + if v.Help == "" { + t.Fatalf("%s missing Help.", v.Name) + } + if v.Measurement == "" { + t.Fatalf("%s missing Measurement.", v.Name) + } + if v.Unit == 0 { + t.Fatalf("%s missing Unit.", v.Name) + } + } +} + +// TestStatusVars verifies that prometheus metrics are available via the +// /_status/vars and /_status/load endpoints. 
+func TestStatusVars(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + if body, err := srvtestutils.GetText(s, s.AdminURL().WithPath(apiconstants.StatusPrefix+"vars").String()); err != nil { + t.Fatal(err) + } else if !bytes.Contains(body, []byte("# TYPE sql_bytesout counter\nsql_bytesout")) { + t.Errorf("expected sql_bytesout, got: %s", body) + } + if body, err := srvtestutils.GetText(s, s.AdminURL().WithPath(apiconstants.StatusPrefix+"load").String()); err != nil { + t.Fatal(err) + } else if !bytes.Contains(body, []byte("# TYPE sys_cpu_user_ns gauge\nsys_cpu_user_ns")) { + t.Errorf("expected sys_cpu_user_ns, got: %s", body) + } +} + +// TestStatusVarsTxnMetrics verifies that the metrics from the /_status/vars +// endpoint for txns and the special cockroach_restart savepoint are correct. +func TestStatusVarsTxnMetrics(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer db.Close() + defer s.Stopper().Stop(context.Background()) + + if _, err := db.Exec("BEGIN;" + + "SAVEPOINT cockroach_restart;" + + "SELECT 1;" + + "RELEASE SAVEPOINT cockroach_restart;" + + "ROLLBACK;"); err != nil { + t.Fatal(err) + } + + body, err := srvtestutils.GetText(s, s.AdminURL().WithPath(apiconstants.StatusPrefix+"vars").String()) + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(body, []byte("sql_txn_begin_count{node_id=\"1\"} 1")) { + t.Errorf("expected `sql_txn_begin_count{node_id=\"1\"} 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_restart_savepoint_count{node_id=\"1\"} 1")) { + t.Errorf("expected `sql_restart_savepoint_count{node_id=\"1\"} 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_restart_savepoint_release_count{node_id=\"1\"} 1")) { + t.Errorf("expected 
`sql_restart_savepoint_release_count{node_id=\"1\"} 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_txn_commit_count{node_id=\"1\"} 1")) { + t.Errorf("expected `sql_txn_commit_count{node_id=\"1\"} 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_txn_rollback_count{node_id=\"1\"} 0")) { + t.Errorf("expected `sql_txn_rollback_count{node_id=\"1\"} 0`, got: %s", body) + } +} + +func TestSpanStatsResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer ts.Stopper().Stop(context.Background()) + + httpClient, err := ts.GetAdminHTTPClient() + if err != nil { + t.Fatal(err) + } + + var response roachpb.SpanStatsResponse + span := roachpb.Span{ + Key: roachpb.RKeyMin.AsRawKey(), + EndKey: roachpb.RKeyMax.AsRawKey(), + } + request := roachpb.SpanStatsRequest{ + NodeID: "1", + Spans: []roachpb.Span{span}, + } + + url := ts.AdminURL().WithPath(apiconstants.StatusPrefix + "span").String() + if err := httputil.PostJSON(httpClient, url, &request, &response); err != nil { + t.Fatal(err) + } + initialRanges, err := ts.ExpectedInitialRangeCount() + if err != nil { + t.Fatal(err) + } + responseSpanStats := response.SpanToStats[span.String()] + if a, e := int(responseSpanStats.RangeCount), initialRanges; a != e { + t.Errorf("expected %d ranges, found %d", e, a) + } +} diff --git a/pkg/server/application_api/query_plan_test.go b/pkg/server/application_api/query_plan_test.go new file mode 100644 index 000000000000..b2459f65114f --- /dev/null +++ b/pkg/server/application_api/query_plan_test.go @@ -0,0 +1,66 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestAdminAPIQueryPlan(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + sqlDB := sqlutils.MakeSQLRunner(conn) + + sqlDB.Exec(t, `CREATE DATABASE api_test`) + sqlDB.Exec(t, `CREATE TABLE api_test.t1 (id int primary key, name string)`) + sqlDB.Exec(t, `CREATE TABLE api_test.t2 (id int primary key, name string)`) + + testCases := []struct { + query string + exp []string + }{ + {"SELECT sum(id) FROM api_test.t1", []string{"nodeNames\":[\"1\"]", "Columns: id"}}, + {"SELECT sum(1) FROM api_test.t1 JOIN api_test.t2 on t1.id = t2.id", []string{"nodeNames\":[\"1\"]", "Columns: id"}}, + } + for i, testCase := range testCases { + var res serverpb.QueryPlanResponse + queryParam := url.QueryEscape(testCase.query) + if err := srvtestutils.GetAdminJSONProto(s, fmt.Sprintf("queryplan?query=%s", queryParam), &res); err != nil { + t.Errorf("%d: got error %s", i, err) + } + + for _, exp := range testCase.exp { + if !strings.Contains(res.DistSQLPhysicalQueryPlan, exp) { + t.Errorf("%d: expected response %v to contain %s", i, res, exp) + } + } + } + +} diff --git a/pkg/server/application_api/schema_inspection_test.go b/pkg/server/application_api/schema_inspection_test.go new file mode 100644 index 000000000000..7a7401c011e4 --- /dev/null +++ b/pkg/server/application_api/schema_inspection_test.go @@ -0,0 +1,620 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + gosql "database/sql" + "fmt" + "net/url" + "reflect" + "regexp" + "sort" + "strings" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/httputil" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdminAPIDatabases(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails with + // it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + ac := ts.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + + testDbName := generateRandomName() + testDbEscaped := tree.NameString(testDbName) + query := "CREATE DATABASE " + testDbEscaped + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + // Test needs to revoke CONNECT on the public database to properly exercise + // fine-grained permissions logic. + if _, err := db.Exec(fmt.Sprintf("REVOKE CONNECT ON DATABASE %s FROM public", testDbEscaped)); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("REVOKE CONNECT ON DATABASE defaultdb FROM public"); err != nil { + t.Fatal(err) + } + + // We have to create the non-admin user before calling + // "GRANT ... TO apiconstants.TestingUserNameNoAdmin". + // This is done in "GetAuthenticatedHTTPClient". + if _, err := ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession); err != nil { + t.Fatal(err) + } + + // Grant permissions to view the tables for the given viewing user. + privileges := []string{"CONNECT"} + query = fmt.Sprintf( + "GRANT %s ON DATABASE %s TO %s", + strings.Join(privileges, ", "), + testDbEscaped, + apiconstants.TestingUserNameNoAdmin().SQLIdentifier(), + ) + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + // Non admins now also require VIEWACTIVITY. + query = fmt.Sprintf( + "GRANT SYSTEM %s TO %s", + "VIEWACTIVITY", + apiconstants.TestingUserNameNoAdmin().SQLIdentifier(), + ) + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + + for _, tc := range []struct { + expectedDBs []string + isAdmin bool + }{ + {[]string{"defaultdb", "postgres", "system", testDbName}, true}, + {[]string{"postgres", testDbName}, false}, + } { + t.Run(fmt.Sprintf("isAdmin:%t", tc.isAdmin), func(t *testing.T) { + // Test databases endpoint. 
+ var resp serverpb.DatabasesResponse + if err := srvtestutils.GetAdminJSONProtoWithAdminOption( + s, + "databases", + &resp, + tc.isAdmin, + ); err != nil { + t.Fatal(err) + } + + if a, e := len(resp.Databases), len(tc.expectedDBs); a != e { + t.Fatalf("length of result %d != expected %d", a, e) + } + + sort.Strings(tc.expectedDBs) + sort.Strings(resp.Databases) + for i, e := range tc.expectedDBs { + if a := resp.Databases[i]; a != e { + t.Fatalf("database name %s != expected %s", a, e) + } + } + + // Test database details endpoint. + var details serverpb.DatabaseDetailsResponse + urlEscapeDbName := url.PathEscape(testDbName) + + if err := srvtestutils.GetAdminJSONProtoWithAdminOption( + s, + "databases/"+urlEscapeDbName, + &details, + tc.isAdmin, + ); err != nil { + t.Fatal(err) + } + + if a, e := len(details.Grants), 3; a != e { + t.Fatalf("# of grants %d != expected %d", a, e) + } + + userGrants := make(map[string][]string) + for _, grant := range details.Grants { + switch grant.User { + case username.AdminRole, username.RootUser, apiconstants.TestingUserNoAdmin: + userGrants[grant.User] = append(userGrants[grant.User], grant.Privileges...) + default: + t.Fatalf("unknown grant to user %s", grant.User) + } + } + for u, p := range userGrants { + switch u { + case username.AdminRole: + if !reflect.DeepEqual(p, []string{"ALL"}) { + t.Fatalf("privileges %v != expected %v", p, privileges) + } + case username.RootUser: + if !reflect.DeepEqual(p, []string{"ALL"}) { + t.Fatalf("privileges %v != expected %v", p, privileges) + } + case apiconstants.TestingUserNoAdmin: + sort.Strings(p) + if !reflect.DeepEqual(p, privileges) { + t.Fatalf("privileges %v != expected %v", p, privileges) + } + default: + t.Fatalf("unknown grant to user %s", u) + } + } + + // Verify Descriptor ID. 
+ databaseID, err := ts.TestingQueryDatabaseID(ctx, username.RootUserName(), testDbName) + if err != nil { + t.Fatal(err) + } + if a, e := details.DescriptorID, int64(databaseID); a != e { + t.Fatalf("db had descriptorID %d, expected %d", a, e) + } + }) + } +} + +func TestAdminAPIDatabaseDoesNotExist(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails with + // it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + const errPattern = "database.+does not exist" + if err := srvtestutils.GetAdminJSONProto(s, "databases/i_do_not_exist", nil); !testutils.IsError(err, errPattern) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) + } +} + +func TestAdminAPIDatabaseSQLInjection(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails with + // it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + const fakedb = "system;DROP DATABASE system;" + const path = "databases/" + fakedb + const errPattern = `target database or schema does not exist` + if err := srvtestutils.GetAdminJSONProto(s, path, nil); !testutils.IsError(err, errPattern) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) + } +} + +func TestAdminAPITableDoesNotExist(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails with + // it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + const fakename = "i_do_not_exist" + const badDBPath = "databases/" + fakename + "/tables/foo" + const dbErrPattern = `relation \\"` + fakename + `.foo\\" does not exist` + if err := srvtestutils.GetAdminJSONProto(s, badDBPath, nil); !testutils.IsError(err, dbErrPattern) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, dbErrPattern) + } + + const badTablePath = "databases/system/tables/" + fakename + const tableErrPattern = `relation \\"system.` + fakename + `\\" does not exist` + if err := srvtestutils.GetAdminJSONProto(s, badTablePath, nil); !testutils.IsError(err, tableErrPattern) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, tableErrPattern) + } +} + +func TestAdminAPITableSQLInjection(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails with + // it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + const fakeTable = "users;DROP DATABASE system;" + const path = "databases/system/tables/" + fakeTable + const errPattern = `relation \"system.` + fakeTable + `\" does not exist` + if err := srvtestutils.GetAdminJSONProto(s, path, nil); !testutils.IsError(err, regexp.QuoteMeta(errPattern)) { + t.Fatalf("unexpected error: %v\nexpected: %s", err, errPattern) + } +} + +func TestAdminAPITableDetails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + for _, tc := range []struct { + name, dbName, tblName, pkName string + }{ + {name: "lower", dbName: "test", tblName: "tbl", pkName: "tbl_pkey"}, + {name: "lower other schema", dbName: "test", tblName: `testschema.tbl`, pkName: "tbl_pkey"}, + {name: "lower with space", dbName: "test test", tblName: `"tbl tbl"`, pkName: "tbl tbl_pkey"}, + {name: "upper", dbName: "TEST", tblName: `"TBL"`, pkName: "TBL_pkey"}, // Regression test for issue #14056 + } { + t.Run(tc.name, func(t *testing.T) { + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. 
+ DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + escDBName := tree.NameStringP(&tc.dbName) + tblName := tc.tblName + schemaName := "testschema" + + ac := ts.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + + tableSchema := `nulls_allowed INT8, + nulls_not_allowed INT8 NOT NULL DEFAULT 1000, + default2 INT8 DEFAULT 2, + string_default STRING DEFAULT 'default_string', + INDEX descidx (default2 DESC)` + + setupQueries := []string{ + fmt.Sprintf("CREATE DATABASE %s", escDBName), + fmt.Sprintf("CREATE SCHEMA %s", schemaName), + fmt.Sprintf(`CREATE TABLE %s.%s (%s)`, escDBName, tblName, tableSchema), + "CREATE USER readonly", + "CREATE USER app", + fmt.Sprintf("GRANT SELECT ON %s.%s TO readonly", escDBName, tblName), + fmt.Sprintf("GRANT SELECT,UPDATE,DELETE ON %s.%s TO app", escDBName, tblName), + fmt.Sprintf("CREATE STATISTICS test_stats FROM %s.%s", escDBName, tblName), + } + pgURL, cleanupGoDB := sqlutils.PGUrl( + t, s.ServingSQLAddr(), "StartServer" /* prefix */, url.User(username.RootUser)) + defer cleanupGoDB() + pgURL.Path = tc.dbName + db, err := gosql.Open("postgres", pgURL.String()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + for _, q := range setupQueries { + if _, err := db.Exec(q); err != nil { + t.Fatal(err) + } + } + + // Perform API call. + var resp serverpb.TableDetailsResponse + url := fmt.Sprintf("databases/%s/tables/%s", tc.dbName, tblName) + if err := srvtestutils.GetAdminJSONProto(s, url, &resp); err != nil { + t.Fatal(err) + } + + // Verify columns. 
+ expColumns := []serverpb.TableDetailsResponse_Column{ + {Name: "nulls_allowed", Type: "INT8", Nullable: true, DefaultValue: ""}, + {Name: "nulls_not_allowed", Type: "INT8", Nullable: false, DefaultValue: "1000"}, + {Name: "default2", Type: "INT8", Nullable: true, DefaultValue: "2"}, + {Name: "string_default", Type: "STRING", Nullable: true, DefaultValue: "'default_string'"}, + {Name: "rowid", Type: "INT8", Nullable: false, DefaultValue: "unique_rowid()", Hidden: true}, + } + testutils.SortStructs(expColumns, "Name") + testutils.SortStructs(resp.Columns, "Name") + if a, e := len(resp.Columns), len(expColumns); a != e { + t.Fatalf("# of result columns %d != expected %d (got: %#v)", a, e, resp.Columns) + } + for i, a := range resp.Columns { + e := expColumns[i] + if a.String() != e.String() { + t.Fatalf("mismatch at column %d: actual %#v != %#v", i, a, e) + } + } + + // Verify grants. + expGrants := []serverpb.TableDetailsResponse_Grant{ + {User: username.AdminRole, Privileges: []string{"ALL"}}, + {User: username.RootUser, Privileges: []string{"ALL"}}, + {User: "app", Privileges: []string{"DELETE"}}, + {User: "app", Privileges: []string{"SELECT"}}, + {User: "app", Privileges: []string{"UPDATE"}}, + {User: "readonly", Privileges: []string{"SELECT"}}, + } + testutils.SortStructs(expGrants, "User") + testutils.SortStructs(resp.Grants, "User") + if a, e := len(resp.Grants), len(expGrants); a != e { + t.Fatalf("# of grant columns %d != expected %d (got: %#v)", a, e, resp.Grants) + } + for i, a := range resp.Grants { + e := expGrants[i] + sort.Strings(a.Privileges) + sort.Strings(e.Privileges) + if a.String() != e.String() { + t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e) + } + } + + // Verify indexes. 
+ expIndexes := []serverpb.TableDetailsResponse_Index{ + {Name: tc.pkName, Column: "string_default", Direction: "N/A", Unique: true, Seq: 5, Storing: true}, + {Name: tc.pkName, Column: "default2", Direction: "N/A", Unique: true, Seq: 4, Storing: true}, + {Name: tc.pkName, Column: "nulls_not_allowed", Direction: "N/A", Unique: true, Seq: 3, Storing: true}, + {Name: tc.pkName, Column: "nulls_allowed", Direction: "N/A", Unique: true, Seq: 2, Storing: true}, + {Name: tc.pkName, Column: "rowid", Direction: "ASC", Unique: true, Seq: 1}, + {Name: "descidx", Column: "rowid", Direction: "ASC", Unique: false, Seq: 2, Implicit: true}, + {Name: "descidx", Column: "default2", Direction: "DESC", Unique: false, Seq: 1}, + } + testutils.SortStructs(expIndexes, "Name", "Seq") + testutils.SortStructs(resp.Indexes, "Name", "Seq") + for i, a := range resp.Indexes { + e := expIndexes[i] + if a.String() != e.String() { + t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e) + } + } + + // Verify range count. + if a, e := resp.RangeCount, int64(1); a != e { + t.Fatalf("# of ranges %d != expected %d", a, e) + } + + // Verify Create Table Statement. + { + + showCreateTableQuery := fmt.Sprintf("SHOW CREATE TABLE %s.%s", escDBName, tblName) + + row := db.QueryRow(showCreateTableQuery) + var createStmt, tableName string + if err := row.Scan(&tableName, &createStmt); err != nil { + t.Fatal(err) + } + + if a, e := resp.CreateTableStatement, createStmt; a != e { + t.Fatalf("mismatched create table statement; expected %s, got %s", e, a) + } + } + + // Verify statistics last updated. 
+ { + + showStatisticsForTableQuery := fmt.Sprintf("SELECT max(created) AS created FROM [SHOW STATISTICS FOR TABLE %s.%s]", escDBName, tblName) + + row := db.QueryRow(showStatisticsForTableQuery) + var createdTs time.Time + if err := row.Scan(&createdTs); err != nil { + t.Fatal(err) + } + + if a, e := resp.StatsLastCreatedAt, createdTs; reflect.DeepEqual(a, e) { + t.Fatalf("mismatched statistics creation timestamp; expected %s, got %s", e, a) + } + } + + // Verify Descriptor ID. + tableID, err := ts.TestingQueryTableID(ctx, username.RootUserName(), tc.dbName, tc.tblName) + if err != nil { + t.Fatal(err) + } + if a, e := resp.DescriptorID, int64(tableID); a != e { + t.Fatalf("table had descriptorID %d, expected %d", a, e) + } + }) + } +} + +func TestAdminAPIDatabaseDetails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + const numServers = 3 + tc := testcluster.StartTestCluster(t, numServers, base.TestClusterArgs{}) + defer tc.Stopper().Stop(context.Background()) + + db := tc.ServerConn(0) + + _, err := db.Exec("CREATE DATABASE test") + require.NoError(t, err) + + _, err = db.Exec("CREATE TABLE test.foo (id INT PRIMARY KEY, val STRING)") + require.NoError(t, err) + + for i := 0; i < 10; i++ { + _, err := db.Exec("INSERT INTO test.foo VALUES($1, $2)", i, "test") + require.NoError(t, err) + } + + // Flush all stores here so that we can read the ApproximateDiskBytes field without waiting for a flush. 
+ for i := 0; i < numServers; i++ { + s := tc.Server(i) + err = s.GetStores().(*kvserver.Stores).VisitStores(func(store *kvserver.Store) error { + return store.TODOEngine().Flush() + }) + require.NoError(t, err) + } + + s := tc.Server(0) + + var resp serverpb.DatabaseDetailsResponse + require.NoError(t, serverutils.GetJSONProto(s, "/_admin/v1/databases/test", &resp)) + assert.Nil(t, resp.Stats, "No Stats unless we ask for them explicitly.") + + nodeIDs := tc.NodeIDs() + testutils.SucceedsSoon(t, func() error { + var resp serverpb.DatabaseDetailsResponse + require.NoError(t, serverutils.GetJSONProto(s, "/_admin/v1/databases/test?include_stats=true", &resp)) + + if resp.Stats.RangeCount != int64(1) { + return errors.Newf("expected range-count=1, got %d", resp.Stats.RangeCount) + } + if len(resp.Stats.NodeIDs) != len(nodeIDs) { + return errors.Newf("expected node-ids=%s, got %s", nodeIDs, resp.Stats.NodeIDs) + } + assert.Equal(t, nodeIDs, resp.Stats.NodeIDs, "NodeIDs") + + // We've flushed data so this estimation should be non-zero. + assert.Positive(t, resp.Stats.ApproximateDiskBytes, "ApproximateDiskBytes") + + return nil + }) +} + +func TestAdminAPITableStats(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + skip.UnderStress(t, "flaky under stress #107156") + skip.UnderRace(t, "flaky under race #107156") + + const nodeCount = 3 + tc := testcluster.StartTestCluster(t, nodeCount, base.TestClusterArgs{ + ReplicationMode: base.ReplicationAuto, + ServerArgs: base.TestServerArgs{ + ScanInterval: time.Millisecond, + ScanMinIdleTime: time.Millisecond, + ScanMaxIdleTime: time.Millisecond, + }, + }) + defer tc.Stopper().Stop(context.Background()) + server0 := tc.Server(0) + + // Create clients (SQL, HTTP) connected to server 0. + db := tc.ServerConn(0) + + client, err := server0.GetAdminHTTPClient() + if err != nil { + t.Fatal(err) + } + + client.Timeout = time.Hour // basically no timeout + + // Make a single table and insert some data. 
The database and test have + // names which require escaping, in order to verify that database and + // table names are being handled correctly. + if _, err := db.Exec(`CREATE DATABASE "test test"`); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(` + CREATE TABLE "test test"."foo foo" ( + id INT PRIMARY KEY, + val STRING + )`, + ); err != nil { + t.Fatal(err) + } + for i := 0; i < 10; i++ { + if _, err := db.Exec(` + INSERT INTO "test test"."foo foo" VALUES( + $1, $2 + )`, i, "test", + ); err != nil { + t.Fatal(err) + } + } + + url := server0.AdminURL().String() + "/_admin/v1/databases/test test/tables/foo foo/stats" + var tsResponse serverpb.TableStatsResponse + + // The new SQL table may not yet have split into its own range. Wait for + // this to occur, and for full replication. + testutils.SucceedsSoon(t, func() error { + if err := httputil.GetJSON(client, url, &tsResponse); err != nil { + t.Fatal(err) + } + if len(tsResponse.MissingNodes) != 0 { + return errors.Errorf("missing nodes: %+v", tsResponse.MissingNodes) + } + if tsResponse.RangeCount != 1 { + return errors.Errorf("Table range not yet separated.") + } + if tsResponse.NodeCount != nodeCount { + return errors.Errorf("Table range not yet replicated to %d nodes.", 3) + } + if a, e := tsResponse.ReplicaCount, int64(nodeCount); a != e { + return errors.Errorf("expected %d replicas, found %d", e, a) + } + if a, e := tsResponse.Stats.KeyCount, int64(30); a < e { + return errors.Errorf("expected at least %d total keys, found %d", e, a) + } + return nil + }) + + if len(tsResponse.MissingNodes) > 0 { + t.Fatalf("expected no missing nodes, found %v", tsResponse.MissingNodes) + } + + // Kill a node, ensure it shows up in MissingNodes and that ReplicaCount is + // lower. 
+ tc.StopServer(1) + + if err := httputil.GetJSON(client, url, &tsResponse); err != nil { + t.Fatal(err) + } + if a, e := tsResponse.NodeCount, int64(nodeCount); a != e { + t.Errorf("expected %d nodes, found %d", e, a) + } + if a, e := tsResponse.RangeCount, int64(1); a != e { + t.Errorf("expected %d ranges, found %d", e, a) + } + if a, e := tsResponse.ReplicaCount, int64((nodeCount/2)+1); a != e { + t.Errorf("expected %d replicas, found %d", e, a) + } + if a, e := tsResponse.Stats.KeyCount, int64(10); a < e { + t.Errorf("expected at least 10 total keys, found %d", a) + } + if len(tsResponse.MissingNodes) != 1 { + t.Errorf("expected one missing node, found %v", tsResponse.MissingNodes) + } + if len(tsResponse.NodeIDs) == 0 { + t.Error("expected at least one node in NodeIds list") + } + + // Call TableStats with a very low timeout. This tests that fan-out queries + // do not leak goroutines if the calling context is abandoned. + // Interestingly, the call can actually sometimes succeed, despite the small + // timeout; however, in aggregate (or in stress tests) this will suffice for + // detecting leaks. + client.Timeout = 1 * time.Nanosecond + _ = httputil.GetJSON(client, url, &tsResponse) +} diff --git a/pkg/server/application_api/security_test.go b/pkg/server/application_api/security_test.go new file mode 100644 index 000000000000..66a6cbf6c9ac --- /dev/null +++ b/pkg/server/application_api/security_test.go @@ -0,0 +1,66 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "reflect" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestAdminAPIUsers(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + // Create sample users. + query := ` +INSERT INTO system.users (username, "hashedPassword", user_id) +VALUES ('adminUser', 'abc', 200), ('bob', 'xyz', 201)` + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + + // Query the API for users. + var resp serverpb.UsersResponse + if err := srvtestutils.GetAdminJSONProto(s, "users", &resp); err != nil { + t.Fatal(err) + } + expResult := serverpb.UsersResponse{ + Users: []serverpb.UsersResponse_User{ + {Username: "adminUser"}, + {Username: "authentic_user"}, + {Username: "bob"}, + {Username: "root"}, + }, + } + + // Verify results. + const sortKey = "Username" + testutils.SortStructs(resp.Users, sortKey) + testutils.SortStructs(expResult.Users, sortKey) + if !reflect.DeepEqual(resp, expResult) { + t.Fatalf("result %v != expected %v", resp, expResult) + } +} diff --git a/pkg/server/application_api/sessions_test.go b/pkg/server/application_api/sessions_test.go new file mode 100644 index 000000000000..615418d6f2fa --- /dev/null +++ b/pkg/server/application_api/sessions_test.go @@ -0,0 +1,338 @@ +// Copyright 2023 The Cockroach Authors. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + gosql "database/sql" + "encoding/hex" + "fmt" + "net/url" + "sort" + "strings" + "sync" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/tests" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func TestListSessionsSecurity(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + ts := s.(*server.TestServer) + defer ts.Stopper().Stop(context.Background()) + + ctx := context.Background() + + for _, requestWithAdmin := range []bool{true, false} { + t.Run(fmt.Sprintf("admin=%v", requestWithAdmin), func(t *testing.T) { + myUser := apiconstants.TestingUserNameNoAdmin() + expectedErrOnListingRootSessions := "does not have permission to view sessions from user" + if requestWithAdmin { + myUser = apiconstants.TestingUserName() + 
expectedErrOnListingRootSessions = "" + } + + // HTTP requests respect the authenticated username from the HTTP session. + testCases := []struct { + endpoint string + expectedErr string + }{ + {"local_sessions", ""}, + {"sessions", ""}, + {fmt.Sprintf("local_sessions?username=%s", myUser.Normalized()), ""}, + {fmt.Sprintf("sessions?username=%s", myUser.Normalized()), ""}, + {"local_sessions?username=" + username.RootUser, expectedErrOnListingRootSessions}, + {"sessions?username=" + username.RootUser, expectedErrOnListingRootSessions}, + } + for _, tc := range testCases { + var response serverpb.ListSessionsResponse + err := srvtestutils.GetStatusJSONProtoWithAdminOption(ts, tc.endpoint, &response, requestWithAdmin) + if tc.expectedErr == "" { + if err != nil || len(response.Errors) > 0 { + t.Errorf("unexpected failure listing sessions from %s; error: %v; response errors: %v", + tc.endpoint, err, response.Errors) + } + } else { + respErr := "" + if len(response.Errors) > 0 { + respErr = response.Errors[0].Message + } + if !testutils.IsError(err, tc.expectedErr) && + !strings.Contains(respErr, tc.expectedErr) { + t.Errorf("did not get expected error %q when listing sessions from %s: %v", + tc.expectedErr, tc.endpoint, err) + } + } + } + }) + } + + // gRPC requests behave as root and thus are always allowed. 
+ rootConfig := testutils.NewTestBaseContext(username.RootUserName()) + rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, rootConfig) + url := ts.ServingRPCAddr() + nodeID := ts.NodeID() + conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + + for _, user := range []string{"", apiconstants.TestingUser, username.RootUser} { + request := &serverpb.ListSessionsRequest{Username: user} + if resp, err := client.ListLocalSessions(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected failure listing local sessions for %q; error: %v; response errors: %v", + user, err, resp.Errors) + } + if resp, err := client.ListSessions(ctx, request); err != nil || len(resp.Errors) > 0 { + t.Errorf("unexpected failure listing sessions for %q; error: %v; response errors: %v", + user, err, resp.Errors) + } + } +} + +func TestStatusCancelSessionGatewayMetadataPropagation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(ctx) + + // Start a SQL session as admin on node 1. + sql0 := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) + results := sql0.QueryStr(t, "SELECT session_id FROM [SHOW SESSIONS] LIMIT 1") + sessionID, err := hex.DecodeString(results[0][0]) + require.NoError(t, err) + + // Attempt to cancel that SQL session as non-admin over HTTP on node 2. 
+ req := &serverpb.CancelSessionRequest{ + SessionID: sessionID, + } + resp := &serverpb.CancelSessionResponse{} + err = srvtestutils.PostStatusJSONProtoWithAdminOption(testCluster.Server(1), "cancel_session/1", req, resp, false) + require.NotNil(t, err) + require.Contains(t, err.Error(), "status: 403 Forbidden") +} + +func TestStatusAPIListSessions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + ctx := context.Background() + testCluster := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ + ServerArgs: params, + }) + defer testCluster.Stopper().Stop(ctx) + + serverProto := testCluster.Server(0) + serverSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) + + appName := "test_sessions_api" + serverSQL.Exec(t, fmt.Sprintf(`SET application_name = "%s"`, appName)) + + getSessionWithTestAppName := func(response *serverpb.ListSessionsResponse) *serverpb.Session { + require.NotEmpty(t, response.Sessions) + for _, s := range response.Sessions { + if s.ApplicationName == appName { + return &s + } + } + t.Errorf("expected to find session with app name %s", appName) + return nil + } + + userNoAdmin := apiconstants.TestingUserNameNoAdmin() + var resp serverpb.ListSessionsResponse + // Non-admin without VIEWWACTIVITY or VIEWACTIVITYREDACTED should work and fetch user's own sessions. + err := srvtestutils.GetStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) + require.NoError(t, err) + + // Grant VIEWACTIVITYREDACTED. 
+ serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", userNoAdmin.Normalized())) + serverSQL.Exec(t, "SELECT 1") + err = srvtestutils.GetStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) + require.NoError(t, err) + session := getSessionWithTestAppName(&resp) + require.Equal(t, session.LastActiveQuery, session.LastActiveQueryNoConstants) + require.Equal(t, "SELECT _", session.LastActiveQueryNoConstants) + + // Grant VIEWACTIVITY, VIEWACTIVITYREDACTED should take precedence. + serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", userNoAdmin.Normalized())) + serverSQL.Exec(t, "SELECT 1, 1") + err = srvtestutils.GetStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) + require.NoError(t, err) + session = getSessionWithTestAppName(&resp) + require.Equal(t, appName, session.ApplicationName) + require.Equal(t, session.LastActiveQuery, session.LastActiveQueryNoConstants) + require.Equal(t, "SELECT _, _", session.LastActiveQueryNoConstants) + + // Remove VIEWACTIVITYREDCATED. User should now see full query. + serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITYREDACTED", userNoAdmin.Normalized())) + serverSQL.Exec(t, "SELECT 2") + err = srvtestutils.GetStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) + require.NoError(t, err) + session = getSessionWithTestAppName(&resp) + require.Equal(t, "SELECT _", session.LastActiveQueryNoConstants) + require.Equal(t, "SELECT 2", session.LastActiveQuery) +} + +func TestListClosedSessions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // The active sessions might close before the stress race can finish. 
+ skip.UnderStressRace(t, "active sessions") + + ctx := context.Background() + serverParams, _ := tests.CreateTestServerParams() + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: serverParams, + }) + defer testCluster.Stopper().Stop(ctx) + + server := testCluster.Server(0) + + doSessionsRequest := func(username string) serverpb.ListSessionsResponse { + var resp serverpb.ListSessionsResponse + path := "/_status/sessions?username=" + username + err := serverutils.GetJSONProto(server, path, &resp) + require.NoError(t, err) + return resp + } + + getUserConn := func(t *testing.T, username string, server serverutils.TestServerInterface) *gosql.DB { + pgURL := url.URL{ + Scheme: "postgres", + User: url.UserPassword(username, "hunter2"), + Host: server.ServingSQLAddr(), + } + db, err := gosql.Open("postgres", pgURL.String()) + require.NoError(t, err) + return db + } + + // Create a test user. + users := []string{"test_user_a", "test_user_b", "test_user_c"} + conn := testCluster.ServerConn(0) + _, err := conn.Exec(fmt.Sprintf(` +CREATE USER %s with password 'hunter2'; +CREATE USER %s with password 'hunter2'; +CREATE USER %s with password 'hunter2'; +`, users[0], users[1], users[2])) + require.NoError(t, err) + + var dbs []*gosql.DB + + // Open 10 sessions for the user and then close them. + for _, user := range users { + for i := 0; i < 10; i++ { + targetDB := getUserConn(t, user, testCluster.Server(0)) + dbs = append(dbs, targetDB) + sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT version()`) + } + } + + for _, db := range dbs { + err := db.Close() + require.NoError(t, err) + } + + var wg sync.WaitGroup + + // Open 5 sessions for the user and leave them open by running pg_sleep(30). + for _, user := range users { + for i := 0; i < 5; i++ { + wg.Add(1) + go func(user string) { + // Open a session for the target user. 
+ targetDB := getUserConn(t, user, testCluster.Server(0)) + defer targetDB.Close() + defer wg.Done() + sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT pg_sleep(30)`) + }(user) + } + } + + // Open 3 sessions for the user and leave them idle by running version(). + for _, user := range users { + for i := 0; i < 3; i++ { + targetDB := getUserConn(t, user, testCluster.Server(0)) + defer targetDB.Close() + sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT version()`) + } + } + + countSessionStatus := func(allSessions []serverpb.Session) (int, int, int) { + var active, idle, closed int + for _, s := range allSessions { + if s.Status.String() == "ACTIVE" { + active++ + } + // IDLE sessions are open sessions with no active queries. + if s.Status.String() == "IDLE" { + idle++ + } + if s.Status.String() == "CLOSED" { + closed++ + } + } + return active, idle, closed + } + + expectedIdle := 3 + expectedActive := 5 + expectedClosed := 10 + + testutils.SucceedsSoon(t, func() error { + for _, user := range users { + sessionsResponse := doSessionsRequest(user) + allSessions := sessionsResponse.Sessions + sort.Slice(allSessions, func(i, j int) bool { + return allSessions[i].Start.Before(allSessions[j].Start) + }) + + active, idle, closed := countSessionStatus(allSessions) + if idle != expectedIdle { + return errors.Newf("User: %s: Expected %d idle sessions, got %d\n", user, expectedIdle, idle) + } + if active != expectedActive { + return errors.Newf("User: %s: Expected %d active sessions, got %d\n", user, expectedActive, active) + } + if closed != expectedClosed { + return errors.Newf("User: %s: Expected %d closed sessions, got %d\n", user, expectedClosed, closed) + } + } + return nil + }) + + // Wait for the goroutines from the pg_sleep() command to finish, so we can + // safely close their connections. 
+ wg.Wait() +} diff --git a/pkg/server/application_api/sql_stats_test.go b/pkg/server/application_api/sql_stats_test.go new file mode 100644 index 000000000000..aacd34842451 --- /dev/null +++ b/pkg/server/application_api/sql_stats_test.go @@ -0,0 +1,941 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + gosql "database/sql" + "fmt" + "net/url" + "reflect" + "sort" + "strings" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/spanconfig" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/appstatspb" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/sqlstats" + "github.com/cockroachdb/cockroach/pkg/sql/tests" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/kr/pretty" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatusAPICombinedTransactions(t 
*testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + params.Knobs.SpanConfig = &spanconfig.TestingKnobs{ManagerDisableJobCreation: true} // TODO(irfansharif): #74919. + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: params, + }) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + thirdServer := testCluster.Server(2) + pgURL, cleanupGoDB := sqlutils.PGUrl( + t, thirdServer.ServingSQLAddr(), "CreateConnections" /* prefix */, url.User(username.RootUser)) + defer cleanupGoDB() + firstServerProto := testCluster.Server(0) + + type testCase struct { + query string + fingerprinted string + count int + shouldRetry bool + numRows int + } + + testCases := []testCase{ + {query: `CREATE DATABASE roachblog`, count: 1, numRows: 0}, + {query: `SET database = roachblog`, count: 1, numRows: 0}, + {query: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, count: 1, numRows: 0}, + { + query: `INSERT INTO posts VALUES (1, 'foo')`, + fingerprinted: `INSERT INTO posts VALUES (_, '_')`, + count: 1, + numRows: 1, + }, + {query: `SELECT * FROM posts`, count: 2, numRows: 1}, + {query: `BEGIN; SELECT * FROM posts; SELECT * FROM posts; COMMIT`, count: 3, numRows: 2}, + { + query: `BEGIN; SELECT crdb_internal.force_retry('2s'); SELECT * FROM posts; COMMIT;`, + fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, + shouldRetry: true, + count: 1, + numRows: 2, + }, + { + query: `BEGIN; SELECT crdb_internal.force_retry('5s'); SELECT * FROM posts; COMMIT;`, + fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, + shouldRetry: true, + count: 1, + numRows: 2, + }, + } + + appNameToTestCase := make(map[string]testCase) + + for i, tc := range testCases { + appName := fmt.Sprintf("app%d", i) + appNameToTestCase[appName] = tc + + // Create a brand new connection for each app, so that we 
don't pollute + // transaction stats collection with `SET application_name` queries. + sqlDB, err := gosql.Open("postgres", pgURL.String()) + if err != nil { + t.Fatal(err) + } + if _, err := sqlDB.Exec(fmt.Sprintf(`SET application_name = "%s"`, appName)); err != nil { + t.Fatal(err) + } + for c := 0; c < tc.count; c++ { + if _, err := sqlDB.Exec(tc.query); err != nil { + t.Fatal(err) + } + } + if err := sqlDB.Close(); err != nil { + t.Fatal(err) + } + } + + // Hit query endpoint. + var resp serverpb.StatementsResponse + if err := srvtestutils.GetStatusJSONProto(firstServerProto, "combinedstmts", &resp); err != nil { + t.Fatal(err) + } + + // Construct a map of all the statement fingerprint IDs. + statementFingerprintIDs := make(map[appstatspb.StmtFingerprintID]bool, len(resp.Statements)) + for _, respStatement := range resp.Statements { + statementFingerprintIDs[respStatement.ID] = true + } + + respAppNames := make(map[string]bool) + for _, respTransaction := range resp.Transactions { + appName := respTransaction.StatsData.App + tc, found := appNameToTestCase[appName] + if !found { + // Ignore internal queries, they aren't relevant to this test. + continue + } + respAppNames[appName] = true + // Ensure all statementFingerprintIDs comprised by the Transaction Response can be + // linked to StatementFingerprintIDs for statements in the response. 
+ for _, stmtFingerprintID := range respTransaction.StatsData.StatementFingerprintIDs { + if _, found := statementFingerprintIDs[stmtFingerprintID]; !found { + t.Fatalf("app: %s, expected stmtFingerprintID: %d not found in StatementResponse.", appName, stmtFingerprintID) + } + } + stats := respTransaction.StatsData.Stats + if tc.count != int(stats.Count) { + t.Fatalf("app: %s, expected count %d, got %d", appName, tc.count, stats.Count) + } + if tc.shouldRetry && respTransaction.StatsData.Stats.MaxRetries == 0 { + t.Fatalf("app: %s, expected retries, got none\n", appName) + } + + // Sanity check numeric stat values + if respTransaction.StatsData.Stats.CommitLat.Mean <= 0 { + t.Fatalf("app: %s, unexpected mean for commit latency\n", appName) + } + if respTransaction.StatsData.Stats.RetryLat.Mean <= 0 && tc.shouldRetry { + t.Fatalf("app: %s, expected retry latency mean to be non-zero as retries were involved\n", appName) + } + if respTransaction.StatsData.Stats.ServiceLat.Mean <= 0 { + t.Fatalf("app: %s, unexpected mean for service latency\n", appName) + } + if respTransaction.StatsData.Stats.NumRows.Mean != float64(tc.numRows) { + t.Fatalf("app: %s, unexpected number of rows observed. expected: %d, got %d\n", + appName, tc.numRows, int(respTransaction.StatsData.Stats.NumRows.Mean)) + } + } + + // Ensure we got transaction statistics for all the queries we sent. 
+ for appName := range appNameToTestCase { + if _, found := respAppNames[appName]; !found { + t.Fatalf("app: %s did not appear in the response\n", appName) + } + } +} + +func TestStatusAPITransactions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + thirdServer := testCluster.Server(2) + pgURL, cleanupGoDB := sqlutils.PGUrl( + t, thirdServer.ServingSQLAddr(), "CreateConnections" /* prefix */, url.User(username.RootUser)) + defer cleanupGoDB() + firstServerProto := testCluster.Server(0) + + type testCase struct { + query string + fingerprinted string + count int + shouldRetry bool + numRows int + } + + testCases := []testCase{ + {query: `CREATE DATABASE roachblog`, count: 1, numRows: 0}, + {query: `SET database = roachblog`, count: 1, numRows: 0}, + {query: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, count: 1, numRows: 0}, + { + query: `INSERT INTO posts VALUES (1, 'foo')`, + fingerprinted: `INSERT INTO posts VALUES (_, _)`, + count: 1, + numRows: 1, + }, + {query: `SELECT * FROM posts`, count: 2, numRows: 1}, + {query: `BEGIN; SELECT * FROM posts; SELECT * FROM posts; COMMIT`, count: 3, numRows: 2}, + { + query: `BEGIN; SELECT crdb_internal.force_retry('2s'); SELECT * FROM posts; COMMIT;`, + fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, + shouldRetry: true, + count: 1, + numRows: 2, + }, + { + query: `BEGIN; SELECT crdb_internal.force_retry('5s'); SELECT * FROM posts; COMMIT;`, + fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, + shouldRetry: true, + count: 1, + numRows: 2, + }, + } + + appNameToTestCase := make(map[string]testCase) + + for i, tc := range testCases { + appName := fmt.Sprintf("app%d", i) + appNameToTestCase[appName] = tc + + // Create a brand new connection for each app, 
so that we don't pollute + // transaction stats collection with `SET application_name` queries. + sqlDB, err := gosql.Open("postgres", pgURL.String()) + if err != nil { + t.Fatal(err) + } + if _, err := sqlDB.Exec(fmt.Sprintf(`SET application_name = "%s"`, appName)); err != nil { + t.Fatal(err) + } + for c := 0; c < tc.count; c++ { + if _, err := sqlDB.Exec(tc.query); err != nil { + t.Fatal(err) + } + } + if err := sqlDB.Close(); err != nil { + t.Fatal(err) + } + } + + // Hit query endpoint. + var resp serverpb.StatementsResponse + if err := srvtestutils.GetStatusJSONProto(firstServerProto, "statements", &resp); err != nil { + t.Fatal(err) + } + + // Construct a map of all the statement fingerprint IDs. + statementFingerprintIDs := make(map[appstatspb.StmtFingerprintID]bool, len(resp.Statements)) + for _, respStatement := range resp.Statements { + statementFingerprintIDs[respStatement.ID] = true + } + + respAppNames := make(map[string]bool) + for _, respTransaction := range resp.Transactions { + appName := respTransaction.StatsData.App + tc, found := appNameToTestCase[appName] + if !found { + // Ignore internal queries, they aren't relevant to this test. + continue + } + respAppNames[appName] = true + // Ensure all statementFingerprintIDs comprised by the Transaction Response can be + // linked to StatementFingerprintIDs for statements in the response. 
+ for _, stmtFingerprintID := range respTransaction.StatsData.StatementFingerprintIDs { + if _, found := statementFingerprintIDs[stmtFingerprintID]; !found { + t.Fatalf("app: %s, expected stmtFingerprintID: %d not found in StatementResponse.", appName, stmtFingerprintID) + } + } + stats := respTransaction.StatsData.Stats + if tc.count != int(stats.Count) { + t.Fatalf("app: %s, expected count %d, got %d", appName, tc.count, stats.Count) + } + if tc.shouldRetry && respTransaction.StatsData.Stats.MaxRetries == 0 { + t.Fatalf("app: %s, expected retries, got none\n", appName) + } + + // Sanity check numeric stat values + if respTransaction.StatsData.Stats.CommitLat.Mean <= 0 { + t.Fatalf("app: %s, unexpected mean for commit latency\n", appName) + } + if respTransaction.StatsData.Stats.RetryLat.Mean <= 0 && tc.shouldRetry { + t.Fatalf("app: %s, expected retry latency mean to be non-zero as retries were involved\n", appName) + } + if respTransaction.StatsData.Stats.ServiceLat.Mean <= 0 { + t.Fatalf("app: %s, unexpected mean for service latency\n", appName) + } + if respTransaction.StatsData.Stats.NumRows.Mean != float64(tc.numRows) { + t.Fatalf("app: %s, unexpected number of rows observed. expected: %d, got %d\n", + appName, tc.numRows, int(respTransaction.StatsData.Stats.NumRows.Mean)) + } + } + + // Ensure we got transaction statistics for all the queries we sent. 
+ for appName := range appNameToTestCase { + if _, found := respAppNames[appName]; !found { + t.Fatalf("app: %s did not appear in the response\n", appName) + } + } +} + +func TestStatusAPITransactionStatementFingerprintIDsTruncation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: params, + }) + defer testCluster.Stopper().Stop(context.Background()) + + firstServerProto := testCluster.Server(0) + thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) + testingApp := "testing" + + thirdServerSQL.Exec(t, `CREATE DATABASE db; CREATE TABLE db.t();`) + thirdServerSQL.Exec(t, fmt.Sprintf(`SET application_name = "%s"`, testingApp)) + + maxStmtFingerprintIDsLen := int(sqlstats.TxnStatsNumStmtFingerprintIDsToRecord.Get( + &firstServerProto.ExecutorConfig().(sql.ExecutorConfig).Settings.SV)) + + // Construct 2 transaction queries that include an absurd number of statements. + // These two queries have the same first 1000 statements, but should still have + // different fingerprints, as fingerprints take into account all + // statementFingerprintIDs (unlike the statementFingerprintIDs stored on the + // proto response, which are capped). + testQuery1 := "BEGIN;" + for i := 0; i < maxStmtFingerprintIDsLen+1; i++ { + testQuery1 += "SELECT * FROM db.t;" + } + testQuery2 := testQuery1 + "SELECT * FROM db.t; COMMIT;" + testQuery1 += "COMMIT;" + + thirdServerSQL.Exec(t, testQuery1) + thirdServerSQL.Exec(t, testQuery2) + + // Hit query endpoint. + var resp serverpb.StatementsResponse + if err := srvtestutils.GetStatusJSONProto(firstServerProto, "statements", &resp); err != nil { + t.Fatal(err) + } + + txnsFound := 0 + for _, respTransaction := range resp.Transactions { + appName := respTransaction.StatsData.App + if appName != testingApp { + // Only testQuery1 and testQuery2 are relevant to this test. 
+ continue + } + + txnsFound++ + if len(respTransaction.StatsData.StatementFingerprintIDs) != maxStmtFingerprintIDsLen { + t.Fatalf("unexpected length of StatementFingerprintIDs. expected:%d, got:%d", + maxStmtFingerprintIDsLen, len(respTransaction.StatsData.StatementFingerprintIDs)) + } + } + if txnsFound != 2 { + t.Fatalf("transactions were not disambiguated as expected. expected %d txns, got: %d", + 2, txnsFound) + } +} + +func TestStatusAPIStatements(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Aug 30 2021 19:50:00 GMT+0000 + aggregatedTs := int64(1630353000) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLStatsKnobs: &sqlstats.TestingKnobs{ + AOSTClause: "AS OF SYSTEM TIME '-1us'", + StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, + }, + SpanConfig: &spanconfig.TestingKnobs{ + ManagerDisableJobCreation: true, // TODO(irfansharif): #74919. + }, + }, + }, + }) + defer testCluster.Stopper().Stop(context.Background()) + + firstServerProto := testCluster.Server(0) + thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) + + statements := []struct { + stmt string + fingerprinted string + }{ + {stmt: `CREATE DATABASE roachblog`}, + {stmt: `SET database = roachblog`}, + {stmt: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`}, + { + stmt: `INSERT INTO posts VALUES (1, 'foo')`, + fingerprinted: `INSERT INTO posts VALUES (_, '_')`, + }, + {stmt: `SELECT * FROM posts`}, + } + + for _, stmt := range statements { + thirdServerSQL.Exec(t, stmt.stmt) + } + + // Test that non-admin without VIEWACTIVITY privileges cannot access. 
+ var resp serverpb.StatementsResponse + err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, "statements", &resp, false) + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + + testPath := func(path string, expectedStmts []string) { + // Hit query endpoint. + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false); err != nil { + t.Fatal(err) + } + + // See if the statements returned are what we executed. + var statementsInResponse []string + for _, respStatement := range resp.Statements { + if respStatement.Key.KeyData.Failed { + // We ignore failed statements here as the INSERT statement can fail and + // be automatically retried, confusing the test success check. + continue + } + if strings.HasPrefix(respStatement.Key.KeyData.App, catconstants.InternalAppNamePrefix) { + // We ignore internal queries, these are not relevant for the + // validity of this test. + continue + } + if strings.HasPrefix(respStatement.Key.KeyData.Query, "ALTER USER") { + // Ignore the ALTER USER ... VIEWACTIVITY statement. + continue + } + statementsInResponse = append(statementsInResponse, respStatement.Key.KeyData.Query) + } + + sort.Strings(expectedStmts) + sort.Strings(statementsInResponse) + + if !reflect.DeepEqual(expectedStmts, statementsInResponse) { + t.Fatalf("expected queries\n\n%v\n\ngot queries\n\n%v\n%s", + expectedStmts, statementsInResponse, pretty.Sprint(resp)) + } + } + + var expectedStatements []string + for _, stmt := range statements { + var expectedStmt = stmt.stmt + if stmt.fingerprinted != "" { + expectedStmt = stmt.fingerprinted + } + expectedStatements = append(expectedStatements, expectedStmt) + } + + // Grant VIEWACTIVITY. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + + // Test no params. 
+ testPath("statements", expectedStatements) + // Test combined=true forwards to CombinedStatements + testPath(fmt.Sprintf("statements?combined=true&start=%d", aggregatedTs+60), nil) + + // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + // Grant VIEWACTIVITYREDACTED. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", apiconstants.TestingUserNameNoAdmin().Normalized())) + + // Test no params. + testPath("statements", expectedStatements) + // Test combined=true forwards to CombinedStatements + testPath(fmt.Sprintf("statements?combined=true&start=%d", aggregatedTs+60), nil) +} + +func TestStatusAPICombinedStatements(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Aug 30 2021 19:50:00 GMT+0000 + aggregatedTs := int64(1630353000) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLStatsKnobs: &sqlstats.TestingKnobs{ + AOSTClause: "AS OF SYSTEM TIME '-1us'", + StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, + }, + SpanConfig: &spanconfig.TestingKnobs{ + ManagerDisableJobCreation: true, // TODO(irfansharif): #74919. 
+ }, + }, + }, + }) + defer testCluster.Stopper().Stop(context.Background()) + + firstServerProto := testCluster.Server(0) + thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) + + statements := []struct { + stmt string + fingerprinted string + }{ + {stmt: `CREATE DATABASE roachblog`}, + {stmt: `SET database = roachblog`}, + {stmt: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`}, + { + stmt: `INSERT INTO posts VALUES (1, 'foo')`, + fingerprinted: `INSERT INTO posts VALUES (_, '_')`, + }, + {stmt: `SELECT * FROM posts`}, + } + + for _, stmt := range statements { + thirdServerSQL.Exec(t, stmt.stmt) + } + + var resp serverpb.StatementsResponse + // Test that non-admin without VIEWACTIVITY privileges cannot access. + err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, "combinedstmts", &resp, false) + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + + verifyStmts := func(path string, expectedStmts []string, hasTxns bool, t *testing.T) { + // Hit query endpoint. + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false); err != nil { + t.Fatal(err) + } + + // See if the statements returned are what we executed. + var statementsInResponse []string + expectedTxnFingerprints := map[appstatspb.TransactionFingerprintID]struct{}{} + for _, respStatement := range resp.Statements { + if respStatement.Key.KeyData.Failed { + // We ignore failed statements here as the INSERT statement can fail and + // be automatically retried, confusing the test success check. + continue + } + if strings.HasPrefix(respStatement.Key.KeyData.App, catconstants.InternalAppNamePrefix) { + // CombinedStatementStats should filter out internal queries. + t.Fatalf("unexpected internal query: %s", respStatement.Key.KeyData.Query) + } + if strings.HasPrefix(respStatement.Key.KeyData.Query, "ALTER USER") { + // Ignore the ALTER USER ... VIEWACTIVITY statement. 
+ continue + } + + statementsInResponse = append(statementsInResponse, respStatement.Key.KeyData.Query) + for _, txnFingerprintID := range respStatement.TxnFingerprintIDs { + expectedTxnFingerprints[txnFingerprintID] = struct{}{} + } + } + + for _, respTxn := range resp.Transactions { + delete(expectedTxnFingerprints, respTxn.StatsData.TransactionFingerprintID) + } + + sort.Strings(expectedStmts) + sort.Strings(statementsInResponse) + + if !reflect.DeepEqual(expectedStmts, statementsInResponse) { + t.Fatalf("expected queries\n\n%v\n\ngot queries\n\n%v\n%s\n path: %s", + expectedStmts, statementsInResponse, pretty.Sprint(resp), path) + } + if hasTxns { + // We expect that expectedTxnFingerprints is now empty since + // we should have removed them all. + assert.Empty(t, expectedTxnFingerprints) + } else { + assert.Empty(t, resp.Transactions) + } + } + + var expectedStatements []string + for _, stmt := range statements { + var expectedStmt = stmt.stmt + if stmt.fingerprinted != "" { + expectedStmt = stmt.fingerprinted + } + expectedStatements = append(expectedStatements, expectedStmt) + } + + oneMinAfterAggregatedTs := aggregatedTs + 60 + + t.Run("fetch_mode=combined, VIEWACTIVITY", func(t *testing.T) { + // Grant VIEWACTIVITY. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + + // Test with no query params. + verifyStmts("combinedstmts", expectedStatements, true, t) + // Test with end = 1 min after aggregatedTs; should give the same results as get all. + verifyStmts(fmt.Sprintf("combinedstmts?end=%d", oneMinAfterAggregatedTs), expectedStatements, true, t) + // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. 
+ verifyStmts(fmt.Sprintf("combinedstmts?start=%d&end=%d", aggregatedTs-3600, oneMinAfterAggregatedTs), + expectedStatements, true, t) + // Test with start = 1 min after aggregatedTs; should give no results + verifyStmts(fmt.Sprintf("combinedstmts?start=%d", oneMinAfterAggregatedTs), nil, true, t) + }) + + t.Run("fetch_mode=combined, VIEWACTIVITYREDACTED", func(t *testing.T) { + // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + // Grant VIEWACTIVITYREDACTED. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", apiconstants.TestingUserNameNoAdmin().Normalized())) + + // Test with no query params. + verifyStmts("combinedstmts", expectedStatements, true, t) + // Test with end = 1 min after aggregatedTs; should give the same results as get all. + verifyStmts(fmt.Sprintf("combinedstmts?end=%d", oneMinAfterAggregatedTs), expectedStatements, true, t) + // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. + verifyStmts(fmt.Sprintf("combinedstmts?start=%d&end=%d", aggregatedTs-3600, oneMinAfterAggregatedTs), expectedStatements, true, t) + // Test with start = 1 min after aggregatedTs; should give no results + verifyStmts(fmt.Sprintf("combinedstmts?start=%d", oneMinAfterAggregatedTs), nil, true, t) + }) + + t.Run("fetch_mode=StmtsOnly", func(t *testing.T) { + verifyStmts("combinedstmts?fetch_mode.stats_type=0", expectedStatements, false, t) + }) + + t.Run("fetch_mode=TxnsOnly with limit", func(t *testing.T) { + // Verify that we only return stmts for the txns in the response. + // We'll add a limit in a later commit to help verify this behaviour. 
+ if err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, "combinedstmts?fetch_mode.stats_type=1&limit=2", + &resp, false); err != nil { + t.Fatal(err) + } + + assert.Equal(t, 2, len(resp.Transactions)) + stmtFingerprintIDs := map[appstatspb.StmtFingerprintID]struct{}{} + for _, txn := range resp.Transactions { + for _, stmtFingerprint := range txn.StatsData.StatementFingerprintIDs { + stmtFingerprintIDs[stmtFingerprint] = struct{}{} + } + } + + for _, stmt := range resp.Statements { + if _, ok := stmtFingerprintIDs[stmt.ID]; !ok { + t.Fatalf("unexpected stmt; stmt unrelated to a txn in the response: %s", stmt.Key.KeyData.Query) + } + } + }) +} + +func TestStatusAPIStatementDetails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + // The liveness session might expire before the stress race can finish. + skip.UnderStressRace(t, "expensive tests") + + // Aug 30 2021 19:50:00 GMT+0000 + aggregatedTs := int64(1630353000) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLStatsKnobs: &sqlstats.TestingKnobs{ + AOSTClause: "AS OF SYSTEM TIME '-1us'", + StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, + }, + SpanConfig: &spanconfig.TestingKnobs{ + ManagerDisableJobCreation: true, + }, + }, + }, + }) + defer testCluster.Stopper().Stop(context.Background()) + + firstServerProto := testCluster.Server(0) + thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) + + statements := []string{ + `set application_name = 'first-app'`, + `CREATE DATABASE roachblog`, + `SET database = roachblog`, + `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, + `INSERT INTO posts VALUES (1, 'foo')`, + `INSERT INTO posts VALUES (2, 'foo')`, + `INSERT INTO posts VALUES (3, 'foo')`, + `SELECT * FROM posts`, + } + + for _, stmt := range statements { + thirdServerSQL.Exec(t, stmt) + } + + query := `INSERT INTO posts 
VALUES (_, '_')` + fingerprintID := appstatspb.ConstructStatementFingerprintID(query, + false, true, `roachblog`) + path := fmt.Sprintf(`stmtdetails/%v`, fingerprintID) + + var resp serverpb.StatementDetailsResponse + // Test that non-admin without VIEWACTIVITY or VIEWACTIVITYREDACTED privileges cannot access. + err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false) + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + + type resultValues struct { + query string + totalCount int + aggregatedTsCount int + planHashCount int + fullScanCount int + appNames []string + databases []string + } + + testPath := func(path string, expected resultValues) { + err := srvtestutils.GetStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false) + require.NoError(t, err) + require.Equal(t, int64(expected.totalCount), resp.Statement.Stats.Count) + require.Equal(t, expected.aggregatedTsCount, len(resp.StatementStatisticsPerAggregatedTs)) + require.Equal(t, expected.planHashCount, len(resp.StatementStatisticsPerPlanHash)) + require.Equal(t, expected.query, resp.Statement.Metadata.Query) + require.Equal(t, expected.appNames, resp.Statement.Metadata.AppNames) + require.Equal(t, int64(expected.totalCount), resp.Statement.Metadata.TotalCount) + require.Equal(t, expected.databases, resp.Statement.Metadata.Databases) + require.Equal(t, int64(expected.fullScanCount), resp.Statement.Metadata.FullScanCount) + } + + // Grant VIEWACTIVITY. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + + // Test with no query params. 
+ testPath( + path, + resultValues{ + query: query, + totalCount: 3, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}, + }) + + // Execute same fingerprint id statement on a different application + statements = []string{ + `set application_name = 'second-app'`, + `INSERT INTO posts VALUES (4, 'foo')`, + `INSERT INTO posts VALUES (5, 'foo')`, + } + for _, stmt := range statements { + thirdServerSQL.Exec(t, stmt) + } + + oneMinAfterAggregatedTs := aggregatedTs + 60 + + testData := []struct { + path string + expectedResult resultValues + }{ + { // Test with no query params. + path: path, + expectedResult: resultValues{ + query: query, + totalCount: 5, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app", "second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with end = 1 min after aggregatedTs; should give the same results as get all. + path: fmt.Sprintf("%v?end=%d", path, oneMinAfterAggregatedTs), + expectedResult: resultValues{ + query: query, + totalCount: 5, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app", "second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. + path: fmt.Sprintf("%v?start=%d&end=%d", path, aggregatedTs-3600, oneMinAfterAggregatedTs), + expectedResult: resultValues{ + query: query, + totalCount: 5, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app", "second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with start = 1 min after aggregatedTs; should give no results. 
+ path: fmt.Sprintf("%v?start=%d", path, oneMinAfterAggregatedTs), + expectedResult: resultValues{ + query: "", + totalCount: 0, + aggregatedTsCount: 0, + planHashCount: 0, + appNames: []string{}, + fullScanCount: 0, + databases: []string{}}, + }, + { // Test with one app_name. + path: fmt.Sprintf("%v?app_names=first-app", path), + expectedResult: resultValues{ + query: query, + totalCount: 3, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with another app_name. + path: fmt.Sprintf("%v?app_names=second-app", path), + expectedResult: resultValues{ + query: query, + totalCount: 2, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with both app_names. + path: fmt.Sprintf("%v?app_names=first-app&app_names=second-app", path), + expectedResult: resultValues{ + query: query, + totalCount: 5, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app", "second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + { // Test with non-existing app_name. + path: fmt.Sprintf("%v?app_names=non-existing", path), + expectedResult: resultValues{ + query: "", + totalCount: 0, + aggregatedTsCount: 0, + planHashCount: 0, + appNames: []string{}, + fullScanCount: 0, + databases: []string{}}, + }, + { // Test with app_name, start and end time. 
+ path: fmt.Sprintf("%v?start=%d&end=%d&app_names=first-app&app_names=second-app", path, aggregatedTs-3600, oneMinAfterAggregatedTs), + expectedResult: resultValues{ + query: query, + totalCount: 5, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"first-app", "second-app"}, + fullScanCount: 0, + databases: []string{"roachblog"}}, + }, + } + + for _, test := range testData { + testPath(test.path, test.expectedResult) + } + + // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + // Grant VIEWACTIVITYREDACTED. + thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", apiconstants.TestingUserNameNoAdmin().Normalized())) + + for _, test := range testData { + testPath(test.path, test.expectedResult) + } + + // Test fix for #83608. The stmt being below requested has a fingerprint id + // that is 15 chars in hexadecimal. We should be able to find this stmt now + // that we construct the filter using a bytes comparison instead of string. 
+ + statements = []string{ + `set application_name = 'fix_83608'`, + `set database = defaultdb`, + `SELECT 1, 2, 3, 4`, + } + for _, stmt := range statements { + thirdServerSQL.Exec(t, stmt) + } + + selectQuery := "SELECT _, _, _, _" + fingerprintID = appstatspb.ConstructStatementFingerprintID(selectQuery, false, + true, "defaultdb") + + testPath( + fmt.Sprintf(`stmtdetails/%v`, fingerprintID), + resultValues{ + query: selectQuery, + totalCount: 1, + aggregatedTsCount: 1, + planHashCount: 1, + appNames: []string{"fix_83608"}, + fullScanCount: 0, + databases: []string{"defaultdb"}, + }) +} + +func TestUnprivilegedUserResetIndexUsageStats(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + sqlConn := sqlutils.MakeSQLRunner(conn) + sqlConn.Exec(t, "CREATE USER nonAdminUser") + + ie := s.InternalExecutor().(*sql.InternalExecutor) + + _, err := ie.ExecEx( + ctx, + "test-reset-index-usage-stats-as-non-admin-user", + nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: username.MakeSQLUsernameFromPreNormalizedString("nonAdminUser"), + }, + "SELECT crdb_internal.reset_index_usage_stats()", + ) + + require.Contains(t, err.Error(), "requires admin privilege") +} diff --git a/pkg/server/application_api/stmtdiag_test.go b/pkg/server/application_api/stmtdiag_test.go new file mode 100644 index 000000000000..dac6634898a4 --- /dev/null +++ b/pkg/server/application_api/stmtdiag_test.go @@ -0,0 +1,265 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestAdminAPIStatementDiagnosticsBundle(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + query := "EXPLAIN ANALYZE (DEBUG) SELECT 'secret'" + _, err := db.Exec(query) + require.NoError(t, err) + + query = "SELECT id FROM system.statement_diagnostics LIMIT 1" + idRow, err := db.Query(query) + require.NoError(t, err) + // Close the result set on every exit path; it was previously leaked. + defer func() { require.NoError(t, idRow.Close()) }() + var diagnosticRow string + if !idRow.Next() { + t.Fatal("no results") + } + require.NoError(t, idRow.Scan(&diagnosticRow)) + + client, err := ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession) + require.NoError(t, err) + resp, err := client.Get(ts.AdminURL().WithPath("/_admin/v1/stmtbundle/" + diagnosticRow).String()) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 500, resp.StatusCode) + + adminClient, err := ts.GetAuthenticatedHTTPClient(true, serverutils.SingleTenantSession) + require.NoError(t, err) + adminResp, err := adminClient.Get(ts.AdminURL().WithPath("/_admin/v1/stmtbundle/" + diagnosticRow).String()) + require.NoError(t, err) +
defer adminResp.Body.Close() + require.Equal(t, 200, adminResp.StatusCode) +} + +func TestCreateStatementDiagnosticsReport(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + req := &serverpb.CreateStatementDiagnosticsReportRequest{ + StatementFingerprint: "INSERT INTO test VALUES (_)", + } + var resp serverpb.CreateStatementDiagnosticsReportResponse + if err := srvtestutils.PostStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { + t.Fatal(err) + } + + var respGet serverpb.StatementDiagnosticsReportsResponse + if err := srvtestutils.GetStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { + t.Fatal(err) + } + + if len(respGet.Reports) == 0 || respGet.Reports[0].StatementFingerprint != req.StatementFingerprint { + t.Fatal("statement diagnostics request was not persisted") + } +} + +func TestCreateStatementDiagnosticsReportWithViewActivityOptions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + db := sqlutils.MakeSQLRunner(sqlDB) + + ctx := context.Background() + ie := s.InternalExecutor().(*sql.InternalExecutor) + + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &serverpb.CreateStatementDiagnosticsReportRequest{}, false); err != nil { + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + _, err := ie.ExecEx( + ctx, + "inserting-stmt-bundle-req", + nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: apiconstants.TestingUserNameNoAdmin(), + }, + "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", + ) + require.Contains(t, err.Error(), "requesting statement bundle requires VIEWACTIVITY or ADMIN role option") + + // Grant VIEWACTIVITY and all tests should
work. + db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + req := &serverpb.CreateStatementDiagnosticsReportRequest{ + StatementFingerprint: "INSERT INTO test VALUES (_)", + } + var resp serverpb.CreateStatementDiagnosticsReportResponse + if err := srvtestutils.PostStatusJSONProtoWithAdminOption(s, "stmtdiagreports", req, &resp, false); err != nil { + t.Fatal(err) + } + var respGet serverpb.StatementDiagnosticsReportsResponse + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &respGet, false); err != nil { + t.Fatal(err) + } + if len(respGet.Reports) == 0 || respGet.Reports[0].StatementFingerprint != req.StatementFingerprint { + t.Fatal("statement diagnostics request was not persisted") + } + _, err = ie.ExecEx( + ctx, + "inserting-stmt-bundle-req", + nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: apiconstants.TestingUserNameNoAdmin(), + }, + "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", + ) + require.NoError(t, err) + + db.CheckQueryResults(t, ` + SELECT count(*) + FROM system.statement_diagnostics_requests + WHERE statement_fingerprint = 'SELECT _' +`, [][]string{{"1"}}) + + // Grant VIEWACTIVITYREDACTED and all tests should get permission errors.
+ db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", apiconstants.TestingUserNameNoAdmin().Normalized())) + + if err := srvtestutils.PostStatusJSONProtoWithAdminOption(s, "stmtdiagreports", req, &resp, false); err != nil { + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &respGet, false); err != nil { + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + + _, err = ie.ExecEx( + ctx, + "inserting-stmt-bundle-req", + nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: apiconstants.TestingUserNameNoAdmin(), + }, + "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", + ) + require.Contains(t, err.Error(), "VIEWACTIVITYREDACTED role option cannot request statement bundle") +} + +func TestStatementDiagnosticsCompleted(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + _, err := db.Exec("CREATE TABLE test (x int PRIMARY KEY)") + if err != nil { + t.Fatal(err) + } + + req := &serverpb.CreateStatementDiagnosticsReportRequest{ + StatementFingerprint: "INSERT INTO test VALUES (_)", + } + var resp serverpb.CreateStatementDiagnosticsReportResponse + if err := srvtestutils.PostStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { + t.Fatal(err) + } + + _, err = db.Exec("INSERT INTO test VALUES (1)") + if err != nil { + t.Fatal(err) + } + + var respGet serverpb.StatementDiagnosticsReportsResponse + if err := srvtestutils.GetStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { + t.Fatal(err) + } + + if len(respGet.Reports) == 0 || !respGet.Reports[0].Completed { + t.Fatal("statement diagnostics was not captured") + } + + var diagRespGet serverpb.StatementDiagnosticsResponse + diagPath :=
fmt.Sprintf("stmtdiag/%d", respGet.Reports[0].StatementDiagnosticsId) + if err := srvtestutils.GetStatusJSONProto(s, diagPath, &diagRespGet); err != nil { + t.Fatal(err) + } +} + +func TestStatementDiagnosticsDoesNotReturnExpiredRequests(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + db := sqlutils.MakeSQLRunner(sqlDB) + + statementFingerprint := "INSERT INTO test VALUES (_)" + expiresAfter := 5 * time.Millisecond + + // Create statement diagnostics request with defined expiry time. + req := &serverpb.CreateStatementDiagnosticsReportRequest{ + StatementFingerprint: statementFingerprint, + MinExecutionLatency: 500 * time.Millisecond, + ExpiresAfter: expiresAfter, + } + var resp serverpb.CreateStatementDiagnosticsReportResponse + if err := srvtestutils.PostStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { + t.Fatal(err) + } + + // Wait for request to expire. + time.Sleep(expiresAfter) + + // Check that created statement diagnostics report is incomplete. + report := db.QueryStr(t, ` +SELECT completed +FROM system.statement_diagnostics_requests +WHERE statement_fingerprint = $1`, statementFingerprint) + + require.Equal(t, "false", report[0][0]) + + // Check that expired report is not returned in API response. + var respGet serverpb.StatementDiagnosticsReportsResponse + if err := srvtestutils.GetStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { + t.Fatal(err) + } + + for _, report := range respGet.Reports { + require.NotEqual(t, report.StatementFingerprint, statementFingerprint) + } +} diff --git a/pkg/server/application_api/storage_inspection_test.go b/pkg/server/application_api/storage_inspection_test.go new file mode 100644 index 000000000000..01e6235b1ba9 --- /dev/null +++ b/pkg/server/application_api/storage_inspection_test.go @@ -0,0 +1,494 @@ +// Copyright 2023 The Cockroach Authors.
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + "fmt" + "math" + "reflect" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/rangetestutils" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/grunning" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdminAPINonTableStats(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(context.Background()) + s := testCluster.Server(0) + + // Skip TableStatsResponse.Stats comparison, since it includes data which + // aren't consistent (time, bytes).
+ expectedResponse := serverpb.NonTableStatsResponse{ + TimeSeriesStats: &serverpb.TableStatsResponse{ + RangeCount: 1, + ReplicaCount: 3, + NodeCount: 3, + }, + InternalUseStats: &serverpb.TableStatsResponse{ + RangeCount: 11, + ReplicaCount: 15, + NodeCount: 3, + }, + } + + var resp serverpb.NonTableStatsResponse + if err := srvtestutils.GetAdminJSONProto(s, "nontablestats", &resp); err != nil { + t.Fatal(err) + } + + assertExpectedStatsResponse := func(expected, actual *serverpb.TableStatsResponse) { + assert.Equal(t, expected.RangeCount, actual.RangeCount) + assert.Equal(t, expected.ReplicaCount, actual.ReplicaCount) + assert.Equal(t, expected.NodeCount, actual.NodeCount) + } + + assertExpectedStatsResponse(expectedResponse.TimeSeriesStats, resp.TimeSeriesStats) + assertExpectedStatsResponse(expectedResponse.InternalUseStats, resp.InternalUseStats) +} + +// Verify that for a cluster with no user data, all the ranges on the Databases +// page consist of: +// 1) the total ranges listed for the system database +// 2) the total ranges listed for the Non-Table data +func TestRangeCount(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(context.Background()) + require.NoError(t, testCluster.WaitForFullReplication()) + s := testCluster.Server(0) + + // Sum up ranges for non-table parts of the system returned + // from the "nontablestats" endpoint. + getNonTableRangeCount := func() (ts, internal int64) { + var resp serverpb.NonTableStatsResponse + if err := srvtestutils.GetAdminJSONProto(s, "nontablestats", &resp); err != nil { + t.Fatal(err) + } + return resp.TimeSeriesStats.RangeCount, resp.InternalUseStats.RangeCount + } + + // Return map tablename=>count obtained from the + // "databases/system/tables/{table}" endpoints.
+ getSystemTableRangeCount := func() map[string]int64 { + m := map[string]int64{} + var dbResp serverpb.DatabaseDetailsResponse + if err := srvtestutils.GetAdminJSONProto(s, "databases/system", &dbResp); err != nil { + t.Fatal(err) + } + for _, tableName := range dbResp.TableNames { + var tblResp serverpb.TableStatsResponse + path := "databases/system/tables/" + tableName + "/stats" + if err := srvtestutils.GetAdminJSONProto(s, path, &tblResp); err != nil { + t.Fatal(err) + } + m[tableName] = tblResp.RangeCount + } + // Hardcode the single range used by each system sequence, the above + // request does not return sequences. + // TODO(richardjcai): Maybe update the request to return + // sequences as well? + m[fmt.Sprintf("public.%s", catconstants.DescIDSequenceTableName)] = 1 + m[fmt.Sprintf("public.%s", catconstants.RoleIDSequenceName)] = 1 + m[fmt.Sprintf("public.%s", catconstants.TenantIDSequenceTableName)] = 1 + return m + } + + getRangeCountFromFullSpan := func() int64 { + ts := s.(*server.TestServer) + stats, err := ts.TestingStatsForSpan(context.Background(), roachpb.Span{ + Key: keys.LocalMax, + EndKey: keys.MaxKey, + }) + if err != nil { + t.Fatal(err) + } + return stats.RangeCount + } + + exp := getRangeCountFromFullSpan() + + var systemTableRangeCount int64 + sysDBMap := getSystemTableRangeCount() + for _, n := range sysDBMap { + systemTableRangeCount += n + } + + tsCount, internalCount := getNonTableRangeCount() + + act := tsCount + internalCount + systemTableRangeCount + + if !assert.Equal(t, + exp, + act, + ) { + t.Log("did nonTableDescriptorRangeCount() change?") + t.Logf( + "claimed numbers:\ntime series = %d\ninternal = %d\nsystemdb = %d (%v)", + tsCount, internalCount, systemTableRangeCount, sysDBMap, + ) + db := testCluster.ServerConn(0) + defer db.Close() + + runner := sqlutils.MakeSQLRunner(db) + s := sqlutils.MatrixToStr(runner.QueryStr(t, `SHOW CLUSTER RANGES`)) + t.Logf("actual ranges:\n%s", s) + } +} + +func TestStatsforSpanOnLocalMax(t
*testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(context.Background()) + firstServer := testCluster.Server(0) + ts := firstServer.(*server.TestServer) + + underTest := roachpb.Span{ + Key: keys.LocalMax, + EndKey: keys.SystemPrefix, + } + + _, err := ts.TestingStatsForSpan(context.Background(), underTest) + if err != nil { + t.Fatal(err) + } +} + +func TestAdminAPIDataDistribution(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(context.Background()) + + firstServer := testCluster.Server(0) + sqlDB := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) + + // TODO(irfansharif): The data-distribution page and underlying APIs don't + // know how to deal with coalesced ranges. See #97942. + sqlDB.Exec(t, `SET CLUSTER SETTING spanconfig.storage_coalesce_adjacent.enabled = false`) + + // Create some tables. + sqlDB.Exec(t, `CREATE DATABASE roachblog`) + sqlDB.Exec(t, `CREATE TABLE roachblog.posts (id INT PRIMARY KEY, title text, body text)`) + sqlDB.Exec(t, `CREATE TABLE roachblog.comments ( + id INT PRIMARY KEY, + post_id INT REFERENCES roachblog.posts, + body text + )`) + sqlDB.Exec(t, `CREATE SCHEMA roachblog."foo bar"`) + sqlDB.Exec(t, `CREATE TABLE roachblog."foo bar".other_stuff(id INT PRIMARY KEY, body TEXT)`) + // Test special characters in DB and table names. + sqlDB.Exec(t, `CREATE DATABASE "sp'ec\ch""ars"`) + sqlDB.Exec(t, `CREATE TABLE "sp'ec\ch""ars"."more\spec'chars" (id INT PRIMARY KEY)`) + + // Make sure secondary tenants don't cause the endpoint to error. + sqlDB.Exec(t, "CREATE TENANT 'app'") + + // Verify that we see their replicas in the DataDistribution response, evenly spread + // across the test cluster's three nodes.
+ + expectedDatabaseInfo := map[string]serverpb.DataDistributionResponse_DatabaseInfo{ + "roachblog": { + TableInfo: map[string]serverpb.DataDistributionResponse_TableInfo{ + "public.posts": { + ReplicaCountByNodeId: map[roachpb.NodeID]int64{ + 1: 1, + 2: 1, + 3: 1, + }, + }, + "public.comments": { + ReplicaCountByNodeId: map[roachpb.NodeID]int64{ + 1: 1, + 2: 1, + 3: 1, + }, + }, + `"foo bar".other_stuff`: { + ReplicaCountByNodeId: map[roachpb.NodeID]int64{ + 1: 1, + 2: 1, + 3: 1, + }, + }, + }, + }, + `sp'ec\ch"ars`: { + TableInfo: map[string]serverpb.DataDistributionResponse_TableInfo{ + `public."more\spec'chars"`: { + ReplicaCountByNodeId: map[roachpb.NodeID]int64{ + 1: 1, + 2: 1, + 3: 1, + }, + }, + }, + }, + } + + // Wait for the new tables' ranges to be created and replicated. + testutils.SucceedsSoon(t, func() error { + var resp serverpb.DataDistributionResponse + if err := srvtestutils.GetAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { + t.Fatal(err) + } + + delete(resp.DatabaseInfo, "system") // delete results for system database. + if !reflect.DeepEqual(resp.DatabaseInfo, expectedDatabaseInfo) { + return fmt.Errorf("expected %v; got %v", expectedDatabaseInfo, resp.DatabaseInfo) + } + + // Don't test anything about the zone configs for now; just verify that something is there. + if len(resp.ZoneConfigs) == 0 { + return fmt.Errorf("no zone configs returned") + } + + return nil + }) + + // Verify that the request still works after a table has been dropped, + // and that dropped_at is set on the dropped table.
+ sqlDB.Exec(t, `DROP TABLE roachblog.comments`) + + var resp serverpb.DataDistributionResponse + if err := srvtestutils.GetAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { + t.Fatal(err) + } + + if resp.DatabaseInfo["roachblog"].TableInfo["public.comments"].DroppedAt == nil { + t.Fatal("expected roachblog.comments to have dropped_at set but it's nil") + } + + // Verify that the request still works after a database has been dropped. + sqlDB.Exec(t, `DROP DATABASE roachblog CASCADE`) + + if err := srvtestutils.GetAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { + t.Fatal(err) + } +} + +func BenchmarkAdminAPIDataDistribution(b *testing.B) { + skip.UnderShort(b, "TODO: fix benchmark") + testCluster := serverutils.StartNewTestCluster(b, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(context.Background()) + + firstServer := testCluster.Server(0) + sqlDB := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) + + sqlDB.Exec(b, `CREATE DATABASE roachblog`) + + // Create a bunch of tables.
+ for i := 0; i < 200; i++ { + sqlDB.Exec( + b, + fmt.Sprintf(`CREATE TABLE roachblog.t%d (id INT PRIMARY KEY, title text, body text)`, i), + ) + // TODO(vilterp): split to increase the number of ranges for each table + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + var resp serverpb.DataDistributionResponse + if err := srvtestutils.GetAdminJSONProto(firstServer, "data_distribution", &resp); err != nil { + b.Fatal(err) + } + } + b.StopTimer() +} + +func TestHotRangesResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts := rangetestutils.StartServer(t) + defer ts.Stopper().Stop(context.Background()) + + var hotRangesResp serverpb.HotRangesResponse + if err := srvtestutils.GetStatusJSONProto(ts, "hotranges", &hotRangesResp); err != nil { + t.Fatal(err) + } + if len(hotRangesResp.HotRangesByNodeID) == 0 { + t.Fatalf("didn't get hot range responses from any nodes") + } + + for nodeID, nodeResp := range hotRangesResp.HotRangesByNodeID { + if len(nodeResp.Stores) == 0 { + t.Errorf("didn't get any stores in hot range response from n%d: %v", + nodeID, nodeResp.ErrorMessage) + } + for _, storeResp := range nodeResp.Stores { + // Only the first store will actually have any ranges on it.
+ if storeResp.StoreID != roachpb.StoreID(1) { + continue + } + lastQPS := math.MaxFloat64 + if len(storeResp.HotRanges) == 0 { + t.Errorf("didn't get any hot ranges in response from n%d,s%d: %v", + nodeID, storeResp.StoreID, nodeResp.ErrorMessage) + } + for _, r := range storeResp.HotRanges { + if r.Desc.RangeID == 0 || (len(r.Desc.StartKey) == 0 && len(r.Desc.EndKey) == 0) { + t.Errorf("unexpected empty/unpopulated range descriptor: %+v", r.Desc) + } + if r.QueriesPerSecond > 0 { + if r.ReadsPerSecond == 0 && r.WritesPerSecond == 0 && r.ReadBytesPerSecond == 0 && r.WriteBytesPerSecond == 0 { + t.Errorf("qps %.2f > 0, expected either reads=%.2f, writes=%.2f, readBytes=%.2f or writeBytes=%.2f to be non-zero", + r.QueriesPerSecond, r.ReadsPerSecond, r.WritesPerSecond, r.ReadBytesPerSecond, r.WriteBytesPerSecond) + } + // If the architecture doesn't support sampling CPU, it + // will also be zero. + if grunning.Supported() && r.CPUTimePerSecond == 0 { + t.Errorf("qps %.2f > 0, expected cpu=%.2f to be non-zero", + r.QueriesPerSecond, r.CPUTimePerSecond) + } + } + if r.QueriesPerSecond > lastQPS { + t.Errorf("unexpected increase in qps between ranges; prev=%.2f, current=%.2f, desc=%v", + lastQPS, r.QueriesPerSecond, r.Desc) + } + lastQPS = r.QueriesPerSecond + } + } + + } +} + +func TestHotRanges2Response(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts := rangetestutils.StartServer(t) + defer ts.Stopper().Stop(context.Background()) + + var hotRangesResp serverpb.HotRangesResponseV2 + if err := srvtestutils.PostStatusJSONProto(ts, "v2/hotranges", &serverpb.HotRangesRequest{}, &hotRangesResp); err != nil { + t.Fatal(err) + } + if len(hotRangesResp.Ranges) == 0 { + t.Fatalf("didn't get hot range responses from any nodes") + } + lastQPS := math.MaxFloat64 + for _, r := range hotRangesResp.Ranges { + if r.RangeID == 0 { + t.Errorf("unexpected empty range id: %d", r.RangeID) + } + if r.QPS > 0 { + if r.ReadsPerSecond == 0 &&
r.WritesPerSecond == 0 && r.ReadBytesPerSecond == 0 && r.WriteBytesPerSecond == 0 { + t.Errorf("qps %.2f > 0, expected either reads=%.2f, writes=%.2f, readBytes=%.2f or writeBytes=%.2f to be non-zero", + r.QPS, r.ReadsPerSecond, r.WritesPerSecond, r.ReadBytesPerSecond, r.WriteBytesPerSecond) + } + // If the architecture doesn't support sampling CPU, it + // will also be zero. + if grunning.Supported() && r.CPUTimePerSecond == 0 { + t.Errorf("qps %.2f > 0, expected cpu=%.2f to be non-zero", r.QPS, r.CPUTimePerSecond) + } + } + if r.QPS > lastQPS { + t.Errorf("unexpected increase in qps between ranges; prev=%.2f, current=%.2f", lastQPS, r.QPS) + } + lastQPS = r.QPS + } +} + +func TestHotRanges2ResponseWithViewActivityOptions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + db := sqlutils.MakeSQLRunner(sqlDB) + + req := &serverpb.HotRangesRequest{} + var hotRangesResp serverpb.HotRangesResponseV2 + if err := srvtestutils.PostStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + + // Grant VIEWCLUSTERMETADATA and all tests should work. + db.Exec(t, fmt.Sprintf("GRANT SYSTEM VIEWCLUSTERMETADATA TO %s", apiconstants.TestingUserNameNoAdmin().Normalized())) + if err := srvtestutils.PostStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { + t.Fatal(err) + } + + // Revoke VIEWCLUSTERMETADATA and all tests should get permission errors.
+ db.Exec(t, fmt.Sprintf("REVOKE SYSTEM VIEWCLUSTERMETADATA FROM %s", apiconstants.TestingUserNameNoAdmin().Normalized())) + if err := srvtestutils.PostStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } +} + +func TestSpanStatsGRPCResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + ts := s.(*server.TestServer) + + rpcStopper := stop.NewStopper() + defer rpcStopper.Stop(ctx) + rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, ts.RPCContext().Config) + span := roachpb.Span{ + Key: roachpb.RKeyMin.AsRawKey(), + EndKey: roachpb.RKeyMax.AsRawKey(), + } + request := roachpb.SpanStatsRequest{ + NodeID: "1", + Spans: []roachpb.Span{span}, + } + + url := ts.ServingRPCAddr() + nodeID := ts.NodeID() + conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + + response, err := client.SpanStats(ctx, &request) + if err != nil { + t.Fatal(err) + } + initialRanges, err := ts.ExpectedInitialRangeCount() + if err != nil { + t.Fatal(err) + } + responseSpanStats := response.SpanToStats[span.String()] + if a, e := int(responseSpanStats.RangeCount), initialRanges; a != e { + t.Fatalf("expected %d ranges, found %d", e, a) + } +} diff --git a/pkg/server/application_api/telemetry_test.go b/pkg/server/application_api/telemetry_test.go new file mode 100644 index 000000000000..84214b447d4a --- /dev/null +++ b/pkg/server/application_api/telemetry_test.go @@ -0,0 +1,119 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt.
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/diagnostics/diagnosticspb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +// TestHealthTelemetry confirms that hits on some status endpoints increment +// feature telemetry counters. +func TestHealthTelemetry(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + rows, err := db.Query("SELECT feature_name, usage_count FROM crdb_internal.feature_usage WHERE feature_name LIKE 'monitoring%' AND usage_count > 0;") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := rows.Close(); err != nil { + t.Fatal(err) + } + }() + + initialCounts := make(map[string]int) + for rows.Next() { + var featureName string + var usageCount int + + if err := rows.Scan(&featureName, &usageCount); err != nil { + t.Fatal(err) + } + + initialCounts[featureName] = usageCount + } + + var details serverpb.DetailsResponse + if err := serverutils.GetJSONProto(s, "/health", &details); err != nil { + t.Fatal(err) + } + if _, err := srvtestutils.GetText(s, s.AdminURL().WithPath(apiconstants.StatusPrefix+"vars").String()); err != nil { + t.Fatal(err) + } + + expectedCounts := map[string]int{ + "monitoring.prometheus.vars": 1, + "monitoring.health.details": 1, + } + + rows2, err :=
db.Query("SELECT feature_name, usage_count FROM crdb_internal.feature_usage WHERE feature_name LIKE 'monitoring%' AND usage_count > 0;") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := rows2.Close(); err != nil { + t.Fatal(err) + } + }() + + for rows2.Next() { + var featureName string + var usageCount int + + if err := rows2.Scan(&featureName, &usageCount); err != nil { + t.Fatal(err) + } + + usageCount -= initialCounts[featureName] + if count, ok := expectedCounts[featureName]; ok { + if count != usageCount { + t.Fatalf("expected %d count for feature %s, got %d", count, featureName, usageCount) + } + delete(expectedCounts, featureName) + } + } + + if len(expectedCounts) > 0 { + t.Fatalf("%d expected telemetry counters not emitted", len(expectedCounts)) + } +} + +func TestDiagnosticsResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + var resp diagnosticspb.DiagnosticReport + if err := srvtestutils.GetStatusJSONProto(s, "diagnostics/local", &resp); err != nil { + t.Fatal(err) + } + + // The endpoint just serializes result of getReportingInfo() which is already + // tested elsewhere, so simply verify that we have a non-empty reply. + if expected, actual := s.NodeID(), resp.Node.NodeID; expected != actual { + t.Fatalf("expected %v got %v", expected, actual) + } +} diff --git a/pkg/server/application_api/util_test.go b/pkg/server/application_api/util_test.go new file mode 100644 index 000000000000..3a839413486e --- /dev/null +++ b/pkg/server/application_api/util_test.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt.
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package application_api_test + +import ( + "github.com/cockroachdb/cockroach/pkg/util/randident" + "github.com/cockroachdb/cockroach/pkg/util/randutil" +) + +func generateRandomName() string { + rand, _ := randutil.NewTestRand() + cfg := randident.DefaultNameGeneratorConfig() + // REST api can not handle `/`. This is fixed in + // the UI by using sql-over-http endpoint instead. + cfg.Punctuate = -1 + cfg.Finalize() + + ng := randident.NewNameGenerator( + &cfg, + rand, + "a b%s-c.d", + ) + return ng.GenerateOne(42) +} diff --git a/pkg/server/application_api/zcfg_test.go b/pkg/server/application_api/zcfg_test.go new file mode 100644 index 000000000000..8979560facee --- /dev/null +++ b/pkg/server/application_api/zcfg_test.go @@ -0,0 +1,138 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package application_api_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/config/zonepb" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/gogo/protobuf/proto" +) + +// TestAdminAPIZoneDetails verifies the zone configuration information returned +// for both DatabaseDetailsResponse AND TableDetailsResponse. +func TestAdminAPIZoneDetails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this test fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + // Create database and table. + ac := ts.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + setupQueries := []string{ + "CREATE DATABASE test", + "CREATE TABLE test.tbl (val STRING)", + } + for _, q := range setupQueries { + if _, err := db.Exec(q); err != nil { + t.Fatalf("error executing '%s': %s", q, err) + } + } + + // Function to verify the zone for table "test.tbl" as returned by the Admin + // API. 
+ verifyTblZone := func( + expectedZone zonepb.ZoneConfig, expectedLevel serverpb.ZoneConfigurationLevel, + ) { + var resp serverpb.TableDetailsResponse + if err := srvtestutils.GetAdminJSONProto(s, "databases/test/tables/tbl", &resp); err != nil { + t.Fatal(err) + } + if a, e := &resp.ZoneConfig, &expectedZone; !a.Equal(e) { + t.Errorf("actual table zone config %v did not match expected value %v", a, e) + } + if a, e := resp.ZoneConfigLevel, expectedLevel; a != e { + t.Errorf("actual table ZoneConfigurationLevel %s did not match expected value %s", a, e) + } + if t.Failed() { + t.FailNow() + } + } + + // Function to verify the zone for database "test" as returned by the Admin + // API. + verifyDbZone := func( + expectedZone zonepb.ZoneConfig, expectedLevel serverpb.ZoneConfigurationLevel, + ) { + var resp serverpb.DatabaseDetailsResponse + if err := srvtestutils.GetAdminJSONProto(s, "databases/test", &resp); err != nil { + t.Fatal(err) + } + if a, e := &resp.ZoneConfig, &expectedZone; !a.Equal(e) { + t.Errorf("actual db zone config %v did not match expected value %v", a, e) + } + if a, e := resp.ZoneConfigLevel, expectedLevel; a != e { + t.Errorf("actual db ZoneConfigurationLevel %s did not match expected value %s", a, e) + } + if t.Failed() { + t.FailNow() + } + } + + // Function to store a zone config for a given object ID. + setZone := func(zoneCfg zonepb.ZoneConfig, id descpb.ID) { + zoneBytes, err := protoutil.Marshal(&zoneCfg) + if err != nil { + t.Fatal(err) + } + const query = `INSERT INTO system.zones VALUES($1, $2)` + if _, err := db.Exec(query, id, zoneBytes); err != nil { + t.Fatalf("error executing '%s': %s", query, err) + } + } + + // Verify zone matches cluster default. 
+ verifyDbZone(s.(*server.TestServer).Cfg.DefaultZoneConfig, serverpb.ZoneConfigurationLevel_CLUSTER) + verifyTblZone(s.(*server.TestServer).Cfg.DefaultZoneConfig, serverpb.ZoneConfigurationLevel_CLUSTER) + + databaseID, err := ts.TestingQueryDatabaseID(ctx, username.RootUserName(), "test") + if err != nil { + t.Fatal(err) + } + tableID, err := ts.TestingQueryTableID(ctx, username.RootUserName(), "test", "tbl") + if err != nil { + t.Fatal(err) + } + + // Apply zone configuration to database and check again. + dbZone := zonepb.ZoneConfig{ + RangeMinBytes: proto.Int64(456), + } + setZone(dbZone, databaseID) + verifyDbZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) + verifyTblZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) + + // Apply zone configuration to table and check again. + tblZone := zonepb.ZoneConfig{ + RangeMinBytes: proto.Int64(789), + } + setZone(tblZone, tableID) + verifyDbZone(dbZone, serverpb.ZoneConfigurationLevel_DATABASE) + verifyTblZone(tblZone, serverpb.ZoneConfigurationLevel_TABLE) +} diff --git a/pkg/server/authserver/BUILD.bazel b/pkg/server/authserver/BUILD.bazel new file mode 100644 index 000000000000..a3d1079a4330 --- /dev/null +++ b/pkg/server/authserver/BUILD.bazel @@ -0,0 +1,97 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "authserver", + srcs = [ + "api.go", + "api_v2.go", + "api_v2_auth.go", + "authentication.go", + "context.go", + "cookie.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/server/authserver", + visibility = ["//visibility:public"], + deps = [ + "//pkg/base", + "//pkg/clusterversion", + "//pkg/multitenant", + "//pkg/roachpb", + "//pkg/security", + "//pkg/security/password", + "//pkg/security/username", + "//pkg/server/apiutil", + "//pkg/server/serverpb", + "//pkg/server/srverrors", + "//pkg/settings", + "//pkg/settings/cluster", + "//pkg/sql", + "//pkg/sql/isql", + "//pkg/sql/roleoption", + "//pkg/sql/sem/tree", + "//pkg/sql/sessiondata", + 
"//pkg/sql/types", + "//pkg/ui", + "//pkg/util/grpcutil", + "//pkg/util/log", + "//pkg/util/protoutil", + "//pkg/util/uuid", + "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", + "@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//metadata", + "@org_golang_google_grpc//status", + ], +) + +go_test( + name = "authserver_test", + srcs = [ + "authentication_test.go", + "main_test.go", + ], + args = ["-test.timeout=295s"], + deps = [ + ":authserver", + "//pkg/base", + "//pkg/ccl", + "//pkg/gossip", + "//pkg/kv/kvclient/kvtenant", + "//pkg/kv/kvpb", + "//pkg/kv/kvserver", + "//pkg/kv/kvserver/closedts/ctpb", + "//pkg/kv/kvserver/kvserverpb", + "//pkg/roachpb", + "//pkg/rpc", + "//pkg/security", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/security/username", + "//pkg/server", + "//pkg/server/apiconstants", + "//pkg/server/debug", + "//pkg/server/serverpb", + "//pkg/settings/cluster", + "//pkg/sql/execinfrapb", + "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", + "//pkg/testutils/testcluster", + "//pkg/ts", + "//pkg/ts/tspb", + "//pkg/util", + "//pkg/util/httputil", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/timeutil", + "@com_github_cockroachdb_errors//:errors", + "@com_github_gogo_protobuf//jsonpb", + "@com_github_lib_pq//:pq", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//credentials", + "@org_golang_x_crypto//bcrypt", + ], +) diff --git a/pkg/server/authserver/api.go b/pkg/server/authserver/api.go new file mode 100644 index 000000000000..a87437c3d27f --- /dev/null +++ b/pkg/server/authserver/api.go @@ -0,0 +1,113 @@ +// Copyright 2023 The Cockroach Authors. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package authserver + +import ( + "context" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "google.golang.org/grpc" +) + +type Server interface { + RegisterService(*grpc.Server) + RegisterGateway(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error + + // UserLogin verifies an incoming request by a user to create a web + // authentication session. It checks the provided credentials against + // the system.users table, and if successful creates a new + // authentication session. The session's ID and secret are returned to + // the caller as an HTTP cookie, added via a "Set-Cookie" header. + UserLogin(ctx context.Context, req *serverpb.UserLoginRequest) (*serverpb.UserLoginResponse, error) + + // UserLoginFromSSO checks for the existence of a given username and + // if it exists, creates a session for the username in the + // `web_sessions` table. The session's ID and secret are returned to + // the caller as an HTTP cookie, added via a "Set-Cookie" header. + UserLoginFromSSO(ctx context.Context, reqUsername string) (*http.Cookie, error) + + // UserLogout allows a user to terminate their currently active session. + UserLogout(ctx context.Context, req *serverpb.UserLogoutRequest) (*serverpb.UserLogoutResponse, error) + + // DemoLogin is the same as UserLogin but using the GET method. 
+ // It is only available for 'cockroach demo' and test clusters. + DemoLogin(w http.ResponseWriter, req *http.Request) + + // NewAuthSession attempts to create a new authentication session for + // the given user. If successful, returns the ID and secret value for + // the new session. + // + // The caller is responsible to ensure the username has been + // normalized already. + // + // This is a low level API and is only exported for use in tests. + // Regular flows should use the login endpoints instead. + NewAuthSession(ctx context.Context, userName username.SQLUsername) (int64, []byte, error) + + // VerifySession verifies the existence and validity of the session + // claimed by the supplied SessionCookie. Returns three parameters: a + // boolean indicating if the session was valid, the username + // associated with the session (if validated), and an error for any + // internal errors which prevented validation. + // + // This is a low level API and is only exported for use in tests. + VerifySession( + ctx context.Context, cookie *serverpb.SessionCookie, + ) (bool, string, error) + + // VerifyPasswordDBConsole verifies the passed username/password + // pair against the system.users table. The returned boolean indicates + // whether or not the verification succeeded; an error is returned if + // the validation process could not be completed. + // + // This is a low level API and is only exported for use in tests. + // Regular flows should use the login endpoints instead. + // + // This function should *not* be used to validate logins into the SQL + // shell since it checks a separate authentication scheme. + // + // The caller is responsible for ensuring that the username is + // normalized. (CockroachDB has case-insensitive usernames, unlike + // PostgreSQL.) 
+ VerifyPasswordDBConsole( + ctx context.Context, userName username.SQLUsername, passwordStr string, + ) (valid bool, expired bool, err error) +} + +type SQLServerInterface interface { + ExecutorConfig() *sql.ExecutorConfig + InternalExecutor() isql.Executor +} + +type AuthMux interface { + http.Handler +} + +func NewServer(cfg *base.Config, sqlServer SQLServerInterface) Server { + return &authenticationServer{ + cfg: cfg, + sqlServer: sqlServer, + } +} + +func NewMux(s Server, inner http.Handler, allowAnonymous bool) AuthMux { + return &authenticationMux{ + server: s.(*authenticationServer), + inner: inner, + allowAnonymous: allowAnonymous, + } +} diff --git a/pkg/server/authserver/api_v2.go b/pkg/server/authserver/api_v2.go new file mode 100644 index 000000000000..472e2c6d9821 --- /dev/null +++ b/pkg/server/authserver/api_v2.go @@ -0,0 +1,74 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package authserver + +import ( + "context" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/roleoption" +) + +type ServerV2 interface { + http.Handler +} + +type AuthV2Mux interface { + http.Handler +} + +type RoleAuthzMux interface { + http.Handler +} + +// NewV2Server creates a new ServerV2 for the given outer Server, +// and base path. 
+func NewV2Server( + ctx context.Context, s SQLServerInterface, cfg *base.Config, basePath string, +) ServerV2 { + simpleMux := http.NewServeMux() + + innerServer := NewServer(cfg, s).(*authenticationServer) + authServer := &authenticationV2Server{ + sqlServer: s, + authServer: innerServer, + mux: simpleMux, + ctx: ctx, + basePath: basePath, + } + + authServer.registerRoutes() + return authServer +} + +// NewV2Mux creates a new AuthV2Mux for the given ServerV2. +func NewV2Mux(s ServerV2, inner http.Handler, allowAnonymous bool) AuthV2Mux { + as := s.(*authenticationV2Server) + return &authenticationV2Mux{ + s: as, + inner: inner, + allowAnonymous: allowAnonymous, + } +} + +// NewRoleAuthzMux creates a new RoleAuthzMux. +func NewRoleAuthzMux( + ie isql.Executor, role APIRole, option roleoption.Option, inner http.Handler, +) RoleAuthzMux { + return &roleAuthorizationMux{ + ie: ie, + role: role, + option: option, + inner: inner, + } +} diff --git a/pkg/server/api_v2_auth.go b/pkg/server/authserver/api_v2_auth.go similarity index 81% rename from pkg/server/api_v2_auth.go rename to pkg/server/authserver/api_v2_auth.go index 54d11ff9981e..3df7ec1bf91d 100644 --- a/pkg/server/api_v2_auth.go +++ b/pkg/server/authserver/api_v2_auth.go @@ -8,17 +8,18 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. 
-package server +package authserver import ( "context" "encoding/base64" "net/http" - "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" + "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -29,37 +30,22 @@ import ( "google.golang.org/grpc/status" ) +const ( + APIV2AuthHeader = "X-Cockroach-API-Session" +) + // authenticationV2Server is a sub-server under apiV2Server that handles // authentication-related endpoints, such as login and logout. The actual // verification of sessions for regular endpoints happens in authenticationV2Mux, // not here. type authenticationV2Server struct { ctx context.Context - sqlServer *SQLServer + sqlServer SQLServerInterface authServer *authenticationServer mux *http.ServeMux basePath string } -// newAuthenticationV2Server creates a new authenticationV2Server for the given -// outer Server, and base path. -func newAuthenticationV2Server( - ctx context.Context, s *SQLServer, cfg *base.Config, basePath string, -) *authenticationV2Server { - simpleMux := http.NewServeMux() - - authServer := &authenticationV2Server{ - sqlServer: s, - authServer: newAuthenticationServer(cfg, s), - mux: simpleMux, - ctx: ctx, - basePath: basePath, - } - - authServer.registerRoutes() - return authServer -} - func (a *authenticationV2Server) registerRoutes() { a.bindEndpoint("login/", a.login) a.bindEndpoint("logout/", a.logout) @@ -76,9 +62,9 @@ func (a *authenticationV2Server) createSessionFor( ctx context.Context, userName username.SQLUsername, ) (string, error) { // Create a new database session, generating an ID and secret key. 
- id, secret, err := a.authServer.newAuthSession(ctx, userName) + id, secret, err := a.authServer.NewAuthSession(ctx, userName) if err != nil { - return "", apiInternalError(ctx, err) + return "", srverrors.APIInternalError(ctx, err) } // Generate and set a session for the response. Because HTTP cookies @@ -150,7 +136,7 @@ func (a *authenticationV2Server) login(w http.ResponseWriter, r *http.Request) { http.Error(w, "not found", http.StatusNotFound) } if err := r.ParseForm(); err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } if r.Form.Get("username") == "" { @@ -166,9 +152,9 @@ func (a *authenticationV2Server) login(w http.ResponseWriter, r *http.Request) { username, _ := username.MakeSQLUsernameFromUserInput(r.Form.Get("username"), username.PurposeValidation) // Verify the provided username/password pair. - verified, expired, err := a.authServer.verifyPasswordDBConsole(a.ctx, username, r.Form.Get("password")) + verified, expired, err := a.authServer.VerifyPasswordDBConsole(a.ctx, username, r.Form.Get("password")) if err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } if expired { @@ -182,11 +168,11 @@ func (a *authenticationV2Server) login(w http.ResponseWriter, r *http.Request) { session, err := a.createSessionFor(a.ctx, username) if err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } - writeJSONResponse(r.Context(), w, http.StatusOK, &loginResponse{Session: session}) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusOK, &loginResponse{Session: session}) } // swagger:model logoutResponse @@ -221,7 +207,7 @@ func (a *authenticationV2Server) logout(w http.ResponseWriter, r *http.Request) if r.Method != "POST" { http.Error(w, "not found", http.StatusNotFound) } - session := r.Header.Get(apiV2AuthHeader) + session := r.Header.Get(APIV2AuthHeader) if session == "" { 
http.Error(w, "invalid or unspecified session", http.StatusBadRequest) return @@ -229,16 +215,16 @@ func (a *authenticationV2Server) logout(w http.ResponseWriter, r *http.Request) var sessionCookie serverpb.SessionCookie decoded, err := base64.StdEncoding.DecodeString(session) if err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } if err := protoutil.Unmarshal(decoded, &sessionCookie); err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } // Revoke the session. - if n, err := a.sqlServer.internalExecutor.ExecEx( + if n, err := a.sqlServer.InternalExecutor().ExecEx( a.ctx, "revoke-auth-session", nil, /* txn */ @@ -246,7 +232,7 @@ func (a *authenticationV2Server) logout(w http.ResponseWriter, r *http.Request) `UPDATE system.web_sessions SET "revokedAt" = now() WHERE id = $1`, sessionCookie.ID, ); err != nil { - apiV2InternalError(r.Context(), err, w) + srverrors.APIV2InternalError(r.Context(), err, w) return } else if n == 0 { err := status.Errorf( @@ -257,7 +243,7 @@ func (a *authenticationV2Server) logout(w http.ResponseWriter, r *http.Request) return } - writeJSONResponse(r.Context(), w, http.StatusOK, &logoutResponse{LoggedOut: true}) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusOK, &logoutResponse{LoggedOut: true}) } func (a *authenticationV2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -274,19 +260,11 @@ type authenticationV2Mux struct { allowAnonymous bool } -func newAuthenticationV2Mux(s *authenticationV2Server, inner http.Handler) *authenticationV2Mux { - return &authenticationV2Mux{ - s: s, - inner: inner, - allowAnonymous: s.sqlServer.cfg.Insecure, - } -} - -// apiV2UseCookieBasedAuth is a magic value of the auth header that +// APIV2UseCookieBasedAuth is a magic value of the auth header that // tells us to look for the session in the cookie. 
This can be used by // frontend code to maintain cookie-based auth while interacting with // the API. -const apiV2UseCookieBasedAuth = "cookie" +const APIV2UseCookieBasedAuth = "cookie" // getSession decodes the cookie from the request, looks up the corresponding // session, and returns the logged-in username. The session can be looked up @@ -301,7 +279,7 @@ func (a *authenticationV2Mux) getSession( ) (string, *serverpb.SessionCookie, int, error) { ctx := req.Context() // Validate the returned session header or cookie. - rawSession := req.Header.Get(apiV2AuthHeader) + rawSession := req.Header.Get(APIV2AuthHeader) if len(rawSession) == 0 { err := errors.New("invalid session header") return "", nil, http.StatusUnauthorized, err @@ -309,9 +287,9 @@ func (a *authenticationV2Mux) getSession( cookie := &serverpb.SessionCookie{} var err error - if rawSession == apiV2UseCookieBasedAuth { - st := a.s.sqlServer.cfg.Settings - cookie, err = findAndDecodeSessionCookie(req.Context(), st, req.Cookies()) + if rawSession == APIV2UseCookieBasedAuth { + st := a.s.sqlServer.ExecutorConfig().Settings + cookie, err = FindAndDecodeSessionCookie(req.Context(), st, req.Cookies()) } else { decoded, err := base64.StdEncoding.DecodeString(rawSession) if err != nil { @@ -327,9 +305,9 @@ func (a *authenticationV2Mux) getSession( err := errors.New("invalid session header") return "", nil, http.StatusBadRequest, err } - valid, username, err := a.s.authServer.verifySession(req.Context(), cookie) + valid, username, err := a.s.authServer.VerifySession(req.Context(), cookie) if err != nil { - apiV2InternalError(req.Context(), err, w) + srverrors.APIV2InternalError(req.Context(), err, w) return "", nil, http.StatusInternalServerError, err } if !valid { @@ -356,16 +334,23 @@ func (a *authenticationV2Mux) ServeHTTP(w http.ResponseWriter, req *http.Request if cookie != nil { sessionID = cookie.ID } - req = req.WithContext(contextWithHTTPAuthInfo(req.Context(), u, sessionID)) + req = 
req.WithContext(ContextWithHTTPAuthInfo(req.Context(), u, sessionID)) a.inner.ServeHTTP(w, req) } -type apiRole int +// APIRole is an enum representing the authorization level +// needed for an APIv2 endpoint. +type APIRole int const ( - regularRole apiRole = iota - adminRole - superUserRole + // RegularRole is the default role for an APIv2 endpoint. + RegularRole APIRole = iota + // AdminRole is the role for an APIv2 endpoint that requires + // admin privileges. + AdminRole + // SuperUserRole is the role for an APIv2 endpoint that requires + // superuser privileges. + SuperUserRole ) // roleAuthorizationMux enforces a role (eg. type of user, role option) @@ -374,40 +359,40 @@ const ( // the `option` roleoption, an HTTP 403 forbidden error is returned. Otherwise, // the request is passed onto the inner http.Handler. type roleAuthorizationMux struct { - ie *sql.InternalExecutor - role apiRole + ie isql.Executor + role APIRole option roleoption.Option inner http.Handler } func (r *roleAuthorizationMux) getRoleForUser( ctx context.Context, user username.SQLUsername, -) (apiRole, error) { +) (APIRole, error) { if user.IsRootUser() { // Shortcut. 
- return superUserRole, nil + return SuperUserRole, nil } row, err := r.ie.QueryRowEx( ctx, "check-is-admin", nil, /* txn */ sessiondata.InternalExecutorOverride{User: user}, "SELECT crdb_internal.is_admin()") if err != nil { - return regularRole, err + return RegularRole, err } if row == nil { - return regularRole, errors.AssertionFailedf("hasAdminRole: expected 1 row, got 0") + return RegularRole, errors.AssertionFailedf("hasAdminRole: expected 1 row, got 0") } if len(row) != 1 { - return regularRole, errors.AssertionFailedf("hasAdminRole: expected 1 column, got %d", len(row)) + return RegularRole, errors.AssertionFailedf("hasAdminRole: expected 1 column, got %d", len(row)) } dbDatum, ok := tree.AsDBool(row[0]) if !ok { - return regularRole, errors.AssertionFailedf("hasAdminRole: expected bool, got %T", row[0]) + return RegularRole, errors.AssertionFailedf("hasAdminRole: expected bool, got %T", row[0]) } if dbDatum { - return adminRole, nil + return AdminRole, nil } - return regularRole, nil + return RegularRole, nil } func (r *roleAuthorizationMux) hasRoleOption( @@ -440,10 +425,10 @@ func (r *roleAuthorizationMux) hasRoleOption( func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { // The username is set in authenticationV2Mux, and must correspond with a // logged-in user. 
- username := userFromHTTPAuthInfoContext(req.Context()) + username := UserFromHTTPAuthInfoContext(req.Context()) if role, err := r.getRoleForUser(req.Context(), username); err != nil || role < r.role { if err != nil { - apiV2InternalError(req.Context(), err, w) + srverrors.APIV2InternalError(req.Context(), err, w) } else { http.Error(w, "user not allowed to access this endpoint", http.StatusForbidden) } @@ -452,7 +437,7 @@ func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Reques if r.option > 0 { ok, err := r.hasRoleOption(req.Context(), username, r.option) if err != nil { - apiV2InternalError(req.Context(), err, w) + srverrors.APIV2InternalError(req.Context(), err, w) return } else if !ok { http.Error(w, "user not allowed to access this endpoint", http.StatusForbidden) diff --git a/pkg/server/authentication.go b/pkg/server/authserver/authentication.go similarity index 55% rename from pkg/server/authentication.go rename to pkg/server/authserver/authentication.go index 3d8ce3c885d3..6a7f357eebcb 100644 --- a/pkg/server/authentication.go +++ b/pkg/server/authserver/authentication.go @@ -8,7 +8,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. 
-package server +package authserver import ( "bytes" @@ -19,17 +19,16 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/clusterversion" - "github.com/cockroachdb/cockroach/pkg/multitenant" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/security/password" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" @@ -51,14 +50,15 @@ import ( ) const ( - // authPrefix is the prefix for RESTful endpoints used to provide - // authentication methods. - loginPath = "/login" - logoutPath = "/logout" + // LoginPath is the URL path to the login handler. + LoginPath = "/login" + + // LogoutPath is the URL path to the logout handler. + LogoutPath = "/logout" + // secretLength is the number of random bytes generated for session secrets. secretLength = 16 - // SessionCookieName is the name of the cookie used for HTTP auth. - SessionCookieName = "session" + // DemoLoginPath is the demo shell auto-login URL. DemoLoginPath = "/demologin" ) @@ -93,7 +93,8 @@ var ConfigureOIDC = func( return &noOIDCConfigured{}, nil } -var webSessionTimeout = settings.RegisterDurationSetting( +// WebSessionTimeout is the cluster setting for web session TTL. 
+var WebSessionTimeout = settings.RegisterDurationSetting( settings.TenantWritable, "server.web_session_timeout", "the duration that a newly created web session will be valid", @@ -103,16 +104,7 @@ var webSessionTimeout = settings.RegisterDurationSetting( type authenticationServer struct { cfg *base.Config - sqlServer *SQLServer -} - -// newAuthenticationServer allocates and returns a new REST server for -// authentication APIs. -func newAuthenticationServer(cfg *base.Config, s *SQLServer) *authenticationServer { - return &authenticationServer{ - cfg: cfg, - sqlServer: s, - } + sqlServer SQLServerInterface } // RegisterService registers the GRPC service. @@ -132,11 +124,7 @@ func (s *authenticationServer) RegisterGateway( return serverpb.RegisterLogOutHandler(ctx, mux, conn) } -// UserLogin verifies an incoming request by a user to create an web -// authentication session. It checks the provided credentials against the -// system.users table, and if successful creates a new authentication session. -// The session's ID and secret are returned to the caller as an HTTP cookie, -// added via a "Set-Cookie" header. +// UserLogin is part of the Server interface. func (s *authenticationServer) UserLogin( ctx context.Context, req *serverpb.UserLoginRequest, ) (*serverpb.UserLoginResponse, error) { @@ -155,9 +143,9 @@ func (s *authenticationServer) UserLogin( username, _ := username.MakeSQLUsernameFromUserInput(req.Username, username.PurposeValidation) // Verify the provided username/password pair. 
- verified, expired, err := s.verifyPasswordDBConsole(ctx, username, req.Password) + verified, expired, err := s.VerifyPasswordDBConsole(ctx, username, req.Password) if err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } if expired { return nil, status.Errorf( @@ -172,20 +160,20 @@ func (s *authenticationServer) UserLogin( cookie, err := s.createSessionFor(ctx, username) if err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } // Set the cookie header on the outgoing response. if err := grpc.SetHeader(ctx, metadata.Pairs("set-cookie", cookie.String())); err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } return &serverpb.UserLoginResponse{}, nil } -// demoLogin is the same as UserLogin but using the GET method. -// It is only available for demo and test clusters. -func (s *authenticationServer) demoLogin(w http.ResponseWriter, req *http.Request) { +// DemoLogin is the same as UserLogin but using the GET method. +// It is only available for 'cockroach demo' and test clusters. +func (s *authenticationServer) DemoLogin(w http.ResponseWriter, req *http.Request) { ctx := context.Background() ctx = logtags.AddTag(ctx, "client", log.SafeOperational(req.RemoteAddr)) ctx = logtags.AddTag(ctx, "demologin", nil) @@ -219,7 +207,7 @@ func (s *authenticationServer) demoLogin(w http.ResponseWriter, req *http.Reques // without further normalization. username, _ := username.MakeSQLUsernameFromUserInput(userInput, username.PurposeValidation) // Verify the provided username/password pair. 
- verified, expired, err := s.verifyPasswordDBConsole(ctx, username, password) + verified, expired, err := s.VerifyPasswordDBConsole(ctx, username, password) if err != nil { fail(err) return @@ -250,10 +238,7 @@ var errWebAuthenticationFailure = status.Errorf( "the provided credentials did not match any account on the server", ) -// UserLoginFromSSO checks for the existence of a given username and if it exists, -// creates a session for the username in the `web_sessions` table. -// The session's ID and secret are returned to the caller as an HTTP cookie, -// added via a "Set-Cookie" header. +// UserLoginFromSSO is part of the Server interface. func (s *authenticationServer) UserLoginFromSSO( ctx context.Context, reqUsername string, ) (*http.Cookie, error) { @@ -266,7 +251,7 @@ func (s *authenticationServer) UserLoginFromSSO( exists, _, canLoginDBConsole, _, _, _, _, err := sql.GetUserSessionInitInfo( ctx, - s.sqlServer.execCfg, + s.sqlServer.ExecutorConfig(), username, "", /* databaseName */ ) @@ -288,9 +273,9 @@ func (s *authenticationServer) createSessionFor( ctx context.Context, userName username.SQLUsername, ) (*http.Cookie, error) { // Create a new database session, generating an ID and secret key. - id, secret, err := s.newAuthSession(ctx, userName) + id, secret, err := s.NewAuthSession(ctx, userName) if err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } // Generate and set a session cookie for the response. Because HTTP cookies @@ -303,17 +288,17 @@ func (s *authenticationServer) createSessionFor( return EncodeSessionCookie(cookieValue, !s.cfg.DisableTLSForHTTP) } -// UserLogout allows a user to terminate their currently active session. +// UserLogout is part of the Server interface. 
func (s *authenticationServer) UserLogout( ctx context.Context, req *serverpb.UserLogoutRequest, ) (*serverpb.UserLogoutResponse, error) { md, ok := grpcutil.FastFromIncomingContext(ctx) if !ok { - return nil, apiInternalError(ctx, fmt.Errorf("couldn't get incoming context")) + return nil, srverrors.APIInternalError(ctx, fmt.Errorf("couldn't get incoming context")) } sessionIDs := md.Get(webSessionIDKeyStr) if len(sessionIDs) != 1 { - return nil, apiInternalError(ctx, fmt.Errorf("couldn't get incoming context")) + return nil, srverrors.APIInternalError(ctx, fmt.Errorf("couldn't get incoming context")) } sessionID, err := strconv.Atoi(sessionIDs[0]) @@ -324,7 +309,7 @@ func (s *authenticationServer) UserLogout( } // Revoke the session. - if n, err := s.sqlServer.internalExecutor.ExecEx( + if n, err := s.sqlServer.InternalExecutor().ExecEx( ctx, "revoke-auth-session", nil, /* txn */ @@ -332,7 +317,7 @@ func (s *authenticationServer) UserLogout( `UPDATE system.web_sessions SET "revokedAt" = now() WHERE id = $1`, sessionID, ); err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } else if n == 0 { err := status.Errorf( codes.InvalidArgument, @@ -348,17 +333,14 @@ func (s *authenticationServer) UserLogout( // Set the cookie header on the outgoing response. if err := grpc.SetHeader(ctx, metadata.Pairs("set-cookie", cookie.String())); err != nil { - return nil, apiInternalError(ctx, err) + return nil, srverrors.APIInternalError(ctx, err) } return &serverpb.UserLogoutResponse{}, nil } -// verifySession verifies the existence and validity of the session claimed by -// the supplied SessionCookie. Returns three parameters: a boolean indicating if -// the session was valid, the username associated with the session (if -// validated), and an error for any internal errors which prevented validation. -func (s *authenticationServer) verifySession( +// VerifySession is part of the Server interface. 
+func (s *authenticationServer) VerifySession( ctx context.Context, cookie *serverpb.SessionCookie, ) (bool, string, error) { // Look up session in database and verify hashed secret value. @@ -374,7 +356,7 @@ WHERE id = $1` isRevoked bool ) - row, err := s.sqlServer.internalExecutor.QueryRowEx( + row, err := s.sqlServer.InternalExecutor().QueryRowEx( ctx, "lookup-auth-session", nil, /* txn */ @@ -401,7 +383,7 @@ WHERE id = $1` return false, "", nil } - if now := s.sqlServer.execCfg.Clock.PhysicalTime(); !now.Before(expiresAt) { + if now := s.sqlServer.ExecutorConfig().Clock.PhysicalTime(); !now.Before(expiresAt) { return false, "", nil } @@ -415,22 +397,14 @@ WHERE id = $1` return true, userName, nil } -// verifyPasswordDBConsole verifies the passed username/password pair against the -// system.users table. The returned boolean indicates whether or not the -// verification succeeded; an error is returned if the validation process could -// not be completed. -// -// This function should *not* be used to validate logins into the SQL -// shell since it checks a separate authentication scheme. -// -// The caller is responsible for ensuring that the username is normalized. +// VerifyPasswordDBConsole is part of the Server interface. // (CockroachDB has case-insensitive usernames, unlike PostgreSQL.) -func (s *authenticationServer) verifyPasswordDBConsole( +func (s *authenticationServer) VerifyPasswordDBConsole( ctx context.Context, userName username.SQLUsername, passwordStr string, ) (valid bool, expired bool, err error) { exists, _, canLoginDBConsole, _, _, _, pwRetrieveFn, err := sql.GetUserSessionInitInfo( ctx, - s.sqlServer.execCfg, + s.sqlServer.ExecutorConfig(), userName, "", /* databaseName */ ) @@ -461,7 +435,7 @@ func (s *authenticationServer) verifyPasswordDBConsole( // pushes clusters upgraded from a previous version into using // SCRAM-SHA-256. 
sql.MaybeConvertStoredPasswordHash(ctx, - s.sqlServer.execCfg, + s.sqlServer.ExecutorConfig(), userName, passwordStr, hashedPassword) } @@ -481,14 +455,17 @@ func CreateAuthSecret() (secret, hashedSecret []byte, err error) { return secret, hashedSecret, nil } -// newAuthSession attempts to create a new authentication session for the given -// user. If successful, returns the ID and secret value for the new session. +// NewAuthSession attempts to create a new authentication session for +// the given user. If successful, returns the ID and secret value for +// the new session. // -// The caller is responsible to ensure the username has been normalized already. -func (s *authenticationServer) newAuthSession( +// The caller is responsible to ensure the username has been +// normalized already. +func (s *authenticationServer) NewAuthSession( ctx context.Context, userName username.SQLUsername, ) (int64, []byte, error) { - webSessionsTableHasUserIDCol := s.sqlServer.execCfg.Settings.Version.IsActive(ctx, + st := s.sqlServer.ExecutorConfig().Settings + webSessionsTableHasUserIDCol := st.Version.IsActive(ctx, clusterversion.V23_1WebSessionsTableHasUserIDColumn) secret, hashedSecret, err := CreateAuthSecret() @@ -496,7 +473,7 @@ func (s *authenticationServer) newAuthSession( return 0, nil, err } - expiration := s.sqlServer.execCfg.Clock.PhysicalTime().Add(webSessionTimeout.Get(&s.sqlServer.execCfg.Settings.SV)) + expiration := s.sqlServer.ExecutorConfig().Clock.PhysicalTime().Add(WebSessionTimeout.Get(&st.SV)) insertSessionStmt := ` INSERT INTO system.web_sessions ("hashedSecret", username, "expiresAt") @@ -512,7 +489,7 @@ RETURNING id } var id int64 - row, err := s.sqlServer.internalExecutor.QueryRowEx( + row, err := s.sqlServer.InternalExecutor().QueryRowEx( ctx, "create-auth-session", nil, /* txn */ @@ -554,35 +531,11 @@ type authenticationMux struct { allowAnonymous bool } -func newAuthenticationMuxAllowAnonymous( - s *authenticationServer, inner http.Handler, -) 
*authenticationMux { - return &authenticationMux{ - server: s, - inner: inner, - allowAnonymous: true, - } -} - -func newAuthenticationMux(s *authenticationServer, inner http.Handler) *authenticationMux { - return &authenticationMux{ - server: s, - inner: inner, - allowAnonymous: false, - } -} - -type webSessionUserKey struct{} -type webSessionIDKey struct{} - -const webSessionUserKeyStr = "websessionuser" -const webSessionIDKeyStr = "websessionid" - func (am *authenticationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { username, cookie, err := am.getSession(w, req) if err == nil { req = req.WithContext( - contextWithHTTPAuthInfo(req.Context(), username, cookie.ID)) + ContextWithHTTPAuthInfo(req.Context(), username, cookie.ID)) } else if !am.allowAnonymous { if log.V(1) { log.Infof(req.Context(), "web session error: %v", err) @@ -625,15 +578,15 @@ func makeCookieWithValue(value string, forHTTPSOnly bool) *http.Cookie { func (am *authenticationMux) getSession( w http.ResponseWriter, req *http.Request, ) (string, *serverpb.SessionCookie, error) { - st := am.server.sqlServer.cfg.Settings - cookie, err := findAndDecodeSessionCookie(req.Context(), st, req.Cookies()) + st := am.server.sqlServer.ExecutorConfig().Settings + cookie, err := FindAndDecodeSessionCookie(req.Context(), st, req.Cookies()) if err != nil { return "", nil, err } - valid, username, err := am.server.verifySession(req.Context(), cookie) + valid, username, err := am.server.VerifySession(req.Context(), cookie) if err != nil { - err := apiInternalError(req.Context(), err) + err := srverrors.APIInternalError(req.Context(), err) return "", nil, err } if !valid { @@ -657,11 +610,11 @@ func decodeSessionCookie(encodedCookie *http.Cookie) (*serverpb.SessionCookie, e return &sessionCookieValue, nil } -// authenticationHeaderMatcher is a GRPC header matcher function, which provides +// AuthenticationHeaderMatcher is a GRPC header matcher function, which provides // a conversion from GRPC headers to 
HTTP headers. This function is needed to // attach the "set-cookie" header to the response; by default, Grpc-Gateway // adds a prefix to all GRPC headers before adding them to the response. -func authenticationHeaderMatcher(key string) (string, bool) { +func AuthenticationHeaderMatcher(key string) (string, bool) { // GRPC converts all headers to lower case. if key == "set-cookie" { return key, true @@ -675,231 +628,3 @@ func authenticationHeaderMatcher(key string) (string, bool) { // duplicated here. return fmt.Sprintf("%s%s", gwruntime.MetadataHeaderPrefix, key), true } - -// contextWithHTTPAuthInfo embeds the HTTP authentication details into -// a go context. Meant for use with userFromHTTPAuthInfoContext(). -func contextWithHTTPAuthInfo( - ctx context.Context, username string, sessionID int64, -) context.Context { - ctx = context.WithValue(ctx, webSessionUserKey{}, username) - if sessionID != 0 { - ctx = context.WithValue(ctx, webSessionIDKey{}, sessionID) - } - return ctx -} - -// userFromHTTPAuthInfoContext returns a SQL username from the request -// context of a HTTP route requiring login. Only use in routes that require -// login (e.g. requiresAuth = true in the API v2 route definition). -// -// Do not use this function in _RPC_ API handlers. These access their -// SQL identity via the RPC incoming context. See -// userFromIncomingRPCContext(). -func userFromHTTPAuthInfoContext(ctx context.Context) username.SQLUsername { - return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) -} - -// maybeUserFromHTTPAuthInfoContext is like userFromHTTPAuthInfoContext but -// it returns a boolean false if there is no user in the context. 
-func maybeUserFromHTTPAuthInfoContext(ctx context.Context) (username.SQLUsername, bool) { - if u := ctx.Value(webSessionUserKey{}); u != nil { - return username.MakeSQLUsernameFromPreNormalizedString(u.(string)), true - } - return username.SQLUsername{}, false -} - -// translateHTTPAuthInfoToGRPCMetadata translates the context.Value -// that results from HTTP authentication into gRPC metadata suitable -// for use by RPC API handlers. -func translateHTTPAuthInfoToGRPCMetadata(ctx context.Context, _ *http.Request) metadata.MD { - md := metadata.MD{} - if user := ctx.Value(webSessionUserKey{}); user != nil { - md.Set(webSessionUserKeyStr, user.(string)) - } - if sessionID := ctx.Value(webSessionIDKey{}); sessionID != nil { - md.Set(webSessionIDKeyStr, fmt.Sprintf("%v", sessionID)) - } - return md -} - -// forwardSQLIdentityThroughRPCCalls forwards the SQL identity of the -// original request (as populated by translateHTTPAuthInfoToGRPCMetadata in -// grpc-gateway) so it remains available to the remote node handling -// the request. -func forwardSQLIdentityThroughRPCCalls(ctx context.Context) context.Context { - if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { - if u, ok := md[webSessionUserKeyStr]; ok { - return metadata.NewOutgoingContext(ctx, metadata.MD{webSessionUserKeyStr: u}) - } - } - return ctx -} - -// forwardHTTPAuthInfoToRPCCalls converts an HTTP API (v1 or v2) context, to one that -// can issue outgoing RPC requests under the same logged-in user. -func forwardHTTPAuthInfoToRPCCalls(ctx context.Context, r *http.Request) context.Context { - md := translateHTTPAuthInfoToGRPCMetadata(ctx, r) - return metadata.NewOutgoingContext(ctx, md) -} - -// userFromIncomingRPCContext is to be used in RPC API handlers. It -// assumes the SQL identity was populated in the context implicitly by -// gRPC via translateHTTPAuthInfoToGRPCMetadata(), or explicitly via -// forwardHTTPAuthInfoToRPCCalls() or -// forwardSQLIdentityThroughRPCCalls(). 
-// -// Do not use this function in _HTTP_ API handlers. Those access their -// SQL identity via a special context key. See -// userFromHTTPAuthInfoContext(). -func userFromIncomingRPCContext(ctx context.Context) (res username.SQLUsername, err error) { - md, ok := grpcutil.FastFromIncomingContext(ctx) - if !ok { - return username.RootUserName(), nil - } - usernames, ok := md[webSessionUserKeyStr] - if !ok { - // If the incoming context has metadata but no attached web session user, - // it's a gRPC / internal SQL connection which has root on the cluster. - // This assumption is a historical hiccup, and would be best described - // as a bug. See: https://github.com/cockroachdb/cockroach/issues/45018 - return username.RootUserName(), nil - } - if len(usernames) != 1 { - log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) - return res, fmt.Errorf( - "context's incoming metadata contains unexpected number of usernames: %+v ", md) - } - // At this point the user is already logged in, so we can assume - // the username has been normalized already. - username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) - return username, nil -} - -// sessionCookieValue defines the data needed to construct the -// aggregate session cookie in the order provided. -type sessionCookieValue struct { - // The name of the tenant. - name string - // The value of set-cookie. - setCookie string -} - -// createAggregatedSessionCookieValue is used for multi-tenant login. -// It takes a slice of sessionCookieValue and converts it to a single -// string which is the aggregated session. Currently the format of the -// aggregated session is: `session,tenant_name,session2,tenant_name2` etc. 
-func createAggregatedSessionCookieValue(sessionCookieValue []sessionCookieValue) string { - var sessionsStr string - for _, val := range sessionCookieValue { - sessionCookieSlice := strings.Split(strings.ReplaceAll(val.setCookie, "session=", ""), ";") - sessionsStr += sessionCookieSlice[0] + "," + val.name + "," - } - if len(sessionsStr) > 0 { - sessionsStr = sessionsStr[:len(sessionsStr)-1] - } - return sessionsStr -} - -// findAndDecodeSessionCookie looks for multitenant-session and session cookies -// in the cookies slice. If they are found the value will need to be processed if -// it is a multitenant-session cookie (see findSessionCookieValueForTenant for details) -// and then decoded. If there is an error in decoding or processing, the function -// will return an error. -func findAndDecodeSessionCookie( - ctx context.Context, st *cluster.Settings, cookies []*http.Cookie, -) (*serverpb.SessionCookie, error) { - found := false - var sessionCookie *serverpb.SessionCookie - tenantSelectCookieVal := findTenantSelectCookieValue(cookies) - for _, cookie := range cookies { - if cookie.Name != SessionCookieName { - continue - } - found = true - mtSessionVal, err := findSessionCookieValueForTenant( - st, - cookie, - tenantSelectCookieVal) - if err != nil { - return sessionCookie, apiInternalError(ctx, err) - } - if mtSessionVal != "" { - cookie.Value = mtSessionVal - } - sessionCookie, err = decodeSessionCookie(cookie) - if err != nil { - // Multiple cookies with the same name may be included in the - // header. We continue searching even if we find a matching - // name with an invalid value. 
- log.Infof(ctx, "found a matching cookie that failed decoding: %v", err) - found = false - continue - } - break - } - if !found { - return nil, http.ErrNoCookie - } - return sessionCookie, nil -} - -// findSessionCookieValueForTenant finds the encoded session in the provided -// aggregated session cookie value established in multi-tenant clusters that's -// associated with the provided tenant name. If an empty tenant name is provided, -// we default to the DefaultTenantSelect cluster setting value. -// -// If the method cannot find a match between the tenant name and session, or -// if the provided session cookie is nil, it will return an empty string. -// -// e.g. tenant name is "system" and session cookie's value is -// "abcd1234,system,efgh5678,app" the output will be "abcd1234". -// -// In the case of legacy session cookies, where tenant names are not encoded -// into the cookie value, we assume that the session belongs to defaultTenantSelect. -// Note that these legacy session cookies only contained a single session string -// as the cookie's value. -func findSessionCookieValueForTenant( - st *cluster.Settings, sessionCookie *http.Cookie, tenantName string, -) (string, error) { - if sessionCookie == nil { - return "", nil - } - if mtSessionStr := sessionCookie.Value; sessionCookie.Value != "" { - sessionSlice := strings.Split(mtSessionStr, ",") - if len(sessionSlice) == 1 { - // If no separator was found in the cookie value, this is likely - // a cookie from a previous CRDB version where the cookie value - // contained a single session string without any tenant names encoded. - // To maintain backwards compatibility, assume this session belongs - // to the default tenant. In this case, the entire cookie value is - // the session string. 
- return mtSessionStr, nil - } - if tenantName == "" { - tenantName = multitenant.DefaultTenantSelect.Get(&st.SV) - } - var encodedSession string - for idx, val := range sessionSlice { - if val == tenantName && idx > 0 { - encodedSession = sessionSlice[idx-1] - } - } - if encodedSession == "" { - return "", errors.Newf("unable to find session cookie value that matches tenant %q", tenantName) - } - return encodedSession, nil - } - return "", nil -} - -// findTenantSelectCookieValue iterates through all request cookies in order -// to find the value of the tenant select cookie. If the tenant select cookie -// is not found, it returns the empty string. -func findTenantSelectCookieValue(cookies []*http.Cookie) string { - for _, c := range cookies { - if c.Name == TenantSelectCookieName { - return c.Value - } - } - return "" -} diff --git a/pkg/server/authentication_test.go b/pkg/server/authserver/authentication_test.go similarity index 84% rename from pkg/server/authentication_test.go rename to pkg/server/authserver/authentication_test.go index 912b8bb55c77..19b8602d96e6 100644 --- a/pkg/server/authentication_test.go +++ b/pkg/server/authserver/authentication_test.go @@ -8,7 +8,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. 
-package server +package authserver_test import ( "bytes" @@ -21,6 +21,7 @@ import ( "net/http" "net/http/cookiejar" "net/url" + "strings" "testing" "time" @@ -34,12 +35,16 @@ import ( "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/debug" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/ts" "github.com/cockroachdb/cockroach/pkg/ts/tspb" "github.com/cockroachdb/cockroach/pkg/util" @@ -137,11 +142,11 @@ func TestSSLEnforcement(t *testing.T) { {"", insecureContext, http.StatusTemporaryRedirect}, // /_admin/: server.adminServer: no auth. - {adminPrefix + "health", rootCertsContext, http.StatusOK}, - {adminPrefix + "health", nodeCertsContext, http.StatusOK}, - {adminPrefix + "health", testCertsContext, http.StatusOK}, - {adminPrefix + "health", noCertsContext, http.StatusOK}, - {adminPrefix + "health", insecureContext, http.StatusTemporaryRedirect}, + {apiconstants.AdminPrefix + "health", rootCertsContext, http.StatusOK}, + {apiconstants.AdminPrefix + "health", nodeCertsContext, http.StatusOK}, + {apiconstants.AdminPrefix + "health", testCertsContext, http.StatusOK}, + {apiconstants.AdminPrefix + "health", noCertsContext, http.StatusOK}, + {apiconstants.AdminPrefix + "health", insecureContext, http.StatusTemporaryRedirect}, // /debug/: server.adminServer: no auth. 
{debug.Endpoint + "vars", rootCertsContext, http.StatusOK}, @@ -151,11 +156,11 @@ func TestSSLEnforcement(t *testing.T) { {debug.Endpoint + "vars", insecureContext, http.StatusTemporaryRedirect}, // /_status/nodes: server.statusServer: no auth. - {statusPrefix + "nodes", rootCertsContext, http.StatusOK}, - {statusPrefix + "nodes", nodeCertsContext, http.StatusOK}, - {statusPrefix + "nodes", testCertsContext, http.StatusOK}, - {statusPrefix + "nodes", noCertsContext, http.StatusOK}, - {statusPrefix + "nodes", insecureContext, http.StatusTemporaryRedirect}, + {apiconstants.StatusPrefix + "nodes", rootCertsContext, http.StatusOK}, + {apiconstants.StatusPrefix + "nodes", nodeCertsContext, http.StatusOK}, + {apiconstants.StatusPrefix + "nodes", testCertsContext, http.StatusOK}, + {apiconstants.StatusPrefix + "nodes", noCertsContext, http.StatusOK}, + {apiconstants.StatusPrefix + "nodes", insecureContext, http.StatusTemporaryRedirect}, // /ts/: ts.Server: no auth. {ts.URLPrefix, rootCertsContext, http.StatusNotFound}, @@ -175,7 +180,7 @@ func TestSSLEnforcement(t *testing.T) { } url := url.URL{ Scheme: tc.ctx.HTTPRequestScheme(), - Host: s.(*TestServer).Cfg.HTTPAddr, + Host: s.(*server.TestServer).Cfg.HTTPAddr, Path: tc.path, } resp, err := client.Get(url.String()) @@ -201,12 +206,12 @@ func TestVerifyPasswordDBConsole(t *testing.T) { s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) - ts := s.(*TestServer) + ts := s.TenantOrServer() if util.RaceEnabled { // The default bcrypt cost makes this test approximately 30s slower when the // race detector is on. - security.BcryptCost.Override(ctx, &ts.Cfg.Settings.SV, int64(bcrypt.MinCost)) + security.BcryptCost.Override(ctx, &ts.ClusterSettings().SV, int64(bcrypt.MinCost)) } //location is used for timezone testing. 
@@ -290,7 +295,8 @@ func TestVerifyPasswordDBConsole(t *testing.T) { } { t.Run(tc.testName, func(t *testing.T) { username := username.MakeSQLUsernameFromPreNormalizedString(tc.username) - valid, expired, err := ts.authentication.verifyPasswordDBConsole(context.Background(), username, tc.password) + authServer := ts.HTTPAuthServer().(authserver.Server) + valid, expired, err := authServer.VerifyPasswordDBConsole(context.Background(), username, tc.password) if err != nil { t.Errorf( "credentials %s/%s failed with error %s, wanted no error", @@ -317,21 +323,21 @@ func TestCreateSession(t *testing.T) { defer log.Scope(t).Close(t) s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) + ts := s.TenantOrServer() username := username.TestUserName() - if err := ts.createAuthUser(username, false /* isAdmin */); err != nil { + if err := ts.CreateAuthUser(username, false /* isAdmin */); err != nil { t.Fatal(err) } // Create an authentication, noting the time before and after creation. This // lets us ensure that the timestamps created are accurate. - timeBoundBefore := ts.clock.PhysicalTime() - id, origSecret, err := ts.authentication.newAuthSession(context.Background(), username) + timeBoundBefore := ts.Clock().PhysicalTime() + id, origSecret, err := ts.HTTPAuthServer().(authserver.Server).NewAuthSession(context.Background(), username) if err != nil { t.Fatalf("error creating auth session: %s", err) } - timeBoundAfter := ts.clock.PhysicalTime() + timeBoundAfter := ts.Clock().PhysicalTime() // Query fields from created session. 
query := ` @@ -391,7 +397,7 @@ WHERE id = $1` if err := verifyTimestamp(sessLastUsed, timeBoundBefore, timeBoundAfter); err != nil { t.Fatalf("bad lastUsedAt timestamp: %s", err) } - timeout := webSessionTimeout.Get(&s.ClusterSettings().SV) + timeout := authserver.WebSessionTimeout.Get(&ts.ClusterSettings().SV) if err := verifyTimestamp( sessExpires, timeBoundBefore.Add(timeout), timeBoundAfter.Add(timeout), ); err != nil { @@ -412,13 +418,15 @@ func TestVerifySession(t *testing.T) { defer log.Scope(t).Close(t) s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) + ts := s.TenantOrServer() sessionUsername := username.TestUserName() - if err := ts.createAuthUser(sessionUsername, false /* isAdmin */); err != nil { + if err := ts.CreateAuthUser(sessionUsername, false /* isAdmin */); err != nil { t.Fatal(err) } - id, origSecret, err := ts.authentication.newAuthSession(context.Background(), sessionUsername) + + authServer := ts.HTTPAuthServer().(authserver.Server) + id, origSecret, err := authServer.NewAuthSession(context.Background(), sessionUsername) if err != nil { t.Fatal(err) } @@ -473,7 +481,7 @@ func TestVerifySession(t *testing.T) { }, } { t.Run(tc.testname, func(t *testing.T) { - valid, username, err := ts.authentication.verifySession(context.Background(), &tc.cookie) + valid, username, err := authServer.VerifySession(context.Background(), &tc.cookie) if err != nil { t.Fatalf("test got error %s, wanted no error", err) } @@ -492,7 +500,7 @@ func TestAuthenticationAPIUserLogin(t *testing.T) { defer log.Scope(t).Close(t) s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) + ts := s.TenantOrServer() const ( validUsername = "testuser" @@ -520,7 +528,7 @@ func TestAuthenticationAPIUserLogin(t *testing.T) { } var resp serverpb.UserLoginResponse return httputil.PostJSONWithRequest( - httpClient, 
ts.AdminURL().WithPath(loginPath).String(), &req, &resp, + httpClient, ts.AdminURL().WithPath(authserver.LoginPath).String(), &req, &resp, ) } @@ -544,7 +552,7 @@ func TestAuthenticationAPIUserLogin(t *testing.T) { if len(cookies) == 0 { t.Fatalf("good login got no cookies: %v", response) } - sessionCookie, err := findAndDecodeSessionCookie(context.Background(), ts.Cfg.Settings, cookies) + sessionCookie, err := authserver.FindAndDecodeSessionCookie(context.Background(), ts.ClusterSettings(), cookies) if err != nil { t.Fatalf("failed to decode session cookie: %s", err) } @@ -580,28 +588,47 @@ func TestAuthenticationAPIUserLogin(t *testing.T) { func TestLogoutClearsCookies(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + DefaultTestTenant: base.TestControlsTenantsExplicitly, + }) defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - // Log in. - authHTTPClient, _, err := ts.getAuthenticatedHTTPClientAndCookie( - authenticatedUserName(), true, serverutils.SingleTenantSession, - ) - require.NoError(t, err) + testFunc := func(ts serverutils.TestTenantInterface, expectTenantCookieInClearList bool) { + // Log in. + authHTTPClient, _, err := ts.GetAuthenticatedHTTPClientAndCookie( + apiconstants.TestingUserName(), true, serverutils.SingleTenantSession, + ) + require.NoError(t, err) - // Log out. - resp, err := authHTTPClient.Get(ts.AdminURL().WithPath(logoutPath).String()) - require.NoError(t, err) - defer resp.Body.Close() + // Log out. 
+ resp, err := authHTTPClient.Get(ts.AdminURL().WithPath(authserver.LogoutPath).String()) + require.NoError(t, err) + defer resp.Body.Close() - cookies := resp.Cookies() - cNames := make([]string, len(cookies)) - for i, c := range cookies { - require.Equal(t, "", c.Value) - cNames[i] = c.Name + cookies := resp.Cookies() + cNames := make([]string, len(cookies)) + for i, c := range cookies { + require.Equal(t, "", c.Value) + cNames[i] = c.Name + } + expected := []string{authserver.SessionCookieName} + if expectTenantCookieInClearList { + expected = append(expected, authserver.TenantSelectCookieName) + } + require.ElementsMatch(t, cNames, expected) } - require.ElementsMatch(t, cNames, []string{SessionCookieName, TenantSelectCookieName}) + + t.Run("system tenant", func(t *testing.T) { + testFunc(s, true) + }) + + t.Run("secondary tenant", func(t *testing.T) { + ts, err := s.StartTenant(context.Background(), base.TestTenantArgs{TenantID: roachpb.MustMakeTenantID(10)}) + if err != nil { + t.Fatal(err) + } + testFunc(ts, false) + }) } func TestLogout(t *testing.T) { @@ -609,11 +636,11 @@ func TestLogout(t *testing.T) { defer log.Scope(t).Close(t) s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) + ts := s.TenantOrServer() // Log in. - authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie( - authenticatedUserName(), true, serverutils.SingleTenantSession, + authHTTPClient, cookie, err := ts.GetAuthenticatedHTTPClientAndCookie( + apiconstants.TestingUserName(), true, serverutils.SingleTenantSession, ) if err != nil { t.Fatal("error opening HTTP client", err) @@ -621,7 +648,7 @@ func TestLogout(t *testing.T) { // Log out. 
var resp serverpb.UserLogoutResponse - if err := httputil.GetJSON(authHTTPClient, ts.AdminURL().WithPath(logoutPath).String(), &resp); err != nil { + if err := httputil.GetJSON(authHTTPClient, ts.AdminURL().WithPath(authserver.LogoutPath).String(), &resp); err != nil { t.Fatal("logout request failed:", err) } @@ -651,7 +678,7 @@ func TestLogout(t *testing.T) { } // Try to use the revoked cookie; verify that it doesn't work. - encodedCookie, err := EncodeSessionCookie(cookie, false /* forHTTPSOnly */) + encodedCookie, err := authserver.EncodeSessionCookie(cookie, false /* forHTTPSOnly */) if err != nil { t.Fatal(err) } @@ -689,7 +716,7 @@ func TestAuthenticationMux(t *testing.T) { defer log.Scope(t).Close(t) s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(context.Background()) - tsrv := s.(*TestServer) + tsrv := s.TenantOrServer() // Both the normal and authenticated client will be used for each test. normalClient, err := tsrv.GetUnauthenticatedHTTPClient() @@ -745,12 +772,16 @@ func TestAuthenticationMux(t *testing.T) { body []byte cookieHeader string }{ - {"GET", adminPrefix + "users", nil, ""}, - {"GET", adminPrefix + "users", nil, "session=badcookie"}, - {"GET", statusPrefix + "sessions", nil, ""}, + {"GET", apiconstants.AdminPrefix + "users", nil, ""}, + {"GET", apiconstants.AdminPrefix + "users", nil, "session=badcookie"}, + {"GET", apiconstants.StatusPrefix + "sessions", nil, ""}, {"POST", ts.URLPrefix + "query", tsReqBuffer.Bytes(), ""}, } { t.Run("path="+tc.path, func(t *testing.T) { + if strings.HasPrefix(tc.path, ts.URLPrefix) { + skip.WithIssue(t, 102378) + } + // Verify normal client returns 401 Unauthorized. 
if err := runRequest(normalClient, tc.method, tc.path, tc.body, tc.cookieHeader, http.StatusUnauthorized); err != nil { t.Fatalf("request %s failed when not authorized: %s", tc.path, err) @@ -772,6 +803,8 @@ func TestGRPCAuthentication(t *testing.T) { s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) + ts := s.TenantOrServer() + // For each subsystem we pick a representative RPC. The idea is not to // exhaustively test each RPC but to prevent server startup from being // refactored in such a way that an entire subsystem becomes inadvertently @@ -779,6 +812,8 @@ func TestGRPCAuthentication(t *testing.T) { subsystems := []struct { name string sendRPC func(context.Context, *grpc.ClientConn) error + + storageOnly bool }{ {"gossip", func(ctx context.Context, conn *grpc.ClientConn) error { stream, err := gossip.NewGossipClient(conn).Gossip(ctx) @@ -788,15 +823,15 @@ func TestGRPCAuthentication(t *testing.T) { _ = stream.Send(&gossip.Request{}) _, err = stream.Recv() return err - }}, + }, true}, {"internal", func(ctx context.Context, conn *grpc.ClientConn) error { _, err := kvpb.NewInternalClient(conn).Batch(ctx, &kvpb.BatchRequest{}) return err - }}, + }, true}, {"perReplica", func(ctx context.Context, conn *grpc.ClientConn) error { _, err := kvserver.NewPerReplicaClient(conn).CollectChecksum(ctx, &kvserver.CollectChecksumRequest{}) return err - }}, + }, true}, {"raft", func(ctx context.Context, conn *grpc.ClientConn) error { stream, err := kvserver.NewMultiRaftClient(conn).RaftMessageBatch(ctx) if err != nil { @@ -805,7 +840,7 @@ func TestGRPCAuthentication(t *testing.T) { _ = stream.Send(&kvserverpb.RaftMessageRequestBatch{}) _, err = stream.Recv() return err - }}, + }, true}, {"closedTimestamp", func(ctx context.Context, conn *grpc.ClientConn) error { stream, err := ctpb.NewSideTransportClient(conn).PushUpdates(ctx) if err != nil { @@ -814,7 +849,7 @@ func TestGRPCAuthentication(t *testing.T) { _ = stream.Send(&ctpb.Update{}) _, 
err = stream.Recv() return err - }}, + }, true}, {"distSQL", func(ctx context.Context, conn *grpc.ClientConn) error { stream, err := execinfrapb.NewDistSQLClient(conn).FlowStream(ctx) if err != nil { @@ -823,22 +858,22 @@ func TestGRPCAuthentication(t *testing.T) { _ = stream.Send(&execinfrapb.ProducerMessage{}) _, err = stream.Recv() return err - }}, + }, false}, {"init", func(ctx context.Context, conn *grpc.ClientConn) error { _, err := serverpb.NewInitClient(conn).Bootstrap(ctx, &serverpb.BootstrapRequest{}) return err - }}, + }, true}, {"admin", func(ctx context.Context, conn *grpc.ClientConn) error { _, err := serverpb.NewAdminClient(conn).Databases(ctx, &serverpb.DatabasesRequest{}) return err - }}, + }, false}, {"status", func(ctx context.Context, conn *grpc.ClientConn) error { _, err := serverpb.NewStatusClient(conn).ListSessions(ctx, &serverpb.ListSessionsRequest{}) return err - }}, + }, false}, } - conn, err := grpc.DialContext(ctx, s.ServingRPCAddr(), + conn, err := grpc.DialContext(ctx, ts.RPCAddr(), grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ InsecureSkipVerify: true, }))) @@ -849,6 +884,10 @@ func TestGRPCAuthentication(t *testing.T) { _ = conn.Close() // nolint:grpcconnclose }(conn) for _, subsystem := range subsystems { + if subsystem.storageOnly && s.StartedDefaultTestTenant() { + // Subsystem only available on the system tenant. 
+ continue + } t.Run(fmt.Sprintf("no-cert/%s", subsystem.name), func(t *testing.T) { err := subsystem.sendRPC(ctx, conn) if exp := "TLSInfo is not available in request context"; !testutils.IsError(err, exp) { @@ -857,7 +896,7 @@ func TestGRPCAuthentication(t *testing.T) { }) } - certManager, err := s.RPCContext().GetCertificateManager() + certManager, err := ts.RPCContext().GetCertificateManager() if err != nil { t.Fatal(err) } @@ -865,7 +904,7 @@ func TestGRPCAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - conn, err = grpc.DialContext(ctx, s.ServingRPCAddr(), + conn, err = grpc.DialContext(ctx, ts.RPCAddr(), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { t.Fatal(err) @@ -874,6 +913,10 @@ func TestGRPCAuthentication(t *testing.T) { _ = conn.Close() // nolint:grpcconnclose }(conn) for _, subsystem := range subsystems { + if subsystem.storageOnly && s.StartedDefaultTestTenant() { + // Subsystem only available on the system tenant. + continue + } t.Run(fmt.Sprintf("bad-user/%s", subsystem.name), func(t *testing.T) { err := subsystem.sendRPC(ctx, conn) if exp := `need root or node client cert to perform RPCs on this server`; !testutils.IsError(err, exp) { @@ -889,19 +932,21 @@ func TestCreateAggregatedSessionCookieValue(t *testing.T) { tests := []struct { name string - mapArg []sessionCookieValue + mapArg []authserver.SessionCookieValue resExpected string }{ - {"standard arg", []sessionCookieValue{ - {name: "system", setCookie: "session=abcd1234"}, - {name: "app", setCookie: "session=efgh5678"}}, + {"standard arg", + []authserver.SessionCookieValue{ + authserver.MakeSessionCookieValue("system", "session=abcd1234"), + authserver.MakeSessionCookieValue("app", "session=efgh5678"), + }, "abcd1234,system,efgh5678,app", }, - {"empty arg", []sessionCookieValue{}, ""}, + {"empty arg", []authserver.SessionCookieValue{}, ""}, } for _, test := range tests { t.Run(fmt.Sprintf("create-session-cookie/%s", test.name), func(t *testing.T) 
{ - res := createAggregatedSessionCookieValue(test.mapArg) + res := authserver.CreateAggregatedSessionCookieValue(test.mapArg) require.Equal(t, test.resExpected, res) }) } @@ -921,7 +966,7 @@ func TestFindSessionCookieValue(t *testing.T) { { name: "standard args", sessionCookie: &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: normalSessionStr, Path: "/", }, @@ -939,7 +984,7 @@ func TestFindSessionCookieValue(t *testing.T) { { name: "no tenant cookie", sessionCookie: &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: normalSessionStr, Path: "/", }, @@ -949,7 +994,7 @@ func TestFindSessionCookieValue(t *testing.T) { { name: "empty string tenant cookie", sessionCookie: &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: normalSessionStr, Path: "/", }, @@ -960,7 +1005,7 @@ func TestFindSessionCookieValue(t *testing.T) { { name: "no tenant name match", sessionCookie: &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: normalSessionStr, Path: "/", }, @@ -971,7 +1016,7 @@ func TestFindSessionCookieValue(t *testing.T) { { name: "legacy session cookie", sessionCookie: &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: "aaskjhf218==", Path: "/", }, @@ -983,7 +1028,7 @@ func TestFindSessionCookieValue(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("find-session-cookie/%s", test.name), func(t *testing.T) { st := cluster.MakeClusterSettings() - res, err := findSessionCookieValueForTenant(st, test.sessionCookie, test.tenantSelectValue) + res, err := authserver.FindSessionCookieValueForTenant(st, test.sessionCookie, test.tenantSelectValue) require.Equal(t, test.resExpected, res) require.Equal(t, test.errorExpected, err != nil) }) diff --git a/pkg/server/authserver/context.go b/pkg/server/authserver/context.go new file mode 100644 index 000000000000..751334c0f777 --- /dev/null +++ 
b/pkg/server/authserver/context.go @@ -0,0 +1,127 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package authserver + +import ( + "context" + "fmt" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/util/grpcutil" + "github.com/cockroachdb/cockroach/pkg/util/log" + "google.golang.org/grpc/metadata" +) + +type webSessionUserKey struct{} +type webSessionIDKey struct{} + +const webSessionUserKeyStr = "websessionuser" +const webSessionIDKeyStr = "websessionid" + +// ContextWithHTTPAuthInfo embeds the HTTP authentication details into +// a go context. Meant for use with UserFromHTTPAuthInfoContext(). +func ContextWithHTTPAuthInfo( + ctx context.Context, username string, sessionID int64, +) context.Context { + ctx = context.WithValue(ctx, webSessionUserKey{}, username) + if sessionID != 0 { + ctx = context.WithValue(ctx, webSessionIDKey{}, sessionID) + } + return ctx +} + +// UserFromHTTPAuthInfoContext returns a SQL username from the request context +// of a HTTP route. Only use in routes that require login (e.g. requiresAuth = +// true in the API v2 route definition): it panics if no user is in the context. +// +// Do not use this function in _RPC_ API handlers. These access their +// SQL identity via the RPC incoming context. See +// UserFromIncomingRPCContext(). +func UserFromHTTPAuthInfoContext(ctx context.Context) username.SQLUsername { + return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string)) +} + +// MaybeUserFromHTTPAuthInfoContext is like UserFromHTTPAuthInfoContext but +// it returns a boolean false if there is no user in the context.
+func MaybeUserFromHTTPAuthInfoContext(ctx context.Context) (username.SQLUsername, bool) { + if u := ctx.Value(webSessionUserKey{}); u != nil { + return username.MakeSQLUsernameFromPreNormalizedString(u.(string)), true + } + return username.SQLUsername{}, false +} + +// TranslateHTTPAuthInfoToGRPCMetadata translates the context.Value +// that results from HTTP authentication into gRPC metadata suitable +// for use by RPC API handlers. +func TranslateHTTPAuthInfoToGRPCMetadata(ctx context.Context, _ *http.Request) metadata.MD { + md := metadata.MD{} + if user := ctx.Value(webSessionUserKey{}); user != nil { + md.Set(webSessionUserKeyStr, user.(string)) + } + if sessionID := ctx.Value(webSessionIDKey{}); sessionID != nil { + md.Set(webSessionIDKeyStr, fmt.Sprintf("%v", sessionID)) + } + return md +} + +// ForwardSQLIdentityThroughRPCCalls forwards the SQL identity of the +// original request (as populated by TranslateHTTPAuthInfoToGRPCMetadata in +// grpc-gateway) so it remains available to the remote node handling +// the request. +func ForwardSQLIdentityThroughRPCCalls(ctx context.Context) context.Context { + if md, ok := grpcutil.FastFromIncomingContext(ctx); ok { + if u, ok := md[webSessionUserKeyStr]; ok { + return metadata.NewOutgoingContext(ctx, metadata.MD{webSessionUserKeyStr: u}) + } + } + return ctx +} + +// ForwardHTTPAuthInfoToRPCCalls converts an HTTP API (v1 or v2) context to one that +// can issue outgoing RPC requests under the same logged-in user. +func ForwardHTTPAuthInfoToRPCCalls(ctx context.Context, r *http.Request) context.Context { + md := TranslateHTTPAuthInfoToGRPCMetadata(ctx, r) + return metadata.NewOutgoingContext(ctx, md) +} + +// UserFromIncomingRPCContext is to be used in RPC API handlers. It +// assumes the SQL identity was populated in the context implicitly by +// gRPC via TranslateHTTPAuthInfoToGRPCMetadata(), or explicitly via +// ForwardHTTPAuthInfoToRPCCalls() or +// ForwardSQLIdentityThroughRPCCalls().
+// +// Do not use this function in _HTTP_ API handlers. Those access their +// SQL identity via a special context key. See +// UserFromHTTPAuthInfoContext(). +func UserFromIncomingRPCContext(ctx context.Context) (res username.SQLUsername, err error) { + md, ok := grpcutil.FastFromIncomingContext(ctx) + if !ok { + return username.RootUserName(), nil + } + usernames, ok := md[webSessionUserKeyStr] + if !ok { + // If the incoming context has metadata but no attached web session user, + // it's a gRPC / internal SQL connection which has root on the cluster. + // This assumption is a historical hiccup, and would be best described + // as a bug. See: https://github.com/cockroachdb/cockroach/issues/45018 + return username.RootUserName(), nil + } + if len(usernames) != 1 { + log.Warningf(ctx, "context's incoming metadata contains unexpected number of usernames: %+v ", md) + return res, fmt.Errorf( + "context's incoming metadata contains unexpected number of usernames: %+v ", md) + } + // At this point the user is already logged in, so we can assume + // the username has been normalized already. + username := username.MakeSQLUsernameFromPreNormalizedString(usernames[0]) + return username, nil +} diff --git a/pkg/server/authserver/cookie.go b/pkg/server/authserver/cookie.go new file mode 100644 index 000000000000..d8145a5fea8c --- /dev/null +++ b/pkg/server/authserver/cookie.go @@ -0,0 +1,176 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +package authserver + +import ( + "context" + "net/http" + "strings" + + "github.com/cockroachdb/cockroach/pkg/multitenant" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" +) + +const ( + // SessionCookieName is the name of the cookie used for HTTP auth. + SessionCookieName = "session" + + // TenantSelectCookieName is the name of the HTTP cookie used to select a particular tenant, + // if the custom header is not specified. + TenantSelectCookieName = `tenant` +) + +// SessionCookieValue defines the data needed to construct the +// aggregate session cookie in the order provided. +type SessionCookieValue struct { + // The name of the tenant. + name string + // The value of set-cookie. + setCookie string +} + +// MakeSessionCookieValue creates a SessionCookieValue from the provided +// tenant name and set-cookie value. +func MakeSessionCookieValue(name, setCookie string) SessionCookieValue { + return SessionCookieValue{ + name: name, + setCookie: setCookie, + } +} + +// Name returns the name of the tenant in the cookie value. +func (s SessionCookieValue) Name() string { + return s.name +} + +// CreateAggregatedSessionCookieValue is used for multi-tenant login. +// It takes a slice of SessionCookieValue and converts it to a single +// string which is the aggregated session. Currently the format of the +// aggregated session is: `session,tenant_name,session2,tenant_name2` etc.
+func CreateAggregatedSessionCookieValue(sessionCookieValue []SessionCookieValue) string { + var sessionsStr string + for _, val := range sessionCookieValue { + sessionCookieSlice := strings.Split(strings.ReplaceAll(val.setCookie, "session=", ""), ";") + sessionsStr += sessionCookieSlice[0] + "," + val.name + "," + } + if len(sessionsStr) > 0 { + sessionsStr = sessionsStr[:len(sessionsStr)-1] + } + return sessionsStr +} + +// FindAndDecodeSessionCookie looks for cookies named SessionCookieName in the +// cookies slice. For each one, the session for the selected tenant is extracted +// (see FindSessionCookieValueForTenant), written to cookie.Value if non-empty, +// and decoded. An extraction error is returned as an API internal error; if no +// cookie decodes successfully, http.ErrNoCookie is returned. +func FindAndDecodeSessionCookie( + ctx context.Context, st *cluster.Settings, cookies []*http.Cookie, +) (*serverpb.SessionCookie, error) { + found := false + var sessionCookie *serverpb.SessionCookie + tenantSelectCookieVal := findTenantSelectCookieValue(cookies) + for _, cookie := range cookies { + if cookie.Name != SessionCookieName { + continue + } + found = true + mtSessionVal, err := FindSessionCookieValueForTenant( + st, + cookie, + tenantSelectCookieVal) + if err != nil { + return sessionCookie, srverrors.APIInternalError(ctx, err) + } + if mtSessionVal != "" { + cookie.Value = mtSessionVal + } + sessionCookie, err = decodeSessionCookie(cookie) + if err != nil { + // Multiple cookies with the same name may be included in the + // header. We continue searching even if we find a matching + // name with an invalid value.
+ log.Infof(ctx, "found a matching cookie that failed decoding: %v", err) + found = false + continue + } + break + } + if !found { + return nil, http.ErrNoCookie + } + return sessionCookie, nil +} + +// FindSessionCookieValueForTenant finds the encoded session in the provided +// aggregated session cookie value established in multi-tenant clusters that's +// associated with the provided tenant name. If an empty tenant name is provided, +// we default to the DefaultTenantSelect cluster setting value. +// +// If the provided session cookie is nil or its value is empty, it returns an +// empty string; if no session matches the tenant name, it returns an error. +// +// e.g. tenant name is "system" and session cookie's value is +// "abcd1234,system,efgh5678,app" the output will be "abcd1234". +// +// In the case of legacy session cookies, where tenant names are not encoded +// into the cookie value, we assume that the session belongs to defaultTenantSelect. +// Note that these legacy session cookies only contained a single session string +// as the cookie's value. +func FindSessionCookieValueForTenant( + st *cluster.Settings, sessionCookie *http.Cookie, tenantName string, +) (string, error) { + if sessionCookie == nil { + return "", nil + } + if mtSessionStr := sessionCookie.Value; sessionCookie.Value != "" { + sessionSlice := strings.Split(mtSessionStr, ",") + if len(sessionSlice) == 1 { + // If no separator was found in the cookie value, this is likely + // a cookie from a previous CRDB version where the cookie value + // contained a single session string without any tenant names encoded. + // To maintain backwards compatibility, assume this session belongs + // to the default tenant. In this case, the entire cookie value is + // the session string.
+ return mtSessionStr, nil + } + if tenantName == "" { + tenantName = multitenant.DefaultTenantSelect.Get(&st.SV) + } + var encodedSession string + for idx, val := range sessionSlice { + if val == tenantName && idx > 0 { + encodedSession = sessionSlice[idx-1] + } + } + if encodedSession == "" { + return "", errors.Newf("unable to find session cookie value that matches tenant %q", tenantName) + } + return encodedSession, nil + } + return "", nil +} + +// findTenantSelectCookieValue iterates through all request cookies in order +// to find the value of the tenant select cookie. If the tenant select cookie +// is not found, it returns the empty string. +func findTenantSelectCookieValue(cookies []*http.Cookie) string { + for _, c := range cookies { + if c.Name == TenantSelectCookieName { + return c.Value + } + } + return "" +} diff --git a/pkg/server/authserver/main_test.go b/pkg/server/authserver/main_test.go new file mode 100644 index 000000000000..a293c92b404f --- /dev/null +++ b/pkg/server/authserver/main_test.go @@ -0,0 +1,41 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +package authserver_test + +import ( + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/ccl" + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvtenant" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" +) + +// TestMain installs the embedded security assets and the test server, test +// cluster and tenant connector factories, then runs the package's tests. +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + kvtenant.InitTestConnectorFactory() + // os.Exit does not run deferred calls, so run the tests inside a + // closure: the enterprise cleanup then executes before the exit. + os.Exit(func() int { + defer ccl.TestingEnableEnterprise()() + return m.Run() + }()) +} + +//go:generate ../../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/combined_statement_stats.go b/pkg/server/combined_statement_stats.go index 1b528692af0d..0d725e9dd87c 100644 --- a/pkg/server/combined_statement_stats.go +++ b/pkg/server/combined_statement_stats.go @@ -18,7 +18,9 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/clusterversion" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/appstatspb" @@ -55,10 +57,10 @@ func closeIterator(it isql.Rows, err error) error { func (s *statusServer) CombinedStatementStats( ctx context.Context, req *serverpb.CombinedStatementsStatsRequest, ) (*serverpb.StatementsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err :=
s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ -126,7 +128,7 @@ func getCombinedStatementStats( activityHasAllData, tableSuffix) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -154,7 +156,7 @@ func getCombinedStatementStats( } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } stmtsRunTime, txnsRunTime, err := getTotalRuntimeSecs( @@ -166,7 +168,7 @@ func getCombinedStatementStats( tableSuffix) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } response := &serverpb.StatementsResponse{ @@ -612,7 +614,7 @@ FROM (SELECT fingerprint_id, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -632,7 +634,7 @@ FROM (SELECT fingerprint_id, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -651,7 +653,7 @@ FROM (SELECT fingerprint_id, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -669,7 +671,7 @@ FROM (SELECT fingerprint_id, var statementFingerprintID uint64 if statementFingerprintID, err = sqlstatsutil.DatumToUint64(row[0]); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var txnFingerprintID uint64 @@ -677,7 +679,7 @@ FROM (SELECT fingerprint_id, txnFingerprintIDs := make([]appstatspb.TransactionFingerprintID, 0, txnFingerprintDatums.Array.Len()) for _, idDatum := range txnFingerprintDatums.Array { if txnFingerprintID, err = sqlstatsutil.DatumToUint64(idDatum); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } txnFingerprintIDs = append(txnFingerprintIDs, appstatspb.TransactionFingerprintID(txnFingerprintID)) } @@ 
-689,14 +691,14 @@ FROM (SELECT fingerprint_id, var metadata appstatspb.CollectedStatementStatistics metadataJSON := tree.MustBeDJSON(row[4]).JSON if err = sqlstatsutil.DecodeStmtStatsMetadataJSON(metadataJSON, &metadata); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } metadata.Key.App = app statsJSON := tree.MustBeDJSON(row[5]).JSON if err = sqlstatsutil.DecodeStmtStatsStatisticsJSON(statsJSON, &metadata.Stats); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } stmt := serverpb.StatementsResponse_CollectedStatementStatistics{ @@ -714,7 +716,7 @@ FROM (SELECT fingerprint_id, } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statements, nil @@ -742,7 +744,7 @@ func getIterator( it, err := ie.QueryIteratorEx(ctx, queryInfo, nil, sessiondata.NodeUserSessionDataOverride, query, args...) if err != nil { - return it, serverError(ctx, err) + return it, srverrors.ServerError(ctx, err) } return it, nil @@ -787,7 +789,7 @@ FROM (SELECT app_name, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -811,7 +813,7 @@ FROM (SELECT app_name, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -830,7 +832,7 @@ FROM (SELECT app_name, aostClause, orderAndLimit) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -851,18 +853,18 @@ FROM (SELECT app_name, aggregatedTs := tree.MustBeDTimestampTZ(row[1]).Time fingerprintID, err := sqlstatsutil.DatumToUint64(row[2]) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var metadata appstatspb.CollectedTransactionStatistics metadataJSON := tree.MustBeDJSON(row[3]).JSON if err = sqlstatsutil.DecodeTxnStatsMetadataJSON(metadataJSON, 
&metadata); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } statsJSON := tree.MustBeDJSON(row[4]).JSON if err = sqlstatsutil.DecodeTxnStatsStatisticsJSON(statsJSON, &metadata.Stats); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } txnStats := serverpb.StatementsResponse_ExtendedCollectedTransactionStatistics{ @@ -879,7 +881,7 @@ FROM (SELECT app_name, } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return transactions, nil @@ -929,7 +931,7 @@ GROUP BY app_name`, whereClause), args...) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -939,7 +941,7 @@ GROUP BY if it != nil { err = closeIterator(it, err) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } query = fmt.Sprintf( @@ -950,7 +952,7 @@ GROUP BY sessiondata.NodeUserSessionDataOverride, query, args...) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -959,7 +961,7 @@ GROUP BY if !it.HasResults() { err = closeIterator(it, err) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } query = fmt.Sprintf(queryFormat, "crdb_internal.statement_statistics", whereClause) @@ -967,7 +969,7 @@ GROUP BY sessiondata.NodeUserSessionDataOverride, query, args...) 
if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -992,25 +994,25 @@ GROUP BY var statementFingerprintID uint64 if statementFingerprintID, err = sqlstatsutil.DatumToUint64(row[0]); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var txnFingerprintID uint64 if txnFingerprintID, err = sqlstatsutil.DatumToUint64(row[1]); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var metadata appstatspb.CollectedStatementStatistics metadataJSON := tree.MustBeDJSON(row[2]).JSON if err = sqlstatsutil.DecodeStmtStatsMetadataJSON(metadataJSON, &metadata); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } metadata.Key.TransactionFingerprintID = appstatspb.TransactionFingerprintID(txnFingerprintID) statsJSON := tree.MustBeDJSON(row[3]).JSON if err = sqlstatsutil.DecodeStmtStatsStatisticsJSON(statsJSON, &metadata.Stats); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } app := string(tree.MustBeDString(row[4])) @@ -1029,7 +1031,7 @@ GROUP BY } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statements, nil @@ -1038,10 +1040,10 @@ GROUP BY func (s *statusServer) StatementDetails( ctx context.Context, req *serverpb.StatementDetailsRequest, ) (*serverpb.StatementDetailsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ -1064,7 +1066,7 @@ func getStatementDetails( showInternal := SQLStatsShowInternal.Get(&settings.SV) whereClause, args, err := 
getStatementDetailsQueryClausesAndArgs(req, testingKnobs, showInternal) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } // Used for mixed cluster version, where we need to use the persisted view with _v22_2. @@ -1092,7 +1094,7 @@ func getStatementDetails( statementTotal, err := getTotalStatementDetails(ctx, ie, whereClause, args, activityHasData, tableSuffix) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } statementStatisticsPerAggregatedTs, err := getStatementDetailsPerAggregatedTs( ctx, @@ -1103,7 +1105,7 @@ func getStatementDetails( activityHasData, tableSuffix) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } statementStatisticsPerPlanHash, err := getStatementDetailsPerPlanHash( ctx, @@ -1114,7 +1116,7 @@ func getStatementDetails( activityHasData, tableSuffix) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } // At this point the counts on statementTotal.metadata have the count for how many times we saw that value @@ -1246,7 +1248,7 @@ GROUP BY fingerprint_id LIMIT 1`, whereClause), args...) if err != nil { - return statement, serverError(ctx, err) + return statement, srverrors.ServerError(ctx, err) } } // If there are no results from the activity table, retrieve the data from the persisted table. @@ -1258,7 +1260,7 @@ LIMIT 1`, whereClause), args...) "crdb_internal.statement_statistics_persisted"+tableSuffix, whereClause), args...) if err != nil { - return statement, serverError(ctx, err) + return statement, srverrors.ServerError(ctx, err) } } @@ -1269,7 +1271,7 @@ LIMIT 1`, whereClause), args...) sessiondata.NodeUserSessionDataOverride, fmt.Sprintf(queryFormat, "crdb_internal.statement_statistics", whereClause), args...) 
if err != nil { - return statement, serverError(ctx, err) + return statement, srverrors.ServerError(ctx, err) } } @@ -1278,7 +1280,7 @@ LIMIT 1`, whereClause), args...) return statement, nil } if row.Len() != expectedNumDatums { - return statement, serverError(ctx, errors.Newf( + return statement, srverrors.ServerError(ctx, errors.Newf( "expected %d columns on getTotalStatementDetails, received %d", expectedNumDatums)) } @@ -1287,7 +1289,7 @@ LIMIT 1`, whereClause), args...) metadataJSON := tree.MustBeDJSON(row[0]).JSON if err = sqlstatsutil.DecodeAggregatedMetadataJSON(metadataJSON, &aggregatedMetadata); err != nil { - return statement, serverError(ctx, err) + return statement, srverrors.ServerError(ctx, err) } apps := tree.MustBeDArray(row[1]) @@ -1299,7 +1301,7 @@ LIMIT 1`, whereClause), args...) statsJSON := tree.MustBeDJSON(row[2]).JSON if err = sqlstatsutil.DecodeStmtStatsStatisticsJSON(statsJSON, &statistics.Stats); err != nil { - return statement, serverError(ctx, err) + return statement, srverrors.ServerError(ctx, err) } aggregatedMetadata.FormattedQuery = aggregatedMetadata.Query @@ -1357,7 +1359,7 @@ LIMIT $%d`, whereClause, len(args)), args...) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1377,7 +1379,7 @@ LIMIT $%d`, whereClause, len(args)), sessiondata.NodeUserSessionDataOverride, query, args...) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1389,7 +1391,7 @@ LIMIT $%d`, whereClause, len(args)), it, err = ie.QueryIteratorEx(ctx, "combined-stmts-details-by-aggregated-timestamp-with-memory", nil, sessiondata.NodeUserSessionDataOverride, query, args...) 
if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1411,12 +1413,12 @@ LIMIT $%d`, whereClause, len(args)), var aggregatedMetadata appstatspb.AggregatedStatementMetadata metadataJSON := tree.MustBeDJSON(row[1]).JSON if err = sqlstatsutil.DecodeAggregatedMetadataJSON(metadataJSON, &aggregatedMetadata); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } statsJSON := tree.MustBeDJSON(row[2]).JSON if err = sqlstatsutil.DecodeStmtStatsStatisticsJSON(statsJSON, &metadata.Stats); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } stmt := serverpb.StatementDetailsResponse_CollectedStatementGroupedByAggregatedTs{ @@ -1428,7 +1430,7 @@ LIMIT $%d`, whereClause, len(args)), statements = append(statements, stmt) } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statements, nil @@ -1555,7 +1557,7 @@ GROUP BY index_recommendations LIMIT $%d`, whereClause, len(args)), args...) if iterErr != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1573,7 +1575,7 @@ LIMIT $%d`, whereClause, len(args)), args...) it, iterErr = ie.QueryIteratorEx(ctx, "combined-stmts-persisted-details-by-plan-hash", nil, sessiondata.NodeUserSessionDataOverride, query, args...) if iterErr != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1585,7 +1587,7 @@ LIMIT $%d`, whereClause, len(args)), args...) it, iterErr = ie.QueryIteratorEx(ctx, "combined-stmts-details-by-plan-hash-with-memory", nil, sessiondata.NodeUserSessionDataOverride, query, args...) if iterErr != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -1603,7 +1605,7 @@ LIMIT $%d`, whereClause, len(args)), args...) 
var planHash uint64 if planHash, err = sqlstatsutil.DatumToUint64(row[0]); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } planGist := string(tree.MustBeDStringOrDNull(row[1])) var explainPlan string @@ -1615,12 +1617,12 @@ LIMIT $%d`, whereClause, len(args)), args...) var aggregatedMetadata appstatspb.AggregatedStatementMetadata metadataJSON := tree.MustBeDJSON(row[2]).JSON if err = sqlstatsutil.DecodeAggregatedMetadataJSON(metadataJSON, &aggregatedMetadata); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } statsJSON := tree.MustBeDJSON(row[3]).JSON if err = sqlstatsutil.DecodeStmtStatsStatisticsJSON(statsJSON, &metadata.Stats); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } recommendations := tree.MustBeDArray(row[4]) @@ -1663,7 +1665,7 @@ LIMIT $%d`, whereClause, len(args)), args...) statements = append(statements, stmt) } if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statements, nil diff --git a/pkg/server/debug/BUILD.bazel b/pkg/server/debug/BUILD.bazel index 3c40fcc2a6ea..13074cc7b4dd 100644 --- a/pkg/server/debug/BUILD.bazel +++ b/pkg/server/debug/BUILD.bazel @@ -49,12 +49,25 @@ go_library( go_test( name = "debug_test", size = "small", - srcs = ["logspy_test.go"], + srcs = [ + "debug_test.go", + "logspy_test.go", + "main_test.go", + ], args = ["-test.timeout=55s"], embed = [":debug"], deps = [ + "//pkg/base", + "//pkg/ccl", + "//pkg/kv/kvclient/kvtenant", "//pkg/roachpb", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/server/srvtestutils", "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/testcluster", "//pkg/util/leaktest", "//pkg/util/log", "//pkg/util/log/logpb", diff --git a/pkg/server/debug/debug_test.go b/pkg/server/debug/debug_test.go new file mode 100644 index 
000000000000..aecb2d256130 --- /dev/null +++ b/pkg/server/debug/debug_test.go @@ -0,0 +1,223 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package debug_test + +import ( + "bytes" + "context" + "net/http" + "net/url" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/debug" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" +) + +// debugURL returns the root debug URL. +func debugURL(s serverutils.TestTenantInterface) string { + return s.AdminURL().WithPath(debug.Endpoint).String() +} + +// TestAdminDebugExpVar verifies that cmdline and memstats variables are +// available via the /debug/vars link. +func TestAdminDebugExpVar(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + ts := s.TenantOrServer() + + jI, err := srvtestutils.GetJSON(ts, debugURL(ts)+"vars") + if err != nil { + t.Fatalf("failed to fetch JSON: %v", err) + } + j := jI.(map[string]interface{}) + if _, ok := j["cmdline"]; !ok { + t.Error("cmdline not found in JSON response") + } + if _, ok := j["memstats"]; !ok { + t.Error("memstats not found in JSON response") + } +} + +// TestAdminDebugMetrics verifies that cmdline and memstats variables are +// available via the /debug/metrics link. 
+func TestAdminDebugMetrics(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + ts := s.TenantOrServer() + + jI, err := srvtestutils.GetJSON(ts, debugURL(ts)+"metrics") + if err != nil { + t.Fatalf("failed to fetch JSON: %v", err) + } + j := jI.(map[string]interface{}) + if _, ok := j["cmdline"]; !ok { + t.Error("cmdline not found in JSON response") + } + if _, ok := j["memstats"]; !ok { + t.Error("memstats not found in JSON response") + } +} + +// TestAdminDebugPprof verifies that pprof tools are available +// via the /debug/pprof/* links. +func TestAdminDebugPprof(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + ts := s.TenantOrServer() + + body, err := srvtestutils.GetText(ts, debugURL(ts)+"pprof/block?debug=1") + if err != nil { + t.Fatal(err) + } + if exp := "contention:\ncycles/second="; !bytes.Contains(body, []byte(exp)) { + t.Errorf("expected %s to contain %s", body, exp) + } +} + +// TestAdminDebugTrace verifies that the net/trace endpoints are available +// via /debug/{requests,events}. 
+func TestAdminDebugTrace(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + ts := s.TenantOrServer() + + tc := []struct { + segment, search string + }{ + {"requests", "/debug/requests"}, + {"events", "events"}, + } + + for _, c := range tc { + body, err := srvtestutils.GetText(ts, debugURL(ts)+c.segment) + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(body, []byte(c.search)) { + t.Errorf("expected %s to be contained in %s", c.search, body) + } + } +} + +func TestAdminDebugAuth(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + ts := s.TenantOrServer() + + url := debugURL(ts) + + // Unauthenticated. + client, err := ts.GetUnauthenticatedHTTPClient() + if err != nil { + t.Fatal(err) + } + resp, err := client.Get(url) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status code %d; got %d", http.StatusUnauthorized, resp.StatusCode) + } + + // Authenticated as non-admin. + client, err = ts.GetAuthenticatedHTTPClient(false, serverutils.SingleTenantSession) + if err != nil { + t.Fatal(err) + } + resp, err = client.Get(url) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status code %d; got %d", http.StatusUnauthorized, resp.StatusCode) + } + + // Authenticated as admin. 
+ client, err = ts.GetAuthenticatedHTTPClient(true, serverutils.SingleTenantSession) + if err != nil { + t.Fatal(err) + } + resp, err = client.Get(url) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code %d; got %d", http.StatusOK, resp.StatusCode) + } +} + +// TestAdminDebugRedirect verifies that the /debug/ endpoint is redirected to on +// incorrect /debug/ paths. +func TestAdminDebugRedirect(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + ts := s.TenantOrServer() + + expURL := debugURL(ts) + origURL := expURL + "incorrect" + + // Must be admin to access debug endpoints + client, err := ts.GetAdminHTTPClient() + if err != nil { + t.Fatal(err) + } + + // Don't follow redirects automatically. + redirectAttemptedError := errors.New("redirect") + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return redirectAttemptedError + } + + resp, err := client.Get(origURL) + if urlError := (*url.Error)(nil); errors.As(err, &urlError) && + errors.Is(urlError.Err, redirectAttemptedError) { + // Ignore the redirectAttemptedError. + err = nil + } + if err != nil { + t.Fatal(err) + } else { + resp.Body.Close() + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("expected status code %d; got %d", http.StatusMovedPermanently, resp.StatusCode) + } + if redirectURL, err := resp.Location(); err != nil { + t.Error(err) + } else if foundURL := redirectURL.String(); foundURL != expURL { + t.Errorf("expected location %s; got %s", expURL, foundURL) + } + } +} diff --git a/pkg/server/debug/main_test.go b/pkg/server/debug/main_test.go new file mode 100644 index 000000000000..d3d09cb7e9cd --- /dev/null +++ b/pkg/server/debug/main_test.go @@ -0,0 +1,35 @@ +// Copyright 2023 The Cockroach Authors. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package debug_test + +import ( + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/ccl" + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvtenant" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" +) + +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + kvtenant.InitTestConnectorFactory() + defer ccl.TestingEnableEnterprise()() + os.Exit(m.Run()) +} + +//go:generate ../../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/decommission.go b/pkg/server/decommission.go index 8bacba913532..f4ce6a4fecbc 100644 --- a/pkg/server/decommission.go +++ b/pkg/server/decommission.go @@ -36,37 +36,37 @@ import ( grpcstatus "google.golang.org/grpc/status" ) -// decommissioningNodeMap tracks the set of nodes that we know are +// DecommissioningNodeMap tracks the set of nodes that we know are // decommissioning. This map is used to inform whether we need to proactively // enqueue some decommissioning node's ranges for rebalancing. 
-type decommissioningNodeMap struct { +type DecommissioningNodeMap struct { syncutil.RWMutex nodes map[roachpb.NodeID]interface{} } -// decommissionRangeCheckResult is the result of evaluating the allocator action +// DecommissionRangeCheckResult is the result of evaluating the allocator action // and target for a single range that has an extant replica on a node targeted // for decommission. -type decommissionRangeCheckResult struct { - desc roachpb.RangeDescriptor - action string - tracingSpans tracingpb.Recording - err error +type DecommissionRangeCheckResult struct { + Desc roachpb.RangeDescriptor + Action string + TracingSpans tracingpb.Recording + Err error } -// decommissionPreCheckResult is the result of checking the readiness +// DecommissionPreCheckResult is the result of checking the readiness // of a node or set of nodes to be decommissioned. -type decommissionPreCheckResult struct { - rangesChecked int - replicasByNode map[roachpb.NodeID][]roachpb.ReplicaIdent - actionCounts map[string]int - rangesNotReady []decommissionRangeCheckResult +type DecommissionPreCheckResult struct { + RangesChecked int + ReplicasByNode map[roachpb.NodeID][]roachpb.ReplicaIdent + ActionCounts map[string]int + RangesNotReady []DecommissionRangeCheckResult } // makeOnNodeDecommissioningCallback returns a callback that enqueues the // decommissioning node's ranges into the `stores`' replicateQueues for // rebalancing. 
-func (t *decommissioningNodeMap) makeOnNodeDecommissioningCallback( +func (t *DecommissioningNodeMap) makeOnNodeDecommissioningCallback( stores *kvserver.Stores, ) func(id roachpb.NodeID) { return func(decommissioningNodeID roachpb.NodeID) { @@ -125,7 +125,7 @@ func (t *decommissioningNodeMap) makeOnNodeDecommissioningCallback( } } -func (t *decommissioningNodeMap) onNodeDecommissioned(nodeID roachpb.NodeID) { +func (t *DecommissioningNodeMap) onNodeDecommissioned(nodeID roachpb.NodeID) { t.Lock() defer t.Unlock() // NB: We may have already deleted this node, but that's ok. @@ -165,11 +165,11 @@ func (s *Server) DecommissionPreCheck( strictReadiness bool, collectTraces bool, maxErrors int, -) (decommissionPreCheckResult, error) { +) (DecommissionPreCheckResult, error) { // Ensure that if collectTraces is enabled, that a maxErrors >0 is set in // order to avoid unlimited memory usage. if collectTraces && maxErrors <= 0 { - return decommissionPreCheckResult{}, + return DecommissionPreCheckResult{}, grpcstatus.Error(codes.InvalidArgument, "MaxErrors must be set to collect traces.") } @@ -177,7 +177,7 @@ func (s *Server) DecommissionPreCheck( decommissionCheckNodeIDs := make(map[roachpb.NodeID]livenesspb.NodeLivenessStatus) replicasByNode := make(map[roachpb.NodeID][]roachpb.ReplicaIdent) actionCounts := make(map[string]int) - var rangeErrors []decommissionRangeCheckResult + var rangeErrors []DecommissionRangeCheckResult const pageSize = 10000 for _, nodeID := range nodeIDs { @@ -210,7 +210,7 @@ func (s *Server) DecommissionPreCheck( err = errors.Errorf("n%d has no initialized store", s.NodeID()) } if err != nil { - return decommissionPreCheckResult{}, grpcstatus.Error(codes.NotFound, err.Error()) + return DecommissionPreCheckResult{}, grpcstatus.Error(codes.NotFound, err.Error()) } // Define our node liveness overrides to simulate that the nodeIDs for which @@ -274,14 +274,14 @@ func (s *Server) DecommissionPreCheck( }) if err != nil { - return 
decommissionPreCheckResult{}, grpcstatus.Errorf(codes.Internal, err.Error()) + return DecommissionPreCheckResult{}, grpcstatus.Errorf(codes.Internal, err.Error()) } - return decommissionPreCheckResult{ - rangesChecked: rangesChecked, - replicasByNode: replicasByNode, - actionCounts: actionCounts, - rangesNotReady: rangeErrors, + return DecommissionPreCheckResult{ + RangesChecked: rangesChecked, + ReplicasByNode: replicasByNode, + ActionCounts: actionCounts, + RangesNotReady: rangeErrors, }, nil } @@ -295,15 +295,15 @@ func evaluateRangeCheckResult( action allocatorimpl.AllocatorAction, recording tracingpb.Recording, rErr error, -) (passed bool, _ decommissionRangeCheckResult) { - checkResult := decommissionRangeCheckResult{ - desc: *desc, - action: action.String(), - err: rErr, +) (passed bool, _ DecommissionRangeCheckResult) { + checkResult := DecommissionRangeCheckResult{ + Desc: *desc, + Action: action.String(), + Err: rErr, } if collectTraces { - checkResult.tracingSpans = recording + checkResult.TracingSpans = recording } if rErr != nil { @@ -313,14 +313,14 @@ func evaluateRangeCheckResult( if action == allocatorimpl.AllocatorRangeUnavailable || action == allocatorimpl.AllocatorNoop || action == allocatorimpl.AllocatorConsiderRebalance { - checkResult.err = errors.Errorf("range r%d requires unexpected allocation action: %s", + checkResult.Err = errors.Errorf("range r%d requires unexpected allocation action: %s", desc.RangeID, action, ) return false, checkResult } if strictReadiness && !(action.Replace() || action.Remove()) { - checkResult.err = errors.Errorf( + checkResult.Err = errors.Errorf( "range r%d needs repair beyond replacing/removing the decommissioning replica: %s", desc.RangeID, action, ) diff --git a/pkg/server/decommission_test.go b/pkg/server/decommission_test.go deleted file mode 100644 index c4bc293a58ee..000000000000 --- a/pkg/server/decommission_test.go +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2023 The Cockroach Authors. 
-// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package server - -import ( - "context" - "testing" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/allocatorimpl" - "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/security/username" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/skip" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - grpcstatus "google.golang.org/grpc/status" -) - -// TestDecommissionPreCheckInvalid tests decommission pre check expected errors. -func TestDecommissionPreCheckInvalid(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - // Set up test cluster. - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 4, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - ServerArgsPerNode: map[int]base.TestServerArgs{ - 0: decommissionTsArgs("a", "n1"), - 1: decommissionTsArgs("b", "n2"), - 2: decommissionTsArgs("c", "n3"), - 3: decommissionTsArgs("a", "n4"), - }, - }) - defer tc.Stopper().Stop(ctx) - - firstSvr := tc.Server(0).(*TestServer) - - // Create database and tables. 
- ac := firstSvr.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - - // Attempt to decommission check with unlimited traces. - decommissioningNodeIDs := []roachpb.NodeID{tc.Server(3).NodeID()} - result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, - true /* strictReadiness */, true /* collectTraces */, 0, /* maxErrors */ - ) - require.Error(t, err) - status, ok := grpcstatus.FromError(err) - require.True(t, ok, "expected grpc status error") - require.Equal(t, codes.InvalidArgument, status.Code()) - require.Equal(t, decommissionPreCheckResult{}, result) -} - -// TestDecommissionPreCheckEvaluation tests evaluation of decommission readiness -// of several nodes in a cluster given the replicas that exist on those nodes. -func TestDecommissionPreCheckEvaluation(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - skip.UnderRace(t) // can't handle 7-node clusters - - tsArgs := func(attrs ...string) base.TestServerArgs { - return decommissionTsArgs("a", attrs...) - } - - // Set up test cluster. - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - ServerArgsPerNode: map[int]base.TestServerArgs{ - 0: tsArgs("ns1", "origin"), - 1: tsArgs("ns2", "west"), - 2: tsArgs("ns3", "central"), - 3: tsArgs("ns4", "central"), - 4: tsArgs("ns5", "east"), - 5: tsArgs("ns6", "east"), - 6: tsArgs("ns7", "east"), - }, - }) - defer tc.Stopper().Stop(ctx) - - firstSvr := tc.Server(0).(*TestServer) - db := tc.ServerConn(0) - runQueries := func(queries ...string) { - for _, q := range queries { - if _, err := db.Exec(q); err != nil { - t.Fatalf("error executing '%s': %s", q, err) - } - } - } - - // Create database and tables. 
- ac := firstSvr.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - setupQueries := []string{ - "CREATE DATABASE test", - "CREATE TABLE test.tblA (val STRING)", - "CREATE TABLE test.tblB (val STRING)", - "INSERT INTO test.tblA VALUES ('testvalA')", - "INSERT INTO test.tblB VALUES ('testvalB')", - } - runQueries(setupQueries...) - alterQueries := []string{ - "ALTER TABLE test.tblA CONFIGURE ZONE USING num_replicas = 3, constraints = '{+west: 1, +central: 1, +east: 1}', " + - "range_max_bytes = 500000000, range_min_bytes = 100", - "ALTER TABLE test.tblB CONFIGURE ZONE USING num_replicas = 3, constraints = '{+east}', " + - "range_max_bytes = 500000000, range_min_bytes = 100", - } - runQueries(alterQueries...) - tblAID, err := firstSvr.admin.queryTableID(ctx, username.RootUserName(), "test", "tblA") - require.NoError(t, err) - tblBID, err := firstSvr.admin.queryTableID(ctx, username.RootUserName(), "test", "tblB") - require.NoError(t, err) - startKeyTblA := firstSvr.Codec().TablePrefix(uint32(tblAID)) - startKeyTblB := firstSvr.Codec().TablePrefix(uint32(tblBID)) - - // Split off ranges for tblA and tblB. - _, rDescA, err := firstSvr.SplitRange(startKeyTblA) - require.NoError(t, err) - _, rDescB, err := firstSvr.SplitRange(startKeyTblB) - require.NoError(t, err) - - // Ensure all nodes have the correct span configs for tblA and tblB. - waitForSpanConfig(t, tc, rDescA.StartKey, 500000000) - waitForSpanConfig(t, tc, rDescB.StartKey, 500000000) - - // Transfer tblA to [west, central, east] and tblB to [east]. 
- tc.AddVotersOrFatal(t, startKeyTblA, tc.Target(1), tc.Target(2), tc.Target(4)) - tc.TransferRangeLeaseOrFatal(t, rDescA, tc.Target(1)) - tc.RemoveVotersOrFatal(t, startKeyTblA, tc.Target(0)) - tc.AddVotersOrFatal(t, startKeyTblB, tc.Target(4), tc.Target(5), tc.Target(6)) - tc.TransferRangeLeaseOrFatal(t, rDescB, tc.Target(4)) - tc.RemoveVotersOrFatal(t, startKeyTblB, tc.Target(0)) - - // Validate range distribution. - rDescA = tc.LookupRangeOrFatal(t, startKeyTblA) - rDescB = tc.LookupRangeOrFatal(t, startKeyTblB) - for _, desc := range []roachpb.RangeDescriptor{rDescA, rDescB} { - require.Lenf(t, desc.Replicas().VoterAndNonVoterDescriptors(), 3, "expected 3 replicas, have %v", desc) - } - - require.True(t, hasReplicaOnServers(tc, &rDescA, 1, 2, 4)) - require.True(t, hasReplicaOnServers(tc, &rDescB, 4, 5, 6)) - - // Evaluate n5 decommission check. - decommissioningNodeIDs := []roachpb.NodeID{tc.Server(4).NodeID()} - result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, - true /* strictReadiness */, true /* collectTraces */, 10000, /* maxErrors */ - ) - require.NoError(t, err) - require.Equal(t, 2, result.rangesChecked, "unexpected number of ranges checked") - require.Equalf(t, 2, result.actionCounts[allocatorimpl.AllocatorReplaceDecommissioningVoter.String()], - "unexpected allocator actions, got %v", result.actionCounts) - require.Lenf(t, result.rangesNotReady, 1, "unexpected number of unready ranges") - - // Validate error on tblB's range as it requires 3 replicas in "east". 
- unreadyResult := result.rangesNotReady[0] - require.Equalf(t, rDescB.StartKey, unreadyResult.desc.StartKey, - "expected tblB's range to be unready, got %s", unreadyResult.desc, - ) - require.Errorf(t, unreadyResult.err, "expected error on %s", unreadyResult.desc) - require.NotEmptyf(t, unreadyResult.tracingSpans, "expected tracing spans on %s", unreadyResult.desc) - var allocatorError allocator.AllocationError - require.ErrorAsf(t, unreadyResult.err, &allocatorError, "expected allocator error on %s", unreadyResult.desc) - - // Evaluate n3 decommission check (not required to satisfy constraints). - decommissioningNodeIDs = []roachpb.NodeID{tc.Server(2).NodeID()} - result, err = firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, - true /* strictReadiness */, false /* collectTraces */, 0, /* maxErrors */ - ) - require.NoError(t, err) - require.Equal(t, 1, result.rangesChecked, "unexpected number of ranges checked") - require.Equalf(t, 1, result.actionCounts[allocatorimpl.AllocatorReplaceDecommissioningVoter.String()], - "unexpected allocator actions, got %v", result.actionCounts) - require.Lenf(t, result.rangesNotReady, 0, "unexpected number of unready ranges") -} - -// TestDecommissionPreCheckOddToEven tests evaluation of decommission readiness -// when moving from 5 nodes to 3, in which case ranges with RF of 5 should have -// an effective RF of 3. -func TestDecommissionPreCheckOddToEven(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - // Set up test cluster. - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - }) - defer tc.Stopper().Stop(ctx) - - firstSvr := tc.Server(0).(*TestServer) - db := tc.ServerConn(0) - runQueries := func(queries ...string) { - for _, q := range queries { - if _, err := db.Exec(q); err != nil { - t.Fatalf("error executing '%s': %s", q, err) - } - } - } - - // Create database and tables. 
- ac := firstSvr.AmbientCtx() - ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") - defer span.Finish() - setupQueries := []string{ - "CREATE DATABASE test", - "CREATE TABLE test.tblA (val STRING)", - "INSERT INTO test.tblA VALUES ('testvalA')", - } - runQueries(setupQueries...) - alterQueries := []string{ - "ALTER TABLE test.tblA CONFIGURE ZONE USING num_replicas = 5, " + - "range_max_bytes = 500000000, range_min_bytes = 100", - } - runQueries(alterQueries...) - tblAID, err := firstSvr.admin.queryTableID(ctx, username.RootUserName(), "test", "tblA") - require.NoError(t, err) - startKeyTblA := firstSvr.Codec().TablePrefix(uint32(tblAID)) - - // Split off range for tblA. - _, rDescA, err := firstSvr.SplitRange(startKeyTblA) - require.NoError(t, err) - - // Ensure all nodes have the correct span configs for tblA. - waitForSpanConfig(t, tc, rDescA.StartKey, 500000000) - - // Transfer tblA to all nodes. - tc.AddVotersOrFatal(t, startKeyTblA, tc.Target(1), tc.Target(2), tc.Target(3), tc.Target(4)) - tc.TransferRangeLeaseOrFatal(t, rDescA, tc.Target(1)) - - // Validate range distribution. - rDescA = tc.LookupRangeOrFatal(t, startKeyTblA) - require.Lenf(t, rDescA.Replicas().VoterAndNonVoterDescriptors(), 5, "expected 5 replicas, have %v", rDescA) - - require.True(t, hasReplicaOnServers(tc, &rDescA, 0, 1, 2, 3, 4)) - - // Evaluate n5 decommission check. 
- decommissioningNodeIDs := []roachpb.NodeID{tc.Server(4).NodeID()} - result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, - true /* strictReadiness */, true /* collectTraces */, 10000, /* maxErrors */ - ) - require.NoError(t, err) - require.Equal(t, 1, result.rangesChecked, "unexpected number of ranges checked") - require.Equalf(t, 1, result.actionCounts[allocatorimpl.AllocatorRemoveDecommissioningVoter.String()], - "unexpected allocator actions, got %v", result.actionCounts) - require.Lenf(t, result.rangesNotReady, 0, "unexpected number of unready ranges") -} - -// decommissionTsArgs returns a base.TestServerArgs for creating a test cluster -// with per-store attributes using a single, in-memory store for each node. -func decommissionTsArgs(region string, attrs ...string) base.TestServerArgs { - return base.TestServerArgs{ - Locality: roachpb.Locality{ - Tiers: []roachpb.Tier{ - { - Key: "region", - Value: region, - }, - }, - }, - StoreSpecs: []base.StoreSpec{ - {InMemory: true, Attributes: roachpb.Attributes{Attrs: attrs}}, - }, - } -} - -// hasReplicaOnServers returns true if the range has replicas on given servers. -func hasReplicaOnServers( - tc serverutils.TestClusterInterface, desc *roachpb.RangeDescriptor, serverIdxs ...int, -) bool { - for _, idx := range serverIdxs { - if !desc.Replicas().HasReplicaOnNode(tc.Server(idx).NodeID()) { - return false - } - } - return true -} - -// waitForSpanConfig waits until all servers in the test cluster have a span -// config for the key with the expected number of max bytes for the range. 
-func waitForSpanConfig( - t *testing.T, tc serverutils.TestClusterInterface, key roachpb.RKey, exp int64, -) { - testutils.SucceedsSoon(t, func() error { - for i := 0; i < tc.NumServers(); i++ { - s := tc.Server(i) - store, err := s.GetStores().(*kvserver.Stores).GetStore(s.GetFirstStoreID()) - if err != nil { - return errors.Wrapf(err, "missing store on server %d", i) - } - conf, err := store.GetStoreConfig().SpanConfigSubscriber.GetSpanConfigForKey(context.Background(), key) - if err != nil { - return errors.Wrapf(err, "missing span config for %s on server %d", key, i) - } - if conf.RangeMaxBytes != exp { - return errors.Errorf("expected %d max bytes, got %d", exp, conf.RangeMaxBytes) - } - } - return nil - }) -} diff --git a/pkg/server/distsql_flows.go b/pkg/server/distsql_flows.go new file mode 100644 index 000000000000..14459b80b684 --- /dev/null +++ b/pkg/server/distsql_flows.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "bytes" + "sort" + + "github.com/cockroachdb/cockroach/pkg/server/serverpb" +) + +// mergeDistSQLRemoteFlows takes in two slices of DistSQL remote flows (that +// satisfy the contract of serverpb.ListDistSQLFlowsResponse) and merges them +// together while adhering to the same contract. +// +// It is assumed that if serverpb.DistSQLRemoteFlows for a particular FlowID +// appear in both arguments - let's call them flowsA and flowsB for a and b, +// respectively - then there are no duplicate NodeIDs among flowsA and flowsB. 
+func mergeDistSQLRemoteFlows(a, b []serverpb.DistSQLRemoteFlows) []serverpb.DistSQLRemoteFlows { + maxLength := len(a) + if len(b) > len(a) { + maxLength = len(b) + } + result := make([]serverpb.DistSQLRemoteFlows, 0, maxLength) + aIter, bIter := 0, 0 + for aIter < len(a) && bIter < len(b) { + cmp := bytes.Compare(a[aIter].FlowID.GetBytes(), b[bIter].FlowID.GetBytes()) + if cmp < 0 { + result = append(result, a[aIter]) + aIter++ + } else if cmp > 0 { + result = append(result, b[bIter]) + bIter++ + } else { + r := a[aIter] + // No need to perform any kind of de-duplication because a + // particular flow will be reported at most once by each node in the + // cluster. + r.Infos = append(r.Infos, b[bIter].Infos...) + sort.Slice(r.Infos, func(i, j int) bool { + return r.Infos[i].NodeID < r.Infos[j].NodeID + }) + result = append(result, r) + aIter++ + bIter++ + } + } + if aIter < len(a) { + result = append(result, a[aIter:]...) + } + if bIter < len(b) { + result = append(result, b[bIter:]...) + } + return result +} diff --git a/pkg/server/distsql_flows_test.go b/pkg/server/distsql_flows_test.go new file mode 100644 index 000000000000..035aaf47b889 --- /dev/null +++ b/pkg/server/distsql_flows_test.go @@ -0,0 +1,205 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package server + +import ( + "bytes" + "sort" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +func TestMergeDistSQLRemoteFlows(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + flowIDs := make([]execinfrapb.FlowID, 4) + for i := range flowIDs { + flowIDs[i].UUID = uuid.FastMakeV4() + } + sort.Slice(flowIDs, func(i, j int) bool { + return bytes.Compare(flowIDs[i].GetBytes(), flowIDs[j].GetBytes()) < 0 + }) + ts := make([]time.Time, 4) + for i := range ts { + ts[i] = timeutil.Now() + } + + for _, tc := range []struct { + a []serverpb.DistSQLRemoteFlows + b []serverpb.DistSQLRemoteFlows + expected []serverpb.DistSQLRemoteFlows + }{ + // a is empty + { + a: []serverpb.DistSQLRemoteFlows{}, + b: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + }, + }, + }, + expected: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + }, + }, + }, + }, + // b is empty + { + a: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + 
FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + }, + }, + }, + b: []serverpb.DistSQLRemoteFlows{}, + expected: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + }, + }, + }, + }, + // both non-empty with some intersections + { + a: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[2], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[3], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + }, + }, + }, + b: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + }, + }, + { + FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + }, + }, + { + FlowID: flowIDs[3], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + }, + }, + }, + expected: []serverpb.DistSQLRemoteFlows{ + { + FlowID: flowIDs[0], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + {NodeID: 3, Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[1], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + }, + }, + { + FlowID: flowIDs[2], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 3, 
Timestamp: ts[3]}, + }, + }, + { + FlowID: flowIDs[3], + Infos: []serverpb.DistSQLRemoteFlows_Info{ + {NodeID: 0, Timestamp: ts[0]}, + {NodeID: 1, Timestamp: ts[1]}, + {NodeID: 2, Timestamp: ts[2]}, + }, + }, + }, + }, + } { + require.Equal(t, tc.expected, mergeDistSQLRemoteFlows(tc.a, tc.b)) + } +} diff --git a/pkg/server/drain.go b/pkg/server/drain.go index 9f6df5aad884..061afbe2bcf7 100644 --- a/pkg/server/drain.go +++ b/pkg/server/drain.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/sql/sqlstats/persistedsqlstats" "github.com/cockroachdb/cockroach/pkg/util/grpcutil" @@ -100,7 +101,7 @@ func (s *adminServer) Drain(req *serverpb.DrainRequest, stream serverpb.Admin_Dr // Connect to the target node. client, err := s.dialNode(ctx, roachpb.NodeID(nodeID)) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } return delegateDrain(ctx, req, client, stream) } diff --git a/pkg/server/fanout_clients.go b/pkg/server/fanout_clients.go index 0451f539ec53..32dade21abbd 100644 --- a/pkg/server/fanout_clients.go +++ b/pkg/server/fanout_clients.go @@ -21,7 +21,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -187,7 +189,7 @@ type kvFanoutClient struct { func (k kvFanoutClient) nodesList(ctx context.Context) 
(*serverpb.NodesListResponse, error) { statuses, _, err := getNodeStatuses(ctx, k.db, 0 /* limit */, 0 /* offset */) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp := &serverpb.NodesListResponse{ Nodes: make([]serverpb.NodeDetails, len(statuses)), @@ -228,7 +230,7 @@ func (k kvFanoutClient) dialNode(ctx context.Context, serverID serverID) (*grpc. } func (k kvFanoutClient) listNodes(ctx context.Context) (*serverpb.NodesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = k.ambientCtx.AnnotateCtx(ctx) statuses, _, err := getNodeStatuses(ctx, k.db, 0, 0) diff --git a/pkg/server/grpc_gateway.go b/pkg/server/grpc_gateway.go index 6eed6aa808d8..883f478c67aa 100644 --- a/pkg/server/grpc_gateway.go +++ b/pkg/server/grpc_gateway.go @@ -15,6 +15,7 @@ import ( "fmt" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/ts" "github.com/cockroachdb/cockroach/pkg/util/httputil" @@ -38,7 +39,7 @@ type grpcGatewayServer interface { var _ grpcGatewayServer = (*adminServer)(nil) var _ grpcGatewayServer = (*statusServer)(nil) -var _ grpcGatewayServer = (*authenticationServer)(nil) +var _ grpcGatewayServer = authserver.Server(nil) var _ grpcGatewayServer = (*ts.Server)(nil) // configureGRPCGateway initializes services necessary for running the @@ -71,8 +72,8 @@ func configureGRPCGateway( gwruntime.WithMarshalerOption(httputil.AltJSONContentType, jsonpb), gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb), gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb), - gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher), - gwruntime.WithMetadata(translateHTTPAuthInfoToGRPCMetadata), + gwruntime.WithOutgoingHeaderMatcher(authserver.AuthenticationHeaderMatcher), + 
gwruntime.WithMetadata(authserver.TranslateHTTPAuthInfoToGRPCMetadata), ) gwCtx, gwCancel := context.WithCancel(ambientCtx.AnnotateCtx(context.Background())) stopper.AddCloser(stop.CloserFn(gwCancel)) diff --git a/pkg/server/grpc_gateway_test.go b/pkg/server/grpc_gateway_test.go new file mode 100644 index 000000000000..205c7cf33a62 --- /dev/null +++ b/pkg/server/grpc_gateway_test.go @@ -0,0 +1,57 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/server/telemetry" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +// TestEndpointTelemetryBasic tests that the telemetry collection on the usage of +// CRDB's endpoints works as expected by recording the call counts of `Admin` & +// `Status` requests. +func TestEndpointTelemetryBasic(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this test fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + // Check that calls over HTTP are recorded.
+ var details serverpb.LocationsResponse + if err := srvtestutils.GetAdminJSONProto(s, "locations", &details); err != nil { + t.Fatal(err) + } + require.GreaterOrEqual(t, telemetry.Read(getServerEndpointCounter( + "/cockroach.server.serverpb.Admin/Locations", + )), int32(1)) + + var resp serverpb.StatementsResponse + if err := srvtestutils.GetStatusJSONProto(s, "statements", &resp); err != nil { + t.Fatal(err) + } + require.Equal(t, int32(1), telemetry.Read(getServerEndpointCounter( + "/cockroach.server.serverpb.Status/Statements", + ))) +} diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go index 6fcd754814af..81eeb28dc21c 100644 --- a/pkg/server/index_usage_stats.go +++ b/pkg/server/index_usage_stats.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" @@ -25,6 +26,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/util/safesql" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" "google.golang.org/grpc/codes" @@ -40,10 +42,10 @@ import ( func (s *statusServer) IndexUsageStatistics( ctx context.Context, req *serverpb.IndexUsageStatisticsRequest, ) (*serverpb.IndexUsageStatisticsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ 
-130,10 +132,10 @@ func indexUsageStatsLocal( func (s *statusServer) ResetIndexUsageStats( ctx context.Context, req *serverpb.ResetIndexUsageStatsRequest, ) (*serverpb.ResetIndexUsageStatsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } @@ -207,10 +209,10 @@ func (s *statusServer) ResetIndexUsageStats( func (s *statusServer) TableIndexStats( ctx context.Context, req *serverpb.TableIndexStatsRequest, ) (*serverpb.TableIndexStatsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } return getTableIndexUsageStats(ctx, req, s.sqlServer.pgServer.SQLServer.GetLocalIndexStatistics(), @@ -228,7 +230,7 @@ func getTableIndexUsageStats( st *cluster.Settings, execConfig *sql.ExecutorConfig, ) (*serverpb.TableIndexStatsResponse, error) { - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { return nil, err } @@ -239,7 +241,7 @@ func getTableIndexUsageStats( return nil, err } - q := makeSQLQuery() + q := safesql.NewQuery() // TODO(#72930): Implement virtual indexes on index_usages_statistics and table_indexes q.Append(` SELECT @@ -387,7 +389,7 @@ func getDatabaseIndexRecommendations( return []*serverpb.IndexRecommendation{}, nil } - userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) if err != nil { return []*serverpb.IndexRecommendation{}, err } diff --git 
a/pkg/server/index_usage_stats_test.go b/pkg/server/index_usage_stats_test.go index 1ce2b2c4ab21..9eabbffaab99 100644 --- a/pkg/server/index_usage_stats_test.go +++ b/pkg/server/index_usage_stats_test.go @@ -21,7 +21,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" @@ -261,7 +263,7 @@ func TestStatusAPIIndexUsage(t *testing.T) { // Test cluster-wide RPC. var resp serverpb.IndexUsageStatisticsResponse - err = getStatusJSONProto(thirdServer, "indexusagestatistics", &resp) + err = srvtestutils.GetStatusJSONProto(thirdServer, "indexusagestatistics", &resp) require.NoError(t, err) statsEntries := 0 @@ -292,7 +294,7 @@ func TestStatusAPIIndexUsage(t *testing.T) { _, err = secondServerSQLConn.Exec("SELECT k, a FROM t WHERE a = 0") require.NoError(t, err) - err = getStatusJSONProto(thirdServer, "indexusagestatistics", &resp) + err = srvtestutils.GetStatusJSONProto(thirdServer, "indexusagestatistics", &resp) require.NoError(t, err) statsEntries = 0 @@ -360,7 +362,7 @@ CREATE TABLE schema.test_table ( `) // Get Table IDs. 
- userName, err := userFromIncomingRPCContext(ctx) + userName, err := authserver.UserFromIncomingRPCContext(ctx) require.NoError(t, err) testCases := []struct { diff --git a/pkg/server/init_handshake.go b/pkg/server/init_handshake.go index 85b29e6f4eb9..10c0666f3224 100644 --- a/pkg/server/init_handshake.go +++ b/pkg/server/init_handshake.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" @@ -258,18 +259,18 @@ func (t *tlsInitHandshaker) onTrustInit( select { case t.trustedPeers <- challenge: case <-ctx.Done(): - apiV2InternalError(req.Context(), ctx.Err(), res) + srverrors.APIV2InternalError(req.Context(), ctx.Err(), res) return } // Acknowledge validation to the client. ack, err := createNodeHostnameAndCA(t.listenAddr, t.tempCerts.CACertificate, t.token) if err != nil { - apiV2InternalError(req.Context(), err, res) + srverrors.APIV2InternalError(req.Context(), err, res) return } if err := json.NewEncoder(res).Encode(ack); err != nil { - apiV2InternalError(req.Context(), err, res) + srverrors.APIV2InternalError(req.Context(), err, res) return } } diff --git a/pkg/server/main_test.go b/pkg/server/main_test.go index 81b35fa20e52..f20ce23797aa 100644 --- a/pkg/server/main_test.go +++ b/pkg/server/main_test.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/security/securityassets" "github.com/cockroachdb/cockroach/pkg/security/securitytest" "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/rangetestutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" ) @@ -26,6 +27,7 @@ func TestMain(m *testing.M) { 
securityassets.SetLoader(securitytest.EmbeddedAssets) serverutils.InitTestServerFactory(server.TestServerFactory) serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + rangetestutils.InitRangeTestServerFactory(server.TestServerFactory) kvtenant.InitTestConnectorFactory() os.Exit(m.Run()) } diff --git a/pkg/server/multi_store_test.go b/pkg/server/multi_store_test.go index 8ffafd543f9f..3006504f5db5 100644 --- a/pkg/server/multi_store_test.go +++ b/pkg/server/multi_store_test.go @@ -65,7 +65,9 @@ func TestAddNewStoresToExistingNodes(t *testing.T) { ServerArgsPerNode: map[int]base.TestServerArgs{}, } for srvIdx := 0; srvIdx < numNodes; srvIdx++ { - var serverArgs base.TestServerArgs + serverArgs := base.TestServerArgs{ + DefaultTestTenant: base.TODOTestTenantDisabled, + } serverArgs.Knobs.Server = &server.TestingKnobs{StickyEngineRegistry: ser} for storeIdx := 0; storeIdx < numStoresPerNode; storeIdx++ { id := fmt.Sprintf("s%d.%d", srvIdx+1, storeIdx+1) diff --git a/pkg/server/nodes_response.go b/pkg/server/nodes_response.go new file mode 100644 index 000000000000..bac7d0cc6892 --- /dev/null +++ b/pkg/server/nodes_response.go @@ -0,0 +1,155 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package server + +import ( + "sort" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" + "github.com/cockroachdb/cockroach/pkg/util" +) + +func nodeStatusToResp(n *statuspb.NodeStatus, hasViewClusterMetadata bool) serverpb.NodeResponse { + tiers := make([]serverpb.Tier, len(n.Desc.Locality.Tiers)) + for j, t := range n.Desc.Locality.Tiers { + tiers[j] = serverpb.Tier{ + Key: t.Key, + Value: t.Value, + } + } + + activity := make(map[roachpb.NodeID]serverpb.NodeResponse_NetworkActivity, len(n.Activity)) + for k, v := range n.Activity { + activity[k] = serverpb.NodeResponse_NetworkActivity{ + Latency: v.Latency, + } + } + + nodeDescriptor := serverpb.NodeDescriptor{ + NodeID: n.Desc.NodeID, + Address: util.UnresolvedAddr{}, + Attrs: roachpb.Attributes{}, + Locality: serverpb.Locality{ + Tiers: tiers, + }, + ServerVersion: serverpb.Version{ + Major: n.Desc.ServerVersion.Major, + Minor: n.Desc.ServerVersion.Minor, + Patch: n.Desc.ServerVersion.Patch, + Internal: n.Desc.ServerVersion.Internal, + }, + BuildTag: n.Desc.BuildTag, + StartedAt: n.Desc.StartedAt, + LocalityAddress: nil, + ClusterName: n.Desc.ClusterName, + SQLAddress: util.UnresolvedAddr{}, + } + + statuses := make([]serverpb.StoreStatus, len(n.StoreStatuses)) + for i, ss := range n.StoreStatuses { + statuses[i] = serverpb.StoreStatus{ + Desc: serverpb.StoreDescriptor{ + StoreID: ss.Desc.StoreID, + Attrs: ss.Desc.Attrs, + Node: nodeDescriptor, + Capacity: ss.Desc.Capacity, + + Properties: roachpb.StoreProperties{ + ReadOnly: ss.Desc.Properties.ReadOnly, + Encrypted: ss.Desc.Properties.Encrypted, + }, + }, + Metrics: ss.Metrics, + } + if fsprops := ss.Desc.Properties.FileStoreProperties; fsprops != nil { + sfsprops := &roachpb.FileStoreProperties{ + FsType: fsprops.FsType, + } + if hasViewClusterMetadata { + sfsprops.Path = fsprops.Path + sfsprops.BlockDevice = fsprops.BlockDevice + 
sfsprops.MountPoint = fsprops.MountPoint + sfsprops.MountOptions = fsprops.MountOptions + } + statuses[i].Desc.Properties.FileStoreProperties = sfsprops + } + } + + resp := serverpb.NodeResponse{ + Desc: nodeDescriptor, + BuildInfo: n.BuildInfo, + StartedAt: n.StartedAt, + UpdatedAt: n.UpdatedAt, + Metrics: n.Metrics, + StoreStatuses: statuses, + Args: nil, + Env: nil, + Latencies: n.Latencies, + Activity: activity, + TotalSystemMemory: n.TotalSystemMemory, + NumCpus: n.NumCpus, + } + + if hasViewClusterMetadata { + resp.Args = n.Args + resp.Env = n.Env + resp.Desc.Attrs = n.Desc.Attrs + resp.Desc.Address = n.Desc.Address + resp.Desc.LocalityAddress = n.Desc.LocalityAddress + resp.Desc.SQLAddress = n.Desc.SQLAddress + for _, n := range resp.StoreStatuses { + n.Desc.Node = resp.Desc + } + } + + return resp +} + +func regionsResponseFromNodesResponse(nr *serverpb.NodesResponse) *serverpb.RegionsResponse { + regionsToZones := make(map[string]map[string]struct{}) + for _, node := range nr.Nodes { + var region string + var zone string + for _, tier := range node.Desc.Locality.Tiers { + switch tier.Key { + case "region": + region = tier.Value + case "zone", "availability-zone", "az": + zone = tier.Value + } + } + if region == "" { + continue + } + if _, ok := regionsToZones[region]; !ok { + regionsToZones[region] = make(map[string]struct{}) + } + if zone != "" { + regionsToZones[region][zone] = struct{}{} + } + } + ret := &serverpb.RegionsResponse{ + Regions: make(map[string]*serverpb.RegionsResponse_Region, len(regionsToZones)), + } + for region, zones := range regionsToZones { + zonesArr := make([]string, 0, len(zones)) + for z := range zones { + zonesArr = append(zonesArr, z) + } + sort.Strings(zonesArr) + ret.Regions[region] = &serverpb.RegionsResponse_Region{ + Zones: zonesArr, + } + } + return ret +} diff --git a/pkg/server/nodes_response_test.go b/pkg/server/nodes_response_test.go new file mode 100644 index 000000000000..dbc1ce3c8c5b --- /dev/null +++ 
b/pkg/server/nodes_response_test.go @@ -0,0 +1,170 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestNodeStatusToResp(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + var nodeStatus = &statuspb.NodeStatus{ + StoreStatuses: []statuspb.StoreStatus{ + {Desc: roachpb.StoreDescriptor{ + Properties: roachpb.StoreProperties{ + Encrypted: true, + FileStoreProperties: &roachpb.FileStoreProperties{ + Path: "/secret", + FsType: "ext4", + }, + }, + }}, + }, + Desc: roachpb.NodeDescriptor{ + Address: util.UnresolvedAddr{ + NetworkField: "network", + AddressField: "address", + }, + Attrs: roachpb.Attributes{ + Attrs: []string{"attr"}, + }, + LocalityAddress: []roachpb.LocalityAddress{{Address: util.UnresolvedAddr{ + NetworkField: "network", + AddressField: "address", + }, LocalityTier: roachpb.Tier{Value: "v", Key: "k"}}}, + SQLAddress: util.UnresolvedAddr{ + NetworkField: "network", + AddressField: "address", + }, + }, + Args: []string{"args"}, + Env: []string{"env"}, + } + resp := nodeStatusToResp(nodeStatus, false) + require.Empty(t, resp.Args) + require.Empty(t, resp.Env) + require.Empty(t, resp.Desc.Address) + require.Empty(t, resp.Desc.Attrs.Attrs) + require.Empty(t, resp.Desc.LocalityAddress) + 
require.Empty(t, resp.Desc.SQLAddress) + require.True(t, resp.StoreStatuses[0].Desc.Properties.Encrypted) + require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.FsType) + require.Empty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.Path) + + // Now fetch all the node statuses as admin. + resp = nodeStatusToResp(nodeStatus, true) + require.NotEmpty(t, resp.Args) + require.NotEmpty(t, resp.Env) + require.NotEmpty(t, resp.Desc.Address) + require.NotEmpty(t, resp.Desc.Attrs.Attrs) + require.NotEmpty(t, resp.Desc.LocalityAddress) + require.NotEmpty(t, resp.Desc.SQLAddress) + require.True(t, resp.StoreStatuses[0].Desc.Properties.Encrypted) + require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.FsType) + require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.Path) +} + +func TestRegionsResponseFromNodesResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + makeNodeResponseWithLocalities := func(tiers [][]roachpb.Tier) *serverpb.NodesResponse { + ret := &serverpb.NodesResponse{} + for _, l := range tiers { + ret.Nodes = append( + ret.Nodes, + statuspb.NodeStatus{ + Desc: roachpb.NodeDescriptor{ + Locality: roachpb.Locality{Tiers: l}, + }, + }, + ) + } + return ret + } + + makeTiers := func(region, zone string) []roachpb.Tier { + return []roachpb.Tier{ + {Key: "region", Value: region}, + {Key: "zone", Value: zone}, + } + } + + testCases := []struct { + desc string + resp *serverpb.NodesResponse + expected *serverpb.RegionsResponse + }{ + { + desc: "no nodes with regions", + resp: makeNodeResponseWithLocalities([][]roachpb.Tier{ + {{Key: "a", Value: "a"}}, + {}, + }), + expected: &serverpb.RegionsResponse{ + Regions: map[string]*serverpb.RegionsResponse_Region{}, + }, + }, + { + desc: "nodes, some with AZs", + resp: makeNodeResponseWithLocalities([][]roachpb.Tier{ + makeTiers("us-east1", "us-east1-a"), + makeTiers("us-east1", "us-east1-a"), + 
makeTiers("us-east1", "us-east1-a"), + makeTiers("us-east1", "us-east1-b"), + + makeTiers("us-east2", "us-east2-a"), + makeTiers("us-east2", "us-east2-a"), + makeTiers("us-east2", "us-east2-a"), + + makeTiers("us-east3", "us-east3-a"), + makeTiers("us-east3", "us-east3-b"), + makeTiers("us-east3", "us-east3-b"), + {{Key: "region", Value: "us-east3"}}, + + {{Key: "region", Value: "us-east4"}}, + }), + expected: &serverpb.RegionsResponse{ + Regions: map[string]*serverpb.RegionsResponse_Region{ + "us-east1": { + Zones: []string{"us-east1-a", "us-east1-b"}, + }, + "us-east2": { + Zones: []string{"us-east2-a"}, + }, + "us-east3": { + Zones: []string{"us-east3-a", "us-east3-b"}, + }, + "us-east4": { + Zones: []string{}, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ret := regionsResponseFromNodesResponse(tc.resp) + require.Equal(t, tc.expected, ret) + }) + } +} diff --git a/pkg/server/privchecker/BUILD.bazel b/pkg/server/privchecker/BUILD.bazel new file mode 100644 index 000000000000..56f6ac43ea52 --- /dev/null +++ b/pkg/server/privchecker/BUILD.bazel @@ -0,0 +1,55 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "privchecker", + srcs = [ + "api.go", + "privchecker.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/server/privchecker", + visibility = ["//visibility:public"], + deps = [ + "//pkg/security/username", + "//pkg/server/authserver", + "//pkg/server/srverrors", + "//pkg/settings/cluster", + "//pkg/sql", + "//pkg/sql/isql", + "//pkg/sql/privilege", + "//pkg/sql/roleoption", + "//pkg/sql/sem/tree", + "//pkg/sql/sessiondata", + "//pkg/sql/syntheticprivilege", + "@com_github_cockroachdb_errors//:errors", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//status", + ], +) + +go_test( + name = "privchecker_test", + srcs = [ + "main_test.go", + "privchecker_test.go", + ], + args = ["-test.timeout=295s"], + deps = [ + ":privchecker", + "//pkg/base", + 
"//pkg/ccl", + "//pkg/kv/kvclient/kvtenant", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/security/username", + "//pkg/server", + "//pkg/sql", + "//pkg/sql/isql", + "//pkg/testutils/serverutils", + "//pkg/testutils/sqlutils", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/log", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//metadata", + ], +) diff --git a/pkg/server/privchecker/api.go b/pkg/server/privchecker/api.go new file mode 100644 index 000000000000..329396bfa272 --- /dev/null +++ b/pkg/server/privchecker/api.go @@ -0,0 +1,98 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package privchecker + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/roleoption" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" +) + +// CheckerForRPCHandlers describes a helper for checking privileges. +// +// Note: this interface is intended for use inside RPC handlers in the +// 'server' package, where the identity of the current user is carried +// by the context.Context as per the 'authserver' protocol. +type CheckerForRPCHandlers interface { + SQLPrivilegeChecker + + // GetUserAndRole returns the current user's name and whether the + // user is an admin. 
+ // + // Note that the function returns plain errors, and it is the caller's + // responsibility to convert them through srverrors.ServerError. + GetUserAndRole(ctx context.Context) (userName username.SQLUsername, isAdmin bool, err error) + + // RequireAdminUser validates the current user is + // an admin user. It returns the current user's name. + // Its error return is a gRPC error. + RequireAdminUser(ctx context.Context) (userName username.SQLUsername, err error) + + // RequireViewActivityPermission validates the current user is an admin + // or has the VIEWACTIVITY privilege or role option. + // Its error return is a gRPC error. + RequireViewActivityPermission(ctx context.Context) error + + RequireViewActivityOrViewActivityRedactedPermission(ctx context.Context) error + RequireViewClusterSettingOrModifyClusterSettingPermission(ctx context.Context) error + RequireViewActivityAndNoViewActivityRedactedPermission(ctx context.Context) error + RequireViewClusterMetadataPermission(ctx context.Context) error + RequireViewDebugPermission(ctx context.Context) error +} + +// SQLPrivilegeChecker is the part of the privilege checker that can +// be used outside of RPC handlers, because it takes the identity as +// an explicit argument. +type SQLPrivilegeChecker interface { + // HasAdminRole checks if the user has the admin role. + // Note that the function returns plain errors, and it is the + // caller's responsibility to convert them through + // srverrors.ServerError. + HasAdminRole(ctx context.Context, user username.SQLUsername) (bool, error) + + // HasRoleOption checks if the user has the given role option. + // Note that the function returns plain errors, and it is the + // caller's responsibility to convert them through + // srverrors.ServerError. + HasRoleOption(ctx context.Context, user username.SQLUsername, roleOption roleoption.Option) (bool, error) + + // SetAuthzAccessorFactory sets the accessor factory that can be + // used by HasGlobalPrivilege.
+ SetAuthzAccessorFactory(factory func(opName string) (sql.AuthorizationAccessor, func())) + + // HasGlobalPrivilege is a convenience wrapper + HasGlobalPrivilege(ctx context.Context, user username.SQLUsername, privilege privilege.Kind) (bool, error) +} + +// NewChecker constructs a new CheckerForRPCHandlers. +func NewChecker(ie isql.Executor, st *cluster.Settings) CheckerForRPCHandlers { + return &adminPrivilegeChecker{ + ie: ie, + st: st, + } +} + +// ErrRequiresAdmin is returned when the admin role is required by an API. +var ErrRequiresAdmin = grpcstatus.Error(codes.PermissionDenied, "this operation requires admin privilege") + +// ErrRequiresRoleOption can be used to construct an error that tells the user +// a given role option or privilege is required. +func ErrRequiresRoleOption(option roleoption.Option) error { + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires %s privilege", option) +} diff --git a/pkg/server/privchecker/main_test.go b/pkg/server/privchecker/main_test.go new file mode 100644 index 000000000000..fe5714400354 --- /dev/null +++ b/pkg/server/privchecker/main_test.go @@ -0,0 +1,35 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package privchecker_test + +import ( + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/ccl" + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvtenant" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" +) + +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + kvtenant.InitTestConnectorFactory() + defer ccl.TestingEnableEnterprise()() + os.Exit(m.Run()) +} + +//go:generate ../../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/privchecker/privchecker.go b/pkg/server/privchecker/privchecker.go new file mode 100644 index 000000000000..50e4c19fb092 --- /dev/null +++ b/pkg/server/privchecker/privchecker.go @@ -0,0 +1,326 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +package privchecker + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/roleoption" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege" + "github.com/cockroachdb/errors" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" +) + +// adminPrivilegeChecker implements CheckerForRPCHandlers: it checks whether +// given users have the admin role, role options or global privileges. +type adminPrivilegeChecker struct { + ie isql.Executor + st *cluster.Settings + + // makeAuthzAccessor is a function that calls NewInternalPlanner to + // make a sql.AuthorizationAccessor outside of the sql package. This + // is a hack to get around a Go package dependency cycle. See + // comment in pkg/scheduledjobs/env.go on planHookMaker. It is set + // via SetAuthzAccessorFactory; its second return value is a + // cleanup function. + makeAuthzAccessor func(opName string) (sql.AuthorizationAccessor, func()) +} + +// RequireAdminUser is part of the CheckerForRPCHandlers interface. +func (c *adminPrivilegeChecker) RequireAdminUser( + ctx context.Context, +) (userName username.SQLUsername, err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return userName, srverrors.ServerError(ctx, err) + } + if !isAdmin { + return userName, ErrRequiresAdmin + } + return userName, nil +} + +// RequireViewActivityPermission is part of the CheckerForRPCHandlers interface.
+func (c *adminPrivilegeChecker) RequireViewActivityPermission(ctx context.Context) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if isAdmin { + return nil + } + if hasView, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITY); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + if hasView, err := c.HasRoleOption(ctx, userName, roleoption.VIEWACTIVITY); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires the %s system privilege", + roleoption.VIEWACTIVITY) +} + +// RequireViewActivityOrViewActivityRedactedPermission's error return is a gRPC error. +func (c *adminPrivilegeChecker) RequireViewActivityOrViewActivityRedactedPermission( + ctx context.Context, +) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if isAdmin { + return nil + } + if hasView, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITY); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + if hasViewRedacted, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITYREDACTED); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasViewRedacted { + return nil + } + if hasView, err := c.HasRoleOption(ctx, userName, roleoption.VIEWACTIVITY); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + if hasViewRedacted, err := c.HasRoleOption(ctx, userName, roleoption.VIEWACTIVITYREDACTED); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasViewRedacted { + return nil + } + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires the %s or %s system privileges", + roleoption.VIEWACTIVITY, roleoption.VIEWACTIVITYREDACTED) +} 
+ +// RequireViewClusterSettingOrModifyClusterSettingPermission's error return is a gRPC error. +func (c *adminPrivilegeChecker) RequireViewClusterSettingOrModifyClusterSettingPermission( + ctx context.Context, +) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if isAdmin { + return nil + } + if hasView, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWCLUSTERSETTING); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + if hasModify, err := c.HasGlobalPrivilege(ctx, userName, privilege.MODIFYCLUSTERSETTING); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasModify { + return nil + } + if hasView, err := c.HasRoleOption(ctx, userName, roleoption.VIEWCLUSTERSETTING); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasView { + return nil + } + if hasModify, err := c.HasRoleOption(ctx, userName, roleoption.MODIFYCLUSTERSETTING); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasModify { + return nil + } + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires the %s or %s system privileges", + privilege.VIEWCLUSTERSETTING, privilege.MODIFYCLUSTERSETTING) +} + +// RequireViewActivityAndNoViewActivityRedactedPermission requires +// that the user have the VIEWACTIVITY role, but does not have the +// VIEWACTIVITYREDACTED role. This function's error return is a gRPC +// error. 
+func (c *adminPrivilegeChecker) RequireViewActivityAndNoViewActivityRedactedPermission( + ctx context.Context, +) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + + if !isAdmin { + hasViewRedacted, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWACTIVITYREDACTED) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if !hasViewRedacted { + hasViewRedacted, err := c.HasRoleOption(ctx, userName, roleoption.VIEWACTIVITYREDACTED) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if hasViewRedacted { + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires %s role option and is not allowed for %s role option", + roleoption.VIEWACTIVITY, roleoption.VIEWACTIVITYREDACTED) + } + } else { + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires %s system privilege and is not allowed for %s system privilege", + privilege.VIEWACTIVITY, privilege.VIEWACTIVITYREDACTED) + } + return c.RequireViewActivityPermission(ctx) + } + return nil +} + +// RequireViewClusterMetadataPermission requires the user have admin +// or the VIEWCLUSTERMETADATA system privilege and returns an error if +// the user does not have it. 
+func (c *adminPrivilegeChecker) RequireViewClusterMetadataPermission( + ctx context.Context, +) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if isAdmin { + return nil + } + if hasViewClusterMetadata, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWCLUSTERMETADATA); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasViewClusterMetadata { + return nil + } + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires the %s system privilege", + privilege.VIEWCLUSTERMETADATA) +} + +// RequireViewDebugPermission requires the user have admin or the +// VIEWDEBUG system privilege and returns an error if the user does +// not have it. +func (c *adminPrivilegeChecker) RequireViewDebugPermission(ctx context.Context) (err error) { + userName, isAdmin, err := c.GetUserAndRole(ctx) + if err != nil { + return srverrors.ServerError(ctx, err) + } + if isAdmin { + return nil + } + if hasViewDebug, err := c.HasGlobalPrivilege(ctx, userName, privilege.VIEWDEBUG); err != nil { + return srverrors.ServerError(ctx, err) + } else if hasViewDebug { + return nil + } + return grpcstatus.Errorf( + codes.PermissionDenied, "this operation requires the %s system privilege", + privilege.VIEWDEBUG) +} + +// GetUserAndRole is part of the CheckerForRPCHandlers interface. +func (c *adminPrivilegeChecker) GetUserAndRole( + ctx context.Context, +) (userName username.SQLUsername, isAdmin bool, err error) { + userName, err = authserver.UserFromIncomingRPCContext(ctx) + if err != nil { + return userName, false, err + } + isAdmin, err = c.HasAdminRole(ctx, userName) + return userName, isAdmin, err +} + +// HasAdminRole is part of the SQLPrivilegeChecker interface. +// Note that the function returns plain errors, and it is the caller's +// responsibility to convert them to serverErrors. 
+func (c *adminPrivilegeChecker) HasAdminRole( + ctx context.Context, user username.SQLUsername, +) (bool, error) { + if user.IsRootUser() { + // Shortcut: the root user is reported as admin without a query. + return true, nil + } + row, err := c.ie.QueryRowEx( + ctx, "check-is-admin", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: user}, + "SELECT crdb_internal.is_admin()") + if err != nil { + return false, err + } + if row == nil { + return false, errors.AssertionFailedf("hasAdminRole: expected 1 row, got 0") + } + if len(row) != 1 { + return false, errors.AssertionFailedf("hasAdminRole: expected 1 column, got %d", len(row)) + } + dbDatum, ok := tree.AsDBool(row[0]) + if !ok { + return false, errors.AssertionFailedf("hasAdminRole: expected bool, got %T", row[0]) + } + return bool(dbDatum), nil +} + +// HasRoleOption is part of the SQLPrivilegeChecker interface. +// Note that the function returns plain errors, and it is the caller's +// responsibility to convert them through srverrors.ServerError. +func (c *adminPrivilegeChecker) HasRoleOption( + ctx context.Context, user username.SQLUsername, roleOption roleoption.Option, +) (bool, error) { + if user.IsRootUser() { + // Shortcut: the root user is reported as having every role option. + return true, nil + } + row, err := c.ie.QueryRowEx( + ctx, "check-role-option", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: user}, + "SELECT crdb_internal.has_role_option($1)", roleOption.String()) + if err != nil { + return false, err + } + if row == nil { + return false, errors.AssertionFailedf("hasRoleOption: expected 1 row, got 0") + } + if len(row) != 1 { + return false, errors.AssertionFailedf("hasRoleOption: expected 1 column, got %d", len(row)) + } + dbDatum, ok := tree.AsDBool(row[0]) + if !ok { + return false, errors.AssertionFailedf("hasRoleOption: expected bool, got %T", row[0]) + } + return bool(dbDatum), nil +} + +// HasGlobalPrivilege is part of the SQLPrivilegeChecker interface. It +// calls HasPrivilege on the global privilege object and returns its +// result.
+func (c *adminPrivilegeChecker) HasGlobalPrivilege( + ctx context.Context, user username.SQLUsername, privilege privilege.Kind, +) (bool, error) { + aa, cleanup := c.makeAuthzAccessor("check-system-privilege") + defer cleanup() + return aa.HasPrivilege(ctx, syntheticprivilege.GlobalPrivilegeObject, privilege, user) +} + +// SetAuthzAccessorFactory is part of the SQLPrivilegeChecker interface. +func (c *adminPrivilegeChecker) SetAuthzAccessorFactory( + fn func(opName string) (sql.AuthorizationAccessor, func()), +) { + c.makeAuthzAccessor = fn +} diff --git a/pkg/server/privchecker/privchecker_test.go b/pkg/server/privchecker/privchecker_test.go new file mode 100644 index 000000000000..266686018c02 --- /dev/null +++ b/pkg/server/privchecker/privchecker_test.go @@ -0,0 +1,163 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +package privchecker_test + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" +) + +func TestAdminPrivilegeChecker(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + ts := s.TenantOrServer() + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, "CREATE USER withadmin") + sqlDB.Exec(t, "GRANT admin TO withadmin") + sqlDB.Exec(t, "CREATE USER withva") + sqlDB.Exec(t, "ALTER ROLE withva WITH VIEWACTIVITY") + sqlDB.Exec(t, "CREATE USER withvaredacted") + sqlDB.Exec(t, "ALTER ROLE withvaredacted WITH VIEWACTIVITYREDACTED") + sqlDB.Exec(t, "CREATE USER withvaandredacted") + sqlDB.Exec(t, "ALTER ROLE withvaandredacted WITH VIEWACTIVITY") + sqlDB.Exec(t, "ALTER ROLE withvaandredacted WITH VIEWACTIVITYREDACTED") + sqlDB.Exec(t, "CREATE USER withoutprivs") + sqlDB.Exec(t, "CREATE USER withvaglobalprivilege") + sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITY TO withvaglobalprivilege") + sqlDB.Exec(t, "CREATE USER withvaredactedglobalprivilege") + sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITYREDACTED TO withvaredactedglobalprivilege") + sqlDB.Exec(t, "CREATE USER withvaandredactedglobalprivilege") + sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITY TO withvaandredactedglobalprivilege") + sqlDB.Exec(t, "GRANT SYSTEM VIEWACTIVITYREDACTED TO withvaandredactedglobalprivilege") + sqlDB.Exec(t, 
"CREATE USER withviewclustermetadata") + sqlDB.Exec(t, "GRANT SYSTEM VIEWCLUSTERMETADATA TO withviewclustermetadata") + sqlDB.Exec(t, "CREATE USER withviewdebug") + sqlDB.Exec(t, "GRANT SYSTEM VIEWDEBUG TO withviewdebug") + + execCfg := ts.ExecutorConfig().(sql.ExecutorConfig) + kvDB := ts.DB() + + plannerFn := func(opName string) (sql.AuthorizationAccessor, func()) { + // This is a hack to get around a Go package dependency cycle. See comment + // in sql/jobs/registry.go on planHookMaker. + txn := kvDB.NewTxn(ctx, "test") + p, cleanup := sql.NewInternalPlanner( + opName, + txn, + username.RootUserName(), + &sql.MemoryMetrics{}, + &execCfg, + sql.NewInternalSessionData(ctx, execCfg.Settings, opName), + ) + return p.(sql.AuthorizationAccessor), cleanup + } + + underTest := privchecker.NewChecker( + ts.InternalExecutor().(isql.Executor), + ts.ClusterSettings(), + ) + + underTest.SetAuthzAccessorFactory(plannerFn) + + withAdmin, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withadmin") + require.NoError(t, err) + withVa, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withva") + require.NoError(t, err) + withVaRedacted, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withvaredacted") + require.NoError(t, err) + withVaAndRedacted, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withvaandredacted") + require.NoError(t, err) + withoutPrivs, err := username.MakeSQLUsernameFromPreNormalizedStringChecked("withoutprivs") + require.NoError(t, err) + withVaGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaglobalprivilege") + withVaRedactedGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaredactedglobalprivilege") + withVaAndRedactedGlobalPrivilege := username.MakeSQLUsernameFromPreNormalizedString("withvaandredactedglobalprivilege") + withviewclustermetadata := username.MakeSQLUsernameFromPreNormalizedString("withviewclustermetadata") + withViewDebug := 
username.MakeSQLUsernameFromPreNormalizedString("withviewdebug") + + tests := []struct { + name string + checkerFun func(context.Context) error + usernameWantErr map[username.SQLUsername]bool + }{ + { + "requireViewActivityPermission", + underTest.RequireViewActivityPermission, + map[username.SQLUsername]bool{ + withAdmin: false, withVa: false, withVaRedacted: true, withVaAndRedacted: false, withoutPrivs: true, + withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: true, withVaAndRedactedGlobalPrivilege: false, + }, + }, + { + "requireViewActivityOrViewActivityRedactedPermission", + underTest.RequireViewActivityOrViewActivityRedactedPermission, + map[username.SQLUsername]bool{ + withAdmin: false, withVa: false, withVaRedacted: false, withVaAndRedacted: false, withoutPrivs: true, + withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: false, withVaAndRedactedGlobalPrivilege: false, + }, + }, + { + "requireViewActivityAndNoViewActivityRedactedPermission", + underTest.RequireViewActivityAndNoViewActivityRedactedPermission, + map[username.SQLUsername]bool{ + withAdmin: false, withVa: false, withVaRedacted: true, withVaAndRedacted: true, withoutPrivs: true, + withVaGlobalPrivilege: false, withVaRedactedGlobalPrivilege: true, withVaAndRedactedGlobalPrivilege: true, + }, + }, + { + "requireViewClusterMetadataPermission", + underTest.RequireViewClusterMetadataPermission, + map[username.SQLUsername]bool{ + withAdmin: false, withoutPrivs: true, withviewclustermetadata: false, + }, + }, + { + "requireViewDebugPermission", + underTest.RequireViewDebugPermission, + map[username.SQLUsername]bool{ + withAdmin: false, withoutPrivs: true, withViewDebug: false, + }, + }, + } + + for _, tt := range tests { + for userName, wantErr := range tt.usernameWantErr { + t.Run(fmt.Sprintf("%s-%s", tt.name, userName), func(t *testing.T) { + ctx := metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"websessionuser": userName.SQLIdentifier()})) + err := 
tt.checkerFun(ctx) + if wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } + } +} diff --git a/pkg/server/problem_ranges.go b/pkg/server/problem_ranges.go index 441f74e7f172..5f2953820fe7 100644 --- a/pkg/server/problem_ranges.go +++ b/pkg/server/problem_ranges.go @@ -26,7 +26,7 @@ func (s *systemStatusServer) ProblemRanges( ) (*serverpb.ProblemRangesResponse, error) { ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { return nil, err } diff --git a/pkg/server/purge_auth_session_test.go b/pkg/server/purge_auth_session_test.go index dd2626e4ffd3..30a68e82ede4 100644 --- a/pkg/server/purge_auth_session_test.go +++ b/pkg/server/purge_auth_session_test.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -34,11 +35,11 @@ func TestPurgeSession(t *testing.T) { ts := s.(*TestServer) userName := username.TestUserName() - if err := ts.createAuthUser(userName, false /* isAdmin */); err != nil { + if err := ts.CreateAuthUser(userName, false /* isAdmin */); err != nil { t.Fatal(err) } - _, hashedSecret, err := CreateAuthSecret() + _, hashedSecret, err := authserver.CreateAuthSecret() if err != nil { t.Fatal(err) } diff --git a/pkg/server/rangetestutils/BUILD.bazel b/pkg/server/rangetestutils/BUILD.bazel new file mode 100644 index 000000000000..56f8d87af399 --- /dev/null +++ b/pkg/server/rangetestutils/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "rangetestutils", + srcs = ["rangetestutils.go"], + importpath =
"github.com/cockroachdb/cockroach/pkg/server/rangetestutils", + visibility = ["//visibility:public"], + deps = [ + "//pkg/base", + "//pkg/testutils/serverutils", + ], +) diff --git a/pkg/server/rangetestutils/rangetestutils.go b/pkg/server/rangetestutils/rangetestutils.go new file mode 100644 index 000000000000..5f5a4395dd48 --- /dev/null +++ b/pkg/server/rangetestutils/rangetestutils.go @@ -0,0 +1,49 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package rangetestutils + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" +) + +// TestServerFactory creates test servers with an initial set of +// inspectable ranges. +type TestServerFactory interface { + // MakeRangeTestServerArgs instantiates TestServerArgs suitable + // to instantiate a range test server. + MakeRangeTestServerArgs() base.TestServerArgs + // PrepareRangeTestServer prepares a range test server for use. + PrepareRangeTestServer(testServer interface{}) error +} + +var srvFactoryImpl TestServerFactory + +// InitRangeTestServerFactory should be called once to provide the +// implementation of the service. It will be called from a xx_test +// package that can import the server package. +func InitRangeTestServerFactory(impl TestServerFactory) { + srvFactoryImpl = impl +} + +// StartServer starts a server with multiple stores and a short scan +// interval, waits for the scan to complete, and returns the server. The +// caller is responsible for stopping the server.
+func StartServer(t testing.TB) serverutils.TestServerInterface { + params := srvFactoryImpl.MakeRangeTestServerArgs() + s, _, _ := serverutils.StartServer(t, params) + if err := srvFactoryImpl.PrepareRangeTestServer(s); err != nil { + t.Fatal(err) + } + return s +} diff --git a/pkg/server/server.go b/pkg/server/server.go index ab58665470d3..a12205befe86 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -69,8 +69,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/rpc/nodedialer" "github.com/cockroachdb/cockroach/pkg/security/clientsecopts" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/debug" "github.com/cockroachdb/cockroach/pkg/server/diagnostics" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/serverrules" "github.com/cockroachdb/cockroach/pkg/server/status" @@ -157,12 +159,12 @@ type Server struct { ctSender *sidetransport.Sender http *httpServer - adminAuthzCheck *adminPrivilegeChecker + adminAuthzCheck privchecker.CheckerForRPCHandlers admin *systemAdminServer status *systemStatusServer drain *drainServer - decomNodeMap *decommissioningNodeMap - authentication *authenticationServer + decomNodeMap *DecommissioningNodeMap + authentication authserver.Server migrationServer *migrationServer tsDB *ts.DB tsServer *ts.Server @@ -484,7 +486,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { stores := kvserver.NewStores(cfg.AmbientCtx, clock) - decomNodeMap := &decommissioningNodeMap{ + decomNodeMap := &DecommissioningNodeMap{ nodes: make(map[roachpb.NodeID]interface{}), } nodeLiveness := liveness.NewNodeLiveness(liveness.NodeLivenessOptions{ @@ -951,11 +953,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { // Instantiate the API privilege checker. 
// // TODO(tbg): give adminServer only what it needs (and avoid circular deps). - adminAuthzCheck := &adminPrivilegeChecker{ - ie: internalExecutor, - st: st, - makePlanner: nil, - } + adminAuthzCheck := privchecker.NewChecker(internalExecutor, st) // Instantiate the HTTP server. // These callbacks help us avoid a dependency on gossip in httpServer. @@ -1159,11 +1157,11 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { } // Tell the authz server how to connect to SQL. - adminAuthzCheck.makePlanner = func(opName string) (interface{}, func()) { + adminAuthzCheck.SetAuthzAccessorFactory(func(opName string) (sql.AuthorizationAccessor, func()) { // This is a hack to get around a Go package dependency cycle. See comment // in sql/jobs/registry.go on planHookMaker. txn := db.NewTxn(ctx, "check-system-privilege") - return sql.NewInternalPlanner( + p, cleanup := sql.NewInternalPlanner( opName, txn, username.RootUserName(), @@ -1171,10 +1169,11 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { sqlServer.execCfg, sql.NewInternalSessionData(ctx, sqlServer.execCfg.Settings, opName), ) - } + return p.(sql.AuthorizationAccessor), cleanup + }) // Create the authentication RPC server (login/logout). - sAuth := newAuthenticationServer(cfg.Config, sqlServer) + sAuth := authserver.NewServer(cfg.Config, sqlServer) // Create a drain server. 
drain := newDrainServer(cfg.BaseConfig, stopper, stopTrigger, grpcServer, sqlServer) diff --git a/pkg/server/server_controller_http.go b/pkg/server/server_controller_http.go index 59a1f088d94d..bd5bd2aad243 100644 --- a/pkg/server/server_controller_http.go +++ b/pkg/server/server_controller_http.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/multitenant" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/timeutil" @@ -32,10 +33,6 @@ const ( // to select a particular virtual cluster. ClusterNameParamInQueryURL = "cluster" - // TenantSelectCookieName is the name of the HTTP cookie used to select a particular tenant, - // if the custom header is not specified. - TenantSelectCookieName = `tenant` - // AcceptHeader is the canonical header name for accept. AcceptHeader = "Accept" @@ -59,10 +56,10 @@ func (c *serverController) httpMux(w http.ResponseWriter, r *http.Request) { // routed to a specific node and skip the fanout, creating inconsistent // outcomes that were path-dependent on the user's existing cookies. switch r.URL.Path { - case loginPath, DemoLoginPath: + case authserver.LoginPath, authserver.DemoLoginPath: c.attemptLoginToAllTenants().ServeHTTP(w, r) return - case logoutPath: + case authserver.LogoutPath: // Since we do not support per-tenant logout until // https://github.com/cockroachdb/cockroach/issues/92855 // is completed, we should always fanout a logout @@ -77,14 +74,14 @@ func (c *serverController) httpMux(w http.ResponseWriter, r *http.Request) { log.Warningf(ctx, "unable to find server for tenant %q: %v", tenantName, err) // Clear session and tenant cookies since it appears they reference invalid state. 
http.SetCookie(w, &http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: "", Path: "/", HttpOnly: true, Expires: timeutil.Unix(0, 0), }) http.SetCookie(w, &http.Cookie{ - Name: TenantSelectCookieName, + Name: authserver.TenantSelectCookieName, Value: "", Path: "/", HttpOnly: false, @@ -120,7 +117,7 @@ func getTenantNameFromHTTPRequest(st *cluster.Settings, r *http.Request) roachpb } // No parameter, no explicit header. Is there a cookie? - if c, _ := r.Cookie(TenantSelectCookieName); c != nil && c.Value != "" { + if c, _ := r.Cookie(authserver.TenantSelectCookieName); c != nil && c.Value != "" { return roachpb.TenantName(c.Value) } @@ -146,7 +143,7 @@ func (c *serverController) attemptLoginToAllTenants() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() tenantNames := c.getCurrentTenantNames() - var tenantNameToSetCookieSlice []sessionCookieValue + var tenantNameToSetCookieSlice []authserver.SessionCookieValue // The request body needs to be cloned since r.Clone() does not do it. clonedBody, err := io.ReadAll(r.Body) if err != nil { @@ -188,10 +185,10 @@ func (c *serverController) attemptLoginToAllTenants() http.Handler { collectedErrors[i] = sw.buf.String() log.Warningf(ctx, "unable to find session cookie for tenant %q: HTTP %d - %s", name, sw.code, &sw.buf) } else { - tenantNameToSetCookieSlice = append(tenantNameToSetCookieSlice, sessionCookieValue{ - name: string(name), - setCookie: setCookieHeader, - }) + tenantNameToSetCookieSlice = append(tenantNameToSetCookieSlice, authserver.MakeSessionCookieValue( + string(name), + setCookieHeader, + )) // In the case of /demologin, we want to redirect to the provided location // in the header. If we get back a cookie along with an // http.StatusTemporaryRedirect code, be sure to transfer the response code @@ -208,9 +205,9 @@ func (c *serverController) attemptLoginToAllTenants() http.Handler { // be called and cookies should be set. 
Otherwise, login was not successful // for any of the tenants. if len(tenantNameToSetCookieSlice) > 0 { - sessionsStr := createAggregatedSessionCookieValue(tenantNameToSetCookieSlice) + sessionsStr := authserver.CreateAggregatedSessionCookieValue(tenantNameToSetCookieSlice) cookie := http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: sessionsStr, Path: "/", HttpOnly: false, @@ -222,16 +219,16 @@ func (c *serverController) attemptLoginToAllTenants() http.Handler { // We only set the default selection from the cluster setting // if it's one of the valid logins. Otherwise, we just use the // first one in the list. - tenantSelection := tenantNameToSetCookieSlice[0].name + tenantSelection := tenantNameToSetCookieSlice[0].Name() defaultName := multitenant.DefaultTenantSelect.Get(&c.st.SV) for _, t := range tenantNameToSetCookieSlice { - if t.name == defaultName { - tenantSelection = t.name + if t.Name() == defaultName { + tenantSelection = t.Name() break } } cookie = http.Cookie{ - Name: TenantSelectCookieName, + Name: authserver.TenantSelectCookieName, Value: tenantSelection, Path: "/", HttpOnly: false, @@ -277,9 +274,9 @@ func (c *serverController) attemptLogoutFromAllTenants() http.Handler { w.WriteHeader(http.StatusInternalServerError) return } - sessionCookie, err := r.Cookie(SessionCookieName) + sessionCookie, err := r.Cookie(authserver.SessionCookieName) if errors.Is(err, http.ErrNoCookie) { - sessionCookie, err = r.Cookie(SessionCookieName) + sessionCookie, err = r.Cookie(authserver.SessionCookieName) if err != nil { log.Warningf(ctx, "unable to find session cookie: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -328,7 +325,7 @@ func (c *serverController) attemptLogoutFromAllTenants() http.Handler { } // Clear session and tenant cookies after all logouts have completed. 
cookie := http.Cookie{ - Name: SessionCookieName, + Name: authserver.SessionCookieName, Value: "", Path: "/", HttpOnly: false, @@ -336,7 +333,7 @@ func (c *serverController) attemptLogoutFromAllTenants() http.Handler { } http.SetCookie(w, &cookie) cookie = http.Cookie{ - Name: TenantSelectCookieName, + Name: authserver.TenantSelectCookieName, Value: "", Path: "/", HttpOnly: false, diff --git a/pkg/server/server_http.go b/pkg/server/server_http.go index 897bccfb02d5..e286f4960c3f 100644 --- a/pkg/server/server_http.go +++ b/pkg/server/server_http.go @@ -19,8 +19,12 @@ import ( "github.com/cockroachdb/cmux" "github.com/cockroachdb/cockroach/pkg/inspectz" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/debug" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/server/status" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/ts" @@ -89,8 +93,8 @@ func (s *httpServer) handleHealth(healthHandler http.Handler) { func (s *httpServer) setupRoutes( ctx context.Context, - authnServer *authenticationServer, - adminAuthzCheck *adminPrivilegeChecker, + authnServer authserver.Server, + adminAuthzCheck privchecker.CheckerForRPCHandlers, metricSource metricMarshaler, runtimeStatSampler *status.RuntimeStatSampler, handleRequestsUnauthenticated http.Handler, @@ -101,7 +105,7 @@ func (s *httpServer) setupRoutes( ) error { // OIDC Configuration must happen prior to the UI Handler being defined below so that we have // the system settings initialized for it to pick up from the oidcAuthenticationServer. 
- oidc, err := ConfigureOIDC( + oidc, err := authserver.ConfigureOIDC( ctx, s.cfg.Settings, s.cfg.Locality, s.mux.Handle, authnServer.UserLoginFromSSO, s.cfg.AmbientCtx, s.cfg.ClusterIDContainer.Get(), ) @@ -115,7 +119,7 @@ func (s *httpServer) setupRoutes( NodeID: s.cfg.IDContainer, OIDC: oidc, GetUser: func(ctx context.Context) *string { - if user, ok := maybeUserFromHTTPAuthInfoContext(ctx); ok { + if user, ok := authserver.MaybeUserFromHTTPAuthInfoContext(ctx); ok { ustring := user.Normalized() return &ustring } @@ -128,49 +132,49 @@ func (s *httpServer) setupRoutes( // assets are served up whether or not there is a session. If there is a session, the mux // adds it to the context, and it is templated into index.html so that the UI can show // the username of the currently-logged-in user. - authenticatedUIHandler := newAuthenticationMuxAllowAnonymous( - authnServer, assetHandler) + authenticatedUIHandler := authserver.NewMux( + authnServer, assetHandler, true /* allowAnonymous */) s.mux.Handle("/", authenticatedUIHandler) // Add HTTP authentication to the gRPC-gateway endpoints used by the UI, // if not disabled by configuration. var authenticatedHandler = handleRequestsUnauthenticated if !s.cfg.InsecureWebAccess() { - authenticatedHandler = newAuthenticationMux(authnServer, authenticatedHandler) + authenticatedHandler = authserver.NewMux(authnServer, authenticatedHandler, false /* allowAnonymous */) } // Login and logout paths. // The /login endpoint is, by definition, available pre-authentication. - s.mux.Handle(loginPath, handleRequestsUnauthenticated) - s.mux.Handle(logoutPath, authenticatedHandler) + s.mux.Handle(authserver.LoginPath, handleRequestsUnauthenticated) + s.mux.Handle(authserver.LogoutPath, authenticatedHandler) // The login path for 'cockroach demo', if we're currently running // that. 
if s.cfg.EnableDemoLoginEndpoint { - s.mux.Handle(DemoLoginPath, http.HandlerFunc(authnServer.demoLogin)) + s.mux.Handle(authserver.DemoLoginPath, http.HandlerFunc(authnServer.DemoLogin)) } // Admin/Status servers. These are used by the UI via RPC-over-HTTP. - s.mux.Handle(statusPrefix, authenticatedHandler) - s.mux.Handle(adminPrefix, authenticatedHandler) + s.mux.Handle(apiconstants.StatusPrefix, authenticatedHandler) + s.mux.Handle(apiconstants.AdminPrefix, authenticatedHandler) // The timeseries endpoint, used to produce graphs. s.mux.Handle(ts.URLPrefix, authenticatedHandler) // Exempt the 2nd health check endpoint from authentication. // (This simply mirrors /health and exists for backward compatibility.) - s.mux.Handle(adminHealth, handleRequestsUnauthenticated) + s.mux.Handle(apiconstants.AdminHealth, handleRequestsUnauthenticated) // The /_status/vars endpoint is not authenticated either. Useful for monitoring. - s.mux.Handle(statusVars, http.HandlerFunc(varsHandler{metricSource, s.cfg.Settings}.handleVars)) + s.mux.Handle(apiconstants.StatusVars, http.HandlerFunc(varsHandler{metricSource, s.cfg.Settings}.handleVars)) // Same for /_status/load. le, err := newLoadEndpoint(runtimeStatSampler, metricSource) if err != nil { return err } - s.mux.Handle(loadStatusVars, le) + s.mux.Handle(apiconstants.LoadStatusVars, le) if apiServer != nil { // The new "v2" HTTP API tree. - s.mux.Handle(apiV2Path, apiServer) + s.mux.Handle(apiconstants.APIV2Path, apiServer) } // Register debugging endpoints. @@ -179,10 +183,10 @@ func (s *httpServer) setupRoutes( if !s.cfg.InsecureWebAccess() { // Mandate both authentication and admin authorization. 
handleDebugAuthenticated = makeAdminAuthzCheckHandler(adminAuthzCheck, handleDebugAuthenticated) - handleDebugAuthenticated = newAuthenticationMux(authnServer, handleDebugAuthenticated) + handleDebugAuthenticated = authserver.NewMux(authnServer, handleDebugAuthenticated, false /* allowAnonymous */) handleInspectzAuthenticated = makeAdminAuthzCheckHandler(adminAuthzCheck, handleInspectzAuthenticated) - handleInspectzAuthenticated = newAuthenticationMux(authnServer, handleInspectzAuthenticated) + handleInspectzAuthenticated = authserver.NewMux(authnServer, handleInspectzAuthenticated, false /* allowAnonymous */) } s.mux.Handle(debug.Endpoint, handleDebugAuthenticated) s.mux.Handle(inspectz.URLPrefix, handleInspectzAuthenticated) @@ -193,15 +197,15 @@ func (s *httpServer) setupRoutes( } func makeAdminAuthzCheckHandler( - adminAuthzCheck *adminPrivilegeChecker, handler http.Handler, + adminAuthzCheck privchecker.CheckerForRPCHandlers, handler http.Handler, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Retrieve the username embedded in the grpc metadata, if any. // This will be provided by the authenticationMux. - md := translateHTTPAuthInfoToGRPCMetadata(req.Context(), req) + md := authserver.TranslateHTTPAuthInfoToGRPCMetadata(req.Context(), req) authCtx := metadata.NewIncomingContext(req.Context(), md) // Check the privileges of the requester. - err := adminAuthzCheck.requireViewDebugPermission(authCtx) + err := adminAuthzCheck.RequireViewDebugPermission(authCtx) if err != nil { http.Error(w, "admin privilege or VIEWDEBUG global privilege required", http.StatusUnauthorized) return @@ -316,7 +320,7 @@ func (s *httpServer) baseHandler(w http.ResponseWriter, r *http.Request) { // Note: use of a background context here so we can log even with the absence of a client. // Assumes appropriate timeouts are used. 
logcrash.ReportPanic(context.Background(), &s.cfg.Settings.SV, p, 1 /* depth */) - http.Error(w, errAPIInternalErrorString, http.StatusInternalServerError) + http.Error(w, srverrors.ErrAPIInternalErrorString, http.StatusInternalServerError) } }() diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index 0d237cba3294..14dfc095b649 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -87,6 +87,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/flowinfra" "github.com/cockroachdb/cockroach/pkg/sql/gcjob/gcjobnotifier" "github.com/cockroachdb/cockroach/pkg/sql/idxusage" + "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/optionalnodeliveness" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/querycache" @@ -1889,3 +1890,13 @@ func (s *SQLServer) LogicalClusterID() uuid.UUID { func (s *SQLServer) ShutdownRequested() <-chan ShutdownRequest { return s.stopTrigger.C() } + +// ExecutorConfig is an accessor for the executor config. +func (s *SQLServer) ExecutorConfig() *sql.ExecutorConfig { + return s.execCfg +} + +// InternalExecutor returns an executor for internal SQL queries. 
+func (s *SQLServer) InternalExecutor() isql.Executor { + return s.internalExecutor +} diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index d009095ad26e..c6b92b9ddc0e 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -40,7 +40,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" @@ -61,7 +63,7 @@ import ( "github.com/gogo/protobuf/jsonpb" "github.com/gogo/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" - "github.com/jackc/pgx/v4" + pgx "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -254,7 +256,7 @@ func TestPlainHTTPServer(t *testing.T) { // still works in insecure mode. var data serverpb.JSONResponse testutils.SucceedsSoon(t, func() error { - return getStatusJSONProto(s, "metrics/local", &data) + return srvtestutils.GetStatusJSONProto(s, "metrics/local", &data) }) // Now make a couple of direct requests using both http and https. 
@@ -372,7 +374,7 @@ func TestAcceptEncoding(t *testing.T) { } for _, d := range testData { func() { - req, err := http.NewRequest("GET", s.AdminURL().WithPath(statusPrefix+"metrics/local").String(), nil) + req, err := http.NewRequest("GET", s.AdminURL().WithPath(apiconstants.StatusPrefix+"metrics/local").String(), nil) if err != nil { t.Fatalf("could not create request: %s", err) } diff --git a/pkg/server/sql_stats.go b/pkg/server/sql_stats.go index 411166d3bf9c..ac8404b8eb02 100644 --- a/pkg/server/sql_stats.go +++ b/pkg/server/sql_stats.go @@ -14,6 +14,7 @@ import ( "context" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/errors" "google.golang.org/grpc/codes" @@ -23,10 +24,10 @@ import ( func (s *statusServer) ResetSQLStats( ctx context.Context, req *serverpb.ResetSQLStatsRequest, ) (*serverpb.ResetSQLStatsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } diff --git a/pkg/server/srverrors/BUILD.bazel b/pkg/server/srverrors/BUILD.bazel new file mode 100644 index 000000000000..07a8ed61fd36 --- /dev/null +++ b/pkg/server/srverrors/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "srverrors", + srcs = ["errors.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/server/srverrors", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgerror", + "//pkg/util/log", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//status", + ], +) + +go_test( + name = "srverrors_test", + srcs = [ + "errors_test.go", + "main_test.go", + ], + args = 
["-test.timeout=295s"], + deps = [ + ":srverrors", + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgerror", + "//pkg/util/leaktest", + "//pkg/util/log", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/server/srverrors/errors.go b/pkg/server/srverrors/errors.go new file mode 100644 index 000000000000..df6a85d72fd1 --- /dev/null +++ b/pkg/server/srverrors/errors.go @@ -0,0 +1,80 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package srverrors + +import ( + "context" + "fmt" + "net/http" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/util/log" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" +) + +// ServerError logs the provided error and returns an error that should be returned by +// the RPC endpoint method. +func ServerError(ctx context.Context, err error) error { + log.ErrorfDepth(ctx, 1, "%+v", err) + + // Include the PGCode in the message for easier troubleshooting. + errCode := pgerror.GetPGCode(err).String() + if errCode != pgcode.Uncategorized.String() { + errMessage := fmt.Sprintf("%s Error Code: %s", ErrAPIInternalErrorString, errCode) + return grpcstatus.Errorf(codes.Internal, errMessage) + } + + // The error is already a grpcstatus-formatted error, + // likely from calling ServerError multiple times on the same error. + grpcCode := grpcstatus.Code(err) + if grpcCode != codes.Unknown { + return err + } + + // Fall back to a generic message.
+ return ErrAPIInternalError +} + +// ServerErrorf logs the provided formatted message and returns an error that should be returned by +// the RPC endpoint method. +func ServerErrorf(ctx context.Context, format string, args ...interface{}) error { + log.ErrorfDepth(ctx, 1, format, args...) + return ErrAPIInternalError +} + +// ErrAPIInternalErrorString is the string printed out in the UI when an internal error was encountered. +var ErrAPIInternalErrorString = "An internal server error has occurred. Please check your CockroachDB logs for more details." + +// ErrAPIInternalError is the gRPC status error returned when an internal error was encountered. +var ErrAPIInternalError = grpcstatus.Errorf( + codes.Internal, + ErrAPIInternalErrorString, +) + +// APIInternalError should be used to wrap server-side errors during API +// requests. This method records the contents of the error to the server log, +// and returns a standard GRPC error which is appropriate to return to the +// client. +func APIInternalError(ctx context.Context, err error) error { + log.ErrorfDepth(ctx, 1, "%s", err) + return ErrAPIInternalError +} + +// APIV2InternalError should be used to wrap server-side errors during API +// requests for V2 (non-GRPC) endpoints. This method records the contents +// of the error to the server log, and sends the standard internal error string +// over the http.ResponseWriter. +func APIV2InternalError(ctx context.Context, err error, w http.ResponseWriter) { + log.ErrorfDepth(ctx, 1, "%s", err) + http.Error(w, ErrAPIInternalErrorString, http.StatusInternalServerError) +} diff --git a/pkg/server/srverrors/errors_test.go b/pkg/server/srverrors/errors_test.go new file mode 100644 index 000000000000..4114a548344d --- /dev/null +++ b/pkg/server/srverrors/errors_test.go @@ -0,0 +1,40 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt.
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package srverrors_test + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/server/srverrors" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestServerError(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + pgError := pgerror.New(pgcode.OutOfMemory, "TestServerError.OutOfMemory") + err := srverrors.ServerError(ctx, pgError) + require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details. Error Code: 53200", err.Error()) + + err = srverrors.ServerError(ctx, err) + require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details. Error Code: 53200", err.Error()) + + err = fmt.Errorf("random error that is not pgerror or grpcstatus") + err = srverrors.ServerError(ctx, err) + require.Equal(t, "rpc error: code = Internal desc = An internal server error has occurred. Please check your CockroachDB logs for more details.", err.Error()) +} diff --git a/pkg/server/srverrors/main_test.go b/pkg/server/srverrors/main_test.go new file mode 100644 index 000000000000..c9291a69d0bc --- /dev/null +++ b/pkg/server/srverrors/main_test.go @@ -0,0 +1,22 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package srverrors_test + +import ( + "os" + "testing" +) + +func TestMain(m *testing.M) { + os.Exit(m.Run()) +} + +//go:generate ../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/srvtestutils/BUILD.bazel b/pkg/server/srvtestutils/BUILD.bazel new file mode 100644 index 000000000000..126c3da4aec9 --- /dev/null +++ b/pkg/server/srvtestutils/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "srvtestutils", + srcs = ["testutils.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/server/srvtestutils", + visibility = ["//visibility:public"], + deps = [ + "//pkg/base", + "//pkg/roachpb", + "//pkg/rpc", + "//pkg/server/apiconstants", + "//pkg/testutils/serverutils", + "//pkg/util/protoutil", + "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", + ], +) diff --git a/pkg/server/srvtestutils/testutils.go b/pkg/server/srvtestutils/testutils.go new file mode 100644 index 000000000000..7afffc5d3d2f --- /dev/null +++ b/pkg/server/srvtestutils/testutils.go @@ -0,0 +1,157 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package srvtestutils + +import ( + "context" + "encoding/json" + "io" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" +) + +// GetAdminJSONProto performs a RPC-over-HTTP request to the admin endpoint +// and unmarshals the response into the specified proto message. +func GetAdminJSONProto( + ts serverutils.TestTenantInterface, path string, response protoutil.Message, +) error { + return GetAdminJSONProtoWithAdminOption(ts, path, response, true) +} + +// GetAdminJSONProtoWithAdminOption performs a RPC-over-HTTP request to +// the admin endpoint and unmarshals the response into the specified +// proto message. It allows the caller to control whether the request +// is made with the admin role. +func GetAdminJSONProtoWithAdminOption( + ts serverutils.TestTenantInterface, path string, response protoutil.Message, isAdmin bool, +) error { + return serverutils.GetJSONProtoWithAdminOption(ts, apiconstants.AdminPrefix+path, response, isAdmin) +} + +// PostAdminJSONProto performs a RPC-over-HTTP request to the admin endpoint +// and unmarshals the response into the specified proto message. +func PostAdminJSONProto( + ts serverutils.TestTenantInterface, path string, request, response protoutil.Message, +) error { + return PostAdminJSONProtoWithAdminOption(ts, path, request, response, true) +} + +// PostAdminJSONProtoWithAdminOption performs a RPC-over-HTTP request to +// the admin endpoint and unmarshals the response into the specified +// proto message. It allows the caller to control whether the request +// is made with the admin role. 
+func PostAdminJSONProtoWithAdminOption( + ts serverutils.TestTenantInterface, + path string, + request, response protoutil.Message, + isAdmin bool, +) error { + return serverutils.PostJSONProtoWithAdminOption(ts, apiconstants.AdminPrefix+path, request, response, isAdmin) +} + +// GetStatusJSONProto performs a RPC-over-HTTP request to the status endpoint +// and unmarshals the response into the specified proto message. +func GetStatusJSONProto( + ts serverutils.TestTenantInterface, path string, response protoutil.Message, +) error { + return serverutils.GetJSONProto(ts, apiconstants.StatusPrefix+path, response) +} + +// PostStatusJSONProto performs a RPC-over-HTTP request to the status endpoint +// and unmarshals the response into the specified proto message. +func PostStatusJSONProto( + ts serverutils.TestTenantInterface, path string, request, response protoutil.Message, +) error { + return serverutils.PostJSONProto(ts, apiconstants.StatusPrefix+path, request, response) +} + +// GetStatusJSONProtoWithAdminOption performs a RPC-over-HTTP request to +// the status endpoint and unmarshals the response into the specified +// proto message. It allows the caller to control whether the request +// is made with the admin role. +func GetStatusJSONProtoWithAdminOption( + ts serverutils.TestTenantInterface, path string, response protoutil.Message, isAdmin bool, +) error { + return serverutils.GetJSONProtoWithAdminOption(ts, apiconstants.StatusPrefix+path, response, isAdmin) +} + +// PostStatusJSONProtoWithAdminOption performs a RPC-over-HTTP request to +// the status endpoint and unmarshals the response into the specified +// proto message. It allows the caller to control whether the request +// is made with the admin role. 
+func PostStatusJSONProtoWithAdminOption( + ts serverutils.TestTenantInterface, + path string, + request, response protoutil.Message, + isAdmin bool, +) error { + return serverutils.PostJSONProtoWithAdminOption(ts, apiconstants.StatusPrefix+path, request, response, isAdmin) +} + +// GetText fetches the HTTP response body as text in the form of a +// byte slice from the specified URL. +func GetText(ts serverutils.TestTenantInterface, url string) ([]byte, error) { + httpClient, err := ts.GetAdminHTTPClient() + if err != nil { + return nil, err + } + resp, err := httpClient.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return io.ReadAll(resp.Body) +} + +// GetJSON fetches the JSON from the specified URL and returns +// it as unmarshaled JSON. Returns an error on any failure to fetch +// or unmarshal response body. +func GetJSON(ts serverutils.TestTenantInterface, url string) (interface{}, error) { + body, err := GetText(ts, url) + if err != nil { + return nil, err + } + var jI interface{} + if err := json.Unmarshal(body, &jI); err != nil { + return nil, errors.Wrapf(err, "body is:\n%s", body) + } + return jI, nil +} + +// NewRPCTestContext constructs a RPC context for use in API tests. +func NewRPCTestContext( + ctx context.Context, ts serverutils.TestServerInterface, cfg *base.Config, +) *rpc.Context { + var c base.NodeIDContainer + ctx = logtags.AddTag(ctx, "n", &c) + rpcContext := rpc.NewContext(ctx, rpc.ContextOptions{ + TenantID: roachpb.SystemTenantID, + NodeID: &c, + Config: cfg, + Clock: ts.Clock().WallClock(), + ToleratedOffset: ts.Clock().ToleratedOffset(), + Stopper: ts.Stopper(), + Settings: ts.ClusterSettings(), + Knobs: rpc.ContextTestingKnobs{NoLoopbackDialer: true}, + }) + // Ensure that the RPC client context validates the server cluster ID. + // This ensures that a test where the server is restarted will not let + // its test RPC client talk to a server started by an unrelated concurrent test. 
+ rpcContext.StorageClusterID.Set(ctx, ts.StorageClusterID()) + return rpcContext +} diff --git a/pkg/server/statement_diagnostics_requests.go b/pkg/server/statement_diagnostics_requests.go index a9a04fbe0fc1..084695988812 100644 --- a/pkg/server/statement_diagnostics_requests.go +++ b/pkg/server/statement_diagnostics_requests.go @@ -16,6 +16,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/clusterversion" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -71,10 +72,10 @@ func (diagnostics *stmtDiagnostics) toProto() serverpb.StatementDiagnostics { func (s *statusServer) CreateStatementDiagnosticsReport( ctx context.Context, req *serverpb.CreateStatementDiagnosticsReportRequest, ) (*serverpb.CreateStatementDiagnosticsReportResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ -103,10 +104,10 @@ func (s *statusServer) CreateStatementDiagnosticsReport( func (s *statusServer) CancelStatementDiagnosticsReport( ctx context.Context, req *serverpb.CancelStatementDiagnosticsReportRequest, ) (*serverpb.CancelStatementDiagnosticsReportResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ -127,10 +128,10 @@ func (s *statusServer) 
CancelStatementDiagnosticsReport( func (s *statusServer) StatementDiagnosticsRequests( ctx context.Context, _ *serverpb.StatementDiagnosticsReportsRequest, ) (*serverpb.StatementDiagnosticsReportsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { return nil, err } @@ -221,10 +222,10 @@ func (s *statusServer) StatementDiagnosticsRequests( func (s *statusServer) StatementDiagnostics( ctx context.Context, req *serverpb.StatementDiagnosticsRequest, ) (*serverpb.StatementDiagnosticsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityAndNoViewActivityRedactedPermission(ctx); err != nil { return nil, err } diff --git a/pkg/server/statements.go b/pkg/server/statements.go index 41a8ff5fc7be..35c0d9b48206 100644 --- a/pkg/server/statements.go +++ b/pkg/server/statements.go @@ -14,6 +14,7 @@ import ( "context" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/appstatspb" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" @@ -33,10 +34,10 @@ func (s *statusServer) Statements( return s.CombinedStatementStats(ctx, &combinedRequest) } - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err 
:= s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } diff --git a/pkg/server/status.go b/pkg/server/status.go index 8413533ac252..af46ff780eac 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -49,8 +49,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/diagnostics/diagnosticspb" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -68,7 +72,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" "github.com/cockroachdb/cockroach/pkg/sql/sqlstats/insights" - "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/httputil" @@ -95,27 +98,8 @@ const ( // Default Maximum number of log entries returned. defaultMaxLogEntries = 1000 - // statusPrefix is the root of the cluster statistics and metrics API. - statusPrefix = "/_status/" - - // statusVars exposes prometheus metrics for monitoring consumption. - statusVars = statusPrefix + "vars" - - // loadStatusVars exposes prometheus metrics for instant monitoring of CPU load. - loadStatusVars = statusPrefix + "load" - - // raftStateDormant is used when there is no known raft state. 
- raftStateDormant = "StateDormant" - - // maxConcurrentRequests is the maximum number of RPC fan-out requests - // that will be made at any point of time. - maxConcurrentRequests = 100 - - // maxConcurrentPaginatedRequests is the maximum number of RPC fan-out - // requests that will be made at any point of time for a row-limited / - // paginated request. This should be much lower than maxConcurrentRequests - // as too much concurrency here can result in wasted results. - maxConcurrentPaginatedRequests = 4 + // RaftStateDormant is used when there is no known raft state. + RaftStateDormant = "StateDormant" ) var ( @@ -145,7 +129,7 @@ type baseStatusServer struct { serverpb.UnimplementedStatusServer log.AmbientContext - privilegeChecker *adminPrivilegeChecker + privilegeChecker privchecker.CheckerForRPCHandlers sessionRegistry *sql.SessionRegistry closedSessionCache *sql.ClosedSessionCache remoteFlowRunner *flowinfra.RemoteFlowRunner @@ -166,27 +150,27 @@ func isInternalAppName(app string) bool { func (b *baseStatusServer) getLocalSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) ([]serverpb.Session, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) - sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) + sessionUser, isAdmin, err := b.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - hasViewActivityRedacted, err := b.privilegeChecker.hasRoleOption(ctx, sessionUser, roleoption.VIEWACTIVITYREDACTED) + hasViewActivityRedacted, err := b.privilegeChecker.HasRoleOption(ctx, sessionUser, roleoption.VIEWACTIVITYREDACTED) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } - hasViewActivity, err := b.privilegeChecker.hasRoleOption(ctx, sessionUser, roleoption.VIEWACTIVITY) + hasViewActivity, err := 
b.privilegeChecker.HasRoleOption(ctx, sessionUser, roleoption.VIEWACTIVITY) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } reqUsername, err := username.MakeSQLUsernameFromPreNormalizedStringChecked(req.Username) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if !isAdmin && !hasViewActivity && !hasViewActivityRedacted { @@ -270,19 +254,19 @@ func (b *baseStatusServer) getLocalSessions( func (b *baseStatusServer) checkCancelPrivilege( ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername, ) error { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) - ctxUsername, isCtxAdmin, err := b.privilegeChecker.getUserAndRole(ctx) + ctxUsername, isCtxAdmin, err := b.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } if reqUsername.Undefined() { reqUsername = ctxUsername } else if reqUsername != ctxUsername && !isCtxAdmin { // When CANCEL QUERY is run as a SQL statement, sessionUser is always root // and the user who ran the statement is passed as req.Username. - return errRequiresAdmin + return privchecker.ErrRequiresAdmin } // A user can always cancel their own sessions/queries. @@ -294,9 +278,9 @@ func (b *baseStatusServer) checkCancelPrivilege( // checked inside getUserAndRole above. isReqAdmin := isCtxAdmin if reqUsername != ctxUsername { - isReqAdmin, err = b.privilegeChecker.hasAdminRole(ctx, reqUsername) + isReqAdmin, err = b.privilegeChecker.HasAdminRole(ctx, reqUsername) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } } @@ -307,24 +291,24 @@ func (b *baseStatusServer) checkCancelPrivilege( // Must have CANCELQUERY privilege to cancel other users' // sessions/queries. 
- hasGlobalCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY) + hasGlobalCancelQuery, err := b.privilegeChecker.HasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } if !hasGlobalCancelQuery { - hasRoleCancelQuery, err := b.privilegeChecker.hasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY) + hasRoleCancelQuery, err := b.privilegeChecker.HasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } if !hasRoleCancelQuery { - return errRequiresRoleOption(roleoption.CANCELQUERY) + return privchecker.ErrRequiresRoleOption(roleoption.CANCELQUERY) } } // Non-admins cannot cancel admins' sessions/queries. - isSessionAdmin, err := b.privilegeChecker.hasAdminRole(ctx, sessionUsername) + isSessionAdmin, err := b.privilegeChecker.HasAdminRole(ctx, sessionUsername) if err != nil { - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } if isSessionAdmin { return status.Error( @@ -338,11 +322,11 @@ func (b *baseStatusServer) checkCancelPrivilege( func (b *baseStatusServer) ListLocalContentionEvents( ctx context.Context, _ *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) - if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := b.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -355,11 +339,11 @@ func (b *baseStatusServer) ListLocalContentionEvents( func (b *baseStatusServer) ListLocalDistSQLFlows( ctx context.Context, _ *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = b.AnnotateCtx(ctx) - if err := b.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := b.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -555,7 +539,7 @@ func newStatusServer( ambient log.AmbientContext, st *cluster.Settings, cfg *base.Config, - adminAuthzCheck *adminPrivilegeChecker, + adminAuthzCheck privchecker.CheckerForRPCHandlers, db *kv.DB, metricSource metricMarshaler, rpcCtx *rpc.Context, @@ -602,7 +586,7 @@ func newSystemStatusServer( ambient log.AmbientContext, st *cluster.Settings, cfg *base.Config, - adminAuthzCheck *adminPrivilegeChecker, + adminAuthzCheck privchecker.CheckerForRPCHandlers, db *kv.DB, gossip *gossip.Gossip, metricSource metricMarshaler, @@ -698,11 +682,11 @@ func (s *statusServer) dialNode( func (s *systemStatusServer) Gossip( ctx context.Context, req *serverpb.GossipRequest, ) (*gossip.InfoStatus, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -718,7 +702,7 @@ func (s *systemStatusServer) Gossip( } status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Gossip(ctx, req) } @@ -726,11 +710,11 @@ func (s *systemStatusServer) Gossip( func (s *systemStatusServer) EngineStats( ctx context.Context, req *serverpb.EngineStatsRequest, ) (*serverpb.EngineStatsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -743,7 +727,7 @@ func (s *systemStatusServer) EngineStats( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.EngineStats(ctx, req) } @@ -760,7 +744,7 @@ func (s *systemStatusServer) EngineStats( return nil }) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return resp, nil } @@ -769,11 +753,11 @@ func (s *systemStatusServer) EngineStats( func (s *systemStatusServer) Allocator( ctx context.Context, req *serverpb.AllocatorRequest, ) (*serverpb.AllocatorResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // 
already returns a proper gRPC error status. return nil, err } @@ -786,7 +770,7 @@ func (s *systemStatusServer) Allocator( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Allocator(ctx, req) } @@ -839,7 +823,7 @@ func (s *systemStatusServer) Allocator( return nil }) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return output, nil } @@ -868,7 +852,7 @@ func (s *systemStatusServer) CriticalNodes( ctx context.Context, req *serverpb.CriticalNodesRequest, ) (*serverpb.CriticalNodesResponse, error) { ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } conformance, err := s.node.SpanConfigConformance( @@ -909,10 +893,10 @@ func (s *systemStatusServer) CriticalNodes( func (s *systemStatusServer) AllocatorRange( ctx context.Context, req *serverpb.AllocatorRangeRequest, ) (*serverpb.AllocatorRangeResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, err } @@ -955,7 +939,7 @@ func (s *systemStatusServer) AllocatorRange( return nil }) }); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -989,7 +973,7 @@ func (s *systemStatusServer) AllocatorRange( } fmt.Fprintf(&buf, "n%d: %s", nodeID, err) } - return nil, serverErrorf(ctx, "%v", buf) + return nil, srverrors.ServerErrorf(ctx, "%v", buf) } return &serverpb.AllocatorRangeResponse{}, nil } @@ -998,11 +982,11 @@ func (s *systemStatusServer) AllocatorRange( func (s *statusServer) Certificates( ctx context.Context, req 
*serverpb.CertificatesRequest, ) (*serverpb.CertificatesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -1019,21 +1003,21 @@ func (s *statusServer) Certificates( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Certificates(ctx, req) } cm, err := s.rpcCtx.GetCertificateManager() if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } // The certificate manager gives us a list of CertInfo objects to avoid // making security depend on serverpb. 
certs, err := cm.ListCertificates() if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } cr := &serverpb.CertificatesResponse{} @@ -1053,7 +1037,7 @@ func (s *statusServer) Certificates( case security.ClientPem: details.Type = serverpb.CertificateDetails_CLIENT default: - return nil, serverErrorf(ctx, "unknown certificate type %v for file %s", cert.FileUsage, cert.Filename) + return nil, srverrors.ServerErrorf(ctx, "unknown certificate type %v for file %s", cert.FileUsage, cert.Filename) } if cert.Error == nil { @@ -1119,11 +1103,11 @@ func extractCertFields(contents []byte, details *serverpb.CertificateDetails) er func (s *statusServer) Details( ctx context.Context, req *serverpb.DetailsRequest, ) (*serverpb.DetailsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1135,7 +1119,7 @@ func (s *statusServer) Details( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Details(ctx, req) } @@ -1160,11 +1144,11 @@ func (s *statusServer) Details( func (s *statusServer) GetFiles( ctx context.Context, req *serverpb.GetFilesRequest, ) (*serverpb.GetFilesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -1176,7 +1160,7 @@ func (s *statusServer) GetFiles( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.GetFiles(ctx, req) } @@ -1216,11 +1200,11 @@ func checkFilePattern(pattern string) error { func (s *statusServer) LogFilesList( ctx context.Context, req *serverpb.LogFilesListRequest, ) (*serverpb.LogFilesListResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1232,14 +1216,14 @@ func (s *statusServer) LogFilesList( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.LogFilesList(ctx, req) } log.Flush() logFiles, err := log.ListLogFiles() if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return &serverpb.LogFilesListResponse{Files: logFiles}, nil } @@ -1251,11 +1235,11 @@ func (s *statusServer) LogFilesList( func (s *statusServer) LogFile( ctx context.Context, req *serverpb.LogFileRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -1267,7 +1251,7 @@ func (s *statusServer) LogFile( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.LogFile(ctx, req) } @@ -1281,14 +1265,14 @@ func (s *statusServer) LogFile( // Read the logs. 
reader, err := log.GetLogReader(req.File) if err != nil { - return nil, serverError(ctx, errors.Wrapf(err, "log file %q could not be opened", req.File)) + return nil, srverrors.ServerError(ctx, errors.Wrapf(err, "log file %q could not be opened", req.File)) } defer reader.Close() var resp serverpb.LogEntriesResponse decoder, err := log.NewEntryDecoder(reader, inputEditMode) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } // Unless we're the system tenant, clients should only be able // to view logs that pertain to their own tenant. Set the filter @@ -1303,7 +1287,7 @@ func (s *statusServer) LogFile( if err == io.EOF { break } - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if tenantIDFilter != "" && entry.TenantID != tenantIDFilter { continue @@ -1349,11 +1333,11 @@ func parseInt64WithDefault(s string, defaultValue int64) (int64, error) { func (s *statusServer) Logs( ctx context.Context, req *serverpb.LogsRequest, ) (*serverpb.LogEntriesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1365,7 +1349,7 @@ func (s *statusServer) Logs( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Logs(ctx, req) } @@ -1412,7 +1396,7 @@ func (s *statusServer) Logs( entries, err := log.FetchEntriesFromFiles( startTimestamp, endTimestamp, int(maxEntries), regex, inputEditMode) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } out := &serverpb.LogEntriesResponse{} @@ -1437,11 +1421,11 @@ func (s *statusServer) Logs( func (s *statusServer) Stacks( ctx context.Context, req *serverpb.StacksRequest, ) (*serverpb.JSONResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1454,7 +1438,7 @@ func (s *statusServer) Stacks( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Stacks(ctx, req) } @@ -1638,7 +1622,7 @@ func (s *statusServer) fetchProfileFromAllNodes( response.profDataByNodeID[nodeID] = &profData{err: err} } if err := s.iterateNodes(ctx, opName, dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var data []byte switch req.Type { @@ -1663,11 +1647,11 @@ func (s *statusServer) fetchProfileFromAllNodes( func (s *statusServer) Profile( ctx context.Context, req *serverpb.ProfileRequest, ) (*serverpb.JSONResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -1686,7 +1670,7 @@ func (s *statusServer) Profile( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Profile(ctx, req) } @@ -1709,61 +1693,22 @@ func (s *systemStatusServer) Regions( ) (*serverpb.RegionsResponse, error) { resp, _, err := s.nodesHelper(ctx, 0 /* limit */, 0 /* offset */) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return regionsResponseFromNodesResponse(resp), nil } -func regionsResponseFromNodesResponse(nr *serverpb.NodesResponse) *serverpb.RegionsResponse { - regionsToZones := make(map[string]map[string]struct{}) - for _, node := range nr.Nodes { - var region string - var zone string - for _, tier := range node.Desc.Locality.Tiers { - switch tier.Key { - case "region": - region = tier.Value - case "zone", "availability-zone", "az": - zone = tier.Value - } - } - if region == "" { - continue - } - if _, ok := regionsToZones[region]; !ok { - regionsToZones[region] = make(map[string]struct{}) - } - if zone != "" { - regionsToZones[region][zone] = struct{}{} - } - } - ret := &serverpb.RegionsResponse{ - Regions: make(map[string]*serverpb.RegionsResponse_Region, len(regionsToZones)), - } - for region, zones := range regionsToZones { - zonesArr := make([]string, 0, len(zones)) - for z := range zones { - zonesArr = append(zonesArr, z) - } - sort.Strings(zonesArr) - ret.Regions[region] = &serverpb.RegionsResponse_Region{ - Zones: zonesArr, - } - } - return ret -} - // NodesList returns a list of nodes with their corresponding addresses. 
func (s *statusServer) NodesList( ctx context.Context, _ *serverpb.NodesListRequest, ) (*serverpb.NodesListResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network // addresses, env vars etc which are admin-only. - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -1784,17 +1729,17 @@ func (s *statusServer) NodesList( func (s *systemStatusServer) Nodes( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, err } resp, _, err := s.nodesHelper(ctx, 0 /* limit */, 0 /* offset */) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return resp, nil } @@ -1810,7 +1755,7 @@ func (s *statusServer) NodesUI( ctx = s.AnnotateCtx(ctx) hasViewClusterMetadata := false - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { if !grpcutil.IsAuthError(err) { return nil, err @@ -1821,7 +1766,7 @@ func (s *statusServer) NodesUI( internalResp, err := s.sqlServer.tenantConnect.Nodes(ctx, req) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp := &serverpb.NodesResponseExternal{ @@ -1838,11 +1783,11 @@ func (s *statusServer) 
NodesUI( func (s *systemStatusServer) NodesUI( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponseExternal, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) hasViewClusterMetadata := false - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { if !grpcutil.IsAuthError(err) { return nil, err @@ -1853,7 +1798,7 @@ func (s *systemStatusServer) NodesUI( internalResp, _, err := s.nodesHelper(ctx, 0 /* limit */, 0 /* offset */) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp := &serverpb.NodesResponseExternal{ Nodes: make([]serverpb.NodeResponse, len(internalResp.Nodes)), @@ -1866,107 +1811,11 @@ func (s *systemStatusServer) NodesUI( return resp, nil } -func nodeStatusToResp(n *statuspb.NodeStatus, hasViewClusterMetadata bool) serverpb.NodeResponse { - tiers := make([]serverpb.Tier, len(n.Desc.Locality.Tiers)) - for j, t := range n.Desc.Locality.Tiers { - tiers[j] = serverpb.Tier{ - Key: t.Key, - Value: t.Value, - } - } - - activity := make(map[roachpb.NodeID]serverpb.NodeResponse_NetworkActivity, len(n.Activity)) - for k, v := range n.Activity { - activity[k] = serverpb.NodeResponse_NetworkActivity{ - Latency: v.Latency, - } - } - - nodeDescriptor := serverpb.NodeDescriptor{ - NodeID: n.Desc.NodeID, - Address: util.UnresolvedAddr{}, - Attrs: roachpb.Attributes{}, - Locality: serverpb.Locality{ - Tiers: tiers, - }, - ServerVersion: serverpb.Version{ - Major: n.Desc.ServerVersion.Major, - Minor: n.Desc.ServerVersion.Minor, - Patch: n.Desc.ServerVersion.Patch, - Internal: n.Desc.ServerVersion.Internal, - }, - BuildTag: n.Desc.BuildTag, - StartedAt: n.Desc.StartedAt, - LocalityAddress: nil, - ClusterName: n.Desc.ClusterName, - SQLAddress: util.UnresolvedAddr{}, - } - - statuses := 
make([]serverpb.StoreStatus, len(n.StoreStatuses)) - for i, ss := range n.StoreStatuses { - statuses[i] = serverpb.StoreStatus{ - Desc: serverpb.StoreDescriptor{ - StoreID: ss.Desc.StoreID, - Attrs: ss.Desc.Attrs, - Node: nodeDescriptor, - Capacity: ss.Desc.Capacity, - - Properties: roachpb.StoreProperties{ - ReadOnly: ss.Desc.Properties.ReadOnly, - Encrypted: ss.Desc.Properties.Encrypted, - }, - }, - Metrics: ss.Metrics, - } - if fsprops := ss.Desc.Properties.FileStoreProperties; fsprops != nil { - sfsprops := &roachpb.FileStoreProperties{ - FsType: fsprops.FsType, - } - if hasViewClusterMetadata { - sfsprops.Path = fsprops.Path - sfsprops.BlockDevice = fsprops.BlockDevice - sfsprops.MountPoint = fsprops.MountPoint - sfsprops.MountOptions = fsprops.MountOptions - } - statuses[i].Desc.Properties.FileStoreProperties = sfsprops - } - } - - resp := serverpb.NodeResponse{ - Desc: nodeDescriptor, - BuildInfo: n.BuildInfo, - StartedAt: n.StartedAt, - UpdatedAt: n.UpdatedAt, - Metrics: n.Metrics, - StoreStatuses: statuses, - Args: nil, - Env: nil, - Latencies: n.Latencies, - Activity: activity, - TotalSystemMemory: n.TotalSystemMemory, - NumCpus: n.NumCpus, - } - - if hasViewClusterMetadata { - resp.Args = n.Args - resp.Env = n.Env - resp.Desc.Attrs = n.Desc.Attrs - resp.Desc.Address = n.Desc.Address - resp.Desc.LocalityAddress = n.Desc.LocalityAddress - resp.Desc.SQLAddress = n.Desc.SQLAddress - for _, n := range resp.StoreStatuses { - n.Desc.Node = resp.Desc - } - } - - return resp -} - // ListNodesInternal is a helper function for the benefit of SQL exclusively. // It skips the privilege check, assuming that SQL is doing privilege checking already. // // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. 
func (s *systemStatusServer) ListNodesInternal( ctx context.Context, req *serverpb.NodesRequest, ) (*serverpb.NodesResponse, error) { @@ -1975,7 +1824,7 @@ func (s *systemStatusServer) ListNodesInternal( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func getNodeStatuses( ctx context.Context, db *kv.DB, limit, offset int, ) (statuses []statuspb.NodeStatus, next int, _ error) { @@ -2005,11 +1854,11 @@ func getNodeStatuses( } // Note that the function returns plain errors, and it is the caller's -// responsibility to convert them to serverErrors. +// responsibility to convert them to srverrors.ServerErrors. func (s *systemStatusServer) nodesHelper( ctx context.Context, limit, offset int, ) (*serverpb.NodesResponse, int, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) statuses, next, err := getNodeStatuses(ctx, s.db, limit, offset) @@ -2037,18 +1886,18 @@ func (s *systemStatusServer) nodesHelper( func (s *statusServer) Node( ctx context.Context, req *serverpb.NodeRequest, ) (*statuspb.NodeStatus, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network // addresses, env vars etc which are admin-only. - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } - // NB: not using serverError() here since nodeStatus + // NB: not using srverrors.ServerError() here since nodeStatus // already returns a proper gRPC error status. 
return s.nodeStatus(ctx, req) } @@ -2065,13 +1914,13 @@ func (s *statusServer) nodeStatus( b := &kv.Batch{} b.Get(key) if err := s.db.Run(ctx, b); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var nodeStatus statuspb.NodeStatus if err := b.Results[0].Rows[0].ValueProto(&nodeStatus); err != nil { err = errors.Wrapf(err, "could not unmarshal NodeStatus from %s", key) - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return &nodeStatus, nil } @@ -2079,19 +1928,19 @@ func (s *statusServer) nodeStatus( func (s *statusServer) NodeUI( ctx context.Context, req *serverpb.NodeRequest, ) (*serverpb.NodeResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // The node status contains details about the command line, network // addresses, env vars etc which are admin-only. - _, isAdmin, err := s.privilegeChecker.getUserAndRole(ctx) + _, isAdmin, err := s.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } nodeStatus, err := s.nodeStatus(ctx, req) if err != nil { - // NB: not using serverError() here since nodeStatus + // NB: not using srverrors.ServerError() here since nodeStatus // already returns a proper gRPC error status. 
return nil, err } @@ -2103,7 +1952,7 @@ func (s *statusServer) NodeUI( func (s *systemStatusServer) NetworkConnectivity( ctx context.Context, req *serverpb.NetworkConnectivityRequest, ) (*serverpb.NetworkConnectivityResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) response := &serverpb.NetworkConnectivityResponse{ @@ -2114,13 +1963,13 @@ func (s *systemStatusServer) NetworkConnectivity( if len(req.NodeID) > 0 { sourceNodeID, local, err := s.parseNodeID(req.NodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if !local { statusClient, err := s.dialNode(ctx, sourceNodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statusClient.NetworkConnectivity(ctx, req) } @@ -2133,7 +1982,7 @@ func (s *systemStatusServer) NetworkConnectivity( return nil }) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } latencies := s.rpcCtx.RemoteClocks.AllLatencies() @@ -2191,7 +2040,7 @@ func (s *systemStatusServer) NetworkConnectivity( } if err := s.iterateNodes(ctx, "network connectivity", dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return response, nil @@ -2201,7 +2050,7 @@ func (s *systemStatusServer) NetworkConnectivity( func (s *statusServer) Metrics( ctx context.Context, req *serverpb.MetricsRequest, ) (*serverpb.JSONResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) @@ -2212,13 +2061,13 @@ func (s *statusServer) Metrics( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return 
status.Metrics(ctx, req) } j, err := marshalJSONResponse(s.metricSource) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return j, nil } @@ -2227,18 +2076,18 @@ func (s *statusServer) Metrics( func (s *systemStatusServer) RaftDebug( ctx context.Context, req *serverpb.RaftDebugRequest, ) (*serverpb.RaftDebugResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } nodes, err := s.ListNodesInternal(ctx, nil) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } mu := struct { @@ -2353,10 +2202,10 @@ func (s *systemStatusServer) Ranges( func (s *systemStatusServer) rangesHelper( ctx context.Context, req *serverpb.RangesRequest, limit, offset int, ) (*serverpb.RangesResponse, int, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, 0, err } @@ -2387,7 +2236,7 @@ func (s *systemStatusServer) rangesHelper( convertRaftStatus := func(raftStatus *raft.Status) serverpb.RaftState { if raftStatus == nil { return serverpb.RaftState{ - State: raftStateDormant, + State: RaftStateDormant, } } @@ -2556,11 +2405,11 @@ func (s *systemStatusServer) rangesHelper( func (t *statusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) 
(*serverpb.TenantRangesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = t.AnnotateCtx(ctx) // The tenant range report contains replica metadata which is admin-only. - if _, err := t.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := t.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } @@ -2570,9 +2419,9 @@ func (t *statusServer) TenantRanges( func (s *systemStatusServer) TenantRanges( ctx context.Context, req *serverpb.TenantRangesRequest, ) (*serverpb.TenantRangesResponse, error) { - forwardSQLIdentityThroughRPCCalls(ctx) + authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } @@ -2644,7 +2493,7 @@ func (s *systemStatusServer) TenantRanges( } else { statusServer, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } resp, err = statusServer.Ranges(ctx, nodeReq) @@ -2720,11 +2569,11 @@ func (s *systemStatusServer) TenantRanges( func (s *systemStatusServer) HotRanges( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -2749,7 +2598,7 @@ func (s *systemStatusServer) HotRanges( // Only hot ranges from one non-local node. 
status, err := s.dialNode(ctx, requestedNodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.HotRanges(ctx, req) } @@ -2775,7 +2624,7 @@ func (s *systemStatusServer) HotRanges( } if err := s.iterateNodes(ctx, "hot ranges", dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return response, nil @@ -2793,7 +2642,7 @@ func (t *statusServer) HotRangesV2( ) (*serverpb.HotRangesResponseV2, error) { ctx = t.AnnotateCtx(ctx) - err := t.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := t.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, err } @@ -2806,9 +2655,9 @@ func (t *statusServer) HotRangesV2( func (s *systemStatusServer) HotRangesV2( ctx context.Context, req *serverpb.HotRangesRequest, ) (*serverpb.HotRangesResponseV2, error) { - ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) + ctx = s.AnnotateCtx(authserver.ForwardSQLIdentityThroughRPCCalls(ctx)) - err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx) + err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx) if err != nil { return nil, err } @@ -3034,7 +2883,7 @@ func (s *statusServer) KeyVisSamples( ctx context.Context, req *serverpb.KeyVisSamplesRequest, ) (*serverpb.KeyVisSamplesResponse, error) { - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } @@ -3079,11 +2928,11 @@ func (s *statusServer) KeyVisSamples( func (s *statusServer) Range( ctx context.Context, req *serverpb.RangeRequest, ) (*serverpb.RangeResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using 
serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3128,7 +2977,7 @@ func (s *statusServer) Range( if err := s.iterateNodes( ctx, fmt.Sprintf("details about range %d", req.RangeId), dialFn, nodeFn, responseFn, errorFn, ); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return response, nil } @@ -3139,7 +2988,7 @@ func (s *statusServer) ListLocalSessions( ) (*serverpb.ListSessionsResponse, error) { sessions, err := s.getLocalSessions(ctx, req) if err != nil { - // NB: not using serverError() here since getLocalSessions + // NB: not using srverrors.ServerError() here since getLocalSessions // already returns a proper gRPC error status. return nil, err } @@ -3198,7 +3047,7 @@ func (s *statusServer) iterateNodes( } // Issue the requests concurrently. - sem := quotapool.NewIntPool("node status", maxConcurrentRequests) + sem := quotapool.NewIntPool("node status", apiconstants.MaxConcurrentRequests) ctx, cancel := s.stopper.WithCancelOnQuiesce(ctx) defer cancel() for nodeID := range nodeStatuses { @@ -3292,7 +3141,7 @@ func (s *statusServer) paginatedIterateNodes( paginator.init() // Issue the requests concurrently. 
- sem := quotapool.NewIntPool("node status", maxConcurrentPaginatedRequests) + sem := quotapool.NewIntPool("node status", apiconstants.MaxConcurrentPaginatedRequests) ctx, cancel := s.stopper.WithCancelOnQuiesce(ctx) defer cancel() for idx, nodeID := range nodeIDs { @@ -3370,18 +3219,18 @@ func (s *statusServer) listSessionsHelper( func (s *statusServer) ListSessions( ctx context.Context, req *serverpb.ListSessionsRequest, ) (*serverpb.ListSessionsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, _, err := s.privilegeChecker.getUserAndRole(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, _, err := s.privilegeChecker.GetUserAndRole(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } resp, _, err := s.listSessionsHelper(ctx, req, 0 /* limit */, paginationState{}) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return resp, nil } @@ -3391,7 +3240,7 @@ func (s *statusServer) ListSessions( func (s *statusServer) CancelSession( ctx context.Context, req *serverpb.CancelSessionRequest, ) (*serverpb.CancelSessionResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) sessionIDBytes := req.SessionID @@ -3411,7 +3260,7 @@ func (s *statusServer) CancelSession( Error: fmt.Sprintf("session ID %s not found", sessionID), }, nil } - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.CancelSession(ctx, req) } @@ -3429,7 +3278,7 @@ func (s *statusServer) CancelSession( } if err := s.checkCancelPrivilege(ctx, reqUsername, session.SessionUser()); err != nil { - // NB: not using serverError() here since the priv checker + // NB: not using 
srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3443,7 +3292,7 @@ func (s *statusServer) CancelSession( func (s *statusServer) CancelQuery( ctx context.Context, req *serverpb.CancelQueryRequest, ) (*serverpb.CancelQueryResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) queryID, err := clusterunique.IDFromString(req.QueryID) @@ -3464,7 +3313,7 @@ func (s *statusServer) CancelQuery( Error: fmt.Sprintf("query ID %s not found", queryID), }, nil } - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.CancelQuery(ctx, req) } @@ -3482,7 +3331,7 @@ func (s *statusServer) CancelQuery( } if err := s.checkCancelPrivilege(ctx, reqUsername, session.SessionUser()); err != nil { - // NB: not using serverError() here since the priv checker + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3542,7 +3391,7 @@ func (s *statusServer) CancelQueryByKey( } // This request needs to be forwarded to another node. - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) client, err := s.dialNode(ctx, roachpb.NodeID(req.SQLInstanceID)) if err != nil { @@ -3556,12 +3405,12 @@ func (s *statusServer) CancelQueryByKey( func (s *statusServer) ListContentionEvents( ctx context.Context, req *serverpb.ListContentionEventsRequest, ) (*serverpb.ListContentionEventsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. 
- if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3595,7 +3444,7 @@ func (s *statusServer) ListContentionEvents( } if err := s.iterateNodes(ctx, "contention events list", dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return &response, nil } @@ -3603,12 +3452,12 @@ func (s *statusServer) ListContentionEvents( func (s *statusServer) ListDistSQLFlows( ctx context.Context, request *serverpb.ListDistSQLFlowsRequest, ) (*serverpb.ListDistSQLFlowsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3642,65 +3491,20 @@ func (s *statusServer) ListDistSQLFlows( } if err := s.iterateNodes(ctx, "distsql flows list", dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return &response, nil } -// mergeDistSQLRemoteFlows takes in two slices of DistSQL remote flows (that -// satisfy the contract of serverpb.ListDistSQLFlowsResponse) and merges them -// together while adhering to the same contract. 
-// -// It is assumed that if serverpb.DistSQLRemoteFlows for a particular FlowID -// appear in both arguments - let's call them flowsA and flowsB for a and b, -// respectively - then there are no duplicate NodeIDs among flowsA and flowsB. -func mergeDistSQLRemoteFlows(a, b []serverpb.DistSQLRemoteFlows) []serverpb.DistSQLRemoteFlows { - maxLength := len(a) - if len(b) > len(a) { - maxLength = len(b) - } - result := make([]serverpb.DistSQLRemoteFlows, 0, maxLength) - aIter, bIter := 0, 0 - for aIter < len(a) && bIter < len(b) { - cmp := bytes.Compare(a[aIter].FlowID.GetBytes(), b[bIter].FlowID.GetBytes()) - if cmp < 0 { - result = append(result, a[aIter]) - aIter++ - } else if cmp > 0 { - result = append(result, b[bIter]) - bIter++ - } else { - r := a[aIter] - // No need to perform any kind of de-duplication because a - // particular flow will be reported at most once by each node in the - // cluster. - r.Infos = append(r.Infos, b[bIter].Infos...) - sort.Slice(r.Infos, func(i, j int) bool { - return r.Infos[i].NodeID < r.Infos[j].NodeID - }) - result = append(result, r) - aIter++ - bIter++ - } - } - if aIter < len(a) { - result = append(result, a[aIter:]...) - } - if bIter < len(b) { - result = append(result, b[bIter:]...) - } - return result -} - func (s *statusServer) ListExecutionInsights( ctx context.Context, req *serverpb.ListExecutionInsightsRequest, ) (*serverpb.ListExecutionInsightsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) // Check permissions early to avoid fan-out to all nodes. 
- if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3717,7 +3521,7 @@ func (s *statusServer) ListExecutionInsights( } statusClient, err := s.dialNode(ctx, requestedNodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return statusClient.ListExecutionInsights(ctx, &localRequest) } @@ -3747,7 +3551,7 @@ func (s *statusServer) ListExecutionInsights( } if err := s.iterateNodes(ctx, "execution insights list", dialFn, nodeFn, responseFn, errorFn); err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return &response, nil } @@ -3758,8 +3562,8 @@ func (s *statusServer) SpanStats( ctx context.Context, req *roachpb.SpanStatsRequest, ) (*roachpb.SpanStatsResponse, error) { ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -3769,10 +3573,10 @@ func (s *statusServer) SpanStats( func (s *systemStatusServer) SpanStats( ctx context.Context, req *roachpb.SpanStatsRequest, ) (*roachpb.SpanStatsResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -3803,7 +3607,7 @@ func (s *systemStatusServer) SpanStats( func (s *statusServer) Diagnostics( ctx context.Context, req *serverpb.DiagnosticsRequest, ) (*diagnosticspb.DiagnosticReport, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) nodeID, local, err := s.parseNodeID(req.NodeId) if err != nil { @@ -3813,7 +3617,7 @@ func (s *statusServer) Diagnostics( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Diagnostics(ctx, req) } @@ -3825,11 +3629,11 @@ func (s *statusServer) Diagnostics( func (s *systemStatusServer) Stores( ctx context.Context, req *serverpb.StoresRequest, ) (*serverpb.StoresResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if err := s.privilegeChecker.requireViewClusterMetadataPermission(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if err := s.privilegeChecker.RequireViewClusterMetadataPermission(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -3842,7 +3646,7 @@ func (s *systemStatusServer) Stores( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.Stores(ctx, req) } @@ -3871,7 +3675,7 @@ func (s *systemStatusServer) Stores( return nil }) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return resp, nil } @@ -3953,11 +3757,11 @@ func (si *systemInfoOnce) systemInfo(ctx context.Context) serverpb.SystemInfo { func (s *statusServer) JobRegistryStatus( ctx context.Context, req *serverpb.JobRegistryStatusRequest, ) (*serverpb.JobRegistryStatusResponse, error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } @@ -3969,7 +3773,7 @@ func (s *statusServer) JobRegistryStatus( if !local { status, err := s.dialNode(ctx, nodeID) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } return status.JobRegistryStatus(ctx, req) } @@ -3992,10 +3796,10 @@ func (s *statusServer) JobRegistryStatus( func (s *statusServer) JobStatus( ctx context.Context, req *serverpb.JobStatusRequest, ) (*serverpb.JobStatusResponse, error) { - ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) + ctx = s.AnnotateCtx(authserver.ForwardSQLIdentityThroughRPCCalls(ctx)) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { - // NB: not using serverError() here since the priv checker + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { + // NB: not using srverrors.ServerError() here since the priv checker // already returns a proper gRPC error status. return nil, err } @@ -4005,7 +3809,7 @@ func (s *statusServer) JobStatus( if je := (*jobs.JobNotFoundError)(nil); errors.As(err, &je) { return nil, status.Errorf(codes.NotFound, "%v", err) } - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } res := &jobspb.Job{ Payload: &jobspb.Payload{}, @@ -4025,8 +3829,8 @@ func (s *statusServer) JobStatus( func (s *statusServer) TxnIDResolution( ctx context.Context, req *serverpb.TxnIDResolutionRequest, ) (*serverpb.TxnIDResolutionResponse, error) { - ctx = s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) - if _, err := s.privilegeChecker.requireAdminUser(ctx); err != nil { + ctx = s.AnnotateCtx(authserver.ForwardSQLIdentityThroughRPCCalls(ctx)) + if _, err := s.privilegeChecker.RequireAdminUser(ctx); err != nil { return nil, err } @@ -4049,23 +3853,23 @@ func (s *statusServer) TxnIDResolution( func (s *statusServer) TransactionContentionEvents( ctx context.Context, req *serverpb.TransactionContentionEventsRequest, ) (*serverpb.TransactionContentionEventsResponse, error) { - ctx = 
s.AnnotateCtx(forwardSQLIdentityThroughRPCCalls(ctx)) + ctx = s.AnnotateCtx(authserver.ForwardSQLIdentityThroughRPCCalls(ctx)) - if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { + if err := s.privilegeChecker.RequireViewActivityOrViewActivityRedactedPermission(ctx); err != nil { return nil, err } - user, isAdmin, err := s.privilegeChecker.getUserAndRole(ctx) + user, isAdmin, err := s.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } shouldRedactContendingKey := false if !isAdmin { shouldRedactContendingKey, err = - s.privilegeChecker.hasRoleOption(ctx, user, roleoption.VIEWACTIVITYREDACTED) + s.privilegeChecker.HasRoleOption(ctx, user, roleoption.VIEWACTIVITYREDACTED) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } } @@ -4132,7 +3936,7 @@ func (s *statusServer) GetJobProfilerExecutionDetails( ) (*serverpb.GetJobProfilerExecutionDetailResponse, error) { ctx = s.AnnotateCtx(ctx) // TODO(adityamaru): Figure out the correct privileges required to get execution details. - _, err := s.privilegeChecker.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return nil, err } @@ -4153,7 +3957,7 @@ func (s *statusServer) ListJobProfilerExecutionDetails( ) (*serverpb.ListJobProfilerExecutionDetailsResponse, error) { ctx = s.AnnotateCtx(ctx) // TODO(adityamaru): Figure out the correct privileges required to get execution details. 
- _, err := s.privilegeChecker.requireAdminUser(ctx) + _, err := s.privilegeChecker.RequireAdminUser(ctx) if err != nil { return nil, err } diff --git a/pkg/server/status_local_file_retrieval.go b/pkg/server/status_local_file_retrieval.go index 2d3c6363c68d..e390fd1d3d4e 100644 --- a/pkg/server/status_local_file_retrieval.go +++ b/pkg/server/status_local_file_retrieval.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/server/debug" "github.com/cockroachdb/cockroach/pkg/server/debug/pprofui" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/allstacks" "google.golang.org/grpc/codes" @@ -49,7 +50,7 @@ func profileLocal( } if err := pprof.StartCPUProfile(&buf); err != nil { // Construct a gRPC error to return to the caller. - return serverError(ctx, err) + return srverrors.ServerError(ctx, err) } defer pprof.StopCPUProfile() select { diff --git a/pkg/server/status_test.go b/pkg/server/status_test.go deleted file mode 100644 index 1a3e23508612..000000000000 --- a/pkg/server/status_test.go +++ /dev/null @@ -1,3889 +0,0 @@ -// Copyright 2015 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -package server - -import ( - "bytes" - "context" - gosql "database/sql" - "encoding/hex" - "fmt" - "math" - "net/url" - "os" - "path/filepath" - "reflect" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/base/serverident" - "github.com/cockroachdb/cockroach/pkg/build" - "github.com/cockroachdb/cockroach/pkg/gossip" - "github.com/cockroachdb/cockroach/pkg/jobs" - "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" - "github.com/cockroachdb/cockroach/pkg/keys" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/plan" - "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" - "github.com/cockroachdb/cockroach/pkg/server/diagnostics/diagnosticspb" - "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" - "github.com/cockroachdb/cockroach/pkg/spanconfig" - "github.com/cockroachdb/cockroach/pkg/sql" - "github.com/cockroachdb/cockroach/pkg/sql/appstatspb" - "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" - "github.com/cockroachdb/cockroach/pkg/sql/clusterunique" - "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" - "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" - "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" - "github.com/cockroachdb/cockroach/pkg/sql/sqlstats" - "github.com/cockroachdb/cockroach/pkg/sql/tests" - "github.com/cockroachdb/cockroach/pkg/storage" - "github.com/cockroachdb/cockroach/pkg/storage/enginepb" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/skip" - "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" - 
"github.com/cockroachdb/cockroach/pkg/ts" - "github.com/cockroachdb/cockroach/pkg/util" - "github.com/cockroachdb/cockroach/pkg/util/grunning" - "github.com/cockroachdb/cockroach/pkg/util/httputil" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/cockroach/pkg/util/log/logpb" - "github.com/cockroachdb/cockroach/pkg/util/protoutil" - "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/cockroach/pkg/util/uuid" - "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" - "github.com/kr/pretty" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func getStatusJSONProto( - ts serverutils.TestServerInterface, path string, response protoutil.Message, -) error { - return serverutils.GetJSONProto(ts, statusPrefix+path, response) -} - -func postStatusJSONProto( - ts serverutils.TestServerInterface, path string, request, response protoutil.Message, -) error { - return serverutils.PostJSONProto(ts, statusPrefix+path, request, response) -} - -func getStatusJSONProtoWithAdminOption( - ts serverutils.TestServerInterface, path string, response protoutil.Message, isAdmin bool, -) error { - return serverutils.GetJSONProtoWithAdminOption(ts, statusPrefix+path, response, isAdmin) -} - -func postStatusJSONProtoWithAdminOption( - ts serverutils.TestServerInterface, - path string, - request, response protoutil.Message, - isAdmin bool, -) error { - return serverutils.PostJSONProtoWithAdminOption(ts, statusPrefix+path, request, response, isAdmin) -} - -// TestStatusJson verifies that status endpoints return expected Json results. -// The content type of the responses is always httputil.JSONContentType. 
-func TestStatusJson(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - ts := s.(*TestServer) - - nodeID := ts.Gossip().NodeID.Get() - addr, err := ts.Gossip().GetNodeIDAddress(nodeID) - if err != nil { - t.Fatal(err) - } - sqlAddr, err := ts.Gossip().GetNodeIDSQLAddress(nodeID) - if err != nil { - t.Fatal(err) - } - - var nodes serverpb.NodesResponse - testutils.SucceedsSoon(t, func() error { - if err := getStatusJSONProto(s, "nodes", &nodes); err != nil { - t.Fatal(err) - } - - if len(nodes.Nodes) == 0 { - return errors.Errorf("expected non-empty node list, got: %v", nodes) - } - return nil - }) - - for _, path := range []string{ - statusPrefix + "details/local", - statusPrefix + "details/" + strconv.FormatUint(uint64(nodeID), 10), - } { - var details serverpb.DetailsResponse - if err := serverutils.GetJSONProto(s, path, &details); err != nil { - t.Fatal(err) - } - if a, e := details.NodeID, nodeID; a != e { - t.Errorf("expected: %d, got: %d", e, a) - } - if a, e := details.Address, *addr; a != e { - t.Errorf("expected: %v, got: %v", e, a) - } - if a, e := details.SQLAddress, *sqlAddr; a != e { - t.Errorf("expected: %v, got: %v", e, a) - } - if a, e := details.BuildInfo, build.GetInfo(); a != e { - t.Errorf("expected: %v, got: %v", e, a) - } - } -} - -// TestHealthTelemetry confirms that hits on some status endpoints increment -// feature telemetry counters. 
-func TestHealthTelemetry(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - rows, err := db.Query("SELECT * FROM crdb_internal.feature_usage WHERE feature_name LIKE 'monitoring%' AND usage_count > 0;") - defer func() { - if err := rows.Close(); err != nil { - t.Fatal(err) - } - }() - if err != nil { - t.Fatal(err) - } - - initialCounts := make(map[string]int) - for rows.Next() { - var featureName string - var usageCount int - - if err := rows.Scan(&featureName, &usageCount); err != nil { - t.Fatal(err) - } - - initialCounts[featureName] = usageCount - } - - var details serverpb.DetailsResponse - if err := serverutils.GetJSONProto(s, "/health", &details); err != nil { - t.Fatal(err) - } - if _, err := getText(s, s.AdminURL().WithPath(statusPrefix+"vars").String()); err != nil { - t.Fatal(err) - } - - expectedCounts := map[string]int{ - "monitoring.prometheus.vars": 1, - "monitoring.health.details": 1, - } - - rows2, err := db.Query("SELECT feature_name, usage_count FROM crdb_internal.feature_usage WHERE feature_name LIKE 'monitoring%' AND usage_count > 0;") - defer func() { - if err := rows2.Close(); err != nil { - t.Fatal(err) - } - }() - if err != nil { - t.Fatal(err) - } - - for rows2.Next() { - var featureName string - var usageCount int - - if err := rows2.Scan(&featureName, &usageCount); err != nil { - t.Fatal(err) - } - - usageCount -= initialCounts[featureName] - if count, ok := expectedCounts[featureName]; ok { - if count != usageCount { - t.Fatalf("expected %d count for feature %s, got %d", count, featureName, usageCount) - } - delete(expectedCounts, featureName) - } - } - - if len(expectedCounts) > 0 { - t.Fatalf("%d expected telemetry counters not emitted", len(expectedCounts)) - } -} - -// TestStatusGossipJson ensures that the output response for the full gossip -// info contains the required fields. 
-func TestStatusGossipJson(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - var data gossip.InfoStatus - if err := getStatusJSONProto(s, "gossip/local", &data); err != nil { - t.Fatal(err) - } - if _, ok := data.Infos["first-range"]; !ok { - t.Errorf("no first-range info returned: %v", data) - } - if _, ok := data.Infos["cluster-id"]; !ok { - t.Errorf("no clusterID info returned: %v", data) - } - if _, ok := data.Infos["node:1"]; !ok { - t.Errorf("no node 1 info returned: %v", data) - } -} - -// TestStatusEngineStatsJson ensures that the output response for the engine -// stats contains the required fields. -func TestStatusEngineStatsJson(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - dir, cleanupFn := testutils.TempDir(t) - defer cleanupFn() - - s, err := serverutils.StartServerRaw(t, base.TestServerArgs{ - StoreSpecs: []base.StoreSpec{{ - Path: dir, - }}, - }) - if err != nil { - t.Fatal(err) - } - defer s.Stopper().Stop(context.Background()) - - t.Logf("using admin URL %s", s.AdminURL()) - - var engineStats serverpb.EngineStatsResponse - // Using SucceedsSoon because we have seen in the wild that - // occasionally requests don't go through with error "transport: - // error while dialing: connection interrupted (did the remote node - // shut down or are there networking issues?)" - testutils.SucceedsSoon(t, func() error { - return getStatusJSONProto(s, "enginestats/local", &engineStats) - }) - - if len(engineStats.Stats) != 1 { - t.Fatal(errors.Errorf("expected one engine stats, got: %v", engineStats)) - } - - if engineStats.Stats[0].EngineType == enginepb.EngineTypePebble || - engineStats.Stats[0].EngineType == enginepb.EngineTypeDefault { - // Pebble does not have RocksDB style TickersAnd Histogram. 
- return - } - - tickers := engineStats.Stats[0].TickersAndHistograms.Tickers - if len(tickers) == 0 { - t.Fatal(errors.Errorf("expected non-empty tickers list, got: %v", tickers)) - } - allTickersZero := true - for _, ticker := range tickers { - if ticker != 0 { - allTickersZero = false - } - } - if allTickersZero { - t.Fatal(errors.Errorf("expected some tickers nonzero, got: %v", tickers)) - } - - histograms := engineStats.Stats[0].TickersAndHistograms.Histograms - if len(histograms) == 0 { - t.Fatal(errors.Errorf("expected non-empty histograms list, got: %v", histograms)) - } - allHistogramsZero := true - for _, histogram := range histograms { - if histogram.Max == 0 { - allHistogramsZero = false - } - } - if allHistogramsZero { - t.Fatal(errors.Errorf("expected some histograms nonzero, got: %v", histograms)) - } -} - -// startServer will start a server with a short scan interval, wait for -// the scan to complete, and return the server. The caller is -// responsible for stopping the server. -func startServer(t *testing.T) *TestServer { - tsI, _, kvDB := serverutils.StartServer(t, base.TestServerArgs{ - StoreSpecs: []base.StoreSpec{ - base.DefaultTestStoreSpec, - base.DefaultTestStoreSpec, - base.DefaultTestStoreSpec, - }, - Knobs: base.TestingKnobs{ - Store: &kvserver.StoreTestingKnobs{ - // Now that we allow same node rebalances, disable it in these tests, - // as they dont expect replicas to move. - ReplicaPlannerKnobs: plan.ReplicaPlannerTestingKnobs{ - DisableReplicaRebalancing: true, - }, - }, - }, - }) - - ts := tsI.(*TestServer) - - // Make sure the range is spun up with an arbitrary read command. We do not - // expect a specific response. - if _, err := kvDB.Get(context.Background(), "a"); err != nil { - t.Fatal(err) - } - - // Make sure the node status is available. This is done by forcing stores to - // publish their status, synchronizing to the event feed with a canary - // event, and then forcing the server to write summaries immediately. 
- if err := ts.node.computeMetricsPeriodically(context.Background(), map[*kvserver.Store]*storage.MetricsForInterval{}, 0); err != nil { - t.Fatalf("error publishing store statuses: %s", err) - } - - if err := ts.WriteSummaries(); err != nil { - t.Fatalf("error writing summaries: %s", err) - } - - return ts -} - -func newRPCTestContext(ctx context.Context, ts *TestServer, cfg *base.Config) *rpc.Context { - var c base.NodeIDContainer - ctx = logtags.AddTag(ctx, "n", &c) - rpcContext := rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - NodeID: &c, - Config: cfg, - Clock: ts.Clock().WallClock(), - ToleratedOffset: ts.Clock().ToleratedOffset(), - Stopper: ts.Stopper(), - Settings: ts.ClusterSettings(), - Knobs: rpc.ContextTestingKnobs{NoLoopbackDialer: true}, - }) - // Ensure that the RPC client context validates the server cluster ID. - // This ensures that a test where the server is restarted will not let - // its test RPC client talk to a server started by an unrelated concurrent test. - rpcContext.StorageClusterID.Set(context.Background(), ts.StorageClusterID()) - return rpcContext -} - -// TestStatusGetFiles tests the GetFiles endpoint. 
-func TestStatusGetFiles(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - tempDir, cleanupFn := testutils.TempDir(t) - defer cleanupFn() - - storeSpec := base.StoreSpec{Path: tempDir} - - tsI, _, _ := serverutils.StartServer(t, base.TestServerArgs{ - StoreSpecs: []base.StoreSpec{ - storeSpec, - }, - }) - ts := tsI.(*TestServer) - defer ts.Stopper().Stop(context.Background()) - - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := newRPCTestContext(context.Background(), ts, rootConfig) - - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - - // Test fetching heap files. - t.Run("heap", func(t *testing.T) { - const testFilesNo = 3 - for i := 0; i < testFilesNo; i++ { - testHeapDir := filepath.Join(storeSpec.Path, "logs", base.HeapProfileDir) - testHeapFile := filepath.Join(testHeapDir, fmt.Sprintf("heap%d.pprof", i)) - if err := os.MkdirAll(testHeapDir, os.ModePerm); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(testHeapFile, []byte(fmt.Sprintf("I'm heap file %d", i)), 0644); err != nil { - t.Fatal(err) - } - } - - request := serverpb.GetFilesRequest{ - NodeId: "local", Type: serverpb.FileType_HEAP, Patterns: []string{"heap*"}} - response, err := client.GetFiles(context.Background(), &request) - if err != nil { - t.Fatal(err) - } - - if a, e := len(response.Files), testFilesNo; a != e { - t.Errorf("expected %d files(s), found %d", e, a) - } - - for i, file := range response.Files { - expectedFileName := fmt.Sprintf("heap%d.pprof", i) - if file.Name != expectedFileName { - t.Fatalf("expected file name %s, found %s", expectedFileName, file.Name) - } - expectedFileContents := []byte(fmt.Sprintf("I'm heap file %d", i)) - if !bytes.Equal(file.Contents, expectedFileContents) { - t.Fatalf("expected file contents %s, 
found %s", expectedFileContents, file.Contents) - } - } - }) - - // Test fetching goroutine files. - t.Run("goroutines", func(t *testing.T) { - const testFilesNo = 3 - for i := 0; i < testFilesNo; i++ { - testGoroutineDir := filepath.Join(storeSpec.Path, "logs", base.GoroutineDumpDir) - testGoroutineFile := filepath.Join(testGoroutineDir, fmt.Sprintf("goroutine_dump%d.txt.gz", i)) - if err := os.MkdirAll(testGoroutineDir, os.ModePerm); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(testGoroutineFile, []byte(fmt.Sprintf("Goroutine dump %d", i)), 0644); err != nil { - t.Fatal(err) - } - } - - request := serverpb.GetFilesRequest{ - NodeId: "local", Type: serverpb.FileType_GOROUTINES, Patterns: []string{"*"}} - response, err := client.GetFiles(context.Background(), &request) - if err != nil { - t.Fatal(err) - } - - if a, e := len(response.Files), testFilesNo; a != e { - t.Errorf("expected %d files(s), found %d", e, a) - } - - for i, file := range response.Files { - expectedFileName := fmt.Sprintf("goroutine_dump%d.txt.gz", i) - if file.Name != expectedFileName { - t.Fatalf("expected file name %s, found %s", expectedFileName, file.Name) - } - expectedFileContents := []byte(fmt.Sprintf("Goroutine dump %d", i)) - if !bytes.Equal(file.Contents, expectedFileContents) { - t.Fatalf("expected file contents %s, found %s", expectedFileContents, file.Contents) - } - } - }) - - // Testing path separators in pattern. - t.Run("path separators", func(t *testing.T) { - request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true, - Type: serverpb.FileType_HEAP, Patterns: []string{"pattern/with/separators"}} - _, err = client.GetFiles(context.Background(), &request) - if !testutils.IsError(err, "invalid pattern: cannot have path seperators") { - t.Errorf("GetFiles: path separators allowed in pattern") - } - }) - - // Testing invalid filetypes. 
- t.Run("filetypes", func(t *testing.T) { - request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true, - Type: -1, Patterns: []string{"*"}} - _, err = client.GetFiles(context.Background(), &request) - if !testutils.IsError(err, "unknown file type: -1") { - t.Errorf("GetFiles: invalid file type allowed") - } - }) -} - -// TestStatusLocalLogs checks to ensure that local/logfiles, -// local/logfiles/{filename} and local/log function -// correctly. -func TestStatusLocalLogs(t *testing.T) { - defer leaktest.AfterTest(t)() - if log.V(3) { - skip.IgnoreLint(t, "Test only works with low verbosity levels") - } - - s := log.ScopeWithoutShowLogs(t) - defer s.Close(t) - - // This test cares about the number of output files. Ensure - // there's just one. - defer s.SetupSingleFileLogging()() - - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - // Log an error of each main type which we expect to be able to retrieve. - // The resolution of our log timestamps is such that it's possible to get - // two subsequent log messages with the same timestamp. This test will fail - // when that occurs. By adding a small sleep in here after each timestamp to - // ensures this isn't the case and that the log filtering doesn't filter out - // the log entires we're looking for. The value of 20 μs was chosen because - // the log timestamps have a fidelity of 10 μs and thus doubling that should - // be a sufficient buffer. - // See util/log/clog.go formatHeader() for more details. 
- const sleepBuffer = time.Microsecond * 20 - timestamp := timeutil.Now().UnixNano() - time.Sleep(sleepBuffer) - log.Errorf(context.Background(), "TestStatusLocalLogFile test message-Error") - time.Sleep(sleepBuffer) - timestampE := timeutil.Now().UnixNano() - time.Sleep(sleepBuffer) - log.Warningf(context.Background(), "TestStatusLocalLogFile test message-Warning") - time.Sleep(sleepBuffer) - timestampEW := timeutil.Now().UnixNano() - time.Sleep(sleepBuffer) - log.Infof(context.Background(), "TestStatusLocalLogFile test message-Info") - time.Sleep(sleepBuffer) - timestampEWI := timeutil.Now().UnixNano() - - var wrapper serverpb.LogFilesListResponse - if err := getStatusJSONProto(ts, "logfiles/local", &wrapper); err != nil { - t.Fatal(err) - } - if a, e := len(wrapper.Files), 1; a != e { - t.Fatalf("expected %d log files; got %d", e, a) - } - - // Check each individual log can be fetched and is non-empty. - var foundInfo, foundWarning, foundError bool - for _, file := range wrapper.Files { - var wrapper serverpb.LogEntriesResponse - if err := getStatusJSONProto(ts, "logfiles/local/"+file.Name, &wrapper); err != nil { - t.Fatal(err) - } - for _, entry := range wrapper.Entries { - switch strings.TrimSpace(entry.Message) { - case "TestStatusLocalLogFile test message-Error": - foundError = true - case "TestStatusLocalLogFile test message-Warning": - foundWarning = true - case "TestStatusLocalLogFile test message-Info": - foundInfo = true - } - } - } - - if !(foundInfo && foundWarning && foundError) { - t.Errorf("expected to find test messages in %v", wrapper.Files) - } - - type levelPresence struct { - Error, Warning, Info bool - } - - testCases := []struct { - MaxEntities int - StartTimestamp int64 - EndTimestamp int64 - Pattern string - levelPresence - }{ - // Test filtering by log severity. - // // Test entry limit. Ignore Info/Warning/Error filters. 
- {1, timestamp, timestampEWI, "", levelPresence{false, false, false}}, - {2, timestamp, timestampEWI, "", levelPresence{false, false, false}}, - {3, timestamp, timestampEWI, "", levelPresence{false, false, false}}, - // Test filtering in different timestamp windows. - {0, timestamp, timestamp, "", levelPresence{false, false, false}}, - {0, timestamp, timestampE, "", levelPresence{true, false, false}}, - {0, timestampE, timestampEW, "", levelPresence{false, true, false}}, - {0, timestampEW, timestampEWI, "", levelPresence{false, false, true}}, - {0, timestamp, timestampEW, "", levelPresence{true, true, false}}, - {0, timestampE, timestampEWI, "", levelPresence{false, true, true}}, - {0, timestamp, timestampEWI, "", levelPresence{true, true, true}}, - // Test filtering by regexp pattern. - {0, 0, 0, "Info", levelPresence{false, false, true}}, - {0, 0, 0, "Warning", levelPresence{false, true, false}}, - {0, 0, 0, "Error", levelPresence{true, false, false}}, - {0, 0, 0, "Info|Error|Warning", levelPresence{true, true, true}}, - {0, 0, 0, "Nothing", levelPresence{false, false, false}}, - } - - for i, testCase := range testCases { - var url bytes.Buffer - fmt.Fprintf(&url, "logs/local?level=") - if testCase.MaxEntities > 0 { - fmt.Fprintf(&url, "&max=%d", testCase.MaxEntities) - } - if testCase.StartTimestamp > 0 { - fmt.Fprintf(&url, "&start_time=%d", testCase.StartTimestamp) - } - if testCase.StartTimestamp > 0 { - fmt.Fprintf(&url, "&end_time=%d", testCase.EndTimestamp) - } - if len(testCase.Pattern) > 0 { - fmt.Fprintf(&url, "&pattern=%s", testCase.Pattern) - } - - var wrapper serverpb.LogEntriesResponse - path := url.String() - if err := getStatusJSONProto(ts, path, &wrapper); err != nil { - t.Fatal(err) - } - - if testCase.MaxEntities > 0 { - if a, e := len(wrapper.Entries), testCase.MaxEntities; a != e { - t.Errorf("%d expected %d entries, got %d: \n%+v", i, e, a, wrapper.Entries) - } - } else { - var actual levelPresence - var logsBuf bytes.Buffer - for _, entry 
:= range wrapper.Entries { - fmt.Fprintln(&logsBuf, entry.Message) - - switch strings.TrimSpace(entry.Message) { - case "TestStatusLocalLogFile test message-Error": - actual.Error = true - case "TestStatusLocalLogFile test message-Warning": - actual.Warning = true - case "TestStatusLocalLogFile test message-Info": - actual.Info = true - } - } - - if testCase.levelPresence != actual { - t.Errorf("%d: expected %+v at %s, got:\n%s", i, testCase, path, logsBuf.String()) - } - } - } -} - -// TestStatusLocalLogsTenantFilter checks to ensure that local/logfiles, -// local/logfiles/{filename} and local/log function correctly filter -// logs by tenant ID. -func TestStatusLocalLogsTenantFilter(t *testing.T) { - defer leaktest.AfterTest(t)() - if log.V(3) { - skip.IgnoreLint(t, "Test only works with low verbosity levels") - } - - s := log.ScopeWithoutShowLogs(t) - defer s.Close(t) - - // This test cares about the number of output files. Ensure - // there's just one. - defer s.SetupSingleFileLogging()() - - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - ctxSysTenant := context.Background() - ctxSysTenant = context.WithValue(ctxSysTenant, serverident.ServerIdentificationContextKey{}, &idProvider{ - tenantID: roachpb.SystemTenantID, - clusterID: &base.ClusterIDContainer{}, - serverID: &base.NodeIDContainer{}, - }) - appTenantID := roachpb.MustMakeTenantID(uint64(2)) - ctxAppTenant := context.Background() - ctxAppTenant = context.WithValue(ctxAppTenant, serverident.ServerIdentificationContextKey{}, &idProvider{ - tenantID: appTenantID, - clusterID: &base.ClusterIDContainer{}, - serverID: &base.NodeIDContainer{}, - }) - - // Log an error of each main type which we expect to be able to retrieve. - // The resolution of our log timestamps is such that it's possible to get - // two subsequent log messages with the same timestamp. This test will fail - // when that occurs. 
By adding a small sleep in here after each timestamp to - // ensures this isn't the case and that the log filtering doesn't filter out - // the log entires we're looking for. The value of 20 μs was chosen because - // the log timestamps have a fidelity of 10 μs and thus doubling that should - // be a sufficient buffer. - // See util/log/clog.go formatHeader() for more details. - const sleepBuffer = time.Microsecond * 20 - log.Errorf(ctxSysTenant, "system tenant msg 1") - time.Sleep(sleepBuffer) - log.Errorf(ctxAppTenant, "app tenant msg 1") - time.Sleep(sleepBuffer) - log.Warningf(ctxSysTenant, "system tenant msg 2") - time.Sleep(sleepBuffer) - log.Warningf(ctxAppTenant, "app tenant msg 2") - time.Sleep(sleepBuffer) - log.Infof(ctxSysTenant, "system tenant msg 3") - time.Sleep(sleepBuffer) - log.Infof(ctxAppTenant, "app tenant msg 3") - timestampEnd := timeutil.Now().UnixNano() - - var listFilesResp serverpb.LogFilesListResponse - if err := getStatusJSONProto(ts, "logfiles/local", &listFilesResp); err != nil { - t.Fatal(err) - } - require.Lenf(t, listFilesResp.Files, 1, "expected 1 log files; got %d", len(listFilesResp.Files)) - - testCases := []struct { - name string - tenantID roachpb.TenantID - }{ - { - name: "logs for system tenant does not apply filter", - tenantID: roachpb.SystemTenantID, - }, - { - name: "logs for app tenant applies tenant ID filter", - tenantID: appTenantID, - }, - } - - for _, testCase := range testCases { - // Non-system tenant servers filter to the tenant that they belong to. - // Set the server tenant ID for this test case. 
- ts.rpcContext.TenantID = testCase.tenantID - - var logfilesResp serverpb.LogEntriesResponse - if err := getStatusJSONProto(ts, "logfiles/local/"+listFilesResp.Files[0].Name, &logfilesResp); err != nil { - t.Fatal(err) - } - var logsResp serverpb.LogEntriesResponse - if err := getStatusJSONProto(ts, fmt.Sprintf("logs/local?end_time=%d", timestampEnd), &logsResp); err != nil { - t.Fatal(err) - } - - // Run the same set of assertions against both responses, as they are both expected - // to contain the log entries we're looking for. - for _, response := range []serverpb.LogEntriesResponse{logfilesResp, logsResp} { - sysTenantFound, appTenantFound := false, false - for _, logEntry := range response.Entries { - if !strings.HasSuffix(logEntry.File, "status_test.go") { - continue - } - - if testCase.tenantID != roachpb.SystemTenantID { - require.Equal(t, logEntry.TenantID, testCase.tenantID.String()) - } else { - // Logs use the literal system tenant ID when tagging. - if logEntry.TenantID == fmt.Sprintf("%d", roachpb.SystemTenantID.InternalValue) { - sysTenantFound = true - } else if logEntry.TenantID == appTenantID.String() { - appTenantFound = true - } - } - } - if testCase.tenantID == roachpb.SystemTenantID { - require.True(t, sysTenantFound) - require.True(t, appTenantFound) - } - } - } -} - -// TestStatusLogRedaction checks that the log file retrieval RPCs -// honor the redaction flags. -func TestStatusLogRedaction(t *testing.T) { - defer leaktest.AfterTest(t)() - - testData := []struct { - redactableLogs bool // logging flag - redact bool // RPC request flag - expectedMessage string - expectedRedactable bool // redactable bit in result entries - }{ - // Note: all combinations of (redactableLogs, redact) must be tested below. - - // If there were no markers to start with (redactableLogs=false), we - // introduce markers around the entire message to indicate it's not known to - // be safe. 
- {false, false, `‹THISISSAFE THISISUNSAFE›`, true}, - // redact=true must be conservative and redact everything out if - // there were no markers to start with (redactableLogs=false). - {false, true, `‹×›`, false}, - // redact=false keeps whatever was in the log file. - {true, false, `THISISSAFE ‹THISISUNSAFE›`, true}, - // Whether or not to keep the redactable markers has no influence - // on the output of redaction, just on the presence of the - // "redactable" marker. In any case no information is leaked. - {true, true, `THISISSAFE ‹×›`, true}, - } - - testutils.RunTrueAndFalse(t, "redactableLogs", - func(t *testing.T, redactableLogs bool) { - s := log.ScopeWithoutShowLogs(t) - defer s.Close(t) - - // This test cares about the number of output files. Ensure - // there's just one. - defer s.SetupSingleFileLogging()() - - // Apply the redactable log boolean for this test. - defer log.TestingSetRedactable(redactableLogs)() - - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - // Log something. - log.Infof(context.Background(), "THISISSAFE %s", "THISISUNSAFE") - - // Determine the log file name. - var wrapper serverpb.LogFilesListResponse - if err := getStatusJSONProto(ts, "logfiles/local", &wrapper); err != nil { - t.Fatal(err) - } - // We expect only the main log. - if a, e := len(wrapper.Files), 1; a != e { - t.Fatalf("expected %d log files; got %d: %+v", e, a, wrapper.Files) - } - file := wrapper.Files[0] - // Assert that the log that's present is not a stderr log. - if strings.Contains("stderr", file.Name) { - t.Fatalf("expected main log, found %v", file.Name) - } - - for _, tc := range testData { - if tc.redactableLogs != redactableLogs { - continue - } - t.Run(fmt.Sprintf("redact=%v", tc.redact), - func(t *testing.T) { - // checkEntries asserts that the redaction results are - // those expected in tc. 
- checkEntries := func(entries []logpb.Entry) { - foundMessage := false - for _, entry := range entries { - if !strings.HasSuffix(entry.File, "status_test.go") { - continue - } - foundMessage = true - - assert.Equal(t, tc.expectedMessage, entry.Message) - } - if !foundMessage { - t.Fatalf("did not find expected message from test in log") - } - } - - // Retrieve the log entries with the configured flags using - // the LogFiles() RPC. - logFilesURL := fmt.Sprintf("logfiles/local/%s?redact=%v", file.Name, tc.redact) - var wrapper serverpb.LogEntriesResponse - if err := getStatusJSONProto(ts, logFilesURL, &wrapper); err != nil { - t.Fatal(err) - } - checkEntries(wrapper.Entries) - - // If the test specifies redact=false, check that a non-admin - // user gets a privilege error. - if !tc.redact { - err := getStatusJSONProtoWithAdminOption(ts, logFilesURL, &wrapper, false /* isAdmin */) - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - - // Retrieve the log entries using the Logs() RPC. - // Set a high `max` value to ensure we get the log line we're searching for. - logsURL := fmt.Sprintf("logs/local?redact=%v&max=5000", tc.redact) - var wrapper2 serverpb.LogEntriesResponse - if err := getStatusJSONProto(ts, logsURL, &wrapper2); err != nil { - t.Fatal(err) - } - checkEntries(wrapper2.Entries) - - // If the test specifies redact=false, check that a non-admin - // user gets a privilege error. - if !tc.redact { - err := getStatusJSONProtoWithAdminOption(ts, logsURL, &wrapper2, false /* isAdmin */) - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - }) - } - }) -} - -// TestNodeStatusResponse verifies that node status returns the expected -// results. 
-func TestNodeStatusResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s := startServer(t) - defer s.Stopper().Stop(context.Background()) - - wrapper := serverpb.NodesResponse{} - - // Check that the node statuses cannot be accessed via a non-admin account. - if err := getStatusJSONProtoWithAdminOption(s, "nodes", &wrapper, false /* isAdmin */); !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - - // Now fetch all the node statuses as admin. - if err := getStatusJSONProto(s, "nodes", &wrapper); err != nil { - t.Fatal(err) - } - nodeStatuses := wrapper.Nodes - - if len(nodeStatuses) != 1 { - t.Errorf("too many node statuses returned - expected:1 actual:%d", len(nodeStatuses)) - } - if !s.node.Descriptor.Equal(&nodeStatuses[0].Desc) { - t.Errorf("node status descriptors are not equal\nexpected:%+v\nactual:%+v\n", s.node.Descriptor, nodeStatuses[0].Desc) - } - - // Now fetch each one individually. Loop through the nodeStatuses to use the - // ids only. - for _, oldNodeStatus := range nodeStatuses { - nodeStatus := statuspb.NodeStatus{} - nodeURL := "nodes/" + oldNodeStatus.Desc.NodeID.String() - // Check that the node statuses cannot be accessed via a non-admin account. - if err := getStatusJSONProtoWithAdminOption(s, nodeURL, &nodeStatus, false /* isAdmin */); !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - - // Now access that node's status. - if err := getStatusJSONProto(s, nodeURL, &nodeStatus); err != nil { - t.Fatal(err) - } - if !s.node.Descriptor.Equal(&nodeStatus.Desc) { - t.Errorf("node status descriptors are not equal\nexpected:%+v\nactual:%+v\n", s.node.Descriptor, nodeStatus.Desc) - } - } -} - -// TestMetricsRecording verifies that Node statistics are periodically recorded -// as time series data. 
-func TestMetricsRecording(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - - s, _, kvDB := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(ctx) - - // Verify that metrics for the current timestamp are recorded. This should - // be true very quickly even though DefaultMetricsSampleInterval is large, - // because the server writes an entry eagerly on startup. - testutils.SucceedsSoon(t, func() error { - now := s.Clock().PhysicalNow() - - var data roachpb.InternalTimeSeriesData - for _, keyName := range []string{ - "cr.store.livebytes.1", - "cr.node.sys.go.allocbytes.1", - } { - key := ts.MakeDataKey(keyName, "", ts.Resolution10s, now) - if err := kvDB.GetProto(ctx, key, &data); err != nil { - return err - } - } - return nil - }) -} - -// TestMetricsEndpoint retrieves the metrics endpoint, which is currently only -// used for development purposes. The metrics within the response are verified -// in other tests. -func TestMetricsEndpoint(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s := startServer(t) - defer s.Stopper().Stop(context.Background()) - - if _, err := getText(s, s.AdminURL().WithPath(statusPrefix+"metrics/"+s.Gossip().NodeID.String()).String()); err != nil { - t.Fatal(err) - } -} - -// TestMetricsMetadata ensures that the server's recorder return metrics and -// that each metric has a Name, Help, Unit, and DisplayUnit defined. 
-func TestMetricsMetadata(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s := startServer(t) - defer s.Stopper().Stop(context.Background()) - - metricsMetadata := s.recorder.GetMetricsMetadata() - - if len(metricsMetadata) < 200 { - t.Fatal("s.recorder.GetMetricsMetadata() failed sanity check; didn't return enough metrics.") - } - - for _, v := range metricsMetadata { - if v.Name == "" { - t.Fatal("metric missing name.") - } - if v.Help == "" { - t.Fatalf("%s missing Help.", v.Name) - } - if v.Measurement == "" { - t.Fatalf("%s missing Measurement.", v.Name) - } - if v.Unit == 0 { - t.Fatalf("%s missing Unit.", v.Name) - } - } -} - -func TestHotRangesResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - var hotRangesResp serverpb.HotRangesResponse - if err := getStatusJSONProto(ts, "hotranges", &hotRangesResp); err != nil { - t.Fatal(err) - } - if len(hotRangesResp.HotRangesByNodeID) == 0 { - t.Fatalf("didn't get hot range responses from any nodes") - } - - for nodeID, nodeResp := range hotRangesResp.HotRangesByNodeID { - if len(nodeResp.Stores) == 0 { - t.Errorf("didn't get any stores in hot range response from n%d: %v", - nodeID, nodeResp.ErrorMessage) - } - for _, storeResp := range nodeResp.Stores { - // Only the first store will actually have any ranges on it. 
- if storeResp.StoreID != roachpb.StoreID(1) { - continue - } - lastQPS := math.MaxFloat64 - if len(storeResp.HotRanges) == 0 { - t.Errorf("didn't get any hot ranges in response from n%d,s%d: %v", - nodeID, storeResp.StoreID, nodeResp.ErrorMessage) - } - for _, r := range storeResp.HotRanges { - if r.Desc.RangeID == 0 || (len(r.Desc.StartKey) == 0 && len(r.Desc.EndKey) == 0) { - t.Errorf("unexpected empty/unpopulated range descriptor: %+v", r.Desc) - } - if r.QueriesPerSecond > 0 { - if r.ReadsPerSecond == 0 && r.WritesPerSecond == 0 && r.ReadBytesPerSecond == 0 && r.WriteBytesPerSecond == 0 { - t.Errorf("qps %.2f > 0, expected either reads=%.2f, writes=%.2f, readBytes=%.2f or writeBytes=%.2f to be non-zero", - r.QueriesPerSecond, r.ReadsPerSecond, r.WritesPerSecond, r.ReadBytesPerSecond, r.WriteBytesPerSecond) - } - // If the architecture doesn't support sampling CPU, it - // will also be zero. - if grunning.Supported() && r.CPUTimePerSecond == 0 { - t.Errorf("qps %.2f > 0, expected cpu=%.2f to be non-zero", - r.QueriesPerSecond, r.CPUTimePerSecond) - } - } - if r.QueriesPerSecond > lastQPS { - t.Errorf("unexpected increase in qps between ranges; prev=%.2f, current=%.2f, desc=%v", - lastQPS, r.QueriesPerSecond, r.Desc) - } - lastQPS = r.QueriesPerSecond - } - } - - } -} - -func TestHotRanges2Response(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - var hotRangesResp serverpb.HotRangesResponseV2 - if err := postStatusJSONProto(ts, "v2/hotranges", &serverpb.HotRangesRequest{}, &hotRangesResp); err != nil { - t.Fatal(err) - } - if len(hotRangesResp.Ranges) == 0 { - t.Fatalf("didn't get hot range responses from any nodes") - } - lastQPS := math.MaxFloat64 - for _, r := range hotRangesResp.Ranges { - if r.RangeID == 0 { - t.Errorf("unexpected empty range id: %d", r.RangeID) - } - if r.QPS > 0 { - if r.ReadsPerSecond == 0 && r.WritesPerSecond == 0 && 
r.ReadBytesPerSecond == 0 && r.WriteBytesPerSecond == 0 { - t.Errorf("qps %.2f > 0, expected either reads=%.2f, writes=%.2f, readBytes=%.2f or writeBytes=%.2f to be non-zero", - r.QPS, r.ReadsPerSecond, r.WritesPerSecond, r.ReadBytesPerSecond, r.WriteBytesPerSecond) - } - // If the architecture doesn't support sampling CPU, it - // will also be zero. - if grunning.Supported() && r.CPUTimePerSecond == 0 { - t.Errorf("qps %.2f > 0, expected cpu=%.2f to be non-zero", r.QPS, r.CPUTimePerSecond) - } - } - if r.QPS > lastQPS { - t.Errorf("unexpected increase in qps between ranges; prev=%.2f, current=%.2f", lastQPS, r.QPS) - } - lastQPS = r.QPS - } -} - -func TestNetworkConnectivity(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - numNodes := 3 - testCluster := serverutils.StartNewTestCluster(t, numNodes, base.TestClusterArgs{ - ReplicationMode: base.ReplicationManual, - }) - ctx := context.Background() - defer testCluster.Stopper().Stop(ctx) - ts := testCluster.Server(0) - - var resp serverpb.NetworkConnectivityResponse - // Should wait because endpoint relies on Gossip. - testutils.SucceedsSoon(t, func() error { - if err := getStatusJSONProto(ts, "connectivity", &resp); err != nil { - return err - } - if len(resp.ErrorsByNodeID) > 0 { - return errors.Errorf("expected no errors but got: %d", len(resp.ErrorsByNodeID)) - } - if len(resp.Connections) < numNodes { - return errors.Errorf("expected results from %d nodes but got: %d", numNodes, len(resp.ErrorsByNodeID)) - } - return nil - }) - // Test when one node is stopped. 
- stoppedNodeID := testCluster.Server(1).NodeID() - testCluster.Server(1).Stopper().Stop(ctx) - - testutils.SucceedsSoon(t, func() error { - if err := getStatusJSONProto(ts, "connectivity", &resp); err != nil { - return err - } - require.Equal(t, len(resp.Connections), numNodes-1) - fmt.Printf("got status: %s", resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Status.String()) - if resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Status != serverpb.NetworkConnectivityResponse_ERROR { - return errors.New("waiting for connection state to be changed.") - } - if latency := resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Latency; latency > 0 { - return errors.Errorf("expected latency to be 0 but got %s", latency.String()) - } - return nil - }) -} - -func TestHotRanges2ResponseWithViewActivityOptions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - db := sqlutils.MakeSQLRunner(sqlDB) - - req := &serverpb.HotRangesRequest{} - var hotRangesResp serverpb.HotRangesResponseV2 - if err := postStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - - // Grant VIEWCLUSTERMETADATA and all test should work. - db.Exec(t, fmt.Sprintf("GRANT SYSTEM VIEWCLUSTERMETADATA TO %s", authenticatedUserNameNoAdmin().Normalized())) - if err := postStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { - t.Fatal(err) - } - - // Grant VIEWACTIVITYREDACTED and all test should get permission errors. 
- db.Exec(t, fmt.Sprintf("REVOKE SYSTEM VIEWCLUSTERMETADATA FROM %s", authenticatedUserNameNoAdmin().Normalized())) - if err := postStatusJSONProtoWithAdminOption(s, "v2/hotranges", req, &hotRangesResp, false); err != nil { - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } -} - -func TestRangesResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - defer kvserver.EnableLeaseHistoryForTesting(100)() - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - t.Run("test ranges response", func(t *testing.T) { - // Perform a scan to ensure that all the raft groups are initialized. - if _, err := ts.db.Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { - t.Fatal(err) - } - - var response serverpb.RangesResponse - if err := getStatusJSONProto(ts, "ranges/local", &response); err != nil { - t.Fatal(err) - } - if len(response.Ranges) == 0 { - t.Errorf("didn't get any ranges") - } - for _, ri := range response.Ranges { - // Do some simple validation based on the fact that this is a - // single-node cluster. 
- if ri.RaftState.State != "StateLeader" && ri.RaftState.State != raftStateDormant { - t.Errorf("expected to be Raft leader or dormant, but was '%s'", ri.RaftState.State) - } - expReplica := roachpb.ReplicaDescriptor{ - NodeID: 1, - StoreID: 1, - ReplicaID: 1, - } - if len(ri.State.Desc.InternalReplicas) != 1 || ri.State.Desc.InternalReplicas[0] != expReplica { - t.Errorf("unexpected replica list %+v", ri.State.Desc.InternalReplicas) - } - if ri.State.Lease == nil || ri.State.Lease.Empty() { - t.Error("expected a nontrivial Lease") - } - if ri.State.LastIndex == 0 { - t.Error("expected positive LastIndex") - } - if len(ri.LeaseHistory) == 0 { - t.Error("expected at least one lease history entry") - } - } - }) - - t.Run("test ranges pagination", func(t *testing.T) { - ctx := context.Background() - rpcStopper := stop.NewStopper() - defer rpcStopper.Stop(ctx) - - conn, err := ts.rpcContext.GRPCDialNode(ts.ServingRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - resp1, err := client.Ranges(ctx, &serverpb.RangesRequest{ - Limit: 1, - }) - require.NoError(t, err) - require.Len(t, resp1.Ranges, 1) - require.Equal(t, int(resp1.Next), 1) - - resp2, err := client.Ranges(ctx, &serverpb.RangesRequest{ - Limit: 1, - Offset: resp1.Next, - }) - require.NoError(t, err) - require.Len(t, resp2.Ranges, 1) - require.Equal(t, int(resp2.Next), 2) - - // Verify pagination functions based on ascending RangeID order. 
- require.True(t, resp1.Ranges[0].State.Desc.RangeID < resp2.Ranges[0].State.Desc.RangeID) - }) -} - -func TestTenantRangesResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ctx := context.Background() - ts := startServer(t) - defer ts.Stopper().Stop(ctx) - - t.Run("returns error when TenantID not set in ctx", func(t *testing.T) { - rpcStopper := stop.NewStopper() - defer rpcStopper.Stop(ctx) - - conn, err := ts.rpcContext.GRPCDialNode(ts.ServingRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - _, err = client.TenantRanges(ctx, &serverpb.TenantRangesRequest{}) - require.Error(t, err) - require.Contains(t, err.Error(), "no tenant ID found in context") - }) -} - -func TestRaftDebug(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s := startServer(t) - defer s.Stopper().Stop(context.Background()) - - var resp serverpb.RaftDebugResponse - if err := getStatusJSONProto(s, "raft", &resp); err != nil { - t.Fatal(err) - } - if len(resp.Ranges) == 0 { - t.Errorf("didn't get any ranges") - } - - if len(resp.Ranges) < 3 { - t.Errorf("expected more than 2 ranges, got %d", len(resp.Ranges)) - } - - reqURI := "raft" - requestedIDs := []roachpb.RangeID{} - for id := range resp.Ranges { - if len(requestedIDs) == 0 { - reqURI += "?" - } else { - reqURI += "&" - } - reqURI += fmt.Sprintf("range_ids=%d", id) - requestedIDs = append(requestedIDs, id) - if len(requestedIDs) >= 2 { - break - } - } - - if err := getStatusJSONProto(s, reqURI, &resp); err != nil { - t.Fatal(err) - } - - // Make sure we get exactly two ranges back. - if len(resp.Ranges) != 2 { - t.Errorf("expected exactly two ranges in response, got %d", len(resp.Ranges)) - } - - // Make sure the ranges returned are those requested. 
- for _, reqID := range requestedIDs { - if _, ok := resp.Ranges[reqID]; !ok { - t.Errorf("request URI was %s, but range ID %d not returned: %+v", reqURI, reqID, resp.Ranges) - } - } -} - -// TestStatusVars verifies that prometheus metrics are available via the -// /_status/vars and /_status/load endpoints. -func TestStatusVars(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - if body, err := getText(s, s.AdminURL().WithPath(statusPrefix+"vars").String()); err != nil { - t.Fatal(err) - } else if !bytes.Contains(body, []byte("# TYPE sql_bytesout counter\nsql_bytesout")) { - t.Errorf("expected sql_bytesout, got: %s", body) - } - if body, err := getText(s, s.AdminURL().WithPath(statusPrefix+"load").String()); err != nil { - t.Fatal(err) - } else if !bytes.Contains(body, []byte("# TYPE sys_cpu_user_ns gauge\nsys_cpu_user_ns")) { - t.Errorf("expected sys_cpu_user_ns, got: %s", body) - } -} - -// TestStatusVarsTxnMetrics verifies that the metrics from the /_status/vars -// endpoint for txns and the special cockroach_restart savepoint are correct. 
-func TestStatusVarsTxnMetrics(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer db.Close() - defer s.Stopper().Stop(context.Background()) - - if _, err := db.Exec("BEGIN;" + - "SAVEPOINT cockroach_restart;" + - "SELECT 1;" + - "RELEASE SAVEPOINT cockroach_restart;" + - "ROLLBACK;"); err != nil { - t.Fatal(err) - } - - body, err := getText(s, s.AdminURL().WithPath(statusPrefix+"vars").String()) - if err != nil { - t.Fatal(err) - } - if !bytes.Contains(body, []byte("sql_txn_begin_count{node_id=\"1\"} 1")) { - t.Errorf("expected `sql_txn_begin_count{node_id=\"1\"} 1`, got: %s", body) - } - if !bytes.Contains(body, []byte("sql_restart_savepoint_count{node_id=\"1\"} 1")) { - t.Errorf("expected `sql_restart_savepoint_count{node_id=\"1\"} 1`, got: %s", body) - } - if !bytes.Contains(body, []byte("sql_restart_savepoint_release_count{node_id=\"1\"} 1")) { - t.Errorf("expected `sql_restart_savepoint_release_count{node_id=\"1\"} 1`, got: %s", body) - } - if !bytes.Contains(body, []byte("sql_txn_commit_count{node_id=\"1\"} 1")) { - t.Errorf("expected `sql_txn_commit_count{node_id=\"1\"} 1`, got: %s", body) - } - if !bytes.Contains(body, []byte("sql_txn_rollback_count{node_id=\"1\"} 0")) { - t.Errorf("expected `sql_txn_rollback_count{node_id=\"1\"} 0`, got: %s", body) - } -} - -func TestSpanStatsResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - httpClient, err := ts.GetAdminHTTPClient() - if err != nil { - t.Fatal(err) - } - - var response roachpb.SpanStatsResponse - span := roachpb.Span{ - Key: roachpb.RKeyMin.AsRawKey(), - EndKey: roachpb.RKeyMax.AsRawKey(), - } - request := roachpb.SpanStatsRequest{ - NodeID: "1", - Spans: []roachpb.Span{span}, - } - - url := ts.AdminURL().WithPath(statusPrefix + "span").String() - if err := httputil.PostJSON(httpClient, url, 
&request, &response); err != nil { - t.Fatal(err) - } - initialRanges, err := ts.ExpectedInitialRangeCount() - if err != nil { - t.Fatal(err) - } - responseSpanStats := response.SpanToStats[span.String()] - if a, e := int(responseSpanStats.RangeCount), initialRanges; a != e { - t.Errorf("expected %d ranges, found %d", e, a) - } -} - -func TestSpanStatsGRPCResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ctx := context.Background() - ts := startServer(t) - defer ts.Stopper().Stop(ctx) - - rpcStopper := stop.NewStopper() - defer rpcStopper.Stop(ctx) - rpcContext := newRPCTestContext(ctx, ts, ts.RPCContext().Config) - span := roachpb.Span{ - Key: roachpb.RKeyMin.AsRawKey(), - EndKey: roachpb.RKeyMax.AsRawKey(), - } - request := roachpb.SpanStatsRequest{ - NodeID: "1", - Spans: []roachpb.Span{span}, - } - - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - - response, err := client.SpanStats(ctx, &request) - if err != nil { - t.Fatal(err) - } - initialRanges, err := ts.ExpectedInitialRangeCount() - if err != nil { - t.Fatal(err) - } - responseSpanStats := response.SpanToStats[span.String()] - if a, e := int(responseSpanStats.RangeCount), initialRanges; a != e { - t.Fatalf("expected %d ranges, found %d", e, a) - } -} - -func TestNodesGRPCResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := newRPCTestContext(context.Background(), ts, rootConfig) - var request serverpb.NodesRequest - - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client 
:= serverpb.NewStatusClient(conn) - - response, err := client.Nodes(context.Background(), &request) - if err != nil { - t.Fatal(err) - } - - if a, e := len(response.Nodes), 1; a != e { - t.Errorf("expected %d node(s), found %d", e, a) - } -} - -func TestCertificatesResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - var response serverpb.CertificatesResponse - if err := getStatusJSONProto(ts, "certificates/local", &response); err != nil { - t.Fatal(err) - } - - // We expect 5 certificates: CA, node, and client certs for root, testuser, testuser2. - if a, e := len(response.Certificates), 5; a != e { - t.Errorf("expected %d certificates, found %d", e, a) - } - - // The response is ordered: CA cert followed by node cert. - cert := response.Certificates[0] - if a, e := cert.Type, serverpb.CertificateDetails_CA; a != e { - t.Errorf("wrong type %s, expected %s", a, e) - } else if cert.ErrorMessage != "" { - t.Errorf("expected cert without error, got %v", cert.ErrorMessage) - } - - cert = response.Certificates[1] - if a, e := cert.Type, serverpb.CertificateDetails_NODE; a != e { - t.Errorf("wrong type %s, expected %s", a, e) - } else if cert.ErrorMessage != "" { - t.Errorf("expected cert without error, got %v", cert.ErrorMessage) - } -} - -func TestDiagnosticsResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - var resp diagnosticspb.DiagnosticReport - if err := getStatusJSONProto(s, "diagnostics/local", &resp); err != nil { - t.Fatal(err) - } - - // The endpoint just serializes result of getReportingInfo() which is already - // tested elsewhere, so simply verify that we have a non-empty reply. 
- if expected, actual := s.NodeID(), resp.Node.NodeID; expected != actual { - t.Fatalf("expected %v got %v", expected, actual) - } -} - -func TestRangeResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - defer kvserver.EnableLeaseHistoryForTesting(100)() - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - // Perform a scan to ensure that all the raft groups are initialized. - if _, err := ts.db.Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { - t.Fatal(err) - } - - var response serverpb.RangeResponse - if err := getStatusJSONProto(ts, "range/1", &response); err != nil { - t.Fatal(err) - } - - // This is a single node cluster, so only expect a single response. - if e, a := 1, len(response.ResponsesByNodeID); e != a { - t.Errorf("got the wrong number of responses, expected %d, actual %d", e, a) - } - - node1Response := response.ResponsesByNodeID[response.NodeID] - - // The response should come back as valid. - if !node1Response.Response { - t.Errorf("node1's response returned as false, expected true") - } - - // The response should include just the one range. - if e, a := 1, len(node1Response.Infos); e != a { - t.Errorf("got the wrong number of ranges in the response, expected %d, actual %d", e, a) - } - - info := node1Response.Infos[0] - expReplica := roachpb.ReplicaDescriptor{ - NodeID: 1, - StoreID: 1, - ReplicaID: 1, - } - - // Check some other values. 
- if len(info.State.Desc.InternalReplicas) != 1 || info.State.Desc.InternalReplicas[0] != expReplica { - t.Errorf("unexpected replica list %+v", info.State.Desc.InternalReplicas) - } - - if info.State.Lease == nil || info.State.Lease.Empty() { - t.Error("expected a nontrivial Lease") - } - - if info.State.LastIndex == 0 { - t.Error("expected positive LastIndex") - } - - if len(info.LeaseHistory) == 0 { - t.Error("expected at least one lease history entry") - } -} - -func TestStatusAPICombinedTransactions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - params, _ := tests.CreateTestServerParams() - params.Knobs.SpanConfig = &spanconfig.TestingKnobs{ManagerDisableJobCreation: true} // TODO(irfansharif): #74919. - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: params, - }) - ctx := context.Background() - defer testCluster.Stopper().Stop(ctx) - - thirdServer := testCluster.Server(2) - pgURL, cleanupGoDB := sqlutils.PGUrl( - t, thirdServer.ServingSQLAddr(), "CreateConnections" /* prefix */, url.User(username.RootUser)) - defer cleanupGoDB() - firstServerProto := testCluster.Server(0) - - type testCase struct { - query string - fingerprinted string - count int - shouldRetry bool - numRows int - } - - testCases := []testCase{ - {query: `CREATE DATABASE roachblog`, count: 1, numRows: 0}, - {query: `SET database = roachblog`, count: 1, numRows: 0}, - {query: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, count: 1, numRows: 0}, - { - query: `INSERT INTO posts VALUES (1, 'foo')`, - fingerprinted: `INSERT INTO posts VALUES (_, '_')`, - count: 1, - numRows: 1, - }, - {query: `SELECT * FROM posts`, count: 2, numRows: 1}, - {query: `BEGIN; SELECT * FROM posts; SELECT * FROM posts; COMMIT`, count: 3, numRows: 2}, - { - query: `BEGIN; SELECT crdb_internal.force_retry('2s'); SELECT * FROM posts; COMMIT;`, - fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, - 
shouldRetry: true, - count: 1, - numRows: 2, - }, - { - query: `BEGIN; SELECT crdb_internal.force_retry('5s'); SELECT * FROM posts; COMMIT;`, - fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, - shouldRetry: true, - count: 1, - numRows: 2, - }, - } - - appNameToTestCase := make(map[string]testCase) - - for i, tc := range testCases { - appName := fmt.Sprintf("app%d", i) - appNameToTestCase[appName] = tc - - // Create a brand new connection for each app, so that we don't pollute - // transaction stats collection with `SET application_name` queries. - sqlDB, err := gosql.Open("postgres", pgURL.String()) - if err != nil { - t.Fatal(err) - } - if _, err := sqlDB.Exec(fmt.Sprintf(`SET application_name = "%s"`, appName)); err != nil { - t.Fatal(err) - } - for c := 0; c < tc.count; c++ { - if _, err := sqlDB.Exec(tc.query); err != nil { - t.Fatal(err) - } - } - if err := sqlDB.Close(); err != nil { - t.Fatal(err) - } - } - - // Hit query endpoint. - var resp serverpb.StatementsResponse - if err := getStatusJSONProto(firstServerProto, "combinedstmts", &resp); err != nil { - t.Fatal(err) - } - - // Construct a map of all the statement fingerprint IDs. - statementFingerprintIDs := make(map[appstatspb.StmtFingerprintID]bool, len(resp.Statements)) - for _, respStatement := range resp.Statements { - statementFingerprintIDs[respStatement.ID] = true - } - - respAppNames := make(map[string]bool) - for _, respTransaction := range resp.Transactions { - appName := respTransaction.StatsData.App - tc, found := appNameToTestCase[appName] - if !found { - // Ignore internal queries, they aren't relevant to this test. - continue - } - respAppNames[appName] = true - // Ensure all statementFingerprintIDs comprised by the Transaction Response can be - // linked to StatementFingerprintIDs for statements in the response. 
- for _, stmtFingerprintID := range respTransaction.StatsData.StatementFingerprintIDs { - if _, found := statementFingerprintIDs[stmtFingerprintID]; !found { - t.Fatalf("app: %s, expected stmtFingerprintID: %d not found in StatementResponse.", appName, stmtFingerprintID) - } - } - stats := respTransaction.StatsData.Stats - if tc.count != int(stats.Count) { - t.Fatalf("app: %s, expected count %d, got %d", appName, tc.count, stats.Count) - } - if tc.shouldRetry && respTransaction.StatsData.Stats.MaxRetries == 0 { - t.Fatalf("app: %s, expected retries, got none\n", appName) - } - - // Sanity check numeric stat values - if respTransaction.StatsData.Stats.CommitLat.Mean <= 0 { - t.Fatalf("app: %s, unexpected mean for commit latency\n", appName) - } - if respTransaction.StatsData.Stats.RetryLat.Mean <= 0 && tc.shouldRetry { - t.Fatalf("app: %s, expected retry latency mean to be non-zero as retries were involved\n", appName) - } - if respTransaction.StatsData.Stats.ServiceLat.Mean <= 0 { - t.Fatalf("app: %s, unexpected mean for service latency\n", appName) - } - if respTransaction.StatsData.Stats.NumRows.Mean != float64(tc.numRows) { - t.Fatalf("app: %s, unexpected number of rows observed. expected: %d, got %d\n", - appName, tc.numRows, int(respTransaction.StatsData.Stats.NumRows.Mean)) - } - } - - // Ensure we got transaction statistics for all the queries we sent. 
- for appName := range appNameToTestCase { - if _, found := respAppNames[appName]; !found { - t.Fatalf("app: %s did not appear in the response\n", appName) - } - } -} - -func TestStatusAPITransactions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - ctx := context.Background() - defer testCluster.Stopper().Stop(ctx) - - thirdServer := testCluster.Server(2) - pgURL, cleanupGoDB := sqlutils.PGUrl( - t, thirdServer.ServingSQLAddr(), "CreateConnections" /* prefix */, url.User(username.RootUser)) - defer cleanupGoDB() - firstServerProto := testCluster.Server(0) - - type testCase struct { - query string - fingerprinted string - count int - shouldRetry bool - numRows int - } - - testCases := []testCase{ - {query: `CREATE DATABASE roachblog`, count: 1, numRows: 0}, - {query: `SET database = roachblog`, count: 1, numRows: 0}, - {query: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, count: 1, numRows: 0}, - { - query: `INSERT INTO posts VALUES (1, 'foo')`, - fingerprinted: `INSERT INTO posts VALUES (_, _)`, - count: 1, - numRows: 1, - }, - {query: `SELECT * FROM posts`, count: 2, numRows: 1}, - {query: `BEGIN; SELECT * FROM posts; SELECT * FROM posts; COMMIT`, count: 3, numRows: 2}, - { - query: `BEGIN; SELECT crdb_internal.force_retry('2s'); SELECT * FROM posts; COMMIT;`, - fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, - shouldRetry: true, - count: 1, - numRows: 2, - }, - { - query: `BEGIN; SELECT crdb_internal.force_retry('5s'); SELECT * FROM posts; COMMIT;`, - fingerprinted: `BEGIN; SELECT crdb_internal.force_retry(_); SELECT * FROM posts; COMMIT;`, - shouldRetry: true, - count: 1, - numRows: 2, - }, - } - - appNameToTestCase := make(map[string]testCase) - - for i, tc := range testCases { - appName := fmt.Sprintf("app%d", i) - appNameToTestCase[appName] = tc - - // Create a brand new connection for each app, 
so that we don't pollute - // transaction stats collection with `SET application_name` queries. - sqlDB, err := gosql.Open("postgres", pgURL.String()) - if err != nil { - t.Fatal(err) - } - if _, err := sqlDB.Exec(fmt.Sprintf(`SET application_name = "%s"`, appName)); err != nil { - t.Fatal(err) - } - for c := 0; c < tc.count; c++ { - if _, err := sqlDB.Exec(tc.query); err != nil { - t.Fatal(err) - } - } - if err := sqlDB.Close(); err != nil { - t.Fatal(err) - } - } - - // Hit query endpoint. - var resp serverpb.StatementsResponse - if err := getStatusJSONProto(firstServerProto, "statements", &resp); err != nil { - t.Fatal(err) - } - - // Construct a map of all the statement fingerprint IDs. - statementFingerprintIDs := make(map[appstatspb.StmtFingerprintID]bool, len(resp.Statements)) - for _, respStatement := range resp.Statements { - statementFingerprintIDs[respStatement.ID] = true - } - - respAppNames := make(map[string]bool) - for _, respTransaction := range resp.Transactions { - appName := respTransaction.StatsData.App - tc, found := appNameToTestCase[appName] - if !found { - // Ignore internal queries, they aren't relevant to this test. - continue - } - respAppNames[appName] = true - // Ensure all statementFingerprintIDs comprised by the Transaction Response can be - // linked to StatementFingerprintIDs for statements in the response. 
- for _, stmtFingerprintID := range respTransaction.StatsData.StatementFingerprintIDs { - if _, found := statementFingerprintIDs[stmtFingerprintID]; !found { - t.Fatalf("app: %s, expected stmtFingerprintID: %d not found in StatementResponse.", appName, stmtFingerprintID) - } - } - stats := respTransaction.StatsData.Stats - if tc.count != int(stats.Count) { - t.Fatalf("app: %s, expected count %d, got %d", appName, tc.count, stats.Count) - } - if tc.shouldRetry && respTransaction.StatsData.Stats.MaxRetries == 0 { - t.Fatalf("app: %s, expected retries, got none\n", appName) - } - - // Sanity check numeric stat values - if respTransaction.StatsData.Stats.CommitLat.Mean <= 0 { - t.Fatalf("app: %s, unexpected mean for commit latency\n", appName) - } - if respTransaction.StatsData.Stats.RetryLat.Mean <= 0 && tc.shouldRetry { - t.Fatalf("app: %s, expected retry latency mean to be non-zero as retries were involved\n", appName) - } - if respTransaction.StatsData.Stats.ServiceLat.Mean <= 0 { - t.Fatalf("app: %s, unexpected mean for service latency\n", appName) - } - if respTransaction.StatsData.Stats.NumRows.Mean != float64(tc.numRows) { - t.Fatalf("app: %s, unexpected number of rows observed. expected: %d, got %d\n", - appName, tc.numRows, int(respTransaction.StatsData.Stats.NumRows.Mean)) - } - } - - // Ensure we got transaction statistics for all the queries we sent. 
- for appName := range appNameToTestCase { - if _, found := respAppNames[appName]; !found { - t.Fatalf("app: %s did not appear in the response\n", appName) - } - } -} - -func TestStatusAPITransactionStatementFingerprintIDsTruncation(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - params, _ := tests.CreateTestServerParams() - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: params, - }) - defer testCluster.Stopper().Stop(context.Background()) - - firstServerProto := testCluster.Server(0) - thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) - testingApp := "testing" - - thirdServerSQL.Exec(t, `CREATE DATABASE db; CREATE TABLE db.t();`) - thirdServerSQL.Exec(t, fmt.Sprintf(`SET application_name = "%s"`, testingApp)) - - maxStmtFingerprintIDsLen := int(sqlstats.TxnStatsNumStmtFingerprintIDsToRecord.Get( - &firstServerProto.ExecutorConfig().(sql.ExecutorConfig).Settings.SV)) - - // Construct 2 transaction queries that include an absurd number of statements. - // These two queries have the same first 1000 statements, but should still have - // different fingerprints, as fingerprints take into account all - // statementFingerprintIDs (unlike the statementFingerprintIDs stored on the - // proto response, which are capped). - testQuery1 := "BEGIN;" - for i := 0; i < maxStmtFingerprintIDsLen+1; i++ { - testQuery1 += "SELECT * FROM db.t;" - } - testQuery2 := testQuery1 + "SELECT * FROM db.t; COMMIT;" - testQuery1 += "COMMIT;" - - thirdServerSQL.Exec(t, testQuery1) - thirdServerSQL.Exec(t, testQuery2) - - // Hit query endpoint. - var resp serverpb.StatementsResponse - if err := getStatusJSONProto(firstServerProto, "statements", &resp); err != nil { - t.Fatal(err) - } - - txnsFound := 0 - for _, respTransaction := range resp.Transactions { - appName := respTransaction.StatsData.App - if appName != testingApp { - // Only testQuery1 and testQuery2 are relevant to this test. 
- continue - } - - txnsFound++ - if len(respTransaction.StatsData.StatementFingerprintIDs) != maxStmtFingerprintIDsLen { - t.Fatalf("unexpected length of StatementFingerprintIDs. expected:%d, got:%d", - maxStmtFingerprintIDsLen, len(respTransaction.StatsData.StatementFingerprintIDs)) - } - } - if txnsFound != 2 { - t.Fatalf("transactions were not disambiguated as expected. expected %d txns, got: %d", - 2, txnsFound) - } -} - -func TestStatusAPIStatements(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - // Aug 30 2021 19:50:00 GMT+0000 - aggregatedTs := int64(1630353000) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{ - Knobs: base.TestingKnobs{ - SQLStatsKnobs: &sqlstats.TestingKnobs{ - AOSTClause: "AS OF SYSTEM TIME '-1us'", - StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, - }, - SpanConfig: &spanconfig.TestingKnobs{ - ManagerDisableJobCreation: true, // TODO(irfansharif): #74919. - }, - }, - }, - }) - defer testCluster.Stopper().Stop(context.Background()) - - firstServerProto := testCluster.Server(0) - thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) - - statements := []struct { - stmt string - fingerprinted string - }{ - {stmt: `CREATE DATABASE roachblog`}, - {stmt: `SET database = roachblog`}, - {stmt: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`}, - { - stmt: `INSERT INTO posts VALUES (1, 'foo')`, - fingerprinted: `INSERT INTO posts VALUES (_, '_')`, - }, - {stmt: `SELECT * FROM posts`}, - } - - for _, stmt := range statements { - thirdServerSQL.Exec(t, stmt.stmt) - } - - // Test that non-admin without VIEWACTIVITY privileges cannot access. 
- var resp serverpb.StatementsResponse - err := getStatusJSONProtoWithAdminOption(firstServerProto, "statements", &resp, false) - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - - testPath := func(path string, expectedStmts []string) { - // Hit query endpoint. - if err := getStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false); err != nil { - t.Fatal(err) - } - - // See if the statements returned are what we executed. - var statementsInResponse []string - for _, respStatement := range resp.Statements { - if respStatement.Key.KeyData.Failed { - // We ignore failed statements here as the INSERT statement can fail and - // be automatically retried, confusing the test success check. - continue - } - if strings.HasPrefix(respStatement.Key.KeyData.App, catconstants.InternalAppNamePrefix) { - // We ignore internal queries, these are not relevant for the - // validity of this test. - continue - } - if strings.HasPrefix(respStatement.Key.KeyData.Query, "ALTER USER") { - // Ignore the ALTER USER ... VIEWACTIVITY statement. - continue - } - statementsInResponse = append(statementsInResponse, respStatement.Key.KeyData.Query) - } - - sort.Strings(expectedStmts) - sort.Strings(statementsInResponse) - - if !reflect.DeepEqual(expectedStmts, statementsInResponse) { - t.Fatalf("expected queries\n\n%v\n\ngot queries\n\n%v\n%s", - expectedStmts, statementsInResponse, pretty.Sprint(resp)) - } - } - - var expectedStatements []string - for _, stmt := range statements { - var expectedStmt = stmt.stmt - if stmt.fingerprinted != "" { - expectedStmt = stmt.fingerprinted - } - expectedStatements = append(expectedStatements, expectedStmt) - } - - // Grant VIEWACTIVITY. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - - // Test no params. 
- testPath("statements", expectedStatements) - // Test combined=true forwards to CombinedStatements - testPath(fmt.Sprintf("statements?combined=true&start=%d", aggregatedTs+60), nil) - - // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - // Grant VIEWACTIVITYREDACTED. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", authenticatedUserNameNoAdmin().Normalized())) - - // Test no params. - testPath("statements", expectedStatements) - // Test combined=true forwards to CombinedStatements - testPath(fmt.Sprintf("statements?combined=true&start=%d", aggregatedTs+60), nil) -} - -func TestStatusAPICombinedStatements(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - // Aug 30 2021 19:50:00 GMT+0000 - aggregatedTs := int64(1630353000) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{ - Knobs: base.TestingKnobs{ - SQLStatsKnobs: &sqlstats.TestingKnobs{ - AOSTClause: "AS OF SYSTEM TIME '-1us'", - StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, - }, - SpanConfig: &spanconfig.TestingKnobs{ - ManagerDisableJobCreation: true, // TODO(irfansharif): #74919. 
- }, - }, - }, - }) - defer testCluster.Stopper().Stop(context.Background()) - - firstServerProto := testCluster.Server(0) - thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) - - statements := []struct { - stmt string - fingerprinted string - }{ - {stmt: `CREATE DATABASE roachblog`}, - {stmt: `SET database = roachblog`}, - {stmt: `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`}, - { - stmt: `INSERT INTO posts VALUES (1, 'foo')`, - fingerprinted: `INSERT INTO posts VALUES (_, '_')`, - }, - {stmt: `SELECT * FROM posts`}, - } - - for _, stmt := range statements { - thirdServerSQL.Exec(t, stmt.stmt) - } - - var resp serverpb.StatementsResponse - // Test that non-admin without VIEWACTIVITY privileges cannot access. - err := getStatusJSONProtoWithAdminOption(firstServerProto, "combinedstmts", &resp, false) - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - - verifyStmts := func(path string, expectedStmts []string, hasTxns bool, t *testing.T) { - // Hit query endpoint. - if err := getStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false); err != nil { - t.Fatal(err) - } - - // See if the statements returned are what we executed. - var statementsInResponse []string - expectedTxnFingerprints := map[appstatspb.TransactionFingerprintID]struct{}{} - for _, respStatement := range resp.Statements { - if respStatement.Key.KeyData.Failed { - // We ignore failed statements here as the INSERT statement can fail and - // be automatically retried, confusing the test success check. - continue - } - if strings.HasPrefix(respStatement.Key.KeyData.App, catconstants.InternalAppNamePrefix) { - // CombinedStatementStats should filter out internal queries. - t.Fatalf("unexpected internal query: %s", respStatement.Key.KeyData.Query) - } - if strings.HasPrefix(respStatement.Key.KeyData.Query, "ALTER USER") { - // Ignore the ALTER USER ... VIEWACTIVITY statement. 
- continue - } - - statementsInResponse = append(statementsInResponse, respStatement.Key.KeyData.Query) - for _, txnFingerprintID := range respStatement.TxnFingerprintIDs { - expectedTxnFingerprints[txnFingerprintID] = struct{}{} - } - } - - for _, respTxn := range resp.Transactions { - delete(expectedTxnFingerprints, respTxn.StatsData.TransactionFingerprintID) - } - - sort.Strings(expectedStmts) - sort.Strings(statementsInResponse) - - if !reflect.DeepEqual(expectedStmts, statementsInResponse) { - t.Fatalf("expected queries\n\n%v\n\ngot queries\n\n%v\n%s\n path: %s", - expectedStmts, statementsInResponse, pretty.Sprint(resp), path) - } - if hasTxns { - // We expect that expectedTxnFingerprints is now empty since - // we should have removed them all. - assert.Empty(t, expectedTxnFingerprints) - } else { - assert.Empty(t, resp.Transactions) - } - } - - var expectedStatements []string - for _, stmt := range statements { - var expectedStmt = stmt.stmt - if stmt.fingerprinted != "" { - expectedStmt = stmt.fingerprinted - } - expectedStatements = append(expectedStatements, expectedStmt) - } - - oneMinAfterAggregatedTs := aggregatedTs + 60 - - t.Run("fetch_mode=combined, VIEWACTIVITY", func(t *testing.T) { - // Grant VIEWACTIVITY. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - - // Test with no query params. - verifyStmts("combinedstmts", expectedStatements, true, t) - // Test with end = 1 min after aggregatedTs; should give the same results as get all. - verifyStmts(fmt.Sprintf("combinedstmts?end=%d", oneMinAfterAggregatedTs), expectedStatements, true, t) - // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. 
- verifyStmts(fmt.Sprintf("combinedstmts?start=%d&end=%d", aggregatedTs-3600, oneMinAfterAggregatedTs), - expectedStatements, true, t) - // Test with start = 1 min after aggregatedTs; should give no results - verifyStmts(fmt.Sprintf("combinedstmts?start=%d", oneMinAfterAggregatedTs), nil, true, t) - }) - - t.Run("fetch_mode=combined, VIEWACTIVITYREDACTED", func(t *testing.T) { - // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - // Grant VIEWACTIVITYREDACTED. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", authenticatedUserNameNoAdmin().Normalized())) - - // Test with no query params. - verifyStmts("combinedstmts", expectedStatements, true, t) - // Test with end = 1 min after aggregatedTs; should give the same results as get all. - verifyStmts(fmt.Sprintf("combinedstmts?end=%d", oneMinAfterAggregatedTs), expectedStatements, true, t) - // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. - verifyStmts(fmt.Sprintf("combinedstmts?start=%d&end=%d", aggregatedTs-3600, oneMinAfterAggregatedTs), expectedStatements, true, t) - // Test with start = 1 min after aggregatedTs; should give no results - verifyStmts(fmt.Sprintf("combinedstmts?start=%d", oneMinAfterAggregatedTs), nil, true, t) - }) - - t.Run("fetch_mode=StmtsOnly", func(t *testing.T) { - verifyStmts("combinedstmts?fetch_mode.stats_type=0", expectedStatements, false, t) - }) - - t.Run("fetch_mode=TxnsOnly with limit", func(t *testing.T) { - // Verify that we only return stmts for the txns in the response. - // We'll add a limit in a later commit to help verify this behaviour. 
- if err := getStatusJSONProtoWithAdminOption(firstServerProto, "combinedstmts?fetch_mode.stats_type=1&limit=2", - &resp, false); err != nil { - t.Fatal(err) - } - - assert.Equal(t, 2, len(resp.Transactions)) - stmtFingerprintIDs := map[appstatspb.StmtFingerprintID]struct{}{} - for _, txn := range resp.Transactions { - for _, stmtFingerprint := range txn.StatsData.StatementFingerprintIDs { - stmtFingerprintIDs[stmtFingerprint] = struct{}{} - } - } - - for _, stmt := range resp.Statements { - if _, ok := stmtFingerprintIDs[stmt.ID]; !ok { - t.Fatalf("unexpected stmt; stmt unrelated to a txn int he response: %s", stmt.Key.KeyData.Query) - } - } - }) -} - -func TestStatusAPIStatementDetails(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - // The liveness session might expire before the stress race can finish. - skip.UnderStressRace(t, "expensive tests") - - // Aug 30 2021 19:50:00 GMT+0000 - aggregatedTs := int64(1630353000) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{ - Knobs: base.TestingKnobs{ - SQLStatsKnobs: &sqlstats.TestingKnobs{ - AOSTClause: "AS OF SYSTEM TIME '-1us'", - StubTimeNow: func() time.Time { return timeutil.Unix(aggregatedTs, 0) }, - }, - SpanConfig: &spanconfig.TestingKnobs{ - ManagerDisableJobCreation: true, - }, - }, - }, - }) - defer testCluster.Stopper().Stop(context.Background()) - - firstServerProto := testCluster.Server(0) - thirdServerSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(2)) - - statements := []string{ - `set application_name = 'first-app'`, - `CREATE DATABASE roachblog`, - `SET database = roachblog`, - `CREATE TABLE posts (id INT8 PRIMARY KEY, body STRING)`, - `INSERT INTO posts VALUES (1, 'foo')`, - `INSERT INTO posts VALUES (2, 'foo')`, - `INSERT INTO posts VALUES (3, 'foo')`, - `SELECT * FROM posts`, - } - - for _, stmt := range statements { - thirdServerSQL.Exec(t, stmt) - } - - query := `INSERT INTO posts VALUES (_, '_')` 
- fingerprintID := appstatspb.ConstructStatementFingerprintID(query, - false, true, `roachblog`) - path := fmt.Sprintf(`stmtdetails/%v`, fingerprintID) - - var resp serverpb.StatementDetailsResponse - // Test that non-admin without VIEWACTIVITY or VIEWACTIVITYREDACTED privileges cannot access. - err := getStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false) - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - - type resultValues struct { - query string - totalCount int - aggregatedTsCount int - planHashCount int - fullScanCount int - appNames []string - databases []string - } - - testPath := func(path string, expected resultValues) { - err := getStatusJSONProtoWithAdminOption(firstServerProto, path, &resp, false) - require.NoError(t, err) - require.Equal(t, int64(expected.totalCount), resp.Statement.Stats.Count) - require.Equal(t, expected.aggregatedTsCount, len(resp.StatementStatisticsPerAggregatedTs)) - require.Equal(t, expected.planHashCount, len(resp.StatementStatisticsPerPlanHash)) - require.Equal(t, expected.query, resp.Statement.Metadata.Query) - require.Equal(t, expected.appNames, resp.Statement.Metadata.AppNames) - require.Equal(t, int64(expected.totalCount), resp.Statement.Metadata.TotalCount) - require.Equal(t, expected.databases, resp.Statement.Metadata.Databases) - require.Equal(t, int64(expected.fullScanCount), resp.Statement.Metadata.FullScanCount) - } - - // Grant VIEWACTIVITY. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - - // Test with no query params. 
- testPath( - path, - resultValues{ - query: query, - totalCount: 3, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}, - }) - - // Execute same fingerprint id statement on a different application - statements = []string{ - `set application_name = 'second-app'`, - `INSERT INTO posts VALUES (4, 'foo')`, - `INSERT INTO posts VALUES (5, 'foo')`, - } - for _, stmt := range statements { - thirdServerSQL.Exec(t, stmt) - } - - oneMinAfterAggregatedTs := aggregatedTs + 60 - - testData := []struct { - path string - expectedResult resultValues - }{ - { // Test with no query params. - path: path, - expectedResult: resultValues{ - query: query, - totalCount: 5, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app", "second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with end = 1 min after aggregatedTs; should give the same results as get all. - path: fmt.Sprintf("%v?end=%d", path, oneMinAfterAggregatedTs), - expectedResult: resultValues{ - query: query, - totalCount: 5, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app", "second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with start = 1 hour before aggregatedTs end = 1 min after aggregatedTs; should give same results as get all. - path: fmt.Sprintf("%v?start=%d&end=%d", path, aggregatedTs-3600, oneMinAfterAggregatedTs), - expectedResult: resultValues{ - query: query, - totalCount: 5, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app", "second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with start = 1 min after aggregatedTs; should give no results. 
- path: fmt.Sprintf("%v?start=%d", path, oneMinAfterAggregatedTs), - expectedResult: resultValues{ - query: "", - totalCount: 0, - aggregatedTsCount: 0, - planHashCount: 0, - appNames: []string{}, - fullScanCount: 0, - databases: []string{}}, - }, - { // Test with one app_name. - path: fmt.Sprintf("%v?app_names=first-app", path), - expectedResult: resultValues{ - query: query, - totalCount: 3, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with another app_name. - path: fmt.Sprintf("%v?app_names=second-app", path), - expectedResult: resultValues{ - query: query, - totalCount: 2, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with both app_names. - path: fmt.Sprintf("%v?app_names=first-app&app_names=second-app", path), - expectedResult: resultValues{ - query: query, - totalCount: 5, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app", "second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - { // Test with non-existing app_name. - path: fmt.Sprintf("%v?app_names=non-existing", path), - expectedResult: resultValues{ - query: "", - totalCount: 0, - aggregatedTsCount: 0, - planHashCount: 0, - appNames: []string{}, - fullScanCount: 0, - databases: []string{}}, - }, - { // Test with app_name, start and end time. 
- path: fmt.Sprintf("%v?start=%d&end=%d&app_names=first-app&app_names=second-app", path, aggregatedTs-3600, oneMinAfterAggregatedTs), - expectedResult: resultValues{ - query: query, - totalCount: 5, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"first-app", "second-app"}, - fullScanCount: 0, - databases: []string{"roachblog"}}, - }, - } - - for _, test := range testData { - testPath(test.path, test.expectedResult) - } - - // Remove VIEWACTIVITY so we can test with just the VIEWACTIVITYREDACTED role. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - // Grant VIEWACTIVITYREDACTED. - thirdServerSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", authenticatedUserNameNoAdmin().Normalized())) - - for _, test := range testData { - testPath(test.path, test.expectedResult) - } - - // Test fix for #83608. The stmt being below requested has a fingerprint id - // that is 15 chars in hexadecimal. We should be able to find this stmt now - // that we construct the filter using a bytes comparison instead of string. 
- - statements = []string{ - `set application_name = 'fix_83608'`, - `set database = defaultdb`, - `SELECT 1, 2, 3, 4`, - } - for _, stmt := range statements { - thirdServerSQL.Exec(t, stmt) - } - - selectQuery := "SELECT _, _, _, _" - fingerprintID = appstatspb.ConstructStatementFingerprintID(selectQuery, false, - true, "defaultdb") - - testPath( - fmt.Sprintf(`stmtdetails/%v`, fingerprintID), - resultValues{ - query: selectQuery, - totalCount: 1, - aggregatedTsCount: 1, - planHashCount: 1, - appNames: []string{"fix_83608"}, - fullScanCount: 0, - databases: []string{"defaultdb"}, - }) -} - -func TestListSessionsSecurity(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - ts := s.(*TestServer) - defer ts.Stopper().Stop(context.Background()) - - ctx := context.Background() - - for _, requestWithAdmin := range []bool{true, false} { - t.Run(fmt.Sprintf("admin=%v", requestWithAdmin), func(t *testing.T) { - myUser := authenticatedUserNameNoAdmin() - expectedErrOnListingRootSessions := "does not have permission to view sessions from user" - if requestWithAdmin { - myUser = authenticatedUserName() - expectedErrOnListingRootSessions = "" - } - - // HTTP requests respect the authenticated username from the HTTP session. 
- testCases := []struct { - endpoint string - expectedErr string - }{ - {"local_sessions", ""}, - {"sessions", ""}, - {fmt.Sprintf("local_sessions?username=%s", myUser.Normalized()), ""}, - {fmt.Sprintf("sessions?username=%s", myUser.Normalized()), ""}, - {"local_sessions?username=" + username.RootUser, expectedErrOnListingRootSessions}, - {"sessions?username=" + username.RootUser, expectedErrOnListingRootSessions}, - } - for _, tc := range testCases { - var response serverpb.ListSessionsResponse - err := getStatusJSONProtoWithAdminOption(ts, tc.endpoint, &response, requestWithAdmin) - if tc.expectedErr == "" { - if err != nil || len(response.Errors) > 0 { - t.Errorf("unexpected failure listing sessions from %s; error: %v; response errors: %v", - tc.endpoint, err, response.Errors) - } - } else { - respErr := "" - if len(response.Errors) > 0 { - respErr = response.Errors[0].Message - } - if !testutils.IsError(err, tc.expectedErr) && - !strings.Contains(respErr, tc.expectedErr) { - t.Errorf("did not get expected error %q when listing sessions from %s: %v", - tc.expectedErr, tc.endpoint, err) - } - } - } - }) - } - - // gRPC requests behave as root and thus are always allowed. 
- rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := newRPCTestContext(ctx, ts, rootConfig) - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - - for _, user := range []string{"", authenticatedUser, username.RootUser} { - request := &serverpb.ListSessionsRequest{Username: user} - if resp, err := client.ListLocalSessions(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing local sessions for %q; error: %v; response errors: %v", - user, err, resp.Errors) - } - if resp, err := client.ListSessions(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing sessions for %q; error: %v; response errors: %v", - user, err, resp.Errors) - } - } -} - -func TestListActivitySecurity(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) - ts := s.(*TestServer) - defer ts.Stopper().Stop(ctx) - - expectedErrNoPermission := "this operation requires the VIEWACTIVITY or VIEWACTIVITYREDACTED system privilege" - contentionMsg := &serverpb.ListContentionEventsResponse{} - flowsMsg := &serverpb.ListDistSQLFlowsResponse{} - getErrors := func(msg protoutil.Message) []serverpb.ListActivityError { - switch r := msg.(type) { - case *serverpb.ListContentionEventsResponse: - return r.Errors - case *serverpb.ListDistSQLFlowsResponse: - return r.Errors - default: - t.Fatal("unexpected message type") - return nil - } - } - - // HTTP requests respect the authenticated username from the HTTP session. 
- testCases := []struct { - endpoint string - expectedErr string - requestWithAdmin bool - requestWithViewActivityGranted bool - response protoutil.Message - }{ - {"local_contention_events", expectedErrNoPermission, false, false, contentionMsg}, - {"contention_events", expectedErrNoPermission, false, false, contentionMsg}, - {"local_contention_events", "", true, false, contentionMsg}, - {"contention_events", "", true, false, contentionMsg}, - {"local_contention_events", "", false, true, contentionMsg}, - {"contention_events", "", false, true, contentionMsg}, - {"local_distsql_flows", expectedErrNoPermission, false, false, flowsMsg}, - {"distsql_flows", expectedErrNoPermission, false, false, flowsMsg}, - {"local_distsql_flows", "", true, false, flowsMsg}, - {"distsql_flows", "", true, false, flowsMsg}, - {"local_distsql_flows", "", false, true, flowsMsg}, - {"distsql_flows", "", false, true, flowsMsg}, - } - myUser := authenticatedUserNameNoAdmin().Normalized() - for _, tc := range testCases { - if tc.requestWithViewActivityGranted { - // Note that for this query to work, it is crucial that - // getStatusJSONProtoWithAdminOption below is called at least once, - // on the previous test case, so that the user exists. 
- _, err := db.Exec(fmt.Sprintf("ALTER USER %s VIEWACTIVITY", myUser)) - require.NoError(t, err) - } - err := getStatusJSONProtoWithAdminOption(s, tc.endpoint, tc.response, tc.requestWithAdmin) - responseErrors := getErrors(tc.response) - if tc.expectedErr == "" { - if err != nil || len(responseErrors) > 0 { - t.Errorf("unexpected failure listing the activity; error: %v; response errors: %v", - err, responseErrors) - } - } else { - respErr := "" - if len(responseErrors) > 0 { - respErr = responseErrors[0].Message - } - if !testutils.IsError(err, tc.expectedErr) && - !strings.Contains(respErr, tc.expectedErr) { - t.Errorf("did not get expected error %q when listing the activity from %s: %v", - tc.expectedErr, tc.endpoint, err) - } - } - if tc.requestWithViewActivityGranted { - _, err := db.Exec(fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", myUser)) - require.NoError(t, err) - } - } - - // gRPC requests behave as root and thus are always allowed. - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := newRPCTestContext(ctx, ts, rootConfig) - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - { - request := &serverpb.ListContentionEventsRequest{} - if resp, err := client.ListLocalContentionEvents(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing local contention events; error: %v; response errors: %v", - err, resp.Errors) - } - if resp, err := client.ListContentionEvents(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing contention events; error: %v; response errors: %v", - err, resp.Errors) - } - } - { - request := &serverpb.ListDistSQLFlowsRequest{} - if resp, err := client.ListLocalDistSQLFlows(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing local 
distsql flows; error: %v; response errors: %v", - err, resp.Errors) - } - if resp, err := client.ListDistSQLFlows(ctx, request); err != nil || len(resp.Errors) > 0 { - t.Errorf("unexpected failure listing distsql flows; error: %v; response errors: %v", - err, resp.Errors) - } - } -} - -func TestMergeDistSQLRemoteFlows(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - flowIDs := make([]execinfrapb.FlowID, 4) - for i := range flowIDs { - flowIDs[i].UUID = uuid.FastMakeV4() - } - sort.Slice(flowIDs, func(i, j int) bool { - return bytes.Compare(flowIDs[i].GetBytes(), flowIDs[j].GetBytes()) < 0 - }) - ts := make([]time.Time, 4) - for i := range ts { - ts[i] = timeutil.Now() - } - - for _, tc := range []struct { - a []serverpb.DistSQLRemoteFlows - b []serverpb.DistSQLRemoteFlows - expected []serverpb.DistSQLRemoteFlows - }{ - // a is empty - { - a: []serverpb.DistSQLRemoteFlows{}, - b: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - }, - }, - }, - expected: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - }, - }, - }, - }, - // b is empty - { - a: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - }, - }, - }, - b: []serverpb.DistSQLRemoteFlows{}, - 
expected: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - }, - }, - }, - }, - // both non-empty with some intersections - { - a: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[2], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[3], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - }, - }, - }, - b: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - }, - }, - { - FlowID: flowIDs[3], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - }, - }, - }, - expected: []serverpb.DistSQLRemoteFlows{ - { - FlowID: flowIDs[0], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[1], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - {NodeID: 1, Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - }, - }, - { - FlowID: flowIDs[2], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 3, Timestamp: ts[3]}, - }, - }, - { - FlowID: flowIDs[3], - Infos: []serverpb.DistSQLRemoteFlows_Info{ - {NodeID: 0, Timestamp: ts[0]}, - {NodeID: 1, 
Timestamp: ts[1]}, - {NodeID: 2, Timestamp: ts[2]}, - }, - }, - }, - }, - } { - require.Equal(t, tc.expected, mergeDistSQLRemoteFlows(tc.a, tc.b)) - } -} - -func TestCreateStatementDiagnosticsReport(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - req := &serverpb.CreateStatementDiagnosticsReportRequest{ - StatementFingerprint: "INSERT INTO test VALUES (_)", - } - var resp serverpb.CreateStatementDiagnosticsReportResponse - if err := postStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { - t.Fatal(err) - } - - var respGet serverpb.StatementDiagnosticsReportsResponse - if err := getStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { - t.Fatal(err) - } - - if respGet.Reports[0].StatementFingerprint != req.StatementFingerprint { - t.Fatal("statement diagnostics request was not persisted") - } -} - -func TestCreateStatementDiagnosticsReportWithViewActivityOptions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - db := sqlutils.MakeSQLRunner(sqlDB) - - ctx := context.Background() - ie := s.InternalExecutor().(*sql.InternalExecutor) - - if err := getStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &serverpb.CreateStatementDiagnosticsReportRequest{}, false); err != nil { - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - _, err := ie.ExecEx( - ctx, - "inserting-stmt-bundle-req", - nil, /* txn */ - sessiondata.InternalExecutorOverride{ - User: authenticatedUserNameNoAdmin(), - }, - "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", - ) - require.Contains(t, err.Error(), "requesting statement bundle requires VIEWACTIVITY or ADMIN role option") - - // Grant 
VIEWACTIVITY and all test should work. - db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - req := &serverpb.CreateStatementDiagnosticsReportRequest{ - StatementFingerprint: "INSERT INTO test VALUES (_)", - } - var resp serverpb.CreateStatementDiagnosticsReportResponse - if err := postStatusJSONProtoWithAdminOption(s, "stmtdiagreports", req, &resp, false); err != nil { - t.Fatal(err) - } - var respGet serverpb.StatementDiagnosticsReportsResponse - if err := getStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &respGet, false); err != nil { - t.Fatal(err) - } - if respGet.Reports[0].StatementFingerprint != req.StatementFingerprint { - t.Fatal("statement diagnostics request was not persisted") - } - _, err = ie.ExecEx( - ctx, - "inserting-stmt-bundle-req", - nil, /* txn */ - sessiondata.InternalExecutorOverride{ - User: authenticatedUserNameNoAdmin(), - }, - "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", - ) - require.NoError(t, err) - - db.CheckQueryResults(t, ` - SELECT count(*) - FROM system.statement_diagnostics_requests - WHERE statement_fingerprint = 'SELECT _' -`, [][]string{{"1"}}) - - // Grant VIEWACTIVITYREDACTED and all test should get permission errors. 
- db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", authenticatedUserNameNoAdmin().Normalized())) - - if err := postStatusJSONProtoWithAdminOption(s, "stmtdiagreports", req, &resp, false); err != nil { - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - if err := getStatusJSONProtoWithAdminOption(s, "stmtdiagreports", &respGet, false); err != nil { - if !testutils.IsError(err, "status: 403") { - t.Fatalf("expected privilege error, got %v", err) - } - } - - _, err = ie.ExecEx( - ctx, - "inserting-stmt-bundle-req", - nil, /* txn */ - sessiondata.InternalExecutorOverride{ - User: authenticatedUserNameNoAdmin(), - }, - "SELECT crdb_internal.request_statement_bundle('SELECT _', 0::FLOAT, 0::INTERVAL, 0::INTERVAL)", - ) - require.Contains(t, err.Error(), "VIEWACTIVITYREDACTED role option cannot request statement bundle") -} - -func TestStatementDiagnosticsCompleted(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - - _, err := db.Exec("CREATE TABLE test (x int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - req := &serverpb.CreateStatementDiagnosticsReportRequest{ - StatementFingerprint: "INSERT INTO test VALUES (_)", - } - var resp serverpb.CreateStatementDiagnosticsReportResponse - if err := postStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { - t.Fatal(err) - } - - _, err = db.Exec("INSERT INTO test VALUES (1)") - if err != nil { - t.Fatal(err) - } - - var respGet serverpb.StatementDiagnosticsReportsResponse - if err := getStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { - t.Fatal(err) - } - - if respGet.Reports[0].Completed != true { - t.Fatal("statement diagnostics was not captured") - } - - var diagRespGet serverpb.StatementDiagnosticsResponse - diagPath := fmt.Sprintf("stmtdiag/%d", 
respGet.Reports[0].StatementDiagnosticsId) - if err := getStatusJSONProto(s, diagPath, &diagRespGet); err != nil { - t.Fatal(err) - } -} - -func TestStatementDiagnosticsDoesNotReturnExpiredRequests(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - db := sqlutils.MakeSQLRunner(sqlDB) - - statementFingerprint := "INSERT INTO test VALUES (_)" - expiresAfter := 5 * time.Millisecond - - // Create statement diagnostics request with defined expiry time. - req := &serverpb.CreateStatementDiagnosticsReportRequest{ - StatementFingerprint: statementFingerprint, - MinExecutionLatency: 500 * time.Millisecond, - ExpiresAfter: expiresAfter, - } - var resp serverpb.CreateStatementDiagnosticsReportResponse - if err := postStatusJSONProto(s, "stmtdiagreports", req, &resp); err != nil { - t.Fatal(err) - } - - // Wait for request to expire. - time.Sleep(expiresAfter) - - // Check that created statement diagnostics report is incomplete. - report := db.QueryStr(t, ` -SELECT completed -FROM system.statement_diagnostics_requests -WHERE statement_fingerprint = $1`, statementFingerprint) - - require.Equal(t, report[0][0], "false") - - // Check that expired report is not returned in API response. 
- var respGet serverpb.StatementDiagnosticsReportsResponse - if err := getStatusJSONProto(s, "stmtdiagreports", &respGet); err != nil { - t.Fatal(err) - } - - for _, report := range respGet.Reports { - require.NotEqual(t, report.StatementFingerprint, statementFingerprint) - } -} - -func TestJobStatusResponse(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - ts := startServer(t) - defer ts.Stopper().Stop(context.Background()) - - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := newRPCTestContext(context.Background(), ts, rootConfig) - - url := ts.ServingRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - - request := &serverpb.JobStatusRequest{JobId: -1} - response, err := client.JobStatus(context.Background(), request) - require.Regexp(t, `job with ID -1 does not exist`, err) - require.Nil(t, response) - - ctx := context.Background() - jr := ts.JobRegistry().(*jobs.Registry) - job, err := jr.CreateJobWithTxn( - ctx, - jobs.Record{ - Description: "testing", - Statements: []string{"SELECT 1"}, - Username: username.RootUserName(), - Details: jobspb.ImportDetails{ - Tables: []jobspb.ImportDetails_Table{ - { - Desc: &descpb.TableDescriptor{ - ID: 1, - }, - }, - { - Desc: &descpb.TableDescriptor{ - ID: 2, - }, - }, - }, - URIs: []string{"a", "b"}, - }, - Progress: jobspb.ImportProgress{}, - DescriptorIDs: []descpb.ID{1, 2, 3}, - }, - jr.MakeJobID(), - nil) - if err != nil { - t.Fatal(err) - } - request.JobId = int64(job.ID()) - response, err = client.JobStatus(context.Background(), request) - if err != nil { - t.Fatal(err) - } - require.Equal(t, job.ID(), response.Job.Id) - require.Equal(t, job.Payload(), *response.Job.Payload) - require.Equal(t, job.Progress(), *response.Job.Progress) -} - -func TestRegionsResponseFromNodesResponse(t *testing.T) 
{ - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - makeNodeResponseWithLocalities := func(tiers [][]roachpb.Tier) *serverpb.NodesResponse { - ret := &serverpb.NodesResponse{} - for _, l := range tiers { - ret.Nodes = append( - ret.Nodes, - statuspb.NodeStatus{ - Desc: roachpb.NodeDescriptor{ - Locality: roachpb.Locality{Tiers: l}, - }, - }, - ) - } - return ret - } - - makeTiers := func(region, zone string) []roachpb.Tier { - return []roachpb.Tier{ - {Key: "region", Value: region}, - {Key: "zone", Value: zone}, - } - } - - testCases := []struct { - desc string - resp *serverpb.NodesResponse - expected *serverpb.RegionsResponse - }{ - { - desc: "no nodes with regions", - resp: makeNodeResponseWithLocalities([][]roachpb.Tier{ - {{Key: "a", Value: "a"}}, - {}, - }), - expected: &serverpb.RegionsResponse{ - Regions: map[string]*serverpb.RegionsResponse_Region{}, - }, - }, - { - desc: "nodes, some with AZs", - resp: makeNodeResponseWithLocalities([][]roachpb.Tier{ - makeTiers("us-east1", "us-east1-a"), - makeTiers("us-east1", "us-east1-a"), - makeTiers("us-east1", "us-east1-a"), - makeTiers("us-east1", "us-east1-b"), - - makeTiers("us-east2", "us-east2-a"), - makeTiers("us-east2", "us-east2-a"), - makeTiers("us-east2", "us-east2-a"), - - makeTiers("us-east3", "us-east3-a"), - makeTiers("us-east3", "us-east3-b"), - makeTiers("us-east3", "us-east3-b"), - {{Key: "region", Value: "us-east3"}}, - - {{Key: "region", Value: "us-east4"}}, - }), - expected: &serverpb.RegionsResponse{ - Regions: map[string]*serverpb.RegionsResponse_Region{ - "us-east1": { - Zones: []string{"us-east1-a", "us-east1-b"}, - }, - "us-east2": { - Zones: []string{"us-east2-a"}, - }, - "us-east3": { - Zones: []string{"us-east3-a", "us-east3-b"}, - }, - "us-east4": { - Zones: []string{}, - }, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - ret := regionsResponseFromNodesResponse(tc.resp) - require.Equal(t, tc.expected, ret) - }) - } -} - -func 
TestStatusServer_nodeStatusToResp(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - var nodeStatus = &statuspb.NodeStatus{ - StoreStatuses: []statuspb.StoreStatus{ - {Desc: roachpb.StoreDescriptor{ - Properties: roachpb.StoreProperties{ - Encrypted: true, - FileStoreProperties: &roachpb.FileStoreProperties{ - Path: "/secret", - FsType: "ext4", - }, - }, - }}, - }, - Desc: roachpb.NodeDescriptor{ - Address: util.UnresolvedAddr{ - NetworkField: "network", - AddressField: "address", - }, - Attrs: roachpb.Attributes{ - Attrs: []string{"attr"}, - }, - LocalityAddress: []roachpb.LocalityAddress{{Address: util.UnresolvedAddr{ - NetworkField: "network", - AddressField: "address", - }, LocalityTier: roachpb.Tier{Value: "v", Key: "k"}}}, - SQLAddress: util.UnresolvedAddr{ - NetworkField: "network", - AddressField: "address", - }, - }, - Args: []string{"args"}, - Env: []string{"env"}, - } - resp := nodeStatusToResp(nodeStatus, false) - require.Empty(t, resp.Args) - require.Empty(t, resp.Env) - require.Empty(t, resp.Desc.Address) - require.Empty(t, resp.Desc.Attrs.Attrs) - require.Empty(t, resp.Desc.LocalityAddress) - require.Empty(t, resp.Desc.SQLAddress) - require.True(t, resp.StoreStatuses[0].Desc.Properties.Encrypted) - require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.FsType) - require.Empty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.Path) - - // Now fetch all the node statuses as admin. 
- resp = nodeStatusToResp(nodeStatus, true) - require.NotEmpty(t, resp.Args) - require.NotEmpty(t, resp.Env) - require.NotEmpty(t, resp.Desc.Address) - require.NotEmpty(t, resp.Desc.Attrs.Attrs) - require.NotEmpty(t, resp.Desc.LocalityAddress) - require.NotEmpty(t, resp.Desc.SQLAddress) - require.True(t, resp.StoreStatuses[0].Desc.Properties.Encrypted) - require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.FsType) - require.NotEmpty(t, resp.StoreStatuses[0].Desc.Properties.FileStoreProperties.Path) -} - -func TestStatusAPIContentionEvents(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - params, _ := tests.CreateTestServerParams() - ctx := context.Background() - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: params, - }) - - defer testCluster.Stopper().Stop(ctx) - - server1Conn := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) - server2Conn := sqlutils.MakeSQLRunner(testCluster.ServerConn(1)) - - contentionCountBefore := testCluster.Server(1).SQLServer().(*sql.Server). 
- Metrics.EngineMetrics.SQLContendedTxns.Count() - - sqlutils.CreateTable( - t, - testCluster.ServerConn(0), - "test", - "x INT PRIMARY KEY", - 1, /* numRows */ - sqlutils.ToRowFn(sqlutils.RowIdxFn), - ) - - testTableID, err := - strconv.Atoi(server1Conn.QueryStr(t, "SELECT 'test.test'::regclass::oid")[0][0]) - require.NoError(t, err) - - server1Conn.Exec(t, "USE test") - server2Conn.Exec(t, "USE test") - server2Conn.Exec(t, "SET application_name = 'contentionTest'") - - server1Conn.Exec(t, ` -SET TRACING=on; -BEGIN; -UPDATE test SET x = 100 WHERE x = 1; -`) - server2Conn.Exec(t, ` -SET TRACING=on; -BEGIN PRIORITY HIGH; -UPDATE test SET x = 1000 WHERE x = 1; -COMMIT; -SET TRACING=off; -`) - server1Conn.ExpectErr( - t, - "^pq: restart transaction.+", - ` -COMMIT; -SET TRACING=off; -`, - ) - - var resp serverpb.ListContentionEventsResponse - require.NoError(t, - getStatusJSONProtoWithAdminOption( - testCluster.Server(2), - "contention_events", - &resp, - true /* isAdmin */), - ) - - require.GreaterOrEqualf(t, len(resp.Events.IndexContentionEvents), 1, - "expecting at least 1 contention event, but found none") - - found := false - for _, event := range resp.Events.IndexContentionEvents { - if event.TableID == descpb.ID(testTableID) && event.IndexID == descpb.IndexID(1) { - found = true - break - } - } - - require.True(t, found, - "expect to find contention event for table %d, but found %+v", testTableID, resp) - - server1Conn.CheckQueryResults(t, ` - SELECT count(*) - FROM crdb_internal.statement_statistics - WHERE - (statistics -> 'execution_statistics' -> 'contentionTime' ->> 'mean')::FLOAT > 0 - AND app_name = 'contentionTest' -`, [][]string{{"1"}}) - - server1Conn.CheckQueryResults(t, ` - SELECT count(*) - FROM crdb_internal.transaction_statistics - WHERE - (statistics -> 'execution_statistics' -> 'contentionTime' ->> 'mean')::FLOAT > 0 - AND app_name = 'contentionTest' -`, [][]string{{"1"}}) - - contentionCountNow := 
testCluster.Server(1).SQLServer().(*sql.Server). - Metrics.EngineMetrics.SQLContendedTxns.Count() - - require.Greaterf(t, contentionCountNow, contentionCountBefore, - "expected txn contention count to be more than %d, but it is %d", - contentionCountBefore, contentionCountNow) -} - -func TestStatusCancelSessionGatewayMetadataPropagation(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - defer testCluster.Stopper().Stop(ctx) - - // Start a SQL session as admin on node 1. - sql0 := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) - results := sql0.QueryStr(t, "SELECT session_id FROM [SHOW SESSIONS] LIMIT 1") - sessionID, err := hex.DecodeString(results[0][0]) - require.NoError(t, err) - - // Attempt to cancel that SQL session as non-admin over HTTP on node 2. - req := &serverpb.CancelSessionRequest{ - SessionID: sessionID, - } - resp := &serverpb.CancelSessionResponse{} - err = postStatusJSONProtoWithAdminOption(testCluster.Server(1), "cancel_session/1", req, resp, false) - require.NotNil(t, err) - require.Contains(t, err.Error(), "status: 403 Forbidden") -} - -func TestStatusAPIListSessions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - params, _ := tests.CreateTestServerParams() - ctx := context.Background() - testCluster := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ - ServerArgs: params, - }) - defer testCluster.Stopper().Stop(ctx) - - serverProto := testCluster.Server(0) - serverSQL := sqlutils.MakeSQLRunner(testCluster.ServerConn(0)) - - appName := "test_sessions_api" - serverSQL.Exec(t, fmt.Sprintf(`SET application_name = "%s"`, appName)) - - getSessionWithTestAppName := func(response *serverpb.ListSessionsResponse) *serverpb.Session { - require.NotEmpty(t, response.Sessions) - for _, s := range response.Sessions { - if s.ApplicationName == appName { - return &s - } - } - 
t.Errorf("expected to find session with app name %s", appName) - return nil - } - - userNoAdmin := authenticatedUserNameNoAdmin() - var resp serverpb.ListSessionsResponse - // Non-admin without VIEWWACTIVITY or VIEWACTIVITYREDACTED should work and fetch user's own sessions. - err := getStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) - require.NoError(t, err) - - // Grant VIEWACTIVITYREDACTED. - serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", userNoAdmin.Normalized())) - serverSQL.Exec(t, "SELECT 1") - err = getStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) - require.NoError(t, err) - session := getSessionWithTestAppName(&resp) - require.Equal(t, session.LastActiveQuery, session.LastActiveQueryNoConstants) - require.Equal(t, "SELECT _", session.LastActiveQueryNoConstants) - - // Grant VIEWACTIVITY, VIEWACTIVITYREDACTED should take precedence. - serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", userNoAdmin.Normalized())) - serverSQL.Exec(t, "SELECT 1, 1") - err = getStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) - require.NoError(t, err) - session = getSessionWithTestAppName(&resp) - require.Equal(t, appName, session.ApplicationName) - require.Equal(t, session.LastActiveQuery, session.LastActiveQueryNoConstants) - require.Equal(t, "SELECT _, _", session.LastActiveQueryNoConstants) - - // Remove VIEWACTIVITYREDCATED. User should now see full query. 
- serverSQL.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITYREDACTED", userNoAdmin.Normalized())) - serverSQL.Exec(t, "SELECT 2") - err = getStatusJSONProtoWithAdminOption(serverProto, "sessions", &resp, false) - require.NoError(t, err) - session = getSessionWithTestAppName(&resp) - require.Equal(t, "SELECT _", session.LastActiveQueryNoConstants) - require.Equal(t, "SELECT 2", session.LastActiveQuery) -} - -func TestListClosedSessions(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - // The active sessions might close before the stress race can finish. - skip.UnderStressRace(t, "active sessions") - - ctx := context.Background() - serverParams, _ := tests.CreateTestServerParams() - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ - ServerArgs: serverParams, - }) - defer testCluster.Stopper().Stop(ctx) - - server := testCluster.Server(0) - - doSessionsRequest := func(username string) serverpb.ListSessionsResponse { - var resp serverpb.ListSessionsResponse - path := "/_status/sessions?username=" + username - err := serverutils.GetJSONProto(server, path, &resp) - require.NoError(t, err) - return resp - } - - getUserConn := func(t *testing.T, username string, server serverutils.TestServerInterface) *gosql.DB { - pgURL := url.URL{ - Scheme: "postgres", - User: url.UserPassword(username, "hunter2"), - Host: server.ServingSQLAddr(), - } - db, err := gosql.Open("postgres", pgURL.String()) - require.NoError(t, err) - return db - } - - // Create a test user. - users := []string{"test_user_a", "test_user_b", "test_user_c"} - conn := testCluster.ServerConn(0) - _, err := conn.Exec(fmt.Sprintf(` -CREATE USER %s with password 'hunter2'; -CREATE USER %s with password 'hunter2'; -CREATE USER %s with password 'hunter2'; -`, users[0], users[1], users[2])) - require.NoError(t, err) - - var dbs []*gosql.DB - - // Open 10 sessions for the user and then close them. 
- for _, user := range users { - for i := 0; i < 10; i++ { - targetDB := getUserConn(t, user, testCluster.Server(0)) - dbs = append(dbs, targetDB) - sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT version()`) - } - } - - for _, db := range dbs { - err := db.Close() - require.NoError(t, err) - } - - var wg sync.WaitGroup - - // Open 5 sessions for the user and leave them open by running pg_sleep(30). - for _, user := range users { - for i := 0; i < 5; i++ { - wg.Add(1) - go func(user string) { - // Open a session for the target user. - targetDB := getUserConn(t, user, testCluster.Server(0)) - defer targetDB.Close() - defer wg.Done() - sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT pg_sleep(30)`) - }(user) - } - } - - // Open 3 sessions for the user and leave them idle by running version(). - for _, user := range users { - for i := 0; i < 3; i++ { - targetDB := getUserConn(t, user, testCluster.Server(0)) - defer targetDB.Close() - sqlutils.MakeSQLRunner(targetDB).Exec(t, `SELECT version()`) - } - } - - countSessionStatus := func(allSessions []serverpb.Session) (int, int, int) { - var active, idle, closed int - for _, s := range allSessions { - if s.Status.String() == "ACTIVE" { - active++ - } - // IDLE sessions are open sessions with no active queries. 
- if s.Status.String() == "IDLE" { - idle++ - } - if s.Status.String() == "CLOSED" { - closed++ - } - } - return active, idle, closed - } - - expectedIdle := 3 - expectedActive := 5 - expectedClosed := 10 - - testutils.SucceedsSoon(t, func() error { - for _, user := range users { - sessionsResponse := doSessionsRequest(user) - allSessions := sessionsResponse.Sessions - sort.Slice(allSessions, func(i, j int) bool { - return allSessions[i].Start.Before(allSessions[j].Start) - }) - - active, idle, closed := countSessionStatus(allSessions) - if idle != expectedIdle { - return errors.Newf("User: %s: Expected %d idle sessions, got %d\n", user, expectedIdle, idle) - } - if active != expectedActive { - return errors.Newf("User: %s: Expected %d active sessions, got %d\n", user, expectedActive, active) - } - if closed != expectedClosed { - return errors.Newf("User: %s: Expected %d closed sessions, got %d\n", user, expectedClosed, closed) - } - } - return nil - }) - - // Wait for the goroutines from the pg_sleep() command to finish, so we can - // safely close their connections. 
- wg.Wait() -} - -func TestTransactionContentionEvents(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - - s, conn1, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(ctx) - - sqlutils.CreateTable( - t, - conn1, - "test", - "x INT PRIMARY KEY", - 1, /* numRows */ - sqlutils.ToRowFn(sqlutils.RowIdxFn), - ) - - conn2 := - serverutils.OpenDBConn(t, s.ServingSQLAddr(), "", false /* insecure */, s.Stopper()) - defer func() { - require.NoError(t, conn2.Close()) - }() - - sqlConn1 := sqlutils.MakeSQLRunner(conn1) - sqlConn1.Exec(t, "SET CLUSTER SETTING sql.contention.txn_id_cache.max_size = '1GB'") - sqlConn1.Exec(t, "USE test") - sqlConn1.Exec(t, "SET application_name='conn1'") - - sqlConn2 := sqlutils.MakeSQLRunner(conn2) - sqlConn2.Exec(t, "USE test") - sqlConn2.Exec(t, "SET application_name='conn2'") - - // Start the first transaction. - sqlConn1.Exec(t, ` - SET TRACING=on; - BEGIN; - `) - - txnID1 := sqlConn1.QueryStr(t, ` - SELECT txn_id - FROM [SHOW TRANSACTIONS] - WHERE application_name = 'conn1'`)[0][0] - - sqlConn1.Exec(t, "UPDATE test SET x = 100 WHERE x = 1") - - // Start the second transaction with higher priority. This will cause the - // first transaction to be aborted. - sqlConn2.Exec(t, ` - SET TRACING=on; - BEGIN PRIORITY HIGH; - `) - - txnID2 := sqlConn1.QueryStr(t, ` - SELECT txn_id - FROM [SHOW TRANSACTIONS] - WHERE application_name = 'conn2'`)[0][0] - - sqlConn2.Exec(t, ` - UPDATE test SET x = 1000 WHERE x = 1; - COMMIT;`) - - // Ensure that the first transaction is aborted. - sqlConn1.ExpectErr( - t, - "^pq: restart transaction.+", - ` - COMMIT; - SET TRACING=off;`, - ) - - // Sanity check to see the first transaction has been aborted. 
- sqlConn1.CheckQueryResults(t, "SELECT * FROM test", - [][]string{{"1000"}}) - - txnIDCache := s.SQLServer().(*sql.Server).GetTxnIDCache() - - // Since contention event store's resolver only retries once in the case of - // missing txn fingerprint ID for a given txnID, we ensure that the txnIDCache - // write buffer is properly drained before we go on to test the contention - // registry. - testutils.SucceedsSoon(t, func() error { - txnIDCache.DrainWriteBuffer() - - txnID, err := uuid.FromString(txnID1) - require.NoError(t, err) - - if _, found := txnIDCache.Lookup(txnID); !found { - return errors.Newf("expected the txn fingerprint ID for txn %s to be "+ - "stored in txnID cache, but it is not", txnID1) - } - - txnID, err = uuid.FromString(txnID2) - require.NoError(t, err) - - if _, found := txnIDCache.Lookup(txnID); !found { - return errors.Newf("expected the txn fingerprint ID for txn %s to be "+ - "stored in txnID cache, but it is not", txnID2) - } - - return nil - }) - - testutils.SucceedsWithin(t, func() error { - err := s.ExecutorConfig().(sql.ExecutorConfig).ContentionRegistry.FlushEventsForTest(ctx) - require.NoError(t, err) - - notEmpty := sqlConn1.QueryStr(t, ` - SELECT count(*) > 0 - FROM crdb_internal.transaction_contention_events - WHERE - blocking_txn_id = $1::UUID AND - waiting_txn_id = $2::UUID AND - encode(blocking_txn_fingerprint_id, 'hex') != '0000000000000000' AND - encode(waiting_txn_fingerprint_id, 'hex') != '0000000000000000' AND - length(contending_key) > 0`, txnID1, txnID2)[0][0] - - if notEmpty != "true" { - return errors.Newf("expected at least one contention events, but " + - "none was found") - } - - return nil - }, 10*time.Second) - - nonAdminUser := authenticatedUserNameNoAdmin().Normalized() - adminUser := authenticatedUserName().Normalized() - - // N.B. We need both test users to be created before establishing SQL - // connections with their usernames. 
We use - // getStatusJSONProtoWithAdminOption() to implicitly create those - // usernames instead of regular CREATE USER statements, since the helper - // getStatusJSONProtoWithAdminOption() couldn't handle the case where - // those two usernames already exist. - // This is the reason why we don't check for returning errors. - _ = getStatusJSONProtoWithAdminOption( - s, - "transactioncontentionevents", - &serverpb.TransactionContentionEventsResponse{}, - true, /* isAdmin */ - ) - _ = getStatusJSONProtoWithAdminOption( - s, - "transactioncontentionevents", - &serverpb.TransactionContentionEventsResponse{}, - false, /* isAdmin */ - ) - - type testCase struct { - testName string - userName string - canViewContendingKey bool - grantPerm string - revokePerm string - isAdmin bool - } - - tcs := []testCase{ - { - testName: "nopermission", - userName: nonAdminUser, - canViewContendingKey: false, - }, - { - testName: "viewactivityredacted", - userName: nonAdminUser, - canViewContendingKey: false, - grantPerm: fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", nonAdminUser), - revokePerm: fmt.Sprintf("ALTER USER %s NOVIEWACTIVITYREDACTED", nonAdminUser), - }, - { - testName: "viewactivity", - userName: nonAdminUser, - canViewContendingKey: true, - grantPerm: fmt.Sprintf("ALTER USER %s VIEWACTIVITY", nonAdminUser), - revokePerm: fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", nonAdminUser), - }, - { - testName: "viewactivity_and_viewactivtyredacted", - userName: nonAdminUser, - canViewContendingKey: false, - grantPerm: fmt.Sprintf(`ALTER USER %s VIEWACTIVITY; - ALTER USER %s VIEWACTIVITYREDACTED;`, - nonAdminUser, nonAdminUser), - revokePerm: fmt.Sprintf(`ALTER USER %s NOVIEWACTIVITY; - ALTER USER %s NOVIEWACTIVITYREDACTED;`, - nonAdminUser, nonAdminUser), - }, - { - testName: "adminuser", - userName: adminUser, - canViewContendingKey: true, - isAdmin: true, - }, - } - - expectationStringHelper := func(canViewContendingKey bool) string { - if canViewContendingKey { - return "able 
to view contending keys" - } - return "not able to view contending keys" - } - - for _, tc := range tcs { - t.Run(tc.testName, func(t *testing.T) { - if tc.grantPerm != "" { - sqlConn1.Exec(t, tc.grantPerm) - } - if tc.revokePerm != "" { - defer sqlConn1.Exec(t, tc.revokePerm) - } - - expectationStr := expectationStringHelper(tc.canViewContendingKey) - t.Run("sql_cli", func(t *testing.T) { - // Check we have proper permission control in SQL CLI. We use internal - // executor here since we can easily override the username without opening - // new SQL sessions. - row, err := s.InternalExecutor().(*sql.InternalExecutor).QueryRowEx( - ctx, - "test-contending-key-redaction", - nil, /* txn */ - sessiondata.InternalExecutorOverride{ - User: username.MakeSQLUsernameFromPreNormalizedString(tc.userName), - }, - ` - SELECT count(*) - FROM crdb_internal.transaction_contention_events - WHERE length(contending_key) > 0`, - ) - if tc.testName == "nopermission" { - require.Contains(t, err.Error(), "does not have VIEWACTIVITY") - } else { - require.NoError(t, err) - visibleContendingKeysCount := tree.MustBeDInt(row[0]) - - require.Equal(t, tc.canViewContendingKey, visibleContendingKeysCount > 0, - "expected to %s, but %d keys have been retrieved", - expectationStr, visibleContendingKeysCount) - } - }) - - t.Run("http", func(t *testing.T) { - // Check we have proper permission control in RPC/HTTP endpoint. 
- resp := serverpb.TransactionContentionEventsResponse{} - err := getStatusJSONProtoWithAdminOption( - s, - "transactioncontentionevents", - &resp, - tc.isAdmin, - ) - - if tc.testName == "nopermission" { - require.Contains(t, err.Error(), "status: 403") - } else { - require.NoError(t, err) - } - - for _, event := range resp.Events { - require.NotEqual(t, event.WaitingStmtFingerprintID, 0) - require.NotEqual(t, event.WaitingStmtID.String(), clusterunique.ID{}.String()) - - require.Equal(t, tc.canViewContendingKey, len(event.BlockingEvent.Key) > 0, - "expected to %s, but the contending key has length of %d", - expectationStr, - len(event.BlockingEvent.Key), - ) - } - }) - - }) - } -} - -func TestUnprivilegedUserResetIndexUsageStats(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - ctx := context.Background() - - s, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) - defer s.Stopper().Stop(ctx) - - sqlConn := sqlutils.MakeSQLRunner(conn) - sqlConn.Exec(t, "CREATE USER nonAdminUser") - - ie := s.InternalExecutor().(*sql.InternalExecutor) - - _, err := ie.ExecEx( - ctx, - "test-reset-index-usage-stats-as-non-admin-user", - nil, /* txn */ - sessiondata.InternalExecutorOverride{ - User: username.MakeSQLUsernameFromPreNormalizedString("nonAdminUser"), - }, - "SELECT crdb_internal.reset_index_usage_stats()", - ) - - require.Contains(t, err.Error(), "requires admin privilege") -} diff --git a/pkg/server/storage_api/BUILD.bazel b/pkg/server/storage_api/BUILD.bazel new file mode 100644 index 000000000000..aa86913d2d6f --- /dev/null +++ b/pkg/server/storage_api/BUILD.bazel @@ -0,0 +1,70 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "storage_api", + srcs = ["doc.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/server/storage_api", + visibility = ["//visibility:public"], +) + +go_test( + name = "storage_api_test", + srcs = [ + "certs_test.go", + "decommission_test.go", + 
"engine_test.go", + "enqueue_test.go", + "files_test.go", + "gossip_test.go", + "health_test.go", + "logfiles_test.go", + "main_test.go", + "network_test.go", + "nodes_test.go", + "raft_test.go", + "rangelog_test.go", + "ranges_test.go", + ], + args = ["-test.timeout=295s"], + deps = [ + "//pkg/base", + "//pkg/build", + "//pkg/gossip", + "//pkg/keys", + "//pkg/kv/kvclient/kvtenant", + "//pkg/kv/kvserver", + "//pkg/kv/kvserver/allocator", + "//pkg/kv/kvserver/allocator/allocatorimpl", + "//pkg/kv/kvserver/kvserverpb", + "//pkg/kv/kvserver/liveness", + "//pkg/kv/kvserver/liveness/livenesspb", + "//pkg/roachpb", + "//pkg/rpc", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/security/username", + "//pkg/server", + "//pkg/server/apiconstants", + "//pkg/server/serverpb", + "//pkg/server/srvtestutils", + "//pkg/server/status/statuspb", + "//pkg/storage/enginepb", + "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", + "//pkg/testutils/testcluster", + "//pkg/ts", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/log/logpb", + "//pkg/util/stop", + "//pkg/util/timeutil", + "@com_github_cockroachdb_errors//:errors", + "@com_github_pkg_errors//:errors", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//status", + ], +) diff --git a/pkg/server/storage_api/certs_test.go b/pkg/server/storage_api/certs_test.go new file mode 100644 index 000000000000..c3c49a705639 --- /dev/null +++ b/pkg/server/storage_api/certs_test.go @@ -0,0 +1,55 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package storage_api_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestCertificatesResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer ts.Stopper().Stop(context.Background()) + + var response serverpb.CertificatesResponse + if err := srvtestutils.GetStatusJSONProto(ts, "certificates/local", &response); err != nil { + t.Fatal(err) + } + + // We expect 5 certificates: CA, node, and client certs for root, testuser, testuser2. + if a, e := len(response.Certificates), 5; a != e { + t.Errorf("expected %d certificates, found %d", e, a) + } + + // The response is ordered: CA cert followed by node cert. 
+ cert := response.Certificates[0] + if a, e := cert.Type, serverpb.CertificateDetails_CA; a != e { + t.Errorf("wrong type %s, expected %s", a, e) + } else if cert.ErrorMessage != "" { + t.Errorf("expected cert without error, got %v", cert.ErrorMessage) + } + + cert = response.Certificates[1] + if a, e := cert.Type, serverpb.CertificateDetails_NODE; a != e { + t.Errorf("wrong type %s, expected %s", a, e) + } else if cert.ErrorMessage != "" { + t.Errorf("expected cert without error, got %v", cert.ErrorMessage) + } +} diff --git a/pkg/server/storage_api/decommission_test.go b/pkg/server/storage_api/decommission_test.go new file mode 100644 index 000000000000..ba9d6f75219b --- /dev/null +++ b/pkg/server/storage_api/decommission_test.go @@ -0,0 +1,1036 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package storage_api_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/allocatorimpl" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TestDecommissionPreCheckInvalid tests decommission pre check expected errors. +func TestDecommissionPreCheckInvalid(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Set up test cluster. + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 4, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + ServerArgsPerNode: map[int]base.TestServerArgs{ + 0: decommissionTsArgs("a", "n1"), + 1: decommissionTsArgs("b", "n2"), + 2: decommissionTsArgs("c", "n3"), + 3: decommissionTsArgs("a", "n4"), + }, + }) + defer tc.Stopper().Stop(ctx) + + firstSvr := tc.Server(0).(*server.TestServer) + + // Create database and tables. 
+ ac := firstSvr.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + + // Attempt to decommission check with unlimited traces. + decommissioningNodeIDs := []roachpb.NodeID{tc.Server(3).NodeID()} + result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, + true /* strictReadiness */, true /* collectTraces */, 0, /* maxErrors */ + ) + require.Error(t, err) + status, ok := status.FromError(err) + require.True(t, ok, "expected grpc status error") + require.Equal(t, codes.InvalidArgument, status.Code()) + require.Equal(t, server.DecommissionPreCheckResult{}, result) +} + +// TestDecommissionPreCheckEvaluation tests evaluation of decommission readiness +// of several nodes in a cluster given the replicas that exist on those nodes. +func TestDecommissionPreCheckEvaluation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + skip.UnderRace(t) // can't handle 7-node clusters + + tsArgs := func(attrs ...string) base.TestServerArgs { + return decommissionTsArgs("a", attrs...) + } + + // Set up test cluster. + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + ServerArgsPerNode: map[int]base.TestServerArgs{ + 0: tsArgs("ns1", "origin"), + 1: tsArgs("ns2", "west"), + 2: tsArgs("ns3", "central"), + 3: tsArgs("ns4", "central"), + 4: tsArgs("ns5", "east"), + 5: tsArgs("ns6", "east"), + 6: tsArgs("ns7", "east"), + }, + }) + defer tc.Stopper().Stop(ctx) + + firstSvr := tc.Server(0).(*server.TestServer) + db := tc.ServerConn(0) + runQueries := func(queries ...string) { + for _, q := range queries { + if _, err := db.Exec(q); err != nil { + t.Fatalf("error executing '%s': %s", q, err) + } + } + } + + // Create database and tables. 
+ ac := firstSvr.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + setupQueries := []string{ + "CREATE DATABASE test", + "CREATE TABLE test.tblA (val STRING)", + "CREATE TABLE test.tblB (val STRING)", + "INSERT INTO test.tblA VALUES ('testvalA')", + "INSERT INTO test.tblB VALUES ('testvalB')", + } + runQueries(setupQueries...) + alterQueries := []string{ + "ALTER TABLE test.tblA CONFIGURE ZONE USING num_replicas = 3, constraints = '{+west: 1, +central: 1, +east: 1}', " + + "range_max_bytes = 500000000, range_min_bytes = 100", + "ALTER TABLE test.tblB CONFIGURE ZONE USING num_replicas = 3, constraints = '{+east}', " + + "range_max_bytes = 500000000, range_min_bytes = 100", + } + runQueries(alterQueries...) + tblAID, err := firstSvr.TestingQueryTableID(ctx, username.RootUserName(), "test", "tblA") + require.NoError(t, err) + tblBID, err := firstSvr.TestingQueryTableID(ctx, username.RootUserName(), "test", "tblB") + require.NoError(t, err) + startKeyTblA := firstSvr.Codec().TablePrefix(uint32(tblAID)) + startKeyTblB := firstSvr.Codec().TablePrefix(uint32(tblBID)) + + // Split off ranges for tblA and tblB. + _, rDescA, err := firstSvr.SplitRange(startKeyTblA) + require.NoError(t, err) + _, rDescB, err := firstSvr.SplitRange(startKeyTblB) + require.NoError(t, err) + + // Ensure all nodes have the correct span configs for tblA and tblB. + waitForSpanConfig(t, tc, rDescA.StartKey, 500000000) + waitForSpanConfig(t, tc, rDescB.StartKey, 500000000) + + // Transfer tblA to [west, central, east] and tblB to [east]. 
+ tc.AddVotersOrFatal(t, startKeyTblA, tc.Target(1), tc.Target(2), tc.Target(4)) + tc.TransferRangeLeaseOrFatal(t, rDescA, tc.Target(1)) + tc.RemoveVotersOrFatal(t, startKeyTblA, tc.Target(0)) + tc.AddVotersOrFatal(t, startKeyTblB, tc.Target(4), tc.Target(5), tc.Target(6)) + tc.TransferRangeLeaseOrFatal(t, rDescB, tc.Target(4)) + tc.RemoveVotersOrFatal(t, startKeyTblB, tc.Target(0)) + + // Validate range distribution. + rDescA = tc.LookupRangeOrFatal(t, startKeyTblA) + rDescB = tc.LookupRangeOrFatal(t, startKeyTblB) + for _, desc := range []roachpb.RangeDescriptor{rDescA, rDescB} { + require.Lenf(t, desc.Replicas().VoterAndNonVoterDescriptors(), 3, "expected 3 replicas, have %v", desc) + } + + require.True(t, hasReplicaOnServers(tc, &rDescA, 1, 2, 4)) + require.True(t, hasReplicaOnServers(tc, &rDescB, 4, 5, 6)) + + // Evaluate n5 decommission check. + decommissioningNodeIDs := []roachpb.NodeID{tc.Server(4).NodeID()} + result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, + true /* strictReadiness */, true /* collectTraces */, 10000, /* maxErrors */ + ) + require.NoError(t, err) + require.Equal(t, 2, result.RangesChecked, "unexpected number of ranges checked") + require.Equalf(t, 2, result.ActionCounts[allocatorimpl.AllocatorReplaceDecommissioningVoter.String()], + "unexpected allocator actions, got %v", result.ActionCounts) + require.Lenf(t, result.RangesNotReady, 1, "unexpected number of unready ranges") + + // Validate error on tblB's range as it requires 3 replicas in "east". 
+ unreadyResult := result.RangesNotReady[0] + require.Equalf(t, rDescB.StartKey, unreadyResult.Desc.StartKey, + "expected tblB's range to be unready, got %s", unreadyResult.Desc, + ) + require.Errorf(t, unreadyResult.Err, "expected error on %s", unreadyResult.Desc) + require.NotEmptyf(t, unreadyResult.TracingSpans, "expected tracing spans on %s", unreadyResult.Desc) + var allocatorError allocator.AllocationError + require.ErrorAsf(t, unreadyResult.Err, &allocatorError, "expected allocator error on %s", unreadyResult.Desc) + + // Evaluate n3 decommission check (not required to satisfy constraints). + decommissioningNodeIDs = []roachpb.NodeID{tc.Server(2).NodeID()} + result, err = firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, + true /* strictReadiness */, false /* collectTraces */, 0, /* maxErrors */ + ) + require.NoError(t, err) + require.Equal(t, 1, result.RangesChecked, "unexpected number of ranges checked") + require.Equalf(t, 1, result.ActionCounts[allocatorimpl.AllocatorReplaceDecommissioningVoter.String()], + "unexpected allocator actions, got %v", result.ActionCounts) + require.Lenf(t, result.RangesNotReady, 0, "unexpected number of unready ranges") +} + +// TestDecommissionPreCheckOddToEven tests evaluation of decommission readiness +// when moving from 5 nodes to 3, in which case ranges with RF of 5 should have +// an effective RF of 3. +func TestDecommissionPreCheckOddToEven(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Set up test cluster. + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + }) + defer tc.Stopper().Stop(ctx) + + firstSvr := tc.Server(0).(*server.TestServer) + db := tc.ServerConn(0) + runQueries := func(queries ...string) { + for _, q := range queries { + if _, err := db.Exec(q); err != nil { + t.Fatalf("error executing '%s': %s", q, err) + } + } + } + + // Create database and tables. 
+ ac := firstSvr.AmbientCtx() + ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test") + defer span.Finish() + setupQueries := []string{ + "CREATE DATABASE test", + "CREATE TABLE test.tblA (val STRING)", + "INSERT INTO test.tblA VALUES ('testvalA')", + } + runQueries(setupQueries...) + alterQueries := []string{ + "ALTER TABLE test.tblA CONFIGURE ZONE USING num_replicas = 5, " + + "range_max_bytes = 500000000, range_min_bytes = 100", + } + runQueries(alterQueries...) + tblAID, err := firstSvr.TestingQueryTableID(ctx, username.RootUserName(), "test", "tblA") + require.NoError(t, err) + startKeyTblA := firstSvr.Codec().TablePrefix(uint32(tblAID)) + + // Split off range for tblA. + _, rDescA, err := firstSvr.SplitRange(startKeyTblA) + require.NoError(t, err) + + // Ensure all nodes have the correct span configs for tblA. + waitForSpanConfig(t, tc, rDescA.StartKey, 500000000) + + // Transfer tblA to all nodes. + tc.AddVotersOrFatal(t, startKeyTblA, tc.Target(1), tc.Target(2), tc.Target(3), tc.Target(4)) + tc.TransferRangeLeaseOrFatal(t, rDescA, tc.Target(1)) + + // Validate range distribution. + rDescA = tc.LookupRangeOrFatal(t, startKeyTblA) + require.Lenf(t, rDescA.Replicas().VoterAndNonVoterDescriptors(), 5, "expected 5 replicas, have %v", rDescA) + + require.True(t, hasReplicaOnServers(tc, &rDescA, 0, 1, 2, 3, 4)) + + // Evaluate n5 decommission check. 
+ decommissioningNodeIDs := []roachpb.NodeID{tc.Server(4).NodeID()} + result, err := firstSvr.DecommissionPreCheck(ctx, decommissioningNodeIDs, + true /* strictReadiness */, true /* collectTraces */, 10000, /* maxErrors */ + ) + require.NoError(t, err) + require.Equal(t, 1, result.RangesChecked, "unexpected number of ranges checked") + require.Equalf(t, 1, result.ActionCounts[allocatorimpl.AllocatorRemoveDecommissioningVoter.String()], + "unexpected allocator actions, got %v", result.ActionCounts) + require.Lenf(t, result.RangesNotReady, 0, "unexpected number of unready ranges") +} + +// decommissionTsArgs returns a base.TestServerArgs for creating a test cluster +// with per-store attributes using a single, in-memory store for each node. +// The returned args carry one locality tier, region=<region>, and pass attrs +// as the attributes of the node's only store. +func decommissionTsArgs(region string, attrs ...string) base.TestServerArgs { + return base.TestServerArgs{ + Locality: roachpb.Locality{ + Tiers: []roachpb.Tier{ + { + Key: "region", + Value: region, + }, + }, + }, + StoreSpecs: []base.StoreSpec{ + {InMemory: true, Attributes: roachpb.Attributes{Attrs: attrs}}, + }, + } +} + +// hasReplicaOnServers reports whether the range described by desc has a replica +// on the node of every test-cluster server whose index is listed in serverIdxs. +// It returns false as soon as one listed server lacks a replica, and true when +// serverIdxs is empty. +func hasReplicaOnServers( + tc serverutils.TestClusterInterface, desc *roachpb.RangeDescriptor, serverIdxs ...int, +) bool { + for _, idx := range serverIdxs { + if !desc.Replicas().HasReplicaOnNode(tc.Server(idx).NodeID()) { + return false + } + } + return true +} + +// waitForSpanConfig waits until all servers in the test cluster have a span +// config for the key with the expected number of max bytes for the range. +// Each server's first store is consulted, and exp is compared against the +// RangeMaxBytes of the span config found for key.
+func waitForSpanConfig(
+	t *testing.T, tc serverutils.TestClusterInterface, key roachpb.RKey, exp int64,
+) {
+	// checkServer inspects the first store of the server at index idx and
+	// returns an error unless its span config for key has RangeMaxBytes == exp.
+	checkServer := func(idx int) error {
+		srv := tc.Server(idx)
+		stores := srv.GetStores().(*kvserver.Stores)
+		firstStore, err := stores.GetStore(srv.GetFirstStoreID())
+		if err != nil {
+			return errors.Wrapf(err, "missing store on server %d", idx)
+		}
+		subscriber := firstStore.GetStoreConfig().SpanConfigSubscriber
+		spanConf, err := subscriber.GetSpanConfigForKey(context.Background(), key)
+		if err != nil {
+			return errors.Wrapf(err, "missing span config for %s on server %d", key, idx)
+		}
+		if got := spanConf.RangeMaxBytes; got != exp {
+			return errors.Errorf("expected %d max bytes, got %d", exp, got)
+		}
+		return nil
+	}
+
+	// One sweep visits the servers in index order and stops at the first
+	// failure; the sweep as a whole is handed to SucceedsSoon.
+	testutils.SucceedsSoon(t, func() error {
+		for idx := 0; idx < tc.NumServers(); idx++ {
+			if err := checkServer(idx); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+}
+
+// TestDecommissionPreCheckBasicReadiness exercises the DecommissionPreCheck
+// admin endpoint in its simplest form: a 7-node cluster with manual
+// replication, a request issued through server 4's admin service, and server
+// 5 as the only node being checked.
+func TestDecommissionPreCheckBasicReadiness(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	skip.UnderRace(t) // can't handle 7-node clusters
+
+	ctx := context.Background()
+	clusterArgs := base.TestClusterArgs{
+		ReplicationMode: base.ReplicationManual, // saves time
+	}
+	tc := serverutils.StartNewTestCluster(t, 7, clusterArgs)
+	defer tc.Stopper().Stop(ctx)
+
+	// Dial server 4 on its own RPC address to obtain an admin client.
+	adminSrv := tc.Server(4)
+	dialer := adminSrv.RPCContext().GRPCDialNode(
+		adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass)
+	conn, err := dialer.Connect(ctx)
+	require.NoError(t, err)
+	adminClient := serverpb.NewAdminClient(conn)
+
+	// Pre-check server 5 alone; the response must contain exactly one node
+	// result, reported ready with a replica count of zero.
+	targetNodeID := tc.Server(5).NodeID()
+	req := &serverpb.DecommissionPreCheckRequest{
+		NodeIDs: []roachpb.NodeID{targetNodeID},
+	}
+	resp, err := adminClient.DecommissionPreCheck(ctx, req)
+	require.NoError(t, err)
+	require.Len(t, resp.CheckedNodes, 1)
+	checkNodeCheckResultReady(t, targetNodeID, 0, resp.CheckedNodes[0])
+}
+
+// TestDecommissionPreCheckUnready tests the functionality of the
+// DecommissionPreCheck endpoint with some nodes not ready.
+func TestDecommissionPreCheckUnready(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + skip.UnderRace(t) // can't handle 7-node clusters + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, // saves time + }) + defer tc.Stopper().Stop(ctx) + + // Add replicas to a node we will check. + // Scratch range should have RF=3, liveness range should have RF=5. + adminSrvIdx := 3 + decommissioningSrvIdx := 5 + scratchKey := tc.ScratchRange(t) + scratchDesc := tc.AddVotersOrFatal(t, scratchKey, tc.Target(decommissioningSrvIdx)) + livenessDesc := tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix) + livenessDesc = tc.AddVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), tc.Target(decommissioningSrvIdx)) + + adminSrv := tc.Server(adminSrvIdx) + decommissioningSrv := tc.Server(decommissioningSrvIdx) + conn, err := adminSrv.RPCContext().GRPCDialNode( + adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) + require.NoError(t, err) + adminClient := serverpb.NewAdminClient(conn) + + checkNodeReady := func(nID roachpb.NodeID, replicaCount int64, strict bool) { + resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ + NodeIDs: []roachpb.NodeID{nID}, + StrictReadiness: strict, + }) + require.NoError(t, err) + require.Len(t, resp.CheckedNodes, 1) + checkNodeCheckResultReady(t, nID, replicaCount, resp.CheckedNodes[0]) + } + + awaitDecommissioned := func(nID roachpb.NodeID) { + testutils.SucceedsSoon(t, func() error { + livenesses, err := adminSrv.NodeLiveness().(*liveness.NodeLiveness).ScanNodeVitalityFromKV(ctx) + if err != nil { + return err + } + for nodeID, nodeLiveness := range livenesses { + if nodeID == nID { + if nodeLiveness.IsDecommissioned() { + return nil + } else { + return errors.Errorf("n%d has membership: %s", nID, nodeLiveness.MembershipStatus()) + } + } + } + return errors.Errorf("n%d liveness not 
found", nID) + }) + } + + checkAndDecommission := func(srvIdx int, replicaCount int64, strict bool) { + nID := tc.Server(srvIdx).NodeID() + checkNodeReady(nID, replicaCount, strict) + require.NoError(t, adminSrv.Decommission( + ctx, livenesspb.MembershipStatus_DECOMMISSIONING, []roachpb.NodeID{nID})) + require.NoError(t, adminSrv.Decommission( + ctx, livenesspb.MembershipStatus_DECOMMISSIONED, []roachpb.NodeID{nID})) + awaitDecommissioned(nID) + } + + // In non-strict mode, this decommission appears "ready". This is because the + // ranges with replicas on decommissioningSrv have priority action "AddVoter", + // and they have valid targets. + checkNodeReady(decommissioningSrv.NodeID(), 2, false) + + // In strict mode, we would expect the readiness check to fail. + resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ + NodeIDs: []roachpb.NodeID{decommissioningSrv.NodeID()}, + NumReplicaReport: 50, + StrictReadiness: true, + CollectTraces: true, + }) + require.NoError(t, err) + nodeCheckResult := resp.CheckedNodes[0] + require.Equalf(t, serverpb.DecommissionPreCheckResponse_ALLOCATION_ERRORS, nodeCheckResult.DecommissionReadiness, + "expected n%d to have allocation errors, got %s", nodeCheckResult.NodeID, nodeCheckResult.DecommissionReadiness) + require.Len(t, nodeCheckResult.CheckedRanges, 2) + checkRangeCheckResult(t, livenessDesc, nodeCheckResult.CheckedRanges[0], + "add voter", "needs repair beyond replacing/removing", true, + ) + checkRangeCheckResult(t, scratchDesc, nodeCheckResult.CheckedRanges[1], + "add voter", "needs repair beyond replacing/removing", true, + ) + + // Add replicas to ensure we have the correct number of replicas for each range. 
+ scratchDesc = tc.AddVotersOrFatal(t, scratchKey, tc.Target(adminSrvIdx)) + livenessDesc = tc.AddVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), + tc.Target(adminSrvIdx), tc.Target(4), tc.Target(6), + ) + require.True(t, hasReplicaOnServers(tc, &scratchDesc, 0, adminSrvIdx, decommissioningSrvIdx)) + require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx, 4, 6)) + require.Len(t, scratchDesc.InternalReplicas, 3) + require.Len(t, livenessDesc.InternalReplicas, 5) + + // Decommissioning pre-check should pass on decommissioningSrv in both strict + // and non-strict modes, as each range can find valid upreplication targets. + checkNodeReady(decommissioningSrv.NodeID(), 2, true) + + // Check and decommission empty nodes, decreasing to a 5-node cluster. + checkAndDecommission(1, 0, true) + checkAndDecommission(2, 0, true) + + // Check that we can still decommission. + // Below 5 nodes, system ranges will have an effective RF=3. + checkNodeReady(decommissioningSrv.NodeID(), 2, true) + + // Check that we can decommission the nodes with liveness replicas only. + checkAndDecommission(4, 1, true) + checkAndDecommission(6, 1, true) + + // Check range descriptors are as expected. + scratchDesc = tc.LookupRangeOrFatal(t, scratchDesc.StartKey.AsRawKey()) + livenessDesc = tc.LookupRangeOrFatal(t, livenessDesc.StartKey.AsRawKey()) + require.True(t, hasReplicaOnServers(tc, &scratchDesc, 0, adminSrvIdx, decommissioningSrvIdx)) + require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx, 4, 6)) + require.Len(t, scratchDesc.InternalReplicas, 3) + require.Len(t, livenessDesc.InternalReplicas, 5) + + // Cleanup orphaned liveness replicas and check. 
+ livenessDesc = tc.RemoveVotersOrFatal(t, livenessDesc.StartKey.AsRawKey(), tc.Target(4), tc.Target(6)) + require.True(t, hasReplicaOnServers(tc, &livenessDesc, 0, adminSrvIdx, decommissioningSrvIdx)) + require.Len(t, livenessDesc.InternalReplicas, 3) + + // Validate that the node is not ready to decommission. + resp, err = adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ + NodeIDs: []roachpb.NodeID{decommissioningSrv.NodeID()}, + NumReplicaReport: 1, // Test that we limit errors. + StrictReadiness: true, + }) + require.NoError(t, err) + nodeCheckResult = resp.CheckedNodes[0] + require.Equalf(t, serverpb.DecommissionPreCheckResponse_ALLOCATION_ERRORS, nodeCheckResult.DecommissionReadiness, + "expected n%d to have allocation errors, got %s", nodeCheckResult.NodeID, nodeCheckResult.DecommissionReadiness) + require.Equal(t, int64(2), nodeCheckResult.ReplicaCount) + require.Len(t, nodeCheckResult.CheckedRanges, 1) + checkRangeCheckResult(t, livenessDesc, nodeCheckResult.CheckedRanges[0], + "replace decommissioning voter", + "0 of 2 live stores are able to take a new replica for the range "+ + "(2 already have a voter, 0 already have a non-voter); "+ + "likely not enough nodes in cluster", + false, + ) +} + +// TestDecommissionPreCheckMultiple tests the functionality of the +// DecommissionPreCheck endpoint with multiple nodes. +func TestDecommissionPreCheckMultiple(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, // saves time + }) + defer tc.Stopper().Stop(ctx) + + // TODO(sarkesian): Once #95909 is merged, test checks on a 3-node decommission. + // e.g. Test both server idxs 3,4 and 2,3,4 (which should not pass checks). 
+ adminSrvIdx := 1 + decommissioningSrvIdxs := []int{3, 4} + decommissioningSrvNodeIDs := make([]roachpb.NodeID, len(decommissioningSrvIdxs)) + for i, srvIdx := range decommissioningSrvIdxs { + decommissioningSrvNodeIDs[i] = tc.Server(srvIdx).NodeID() + } + + // Add replicas to nodes we will check. + // Scratch range should have RF=3, liveness range should have RF=5. + rangeDescs := []roachpb.RangeDescriptor{ + tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix), + tc.LookupRangeOrFatal(t, tc.ScratchRange(t)), + } + rangeDescSrvIdxs := [][]int{ + {0, 1, 2, 3, 4}, + {0, 3, 4}, + } + rangeDescSrvTargets := make([][]roachpb.ReplicationTarget, len(rangeDescs)) + for i, srvIdxs := range rangeDescSrvIdxs { + for _, srvIdx := range srvIdxs { + if srvIdx != 0 { + rangeDescSrvTargets[i] = append(rangeDescSrvTargets[i], tc.Target(srvIdx)) + } + } + } + + for i, rangeDesc := range rangeDescs { + rangeDescs[i] = tc.AddVotersOrFatal(t, rangeDesc.StartKey.AsRawKey(), rangeDescSrvTargets[i]...) + } + + for i, rangeDesc := range rangeDescs { + require.True(t, hasReplicaOnServers(tc, &rangeDesc, rangeDescSrvIdxs[i]...)) + require.Len(t, rangeDesc.InternalReplicas, len(rangeDescSrvIdxs[i])) + } + + adminSrv := tc.Server(adminSrvIdx) + conn, err := adminSrv.RPCContext().GRPCDialNode( + adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) + require.NoError(t, err) + adminClient := serverpb.NewAdminClient(conn) + + // We expect to be able to decommission the targeted nodes simultaneously. 
+ resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ + NodeIDs: decommissioningSrvNodeIDs, + NumReplicaReport: 50, + StrictReadiness: true, + CollectTraces: true, + }) + require.NoError(t, err) + require.Len(t, resp.CheckedNodes, len(decommissioningSrvIdxs)) + for i, nID := range decommissioningSrvNodeIDs { + checkNodeCheckResultReady(t, nID, int64(len(rangeDescs)), resp.CheckedNodes[i]) + } +} + +// TestDecommissionPreCheckInvalidNode tests the functionality of the +// DecommissionPreCheck endpoint where some nodes are invalid. +func TestDecommissionPreCheckInvalidNode(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 5, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, // saves time + }) + defer tc.Stopper().Stop(ctx) + + adminSrvIdx := 1 + validDecommissioningNodeID := roachpb.NodeID(5) + invalidDecommissioningNodeID := roachpb.NodeID(34) + decommissioningNodeIDs := []roachpb.NodeID{validDecommissioningNodeID, invalidDecommissioningNodeID} + + // Add replicas to nodes we will check. + // Scratch range should have RF=3, liveness range should have RF=5. + rangeDescs := []roachpb.RangeDescriptor{ + tc.LookupRangeOrFatal(t, keys.NodeLivenessPrefix), + tc.LookupRangeOrFatal(t, tc.ScratchRange(t)), + } + rangeDescSrvIdxs := [][]int{ + {0, 1, 2, 3, 4}, + {0, 3, 4}, + } + rangeDescSrvTargets := make([][]roachpb.ReplicationTarget, len(rangeDescs)) + for i, srvIdxs := range rangeDescSrvIdxs { + for _, srvIdx := range srvIdxs { + if srvIdx != 0 { + rangeDescSrvTargets[i] = append(rangeDescSrvTargets[i], tc.Target(srvIdx)) + } + } + } + + for i, rangeDesc := range rangeDescs { + rangeDescs[i] = tc.AddVotersOrFatal(t, rangeDesc.StartKey.AsRawKey(), rangeDescSrvTargets[i]...) 
+ } + + for i, rangeDesc := range rangeDescs { + require.True(t, hasReplicaOnServers(tc, &rangeDesc, rangeDescSrvIdxs[i]...)) + require.Len(t, rangeDesc.InternalReplicas, len(rangeDescSrvIdxs[i])) + } + + adminSrv := tc.Server(adminSrvIdx) + conn, err := adminSrv.RPCContext().GRPCDialNode( + adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) + require.NoError(t, err) + adminClient := serverpb.NewAdminClient(conn) + + // We expect the pre-check to fail as some node IDs are invalid. + resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ + NodeIDs: decommissioningNodeIDs, + NumReplicaReport: 50, + StrictReadiness: true, + CollectTraces: true, + }) + require.NoError(t, err) + require.Len(t, resp.CheckedNodes, len(decommissioningNodeIDs)) + checkNodeCheckResultReady(t, validDecommissioningNodeID, int64(len(rangeDescs)), resp.CheckedNodes[0]) + require.Equal(t, serverpb.DecommissionPreCheckResponse_NodeCheckResult{ + NodeID: invalidDecommissioningNodeID, + DecommissionReadiness: serverpb.DecommissionPreCheckResponse_UNKNOWN, + ReplicaCount: 0, + CheckedRanges: nil, + }, resp.CheckedNodes[1]) +} + +func TestDecommissionSelf(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + skip.UnderRace(t) // can't handle 7-node clusters + + // Set up test cluster. + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, // saves time + }) + defer tc.Stopper().Stop(ctx) + + // Decommission several nodes, including the node we're submitting the + // decommission request to. We use the admin client in order to test the + // admin server's logic, which involves a subsequent DecommissionStatus + // call which could fail if used from a node that's just decommissioned. 
+ adminSrv := tc.Server(4) + conn, err := adminSrv.RPCContext().GRPCDialNode( + adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) + require.NoError(t, err) + adminClient := serverpb.NewAdminClient(conn) + decomNodeIDs := []roachpb.NodeID{ + tc.Server(4).NodeID(), + tc.Server(5).NodeID(), + tc.Server(6).NodeID(), + } + + // The DECOMMISSIONING call should return a full status response. + resp, err := adminClient.Decommission(ctx, &serverpb.DecommissionRequest{ + NodeIDs: decomNodeIDs, + TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, + }) + require.NoError(t, err) + require.Len(t, resp.Status, len(decomNodeIDs)) + for i, nodeID := range decomNodeIDs { + status := resp.Status[i] + require.Equal(t, nodeID, status.NodeID) + // Liveness entries may not have been updated yet. + require.Contains(t, []livenesspb.MembershipStatus{ + livenesspb.MembershipStatus_ACTIVE, + livenesspb.MembershipStatus_DECOMMISSIONING, + }, status.Membership, "unexpected membership status %v for node %v", status, nodeID) + } + + // The DECOMMISSIONED call should return an empty response, to avoid + // erroring due to loss of cluster RPC access when decommissioning self. + resp, err = adminClient.Decommission(ctx, &serverpb.DecommissionRequest{ + NodeIDs: decomNodeIDs, + TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONED, + }) + require.NoError(t, err) + require.Empty(t, resp.Status) + + // The nodes should now have been (or soon become) decommissioned. 
+ for i := 0; i < tc.NumServers(); i++ { + srv := tc.Server(i) + expect := livenesspb.MembershipStatus_ACTIVE + for _, nodeID := range decomNodeIDs { + if srv.NodeID() == nodeID { + expect = livenesspb.MembershipStatus_DECOMMISSIONED + break + } + } + require.Eventually(t, func() bool { + liveness, ok := srv.NodeLiveness().(*liveness.NodeLiveness).GetLiveness(srv.NodeID()) + return ok && liveness.Membership == expect + }, 5*time.Second, 100*time.Millisecond, "timed out waiting for node %v status %v", i, expect) + } +} + +// TestDecommissionEnqueueReplicas tests that a decommissioning node's replicas +// are proactively enqueued into their replicateQueues by the other nodes in the +// system. +func TestDecommissionEnqueueReplicas(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + skip.UnderRace(t) // can't handle 7-node clusters + + ctx := context.Background() + enqueuedRangeIDs := make(chan roachpb.RangeID) + tc := serverutils.StartNewTestCluster(t, 7, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + ServerArgs: base.TestServerArgs{ + Insecure: true, // allows admin client without setting up certs + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + EnqueueReplicaInterceptor: func( + queueName string, repl *kvserver.Replica, + ) { + require.Equal(t, queueName, "replicate") + enqueuedRangeIDs <- repl.RangeID + }, + }, + }, + }, + }) + defer tc.Stopper().Stop(ctx) + + decommissionAndCheck := func(decommissioningSrvIdx int) { + t.Logf("decommissioning n%d", tc.Target(decommissioningSrvIdx).NodeID) + // Add a scratch range's replica to a node we will decommission. 
+ scratchKey := tc.ScratchRange(t) + decommissioningSrv := tc.Server(decommissioningSrvIdx) + tc.AddVotersOrFatal(t, scratchKey, tc.Target(decommissioningSrvIdx)) + + conn, err := decommissioningSrv.RPCContext().GRPCDialNode( + decommissioningSrv.RPCAddr(), decommissioningSrv.NodeID(), rpc.DefaultClass, + ).Connect(ctx) + require.NoError(t, err) + adminClient := serverpb.NewAdminClient(conn) + decomNodeIDs := []roachpb.NodeID{tc.Server(decommissioningSrvIdx).NodeID()} + _, err = adminClient.Decommission( + ctx, + &serverpb.DecommissionRequest{ + NodeIDs: decomNodeIDs, + TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, + }, + ) + require.NoError(t, err) + + // Ensure that the scratch range's replica was proactively enqueued. + require.Equal(t, <-enqueuedRangeIDs, tc.LookupRangeOrFatal(t, scratchKey).RangeID) + + // Check that the node was marked as decommissioning in each of the nodes' + // decommissioningNodeMap. This needs to be wrapped in a SucceedsSoon to + // deal with gossip propagation delays. 
+ testutils.SucceedsSoon(t, func() error { + for i := 0; i < tc.NumServers(); i++ { + srv := tc.Server(i) + if _, exists := srv.DecommissioningNodeMap()[decommissioningSrv.NodeID()]; !exists { + return errors.Newf("node %d not detected to be decommissioning", decommissioningSrv.NodeID()) + } + } + return nil + }) + } + + decommissionAndCheck(2 /* decommissioningSrvIdx */) + decommissionAndCheck(3 /* decommissioningSrvIdx */) + decommissionAndCheck(5 /* decommissioningSrvIdx */) +} + +func TestAdminDecommissionedOperations(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + skip.UnderRace(t, "test uses timeouts, and race builds cause the timeouts to be exceeded") + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 2, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, // saves time + ServerArgs: base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + Insecure: true, // allows admin client without setting up certs + }, + }) + defer tc.Stopper().Stop(ctx) + + serverutils.SetClusterSetting(t, tc, "server.shutdown.jobs_wait", 0) + + scratchKey := tc.ScratchRange(t) + scratchRange := tc.LookupRangeOrFatal(t, scratchKey) + require.Len(t, scratchRange.InternalReplicas, 1) + require.Equal(t, tc.Server(0).NodeID(), scratchRange.InternalReplicas[0].NodeID) + + // Decommission server 1 and wait for it to lose cluster access. 
+ srv := tc.Server(0) + decomSrv := tc.Server(1) + for _, status := range []livenesspb.MembershipStatus{ + livenesspb.MembershipStatus_DECOMMISSIONING, livenesspb.MembershipStatus_DECOMMISSIONED, + } { + require.NoError(t, srv.Decommission(ctx, status, []roachpb.NodeID{decomSrv.NodeID()})) + } + + testutils.SucceedsWithin(t, func() error { + timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + _, err := decomSrv.DB().Scan(timeoutCtx, keys.LocalMax, keys.MaxKey, 0) + if err == nil { + return errors.New("expected error") + } + s, ok := status.FromError(errors.UnwrapAll(err)) + if ok && s.Code() == codes.PermissionDenied { + return nil + } + return err + }, 10*time.Second) + + // Set up an admin client. + //lint:ignore SA1019 grpc.WithInsecure is deprecated + conn, err := grpc.Dial(decomSrv.ServingRPCAddr(), grpc.WithInsecure()) + require.NoError(t, err) + defer func() { + _ = conn.Close() // nolint:grpcconnclose + }() + adminClient := serverpb.NewAdminClient(conn) + + // Run some operations on the decommissioned node. The ones that require + // access to the cluster should fail, other should succeed. We're mostly + // concerned with making sure they return rather than hang due to internal + // retries. 
+ testcases := []struct { + name string + expectCode codes.Code + op func(context.Context, serverpb.AdminClient) error + }{ + {"Cluster", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Cluster(ctx, &serverpb.ClusterRequest{}) + return err + }}, + {"Databases", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Databases(ctx, &serverpb.DatabasesRequest{}) + return err + }}, + {"DatabaseDetails", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.DatabaseDetails(ctx, &serverpb.DatabaseDetailsRequest{Database: "foo"}) + return err + }}, + {"DataDistribution", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.DataDistribution(ctx, &serverpb.DataDistributionRequest{}) + return err + }}, + {"Decommission", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Decommission(ctx, &serverpb.DecommissionRequest{ + NodeIDs: []roachpb.NodeID{srv.NodeID(), decomSrv.NodeID()}, + TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONED, + }) + return err + }}, + {"DecommissionStatus", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.DecommissionStatus(ctx, &serverpb.DecommissionStatusRequest{ + NodeIDs: []roachpb.NodeID{srv.NodeID(), decomSrv.NodeID()}, + }) + return err + }}, + {"EnqueueRange", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.EnqueueRange(ctx, &serverpb.EnqueueRangeRequest{ + RangeID: scratchRange.RangeID, + Queue: "replicaGC", + }) + return err + }}, + {"Events", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Events(ctx, &serverpb.EventsRequest{}) + return err + }}, + {"Health", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Health(ctx, &serverpb.HealthRequest{}) + return err + }}, + {"Jobs", codes.Internal, func(ctx 
context.Context, c serverpb.AdminClient) error { + _, err := c.Jobs(ctx, &serverpb.JobsRequest{}) + return err + }}, + {"Liveness", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Liveness(ctx, &serverpb.LivenessRequest{}) + return err + }}, + {"Locations", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Locations(ctx, &serverpb.LocationsRequest{}) + return err + }}, + {"NonTableStats", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.NonTableStats(ctx, &serverpb.NonTableStatsRequest{}) + return err + }}, + {"QueryPlan", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.QueryPlan(ctx, &serverpb.QueryPlanRequest{Query: "SELECT 1"}) + return err + }}, + {"RangeLog", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.RangeLog(ctx, &serverpb.RangeLogRequest{}) + return err + }}, + {"Settings", codes.OK, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Settings(ctx, &serverpb.SettingsRequest{}) + return err + }}, + {"TableStats", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.TableStats(ctx, &serverpb.TableStatsRequest{Database: "foo", Table: "bar"}) + return err + }}, + {"TableDetails", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.TableDetails(ctx, &serverpb.TableDetailsRequest{Database: "foo", Table: "bar"}) + return err + }}, + {"Users", codes.Internal, func(ctx context.Context, c serverpb.AdminClient) error { + _, err := c.Users(ctx, &serverpb.UsersRequest{}) + return err + }}, + // We drain at the end, since it may evict us. 
+ {"Drain", codes.Unknown, func(ctx context.Context, c serverpb.AdminClient) error { + stream, err := c.Drain(ctx, &serverpb.DrainRequest{DoDrain: true}) + if err != nil { + return err + } + _, err = stream.Recv() + return err + }}, + } + + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + testutils.SucceedsWithin(t, func() error { + timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + err := tc.op(timeoutCtx, adminClient) + if tc.expectCode == codes.OK { + require.NoError(t, err) + return nil + } + if err == nil { + // This will cause SucceedsWithin to retry. + return errors.New("expected error, got no error") + } + s, ok := status.FromError(errors.UnwrapAll(err)) + if !ok { + // Not a gRPC error. + // This will cause SucceedsWithin to retry. + return err + } + require.Equal(t, tc.expectCode, s.Code(), "%+v", err) + return nil + }, 10*time.Second) + }) + } +} + +// checkNodeCheckResultReady is a helper function for validating that the +// results of a decommission pre-check on a single node show it is ready. +// It requires checkResult to equal a NodeCheckResult for nID whose readiness +// is READY, whose ReplicaCount is replicaCount, and whose CheckedRanges is nil. +func checkNodeCheckResultReady( + t *testing.T, + nID roachpb.NodeID, + replicaCount int64, + checkResult serverpb.DecommissionPreCheckResponse_NodeCheckResult, +) { + require.Equal(t, serverpb.DecommissionPreCheckResponse_NodeCheckResult{ + NodeID: nID, + DecommissionReadiness: serverpb.DecommissionPreCheckResponse_READY, + ReplicaCount: replicaCount, + CheckedRanges: nil, + }, checkResult) +} + +// checkRangeCheckResult is a helper function for validating a range error +// returned as part of a decommission pre-check. +// It requires checkResult to carry desc's range ID, the action expectedAction, +// and a non-empty error that contains expectedErrSubstr when that is non-empty. +// Trace events must be present exactly when expectTraces is set. If a check +// fails, the descriptor (and, when traces were expected, the events) is logged.
+func checkRangeCheckResult(
+	t *testing.T,
+	desc roachpb.RangeDescriptor,
+	checkResult serverpb.DecommissionPreCheckResponse_RangeCheckResult,
+	expectedAction string,
+	expectedErrSubstr string,
+	expectTraces bool,
+) {
+	passed := false // flipped to true only after every assertion below has passed
+	defer func() { // on failure, log the range descriptor and, if traces were expected, the recorded events
+		if !passed {
+			t.Logf("failed checking %s", desc)
+			if expectTraces {
+				var traceBuilder strings.Builder
+				for _, event := range checkResult.Events {
+					fmt.Fprintf(&traceBuilder, "\n(%s) %s", event.Time, event.Message)
+				}
+				t.Logf("trace events: %s", traceBuilder.String())
+			}
+		}
+	}()
+	require.Equalf(t, desc.RangeID, checkResult.RangeID, "expected r%d, got r%d with error: \"%s\"",
+		desc.RangeID, checkResult.RangeID, checkResult.Error)
+	require.Equalf(t, expectedAction, checkResult.Action, "r%d expected action %s, got action %s with error: \"%s\"",
+		desc.RangeID, expectedAction, checkResult.Action, checkResult.Error)
+	require.NotEmptyf(t, checkResult.Error, "r%d expected non-empty error", checkResult.RangeID)
+	if len(expectedErrSubstr) > 0 { // an empty substring skips the error-text check
+		require.Containsf(t, checkResult.Error, expectedErrSubstr, "r%d expected error with \"%s\", got error: \"%s\"",
+			desc.RangeID, expectedErrSubstr, checkResult.Error)
+	}
+	if expectTraces { // trace events must be present exactly when the caller expects them
+		require.NotEmptyf(t, checkResult.Events, "r%d expected traces, got none with error: \"%s\"",
+			checkResult.RangeID, checkResult.Error)
+	} else {
+		require.Emptyf(t, checkResult.Events, "r%d expected no traces with error: \"%s\"",
+			checkResult.RangeID, checkResult.Error)
+	}
+	passed = true // reached only when no assertion above aborted the test
+}
diff --git a/pkg/server/storage_api/doc.go b/pkg/server/storage_api/doc.go
new file mode 100644
index 000000000000..77bf675c8fcc
--- /dev/null
+++ b/pkg/server/storage_api/doc.go
@@ -0,0 +1,15 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// Package storage_api pertains to the RPC and HTTP APIs exposed by +// the storage and KV layers. +// Application-level APIs (e.g. SQL inspection) are in the +// application_api package. +package storage_api diff --git a/pkg/server/storage_api/engine_test.go b/pkg/server/storage_api/engine_test.go new file mode 100644 index 000000000000..af3b3b860c64 --- /dev/null +++ b/pkg/server/storage_api/engine_test.go @@ -0,0 +1,95 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package storage_api_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/storage/enginepb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/pkg/errors" +) + +// TestStatusEngineStatsJson ensures that the output response for the engine +// stats contains the required fields. 
+func TestStatusEngineStatsJson(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	dir, cleanupFn := testutils.TempDir(t)
+	defer cleanupFn()
+
+	s, err := serverutils.StartServerRaw(t, base.TestServerArgs{
+		StoreSpecs: []base.StoreSpec{{
+			Path: dir,
+		}},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Stopper().Stop(context.Background())
+
+	t.Logf("using admin URL %s", s.AdminURL())
+
+	var engineStats serverpb.EngineStatsResponse
+	// Using SucceedsSoon because we have seen in the wild that
+	// occasionally requests don't go through with error "transport:
+	// error while dialing: connection interrupted (did the remote node
+	// shut down or are there networking issues?)"
+	testutils.SucceedsSoon(t, func() error {
+		return srvtestutils.GetStatusJSONProto(s, "enginestats/local", &engineStats)
+	})
+
+	if len(engineStats.Stats) != 1 { // one store was configured above, so one stats entry is expected
+		t.Fatal(errors.Errorf("expected one engine stats, got: %v", engineStats))
+	}
+
+	if engineStats.Stats[0].EngineType == enginepb.EngineTypePebble ||
+		engineStats.Stats[0].EngineType == enginepb.EngineTypeDefault {
+		// Pebble does not have RocksDB-style tickers and histograms.
+		return
+	}
+
+	tickers := engineStats.Stats[0].TickersAndHistograms.Tickers
+	if len(tickers) == 0 {
+		t.Fatal(errors.Errorf("expected non-empty tickers list, got: %v", tickers))
+	}
+	allTickersZero := true
+	for _, ticker := range tickers {
+		if ticker != 0 {
+			allTickersZero = false
+		}
+	}
+	if allTickersZero {
+		t.Fatal(errors.Errorf("expected some tickers nonzero, got: %v", tickers))
+	}
+
+	histograms := engineStats.Stats[0].TickersAndHistograms.Histograms
+	if len(histograms) == 0 {
+		t.Fatal(errors.Errorf("expected non-empty histograms list, got: %v", histograms))
+	}
+	allHistogramsZero := true
+	for _, histogram := range histograms {
+		if histogram.Max != 0 { // a nonzero Max marks a populated histogram, mirroring the ticker check
+			allHistogramsZero = false
+		}
+	}
+	if allHistogramsZero {
+		t.Fatal(errors.Errorf("expected some histograms nonzero, got: %v", histograms))
+	}
+}
diff --git a/pkg/server/storage_api/enqueue_test.go b/pkg/server/storage_api/enqueue_test.go
new file mode 100644
index 000000000000..0de5d130ad14
--- /dev/null
+++ b/pkg/server/storage_api/enqueue_test.go
@@ -0,0 +1,124 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package storage_api_test
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
+	"github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+)
+
+func TestEnqueueRange(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{
+		ReplicationMode: base.ReplicationManual,
+	})
+	defer testCluster.Stopper().Stop(context.Background())
+
+	// Up-replicate r1 to all 3 nodes. We use manual replication to avoid lease
+	// transfers causing temporary conditions in which no store is the
+	// leaseholder, which can break the tests below.
+	_, err := testCluster.AddVoters(roachpb.KeyMin, testCluster.Target(1), testCluster.Target(2))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Range IDs being queued: r1 (up-replicated above) and a fake r999.
+	const realRangeID = 1
+	const fakeRangeID = 999
+
+	// Who we expect responses from, expressed as a count of response details.
+	const none = 0
+	const leaseholder = 1
+	const allReplicas = 3
+
+	testCases := []struct {
+		nodeID            roachpb.NodeID
+		queue             string
+		rangeID           roachpb.RangeID
+		expectedDetails   int // expected len(resp.Details)
+		expectedNonErrors int // expected number of details with events and an empty error
+	}{
+		// Success cases.
+		{0, "mvccGC", realRangeID, allReplicas, leaseholder},
+		{0, "split", realRangeID, allReplicas, leaseholder},
+		{0, "replicaGC", realRangeID, allReplicas, allReplicas},
+		{0, "RaFtLoG", realRangeID, allReplicas, allReplicas}, // NOTE(review): mixed case suggests queue names are matched case-insensitively — confirm
+		{0, "RAFTSNAPSHOT", realRangeID, allReplicas, allReplicas},
+		{0, "consistencyChecker", realRangeID, allReplicas, leaseholder},
+		{0, "TIMESERIESmaintenance", realRangeID, allReplicas, leaseholder},
+		{1, "raftlog", realRangeID, leaseholder, leaseholder},
+		{2, "raftlog", realRangeID, leaseholder, 1},
+		{3, "raftlog", realRangeID, leaseholder, 1},
+		// Compatibility cases.
+		// TODO(nvanbenschoten): remove this in v23.1.
+		{0, "gc", realRangeID, allReplicas, leaseholder},
+		{0, "GC", realRangeID, allReplicas, leaseholder},
+		// Error cases: details are still returned, but none of them is error-free.
+		{0, "gv", realRangeID, allReplicas, none},
+		{0, "GC", fakeRangeID, allReplicas, none},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.queue, func(t *testing.T) {
+			req := &serverpb.EnqueueRangeRequest{
+				NodeID:  tc.nodeID,
+				Queue:   tc.queue,
+				RangeID: tc.rangeID,
+			}
+			var resp serverpb.EnqueueRangeResponse
+			if err := srvtestutils.PostAdminJSONProto(testCluster.Server(0), "enqueue_range", req, &resp); err != nil {
+				t.Fatal(err)
+			}
+			if e, a := tc.expectedDetails, len(resp.Details); e != a {
+				t.Errorf("expected %d details; got %d: %+v", e, a, resp)
+			}
+			var numNonErrors int
+			for _, details := range resp.Details {
+				if len(details.Events) > 0 && details.Error == "" { // a detail counts as a success only with events and no error
+					numNonErrors++
+				}
+			}
+			if tc.expectedNonErrors != numNonErrors {
+				t.Errorf("expected %d non-error details; got %d: %+v", tc.expectedNonErrors, numNonErrors, resp)
+			}
+		})
+	}
+
+	// Finally, test a few more basic error cases; each must fail with "400 Bad Request".
+	reqs := []*serverpb.EnqueueRangeRequest{
+		{NodeID: -1, Queue: "mvccGC"},
+		{Queue: ""},
+		{RangeID: -1, Queue: "mvccGC"},
+	}
+	for _, req := range reqs {
+		t.Run(fmt.Sprint(req), func(t *testing.T) {
+			var resp serverpb.EnqueueRangeResponse
+			err := srvtestutils.PostAdminJSONProto(testCluster.Server(0), "enqueue_range", req, &resp)
+			if err == nil {
+				t.Fatalf("unexpected success: %+v", resp)
+			}
+			if !testutils.IsError(err, "400 Bad Request") {
+				t.Fatalf("unexpected error type: %+v", err)
+			}
+		})
+	}
+}
diff --git a/pkg/server/storage_api/files_test.go b/pkg/server/storage_api/files_test.go
new file mode 100644
index 000000000000..f398dc5349d4
--- /dev/null
+++ b/pkg/server/storage_api/files_test.go
@@ -0,0 +1,155 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package storage_api_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/rpc"
+	"github.com/cockroachdb/cockroach/pkg/security/username"
+	"github.com/cockroachdb/cockroach/pkg/server"
+	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
+	"github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+)
+
+// TestStatusGetFiles tests the GetFiles endpoint.
+func TestStatusGetFiles(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	tempDir, cleanupFn := testutils.TempDir(t)
+	defer cleanupFn()
+
+	storeSpec := base.StoreSpec{Path: tempDir}
+
+	tsI, _, _ := serverutils.StartServer(t, base.TestServerArgs{
+		StoreSpecs: []base.StoreSpec{
+			storeSpec,
+		},
+	})
+	ts := tsI.(*server.TestServer)
+	defer ts.Stopper().Stop(context.Background())
+
+	rootConfig := testutils.NewTestBaseContext(username.RootUserName())
+	rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts, rootConfig)
+
+	url := ts.ServingRPCAddr()
+	nodeID := ts.NodeID()
+	conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background())
+	if err != nil {
+		t.Fatal(err)
+	}
+	client := serverpb.NewStatusClient(conn)
+
+	// Test fetching heap files: write them under the store's logs directory, then read them back.
+	t.Run("heap", func(t *testing.T) {
+		const testFilesNo = 3
+		for i := 0; i < testFilesNo; i++ {
+			testHeapDir := filepath.Join(storeSpec.Path, "logs", base.HeapProfileDir)
+			testHeapFile := filepath.Join(testHeapDir, fmt.Sprintf("heap%d.pprof", i))
+			if err := os.MkdirAll(testHeapDir, os.ModePerm); err != nil {
+				t.Fatal(err)
+			}
+			if err := os.WriteFile(testHeapFile, []byte(fmt.Sprintf("I'm heap file %d", i)), 0644); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		request := serverpb.GetFilesRequest{
+			NodeId: "local", Type: serverpb.FileType_HEAP, Patterns: []string{"heap*"}}
+		response, err := client.GetFiles(context.Background(), &request)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if a, e := len(response.Files), testFilesNo; a != e {
+			t.Errorf("expected %d files(s), found %d", e, a)
+		}
+
+		for i, file := range response.Files { // NOTE(review): assumes files are returned in name order — confirm
+			expectedFileName := fmt.Sprintf("heap%d.pprof", i)
+			if file.Name != expectedFileName {
+				t.Fatalf("expected file name %s, found %s", expectedFileName, file.Name)
+			}
+			expectedFileContents := []byte(fmt.Sprintf("I'm heap file %d", i))
+			if !bytes.Equal(file.Contents, expectedFileContents) {
+				t.Fatalf("expected file contents %s, found %s", expectedFileContents, file.Contents)
+			}
+		}
+	})
+
+	// Test fetching goroutine files, following the same write-then-read pattern as above.
+	t.Run("goroutines", func(t *testing.T) {
+		const testFilesNo = 3
+		for i := 0; i < testFilesNo; i++ {
+			testGoroutineDir := filepath.Join(storeSpec.Path, "logs", base.GoroutineDumpDir)
+			testGoroutineFile := filepath.Join(testGoroutineDir, fmt.Sprintf("goroutine_dump%d.txt.gz", i))
+			if err := os.MkdirAll(testGoroutineDir, os.ModePerm); err != nil {
+				t.Fatal(err)
+			}
+			if err := os.WriteFile(testGoroutineFile, []byte(fmt.Sprintf("Goroutine dump %d", i)), 0644); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		request := serverpb.GetFilesRequest{
+			NodeId: "local", Type: serverpb.FileType_GOROUTINES, Patterns: []string{"*"}}
+		response, err := client.GetFiles(context.Background(), &request)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if a, e := len(response.Files), testFilesNo; a != e {
+			t.Errorf("expected %d files(s), found %d", e, a)
+		}
+
+		for i, file := range response.Files { // NOTE(review): assumes files are returned in name order — confirm
+			expectedFileName := fmt.Sprintf("goroutine_dump%d.txt.gz", i)
+			if file.Name != expectedFileName {
+				t.Fatalf("expected file name %s, found %s", expectedFileName, file.Name)
+			}
+			expectedFileContents := []byte(fmt.Sprintf("Goroutine dump %d", i))
+			if !bytes.Equal(file.Contents, expectedFileContents) {
+				t.Fatalf("expected file contents %s, found %s", expectedFileContents, file.Contents)
+			}
+		}
+	})
+
+	// Patterns containing path separators must be rejected.
+	t.Run("path separators", func(t *testing.T) {
+		request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true,
+			Type: serverpb.FileType_HEAP, Patterns: []string{"pattern/with/separators"}}
+		_, err = client.GetFiles(context.Background(), &request) // assigns the outer err declared by the dial above
+		if !testutils.IsError(err, "invalid pattern: cannot have path seperators") {
+			t.Errorf("GetFiles: path separators allowed in pattern")
+		}
+	})
+
+	// Unknown file types must be rejected.
+	t.Run("filetypes", func(t *testing.T) {
+		request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true,
+			Type: -1, Patterns: []string{"*"}}
+		_, err = client.GetFiles(context.Background(), &request) // assigns the outer err declared by the dial above
+		if !testutils.IsError(err, "unknown file type: -1") {
+			t.Errorf("GetFiles: invalid file type allowed")
+		}
+	})
+}
diff --git a/pkg/server/storage_api/gossip_test.go b/pkg/server/storage_api/gossip_test.go
new file mode 100644
index 000000000000..96848b66d048
--- /dev/null
+++ b/pkg/server/storage_api/gossip_test.go
@@ -0,0 +1,47 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package storage_api_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/gossip"
+	"github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+)
+
+// TestStatusGossipJson ensures that the output response for the full gossip
+// info contains the required fields.
+func TestStatusGossipJson(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
+	defer s.Stopper().Stop(context.Background())
+
+	var data gossip.InfoStatus // filled from the "gossip/local" status endpoint below
+	if err := srvtestutils.GetStatusJSONProto(s, "gossip/local", &data); err != nil {
+		t.Fatal(err)
+	}
+	if _, ok := data.Infos["first-range"]; !ok { // t.Errorf, not Fatal: every missing key is reported in one run
+		t.Errorf("no first-range info returned: %v", data)
+	}
+	if _, ok := data.Infos["cluster-id"]; !ok {
+		t.Errorf("no clusterID info returned: %v", data)
+	}
+	if _, ok := data.Infos["node:1"]; !ok { // NOTE(review): presumes the single test server is node 1 — confirm
+		t.Errorf("no node 1 info returned: %v", data)
+	}
+}
diff --git a/pkg/server/storage_api/health_test.go b/pkg/server/storage_api/health_test.go
new file mode 100644
index 000000000000..5bd7f92a70bf
--- /dev/null
+++ b/pkg/server/storage_api/health_test.go
@@ -0,0 +1,143 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package storage_api_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness"
+	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb"
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/server"
+	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
+	"github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/cockroachdb/errors"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHealthAPI(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	ctx := context.Background()
+
+	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{
+		// Disable the default test tenant for now as this test fails
+		// with it enabled. Tracked with #81590.
+		DefaultTestTenant: base.TODOTestTenantDisabled,
+	})
+	defer s.Stopper().Stop(ctx)
+	ts := s.(*server.TestServer)
+
+	// We need to retry because the node ID isn't set until after
+	// bootstrapping.
+	testutils.SucceedsSoon(t, func() error {
+		var resp serverpb.HealthResponse
+		return srvtestutils.GetAdminJSONProto(s, "health", &resp)
+	})
+
+	// Make the SQL listener appear unavailable. Verify that health fails after that.
+	ts.TestingSetReady(false)
+	var resp serverpb.HealthResponse
+	err := srvtestutils.GetAdminJSONProto(s, "health?ready=1", &resp)
+	if err == nil {
+		t.Error("server appears ready even though SQL listener is not")
+	}
+	ts.TestingSetReady(true) // restore readiness; the same request must succeed again
+	err = srvtestutils.GetAdminJSONProto(s, "health?ready=1", &resp)
+	if err != nil {
+		t.Errorf("server not ready after SQL listener is ready again: %v", err)
+	}
+
+	// Expire this node's liveness record by pausing heartbeats and advancing the
+	// server's clock.
+	nl := ts.NodeLiveness().(*liveness.NodeLiveness)
+	defer nl.PauseAllHeartbeatsForTest()()
+	self, ok := nl.Self()
+	assert.True(t, ok) // NOTE(review): non-fatal; `self` is still used on the next line when !ok
+	s.Clock().Update(self.Expiration.ToTimestamp().Add(1, 0).UnsafeToClockTimestamp())
+
+	testutils.SucceedsSoon(t, func() error {
+		err := srvtestutils.GetAdminJSONProto(s, "health?ready=1", &resp)
+		if err == nil {
+			return errors.New("health OK, still waiting for unhealth")
+		}
+
+		t.Logf("observed error: %v", err)
+		if !testutils.IsError(err, `(?s)503 Service Unavailable.*"error": "node is not healthy"`) {
+			return err
+		}
+		return nil
+	})
+
+	// After the node reports an error with `?ready=1`, the health
+	// endpoint must still succeed without error when `?ready=1` is not specified.
+	if err := srvtestutils.GetAdminJSONProto(s, "health", &resp); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestLivenessAPI(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	tc := testcluster.StartTestCluster(t, 3, base.TestClusterArgs{})
+	defer tc.Stopper().Stop(context.Background())
+
+	startTime := tc.Server(0).Clock().PhysicalNow() // lower bound checked against each liveness expiration below
+
+	// We need to retry because the gossiping of liveness status is an
+	// asynchronous process.
+	testutils.SucceedsSoon(t, func() error {
+		var resp serverpb.LivenessResponse
+		if err := serverutils.GetJSONProto(tc.Server(0), "/_admin/v1/liveness", &resp); err != nil {
+			return err
+		}
+		if a, e := len(resp.Livenesses), tc.NumServers(); a != e {
+			return errors.Errorf("found %d liveness records, wanted %d", a, e)
+		}
+		livenessMap := make(map[roachpb.NodeID]livenesspb.Liveness) // index the records by node ID for the per-server checks
+		for _, l := range resp.Livenesses {
+			livenessMap[l.NodeID] = l
+		}
+		for i := 0; i < tc.NumServers(); i++ {
+			s := tc.Server(i)
+			sl, ok := livenessMap[s.NodeID()]
+			if !ok {
+				return errors.Errorf("found no liveness record for node %d", s.NodeID())
+			}
+			if sl.Expiration.WallTime < startTime {
+				return errors.Errorf(
+					"expected node %d liveness to expire in future (after %d), expiration was %d",
+					s.NodeID(),
+					startTime,
+					sl.Expiration,
+				)
+			}
+			status, ok := resp.Statuses[s.NodeID()]
+			if !ok {
+				return errors.Errorf("found no liveness status for node %d", s.NodeID())
+			}
+			if a, e := status, livenesspb.NodeLivenessStatus_LIVE; a != e {
+				return errors.Errorf(
+					"liveness status for node %s was %s, wanted %s", s.NodeID(), a, e,
+				)
+			}
+		}
+		return nil
+	})
+}
diff --git a/pkg/server/storage_api/logfiles_test.go b/pkg/server/storage_api/logfiles_test.go
new file mode 100644
index 000000000000..16850ac11cbd
--- /dev/null
+++ b/pkg/server/storage_api/logfiles_test.go
@@ -0,0 +1,423 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package storage_api_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/server"
+	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
+	"github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/skip"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/cockroachdb/cockroach/pkg/util/log/logpb"
+	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestStatusLocalLogs checks to ensure that local/logfiles,
+// local/logfiles/{filename} and local/log function
+// correctly.
+func TestStatusLocalLogs(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	if log.V(3) {
+		skip.IgnoreLint(t, "Test only works with low verbosity levels")
+	}
+
+	s := log.ScopeWithoutShowLogs(t)
+	defer s.Close(t)
+
+	// This test cares about the number of output files. Ensure
+	// there's just one.
+	defer s.SetupSingleFileLogging()()
+
+	ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
+	defer ts.Stopper().Stop(context.Background())
+
+	// Log an error of each main type which we expect to be able to retrieve.
+	// The resolution of our log timestamps is such that it's possible to get
+	// two subsequent log messages with the same timestamp. This test will fail
+	// when that occurs. Adding a small sleep in here after each timestamp
+	// ensures this isn't the case and that the log filtering doesn't filter out
+	// the log entries we're looking for. The value of 20 μs was chosen because
+	// the log timestamps have a fidelity of 10 μs and thus doubling that should
+	// be a sufficient buffer.
+	// See util/log/clog.go formatHeader() for more details.
+	const sleepBuffer = time.Microsecond * 20
+	timestamp := timeutil.Now().UnixNano()
+	time.Sleep(sleepBuffer)
+	log.Errorf(context.Background(), "TestStatusLocalLogFile test message-Error")
+	time.Sleep(sleepBuffer)
+	timestampE := timeutil.Now().UnixNano()
+	time.Sleep(sleepBuffer)
+	log.Warningf(context.Background(), "TestStatusLocalLogFile test message-Warning")
+	time.Sleep(sleepBuffer)
+	timestampEW := timeutil.Now().UnixNano()
+	time.Sleep(sleepBuffer)
+	log.Infof(context.Background(), "TestStatusLocalLogFile test message-Info")
+	time.Sleep(sleepBuffer)
+	timestampEWI := timeutil.Now().UnixNano()
+
+	var wrapper serverpb.LogFilesListResponse
+	if err := srvtestutils.GetStatusJSONProto(ts, "logfiles/local", &wrapper); err != nil {
+		t.Fatal(err)
+	}
+	if a, e := len(wrapper.Files), 1; a != e {
+		t.Fatalf("expected %d log files; got %d", e, a)
+	}
+
+	// Check each individual log can be fetched and is non-empty.
+	var foundInfo, foundWarning, foundError bool
+	for _, file := range wrapper.Files {
+		var wrapper serverpb.LogEntriesResponse
+		if err := srvtestutils.GetStatusJSONProto(ts, "logfiles/local/"+file.Name, &wrapper); err != nil {
+			t.Fatal(err)
+		}
+		for _, entry := range wrapper.Entries {
+			switch strings.TrimSpace(entry.Message) {
+			case "TestStatusLocalLogFile test message-Error":
+				foundError = true
+			case "TestStatusLocalLogFile test message-Warning":
+				foundWarning = true
+			case "TestStatusLocalLogFile test message-Info":
+				foundInfo = true
+			}
+		}
+	}
+
+	if !(foundInfo && foundWarning && foundError) {
+		t.Errorf("expected to find test messages in %v", wrapper.Files)
+	}
+
+	type levelPresence struct {
+		Error, Warning, Info bool
+	}
+
+	testCases := []struct {
+		MaxEntities    int
+		StartTimestamp int64
+		EndTimestamp   int64
+		Pattern        string
+		levelPresence
+	}{
+		// Test entry limit. The levelPresence expectations are ignored for
+		// these cases; only the number of returned entries is checked.
+		{1, timestamp, timestampEWI, "", levelPresence{false, false, false}},
+		{2, timestamp, timestampEWI, "", levelPresence{false, false, false}},
+		{3, timestamp, timestampEWI, "", levelPresence{false, false, false}},
+		// Test filtering in different timestamp windows.
+		{0, timestamp, timestamp, "", levelPresence{false, false, false}},
+		{0, timestamp, timestampE, "", levelPresence{true, false, false}},
+		{0, timestampE, timestampEW, "", levelPresence{false, true, false}},
+		{0, timestampEW, timestampEWI, "", levelPresence{false, false, true}},
+		{0, timestamp, timestampEW, "", levelPresence{true, true, false}},
+		{0, timestampE, timestampEWI, "", levelPresence{false, true, true}},
+		{0, timestamp, timestampEWI, "", levelPresence{true, true, true}},
+		// Test filtering by regexp pattern.
+		{0, 0, 0, "Info", levelPresence{false, false, true}},
+		{0, 0, 0, "Warning", levelPresence{false, true, false}},
+		{0, 0, 0, "Error", levelPresence{true, false, false}},
+		{0, 0, 0, "Info|Error|Warning", levelPresence{true, true, true}},
+		{0, 0, 0, "Nothing", levelPresence{false, false, false}},
+	}
+
+	for i, testCase := range testCases {
+		var url bytes.Buffer
+		fmt.Fprintf(&url, "logs/local?level=")
+		if testCase.MaxEntities > 0 {
+			fmt.Fprintf(&url, "&max=%d", testCase.MaxEntities)
+		}
+		if testCase.StartTimestamp > 0 {
+			fmt.Fprintf(&url, "&start_time=%d", testCase.StartTimestamp)
+		}
+		if testCase.EndTimestamp > 0 { // guard end_time on its own field, not on StartTimestamp
+			fmt.Fprintf(&url, "&end_time=%d", testCase.EndTimestamp)
+		}
+		if len(testCase.Pattern) > 0 {
+			fmt.Fprintf(&url, "&pattern=%s", testCase.Pattern)
+		}
+
+		var wrapper serverpb.LogEntriesResponse
+		path := url.String()
+		if err := srvtestutils.GetStatusJSONProto(ts, path, &wrapper); err != nil {
+			t.Fatal(err)
+		}
+
+		if testCase.MaxEntities > 0 {
+			if a, e := len(wrapper.Entries), testCase.MaxEntities; a != e {
+				t.Errorf("%d expected %d entries, got %d: \n%+v", i, e, a, wrapper.Entries)
+			}
+		} else {
+			var actual levelPresence
+			var logsBuf bytes.Buffer
+			for _, entry := range wrapper.Entries {
+				fmt.Fprintln(&logsBuf, entry.Message)
+
+				switch strings.TrimSpace(entry.Message) {
+				case "TestStatusLocalLogFile test message-Error":
+					actual.Error = true
+				case "TestStatusLocalLogFile test message-Warning":
+					actual.Warning = true
+				case "TestStatusLocalLogFile test message-Info":
+					actual.Info = true
+				}
+			}
+
+			if testCase.levelPresence != actual {
+				t.Errorf("%d: expected %+v at %s, got:\n%s", i, testCase, path, logsBuf.String())
+			}
+		}
+	}
+}
+
+// TestStatusLocalLogsTenantFilter checks to ensure that local/logfiles,
+// local/logfiles/{filename} and local/log function correctly filter
+// logs by tenant ID.
+func TestStatusLocalLogsTenantFilter(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	if log.V(3) {
+		skip.IgnoreLint(t, "Test only works with low verbosity levels")
+	}
+
+	s := log.ScopeWithoutShowLogs(t)
+	defer s.Close(t)
+
+	// This test cares about the number of output files. Ensure
+	// there's just one.
+	defer s.SetupSingleFileLogging()()
+
+	srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
+	defer srv.Stopper().Stop(context.Background())
+
+	ts := srv.(*server.TestServer)
+
+	appTenantID := roachpb.MustMakeTenantID(uint64(2))
+	ctxSysTenant, ctxAppTenant := server.TestingMakeLoggingContexts(appTenantID)
+
+	// Log an error of each main type which we expect to be able to retrieve.
+	// The resolution of our log timestamps is such that it's possible to get
+	// two subsequent log messages with the same timestamp. This test will fail
+	// when that occurs. Adding a small sleep in here after each log call
+	// ensures this isn't the case and that the log filtering doesn't filter out
+	// the log entries we're looking for. The value of 20 μs was chosen because
+	// the log timestamps have a fidelity of 10 μs and thus doubling that should
+	// be a sufficient buffer.
+	// See util/log/clog.go formatHeader() for more details.
+	const sleepBuffer = time.Microsecond * 20
+	log.Errorf(ctxSysTenant, "system tenant msg 1")
+	time.Sleep(sleepBuffer)
+	log.Errorf(ctxAppTenant, "app tenant msg 1")
+	time.Sleep(sleepBuffer)
+	log.Warningf(ctxSysTenant, "system tenant msg 2")
+	time.Sleep(sleepBuffer)
+	log.Warningf(ctxAppTenant, "app tenant msg 2")
+	time.Sleep(sleepBuffer)
+	log.Infof(ctxSysTenant, "system tenant msg 3")
+	time.Sleep(sleepBuffer)
+	log.Infof(ctxAppTenant, "app tenant msg 3")
+	timestampEnd := timeutil.Now().UnixNano()
+
+	var listFilesResp serverpb.LogFilesListResponse
+	if err := srvtestutils.GetStatusJSONProto(ts, "logfiles/local", &listFilesResp); err != nil {
+		t.Fatal(err)
+	}
+	require.Lenf(t, listFilesResp.Files, 1, "expected 1 log files; got %d", len(listFilesResp.Files))
+
+	testCases := []struct {
+		name     string
+		tenantID roachpb.TenantID
+	}{
+		{
+			name:     "logs for system tenant does not apply filter",
+			tenantID: roachpb.SystemTenantID,
+		},
+		{
+			name:     "logs for app tenant applies tenant ID filter",
+			tenantID: appTenantID,
+		},
+	}
+
+	for _, testCase := range testCases {
+		// Non-system tenant servers filter to the tenant that they belong to.
+		// Set the server tenant ID for this test case.
+		ts.RPCContext().TenantID = testCase.tenantID
+
+		var logfilesResp serverpb.LogEntriesResponse
+		if err := srvtestutils.GetStatusJSONProto(ts, "logfiles/local/"+listFilesResp.Files[0].Name, &logfilesResp); err != nil {
+			t.Fatal(err)
+		}
+		var logsResp serverpb.LogEntriesResponse
+		if err := srvtestutils.GetStatusJSONProto(ts, fmt.Sprintf("logs/local?end_time=%d", timestampEnd), &logsResp); err != nil {
+			t.Fatal(err)
+		}
+
+		// Run the same set of assertions against both responses, as they are both expected
+		// to contain the log entries we're looking for.
+		for _, response := range []serverpb.LogEntriesResponse{logfilesResp, logsResp} {
+			sysTenantFound, appTenantFound := false, false
+			for _, logEntry := range response.Entries {
+				if !strings.HasSuffix(logEntry.File, "logfiles_test.go") { // only inspect entries logged by this test file
+					continue
+				}
+
+				if testCase.tenantID != roachpb.SystemTenantID {
+					require.Equal(t, testCase.tenantID.String(), logEntry.TenantID)
+				} else {
+					// Logs use the literal system tenant ID when tagging.
+					if logEntry.TenantID == fmt.Sprintf("%d", roachpb.SystemTenantID.InternalValue) {
+						sysTenantFound = true
+					} else if logEntry.TenantID == appTenantID.String() {
+						appTenantFound = true
+					}
+				}
+			}
+			if testCase.tenantID == roachpb.SystemTenantID {
+				require.True(t, sysTenantFound)
+				require.True(t, appTenantFound)
+			}
+		}
+	}
+}
+
+// TestStatusLogRedaction checks that the log file retrieval RPCs
+// honor the redaction flags.
+func TestStatusLogRedaction(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	testData := []struct {
+		redactableLogs     bool // logging flag
+		redact             bool // RPC request flag
+		expectedMessage    string
+		expectedRedactable bool // redactable bit in result entries
+	}{
+		// Note: all combinations of (redactableLogs, redact) must be tested below.
+
+		// If there were no markers to start with (redactableLogs=false), we
+		// introduce markers around the entire message to indicate it's not known to
+		// be safe.
+		{false, false, `‹THISISSAFE THISISUNSAFE›`, true},
+		// redact=true must be conservative and redact everything out if
+		// there were no markers to start with (redactableLogs=false).
+		{false, true, `‹×›`, false},
+		// redact=false keeps whatever was in the log file.
+		{true, false, `THISISSAFE ‹THISISUNSAFE›`, true},
+		// Whether or not to keep the redactable markers has no influence
+		// on the output of redaction, just on the presence of the
+		// "redactable" marker. In any case no information is leaked.
+ {true, true, `THISISSAFE ‹×›`, true}, + } + + testutils.RunTrueAndFalse(t, "redactableLogs", + func(t *testing.T, redactableLogs bool) { + s := log.ScopeWithoutShowLogs(t) + defer s.Close(t) + + // This test cares about the number of output files. Ensure + // there's just one. + defer s.SetupSingleFileLogging()() + + // Apply the redactable log boolean for this test. + defer log.TestingSetRedactable(redactableLogs)() + + ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer ts.Stopper().Stop(context.Background()) + + // Log something. + log.Infof(context.Background(), "THISISSAFE %s", "THISISUNSAFE") + + // Determine the log file name. + var wrapper serverpb.LogFilesListResponse + if err := srvtestutils.GetStatusJSONProto(ts, "logfiles/local", &wrapper); err != nil { + t.Fatal(err) + } + // We expect only the main log. + if a, e := len(wrapper.Files), 1; a != e { + t.Fatalf("expected %d log files; got %d: %+v", e, a, wrapper.Files) + } + file := wrapper.Files[0] + // Assert that the log that's present is not a stderr log. + if strings.Contains(file.Name, "stderr") { + t.Fatalf("expected main log, found %v", file.Name) + } + + for _, tc := range testData { + if tc.redactableLogs != redactableLogs { + continue + } + t.Run(fmt.Sprintf("redact=%v", tc.redact), + func(t *testing.T) { + // checkEntries asserts that the redaction results are + // those expected in tc. + checkEntries := func(entries []logpb.Entry) { + foundMessage := false + for _, entry := range entries { + if !strings.HasSuffix(entry.File, "logfiles_test.go") { + continue + } + foundMessage = true + + assert.Equal(t, tc.expectedMessage, entry.Message) + } + if !foundMessage { + t.Fatalf("did not find expected message from test in log") + } + } + + // Retrieve the log entries with the configured flags using + // the LogFiles() RPC.
+ logFilesURL := fmt.Sprintf("logfiles/local/%s?redact=%v", file.Name, tc.redact) + var wrapper serverpb.LogEntriesResponse + if err := srvtestutils.GetStatusJSONProto(ts, logFilesURL, &wrapper); err != nil { + t.Fatal(err) + } + checkEntries(wrapper.Entries) + + // If the test specifies redact=false, check that a non-admin + // user gets a privilege error. + if !tc.redact { + err := srvtestutils.GetStatusJSONProtoWithAdminOption(ts, logFilesURL, &wrapper, false /* isAdmin */) + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + + // Retrieve the log entries using the Logs() RPC. + // Set a high `max` value to ensure we get the log line we're searching for. + logsURL := fmt.Sprintf("logs/local?redact=%v&max=5000", tc.redact) + var wrapper2 serverpb.LogEntriesResponse + if err := srvtestutils.GetStatusJSONProto(ts, logsURL, &wrapper2); err != nil { + t.Fatal(err) + } + checkEntries(wrapper2.Entries) + + // If the test specifies redact=false, check that a non-admin + // user gets a privilege error. + if !tc.redact { + err := srvtestutils.GetStatusJSONProtoWithAdminOption(ts, logsURL, &wrapper2, false /* isAdmin */) + if !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + } + }) + } + }) +} diff --git a/pkg/server/storage_api/main_test.go b/pkg/server/storage_api/main_test.go new file mode 100644 index 000000000000..6d5d1bf36fb9 --- /dev/null +++ b/pkg/server/storage_api/main_test.go @@ -0,0 +1,33 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package storage_api_test + +import ( + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvtenant" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" +) + +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + kvtenant.InitTestConnectorFactory() + os.Exit(m.Run()) +} + +//go:generate ../../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/server/storage_api/network_test.go b/pkg/server/storage_api/network_test.go new file mode 100644 index 000000000000..72119b9d5fdb --- /dev/null +++ b/pkg/server/storage_api/network_test.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +package storage_api_test + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestNetworkConnectivity(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + numNodes := 3 + testCluster := serverutils.StartNewTestCluster(t, numNodes, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + }) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + ts := testCluster.Server(0) + + var resp serverpb.NetworkConnectivityResponse + // Should wait because endpoint relies on Gossip. + testutils.SucceedsSoon(t, func() error { + if err := srvtestutils.GetStatusJSONProto(ts, "connectivity", &resp); err != nil { + return err + } + if len(resp.ErrorsByNodeID) > 0 { + return errors.Errorf("expected no errors but got: %d", len(resp.ErrorsByNodeID)) + } + if len(resp.Connections) < numNodes { + return errors.Errorf("expected results from %d nodes but got: %d", numNodes, len(resp.Connections)) + } + return nil + }) + // Test when one node is stopped.
+ stoppedNodeID := testCluster.Server(1).NodeID() + testCluster.Server(1).Stopper().Stop(ctx) + + testutils.SucceedsSoon(t, func() error { + if err := srvtestutils.GetStatusJSONProto(ts, "connectivity", &resp); err != nil { + return err + } + require.Equal(t, len(resp.Connections), numNodes-1) + fmt.Printf("got status: %s", resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Status.String()) + if resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Status != serverpb.NetworkConnectivityResponse_ERROR { + return errors.New("waiting for connection state to be changed.") + } + if latency := resp.Connections[ts.NodeID()].Peers[stoppedNodeID].Latency; latency > 0 { + return errors.Errorf("expected latency to be 0 but got %s", latency.String()) + } + return nil + }) +} diff --git a/pkg/server/storage_api/nodes_test.go b/pkg/server/storage_api/nodes_test.go new file mode 100644 index 000000000000..fd30482cbc76 --- /dev/null +++ b/pkg/server/storage_api/nodes_test.go @@ -0,0 +1,214 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package storage_api_test + +import ( + "context" + "strconv" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/build" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/server/status/statuspb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/ts" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/pkg/errors" +) + +// TestStatusJson verifies that status endpoints return expected Json results. +// The content type of the responses is always httputil.JSONContentType. 
+func TestStatusJson(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + nodeID := ts.Gossip().NodeID.Get() + addr, err := ts.Gossip().GetNodeIDAddress(nodeID) + if err != nil { + t.Fatal(err) + } + sqlAddr, err := ts.Gossip().GetNodeIDSQLAddress(nodeID) + if err != nil { + t.Fatal(err) + } + + var nodes serverpb.NodesResponse + testutils.SucceedsSoon(t, func() error { + if err := srvtestutils.GetStatusJSONProto(s, "nodes", &nodes); err != nil { + t.Fatal(err) + } + + if len(nodes.Nodes) == 0 { + return errors.Errorf("expected non-empty node list, got: %v", nodes) + } + return nil + }) + + for _, path := range []string{ + apiconstants.StatusPrefix + "details/local", + apiconstants.StatusPrefix + "details/" + strconv.FormatUint(uint64(nodeID), 10), + } { + var details serverpb.DetailsResponse + if err := serverutils.GetJSONProto(s, path, &details); err != nil { + t.Fatal(err) + } + if a, e := details.NodeID, nodeID; a != e { + t.Errorf("expected: %d, got: %d", e, a) + } + if a, e := details.Address, *addr; a != e { + t.Errorf("expected: %v, got: %v", e, a) + } + if a, e := details.SQLAddress, *sqlAddr; a != e { + t.Errorf("expected: %v, got: %v", e, a) + } + if a, e := details.BuildInfo, build.GetInfo(); a != e { + t.Errorf("expected: %v, got: %v", e, a) + } + } +} + +// TestNodeStatusResponse verifies that node status returns the expected +// results. +func TestNodeStatusResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(context.Background()) + s := srv.(*server.TestServer) + node := s.Node().(*server.Node) + + wrapper := serverpb.NodesResponse{} + + // Check that the node statuses cannot be accessed via a non-admin account. 
+ if err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, "nodes", &wrapper, false /* isAdmin */); !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + + // Now fetch all the node statuses as admin. + if err := srvtestutils.GetStatusJSONProto(s, "nodes", &wrapper); err != nil { + t.Fatal(err) + } + nodeStatuses := wrapper.Nodes + + if len(nodeStatuses) != 1 { + t.Fatalf("unexpected number of node statuses returned - expected:1 actual:%d", len(nodeStatuses)) + } + if !node.Descriptor.Equal(&nodeStatuses[0].Desc) { + t.Errorf("node status descriptors are not equal\nexpected:%+v\nactual:%+v\n", node.Descriptor, nodeStatuses[0].Desc) + } + + // Now fetch each one individually. Loop through the nodeStatuses to use the + // ids only. + for _, oldNodeStatus := range nodeStatuses { + nodeStatus := statuspb.NodeStatus{} + nodeURL := "nodes/" + oldNodeStatus.Desc.NodeID.String() + // Check that the node statuses cannot be accessed via a non-admin account. + if err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, nodeURL, &nodeStatus, false /* isAdmin */); !testutils.IsError(err, "status: 403") { + t.Fatalf("expected privilege error, got %v", err) + } + + // Now access that node's status. + if err := srvtestutils.GetStatusJSONProto(s, nodeURL, &nodeStatus); err != nil { + t.Fatal(err) + } + if !node.Descriptor.Equal(&nodeStatus.Desc) { + t.Errorf("node status descriptors are not equal\nexpected:%+v\nactual:%+v\n", node.Descriptor, nodeStatus.Desc) + } + } +} + +// TestMetricsRecording verifies that Node statistics are periodically recorded +// as time series data. +func TestMetricsRecording(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + s, _, kvDB := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + // Verify that metrics for the current timestamp are recorded.
This should + // be true very quickly even though DefaultMetricsSampleInterval is large, + // because the server writes an entry eagerly on startup. + testutils.SucceedsSoon(t, func() error { + now := s.Clock().PhysicalNow() + + var data roachpb.InternalTimeSeriesData + for _, keyName := range []string{ + "cr.store.livebytes.1", + "cr.node.sys.go.allocbytes.1", + } { + key := ts.MakeDataKey(keyName, "", ts.Resolution10s, now) + if err := kvDB.GetProto(ctx, key, &data); err != nil { + return err + } + } + return nil + }) +} + +// TestMetricsEndpoint retrieves the metrics endpoint, which is currently only +// used for development purposes. The metrics within the response are verified +// in other tests. +func TestMetricsEndpoint(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(context.Background()) + + s := srv.(*server.TestServer) + + if _, err := srvtestutils.GetText(s, s.AdminURL().WithPath(apiconstants.StatusPrefix+"metrics/"+s.Gossip().NodeID.String()).String()); err != nil { + t.Fatal(err) + } +} + +func TestNodesGRPCResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + ts := s.(*server.TestServer) + + rootConfig := testutils.NewTestBaseContext(username.RootUserName()) + rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts, rootConfig) + var request serverpb.NodesRequest + + url := ts.ServingRPCAddr() + nodeID := ts.NodeID() + conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + + response, err := client.Nodes(context.Background(), &request) + if err != nil { + t.Fatal(err) + } + + if a, e := len(response.Nodes), 1; a != e { + t.Errorf("expected %d node(s), 
found %d", e, a) + } +} diff --git a/pkg/server/storage_api/raft_test.go b/pkg/server/storage_api/raft_test.go new file mode 100644 index 000000000000..e55077af3ec7 --- /dev/null +++ b/pkg/server/storage_api/raft_test.go @@ -0,0 +1,75 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package storage_api_test + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestRaftDebug(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + var resp serverpb.RaftDebugResponse + if err := srvtestutils.GetStatusJSONProto(s, "raft", &resp); err != nil { + t.Fatal(err) + } + if len(resp.Ranges) == 0 { + t.Errorf("didn't get any ranges") + } + + if len(resp.Ranges) < 3 { + t.Errorf("expected more than 2 ranges, got %d", len(resp.Ranges)) + } + + reqURI := "raft" + requestedIDs := []roachpb.RangeID{} + for id := range resp.Ranges { + if len(requestedIDs) == 0 { + reqURI += "?" 
+ } else { + reqURI += "&" + } + reqURI += fmt.Sprintf("range_ids=%d", id) + requestedIDs = append(requestedIDs, id) + if len(requestedIDs) >= 2 { + break + } + } + + if err := srvtestutils.GetStatusJSONProto(s, reqURI, &resp); err != nil { + t.Fatal(err) + } + + // Make sure we get exactly two ranges back. + if len(resp.Ranges) != 2 { + t.Errorf("expected exactly two ranges in response, got %d", len(resp.Ranges)) + } + + // Make sure the ranges returned are those requested. + for _, reqID := range requestedIDs { + if _, ok := resp.Ranges[reqID]; !ok { + t.Errorf("request URI was %s, but range ID %d not returned: %+v", reqURI, reqID, resp.Ranges) + } + } +} diff --git a/pkg/server/storage_api/rangelog_test.go b/pkg/server/storage_api/rangelog_test.go new file mode 100644 index 000000000000..edcff540742d --- /dev/null +++ b/pkg/server/storage_api/rangelog_test.go @@ -0,0 +1,178 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package storage_api_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverpb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestAdminAPIRangeLogByRangeID(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + }) + defer s.Stopper().Stop(context.Background()) + + rangeID := 654321 + testCases := []struct { + rangeID int + hasLimit bool + limit int + expected int + }{ + {rangeID, true, 0, 2}, + {rangeID, true, -1, 2}, + {rangeID, true, 1, 1}, + {rangeID, false, 0, 2}, + // We'll create one event that has rangeID+1 as the otherRangeID. 
+ {rangeID + 1, false, 0, 1}, + } + + for _, otherRangeID := range []int{rangeID + 1, rangeID + 2} { + if _, err := db.Exec( + `INSERT INTO system.rangelog ( + timestamp, "rangeID", "otherRangeID", "storeID", "eventType" + ) VALUES ( + now(), $1, $2, $3, $4 + )`, + rangeID, otherRangeID, + 1, // storeID + kvserverpb.RangeLogEventType_add_voter.String(), + ); err != nil { + t.Fatal(err) + } + } + + for _, tc := range testCases { + url := fmt.Sprintf("rangelog/%d", tc.rangeID) + if tc.hasLimit { + url += fmt.Sprintf("?limit=%d", tc.limit) + } + t.Run(url, func(t *testing.T) { + var resp serverpb.RangeLogResponse + if err := srvtestutils.GetAdminJSONProto(s, url, &resp); err != nil { + t.Fatal(err) + } + + if e, a := tc.expected, len(resp.Events); e != a { + t.Fatalf("expected %d events, got %d", e, a) + } + + for _, event := range resp.Events { + expID := roachpb.RangeID(tc.rangeID) + if event.Event.RangeID != expID && event.Event.OtherRangeID != expID { + t.Errorf("expected rangeID or otherRangeID to be %d, got %d and r%d", + expID, event.Event.RangeID, event.Event.OtherRangeID) + } + } + }) + } +} + +// Test the range log API when queries are not filtered by a range ID (like in +// TestAdminAPIRangeLogByRangeID). +func TestAdminAPIFullRangeLog(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, + base.TestServerArgs{ + // Disable the default test tenant for now as this tests fails + // with it enabled. Tracked with #81590. + DefaultTestTenant: base.TODOTestTenantDisabled, + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + DisableSplitQueue: true, + }, + }, + }) + defer s.Stopper().Stop(context.Background()) + + // Insert something in the rangelog table, otherwise it's empty for new + // clusters. 
+ rows, err := db.Query(`SELECT count(1) FROM system.rangelog`) + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("missing row") + } + var cnt int + if err := rows.Scan(&cnt); err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + if cnt != 0 { + t.Fatalf("expected 0 rows in system.rangelog, found: %d", cnt) + } + const rangeID = 100 + for i := 0; i < 10; i++ { + if _, err := db.Exec( + `INSERT INTO system.rangelog ( + timestamp, "rangeID", "storeID", "eventType" + ) VALUES (now(), $1, 1, $2)`, + rangeID, + kvserverpb.RangeLogEventType_add_voter.String(), + ); err != nil { + t.Fatal(err) + } + } + expectedEvents := 10 + + testCases := []struct { + hasLimit bool + limit int + expected int + }{ + {false, 0, expectedEvents}, + {true, 0, expectedEvents}, + {true, -1, expectedEvents}, + {true, 1, 1}, + } + + for _, tc := range testCases { + url := "rangelog" + if tc.hasLimit { + url += fmt.Sprintf("?limit=%d", tc.limit) + } + t.Run(url, func(t *testing.T) { + var resp serverpb.RangeLogResponse + if err := srvtestutils.GetAdminJSONProto(s, url, &resp); err != nil { + t.Fatal(err) + } + events := resp.Events + if e, a := tc.expected, len(events); e != a { + var sb strings.Builder + for _, ev := range events { + sb.WriteString(ev.String() + "\n") + } + t.Fatalf("expected %d events, got %d:\n%s", e, a, sb.String()) + } + }) + } +} diff --git a/pkg/server/storage_api/ranges_test.go b/pkg/server/storage_api/ranges_test.go new file mode 100644 index 000000000000..1c7832e9a661 --- /dev/null +++ b/pkg/server/storage_api/ranges_test.go @@ -0,0 +1,190 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package storage_api_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/stretchr/testify/require" +) + +func TestRangesResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + defer kvserver.EnableLeaseHistoryForTesting(100)() + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + ts := s.(*server.TestServer) + + t.Run("test ranges response", func(t *testing.T) { + // Perform a scan to ensure that all the raft groups are initialized. + if _, err := ts.DB().Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { + t.Fatal(err) + } + + var response serverpb.RangesResponse + if err := srvtestutils.GetStatusJSONProto(ts, "ranges/local", &response); err != nil { + t.Fatal(err) + } + if len(response.Ranges) == 0 { + t.Errorf("didn't get any ranges") + } + for _, ri := range response.Ranges { + // Do some simple validation based on the fact that this is a + // single-node cluster. 
+ if ri.RaftState.State != "StateLeader" && ri.RaftState.State != server.RaftStateDormant { + t.Errorf("expected to be Raft leader or dormant, but was '%s'", ri.RaftState.State) + } + expReplica := roachpb.ReplicaDescriptor{ + NodeID: 1, + StoreID: 1, + ReplicaID: 1, + } + if len(ri.State.Desc.InternalReplicas) != 1 || ri.State.Desc.InternalReplicas[0] != expReplica { + t.Errorf("unexpected replica list %+v", ri.State.Desc.InternalReplicas) + } + if ri.State.Lease == nil || ri.State.Lease.Empty() { + t.Error("expected a nontrivial Lease") + } + if ri.State.LastIndex == 0 { + t.Error("expected positive LastIndex") + } + if len(ri.LeaseHistory) == 0 { + t.Error("expected at least one lease history entry") + } + } + }) + + t.Run("test ranges pagination", func(t *testing.T) { + ctx := context.Background() + rpcStopper := stop.NewStopper() + defer rpcStopper.Stop(ctx) + + conn, err := ts.RPCContext().GRPCDialNode(ts.ServingRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + resp1, err := client.Ranges(ctx, &serverpb.RangesRequest{ + Limit: 1, + }) + require.NoError(t, err) + require.Len(t, resp1.Ranges, 1) + require.Equal(t, int(resp1.Next), 1) + + resp2, err := client.Ranges(ctx, &serverpb.RangesRequest{ + Limit: 1, + Offset: resp1.Next, + }) + require.NoError(t, err) + require.Len(t, resp2.Ranges, 1) + require.Equal(t, int(resp2.Next), 2) + + // Verify pagination functions based on ascending RangeID order. 
+ require.True(t, resp1.Ranges[0].State.Desc.RangeID < resp2.Ranges[0].State.Desc.RangeID) + }) +} + +func TestTenantRangesResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + ts := s.(*server.TestServer) + + t.Run("returns error when TenantID not set in ctx", func(t *testing.T) { + rpcStopper := stop.NewStopper() + defer rpcStopper.Stop(ctx) + + conn, err := ts.RPCContext().GRPCDialNode(ts.ServingRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) + if err != nil { + t.Fatal(err) + } + client := serverpb.NewStatusClient(conn) + _, err = client.TenantRanges(ctx, &serverpb.TenantRangesRequest{}) + require.Error(t, err) + require.Contains(t, err.Error(), "no tenant ID found in context") + }) +} + +func TestRangeResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + defer kvserver.EnableLeaseHistoryForTesting(100)() + ts, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer ts.Stopper().Stop(context.Background()) + + // Perform a scan to ensure that all the raft groups are initialized. + if _, err := ts.DB().Scan(context.Background(), keys.LocalMax, roachpb.KeyMax, 0); err != nil { + t.Fatal(err) + } + + var response serverpb.RangeResponse + if err := srvtestutils.GetStatusJSONProto(ts, "range/1", &response); err != nil { + t.Fatal(err) + } + + // This is a single node cluster, so only expect a single response. + if e, a := 1, len(response.ResponsesByNodeID); e != a { + t.Errorf("got the wrong number of responses, expected %d, actual %d", e, a) + } + + node1Response := response.ResponsesByNodeID[response.NodeID] + + // The response should come back as valid. + if !node1Response.Response { + t.Errorf("node1's response returned as false, expected true") + } + + // The response should include just the one range. 
+ if e, a := 1, len(node1Response.Infos); e != a { + t.Fatalf("got the wrong number of ranges in the response, expected %d, actual %d", e, a) + } + + info := node1Response.Infos[0] + expReplica := roachpb.ReplicaDescriptor{ + NodeID: 1, + StoreID: 1, + ReplicaID: 1, + } + + // Check some other values. + if len(info.State.Desc.InternalReplicas) != 1 || info.State.Desc.InternalReplicas[0] != expReplica { + t.Errorf("unexpected replica list %+v", info.State.Desc.InternalReplicas) + } + + if info.State.Lease == nil || info.State.Lease.Empty() { + t.Error("expected a nontrivial Lease") + } + + if info.State.LastIndex == 0 { + t.Error("expected positive LastIndex") + } + + if len(info.LeaseHistory) == 0 { + t.Error("expected at least one lease history entry") + } +} diff --git a/pkg/server/tenant.go b/pkg/server/tenant.go index 82bd026a9666..d136a7b2332a 100644 --- a/pkg/server/tenant.go +++ b/pkg/server/tenant.go @@ -46,7 +46,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/rpc/nodedialer" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/debug" + "github.com/cockroachdb/cockroach/pkg/server/privchecker" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/status" "github.com/cockroachdb/cockroach/pkg/server/structlogging" @@ -100,11 +103,11 @@ type SQLServerWrapper struct { runtime *status.RuntimeStatSampler http *httpServer - adminAuthzCheck *adminPrivilegeChecker + adminAuthzCheck privchecker.CheckerForRPCHandlers tenantAdmin *adminServer tenantStatus *statusServer drainServer *drainServer - authentication *authenticationServer + authentication authserver.Server // eventsExporter exports data to the Observability Service.
eventsExporter obs.EventsExporterInterface stopper *stop.Stopper @@ -264,11 +267,7 @@ func newTenantServer( // Instantiate the API privilege checker. // // TODO(tbg): give adminServer only what it needs (and avoid circular deps). - adminAuthzCheck := &adminPrivilegeChecker{ - ie: args.circularInternalExecutor, - st: args.Settings, - makePlanner: nil, - } + adminAuthzCheck := privchecker.NewChecker(args.circularInternalExecutor, args.Settings) // Instantiate the HTTP server. // These callbacks help us avoid a dependency on gossip in httpServer. @@ -360,11 +359,11 @@ func newTenantServer( sqlServer.migrationServer = tms // only for testing via TestTenant // Tell the authz server how to connect to SQL. - adminAuthzCheck.makePlanner = func(opName string) (interface{}, func()) { + adminAuthzCheck.SetAuthzAccessorFactory(func(opName string) (sql.AuthorizationAccessor, func()) { // This is a hack to get around a Go package dependency cycle. See comment // in sql/jobs/registry.go on planHookMaker. txn := args.db.NewTxn(ctx, "check-system-privilege") - return sql.NewInternalPlanner( + p, cleanup := sql.NewInternalPlanner( opName, txn, username.RootUserName(), @@ -372,10 +371,11 @@ func newTenantServer( sqlServer.execCfg, sql.NewInternalSessionData(ctx, sqlServer.execCfg.Settings, opName), ) - } + return p.(sql.AuthorizationAccessor), cleanup + }) // Create the authentication RPC server (login/logout). - sAuth := newAuthenticationServer(baseCfg.Config, sqlServer) + sAuth := authserver.NewServer(baseCfg.Config, sqlServer) // Create a drain server. 
drainServer := newDrainServer(baseCfg, args.stopper, args.stopTrigger, args.grpc, sqlServer) @@ -725,7 +725,7 @@ func (s *SQLServerWrapper) PreStart(ctx context.Context) error { gwMux, /* handleRequestsUnauthenticated */ s.debug, /* handleDebugUnauthenticated */ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) + apiutil.WriteJSONResponse(r.Context(), w, http.StatusNotImplemented, nil) }), newAPIV2Server(workersCtx, &apiV2ServerOpts{ admin: s.tenantAdmin, diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 847de6444dcb..e49643e29341 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -25,6 +25,7 @@ import ( "github.com/cenkalti/backoff" circuit "github.com/cockroachdb/circuitbreaker" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/base/serverident" "github.com/cockroachdb/cockroach/pkg/clusterversion" "github.com/cockroachdb/cockroach/pkg/config" "github.com/cockroachdb/cockroach/pkg/config/zonepb" @@ -37,16 +38,20 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/kv/kvprober" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/plan" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" "github.com/cockroachdb/cockroach/pkg/multitenant/tenantcapabilities" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/certnames" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/authserver" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/status" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap" + 
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/deprecatedshowranges" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/physicalplan" @@ -753,6 +758,11 @@ func (t *TestTenant) DistSenderI() interface{} { return t.sql.execCfg.DistSender } +// InternalExecutor is part of the serverutils.TestTenantInterface. +func (t *TestTenant) InternalExecutor() interface{} { + return t.sql.internalExecutor +} + // RPCContext is part of the serverutils.TestTenantInterface. func (t *TestTenant) RPCContext() *rpc.Context { return t.sql.execCfg.RPCContext @@ -978,6 +988,12 @@ func (t *TestTenant) MigrationServer() interface{} { return t.sql.migrationServer } +// HTTPAuthServer is part of the serverutils.TestServerInterface. +// HTTPAuthServer is part of the TestTenantInterface. +func (t *TestTenant) HTTPAuthServer() interface{} { + return t.t.authentication +} + // StartTenant is part of the serverutils.TestServerInterface. 
func (ts *TestServer) StartTenant( ctx context.Context, params base.TestTenantArgs, @@ -1352,18 +1368,6 @@ func (ts *TestServer) DiagnosticsReporter() interface{} { return ts.Server.sqlServer.diagnosticsReporter } -const authenticatedUser = "authentic_user" - -func authenticatedUserName() username.SQLUsername { - return username.MakeSQLUsernameFromPreNormalizedString(authenticatedUser) -} - -const authenticatedUserNoAdmin = "authentic_user_noadmin" - -func authenticatedUserNameNoAdmin() username.SQLUsername { - return username.MakeSQLUsernameFromPreNormalizedString(authenticatedUserNoAdmin) -} - type v2AuthDecorator struct { http.RoundTripper @@ -1371,7 +1375,7 @@ type v2AuthDecorator struct { } func (v *v2AuthDecorator) RoundTrip(r *http.Request) (*http.Response, error) { - r.Header.Add(apiV2AuthHeader, v.session) + r.Header.Add(authserver.APIV2AuthHeader, v.session) return v.RoundTripper.RoundTrip(r) } @@ -1867,11 +1871,91 @@ func (ts *TestServer) KvProber() *kvprober.Prober { return ts.Server.kvProber } +// TestingQueryDatabaseID provides access to the database name-to-ID conversion function +// for use in API tests. +func (ts *TestServer) TestingQueryDatabaseID( + ctx context.Context, userName username.SQLUsername, dbName string, +) (descpb.ID, error) { + return ts.admin.queryDatabaseID(ctx, userName, dbName) +} + +// TestingQueryTableID provides access to the table name-to-ID conversion function +// for use in API tests. +func (ts *TestServer) TestingQueryTableID( + ctx context.Context, userName username.SQLUsername, dbName, tbName string, +) (descpb.ID, error) { + return ts.admin.queryTableID(ctx, userName, dbName, tbName) +} + +// TestingStatsForSpan provides access to the span stats inspection function +// for use in API tests. +func (ts *TestServer) TestingStatsForSpan( + ctx context.Context, span roachpb.Span, +) (*serverpb.TableStatsResponse, error) { + return ts.admin.statsForSpan(ctx, span) +} + +// TestingSetReady is exposed for use in health tests. 
+func (ts *TestServer) TestingSetReady(ready bool) { + ts.sqlServer.isReady.Set(ready) +} + +// HTTPAuthServer is part of the TestTenantInterface. +func (ts *TestServer) HTTPAuthServer() interface{} { + return ts.t.authentication +} + type testServerFactoryImpl struct{} // TestServerFactory can be passed to serverutils.InitTestServerFactory +// and rangetestutils.InitTestServerFactory. var TestServerFactory = testServerFactoryImpl{} +// MakeRangeTestServerArgs is part of the rangetestutils.TestServerFactory interface. +func (testServerFactoryImpl) MakeRangeTestServerArgs() base.TestServerArgs { + return base.TestServerArgs{ + StoreSpecs: []base.StoreSpec{ + base.DefaultTestStoreSpec, + base.DefaultTestStoreSpec, + base.DefaultTestStoreSpec, + }, + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + // Now that we allow same node rebalances, disable it in these tests, + // as they don't expect replicas to move. + ReplicaPlannerKnobs: plan.ReplicaPlannerTestingKnobs{ + DisableReplicaRebalancing: true, + }, + }, + }, + } +} + +// PrepareRangeTestServer is part of the rangetestutils.TestServerFactory interface. +func (testServerFactoryImpl) PrepareRangeTestServer(srv interface{}) error { + ts := srv.(*TestServer) + kvDB := ts.TenantOrServer().DB() + + // Make sure the range is spun up with an arbitrary read command. We do not + // expect a specific response. + if _, err := kvDB.Get(context.Background(), "a"); err != nil { + return err + } + + // Make sure the node status is available. This is done by forcing stores to + // publish their status, synchronizing to the event feed with a canary + // event, and then forcing the server to write summaries immediately. 
+ if err := ts.node.computeMetricsPeriodically(context.Background(), map[*kvserver.Store]*storage.MetricsForInterval{}, 0); err != nil { + return errors.Wrap(err, "error publishing store statuses") + } + + if err := ts.WriteSummaries(); err != nil { + return errors.Wrap(err, "error writing summaries") + } + + return nil +} + // New is part of TestServerFactory interface. func (testServerFactoryImpl) New(params base.TestServerArgs) (interface{}, error) { if params.Knobs.JobsTestingKnobs != nil { @@ -1959,3 +2043,22 @@ func mustGetSQLCounterForRegistry(registry *metric.Registry, name string) int64 } return c } + +// TestingMakeLoggingContexts is exposed for use in tests. +func TestingMakeLoggingContexts( + appTenantID roachpb.TenantID, +) (sysContext, appContext context.Context) { + ctxSysTenant := context.Background() + ctxSysTenant = context.WithValue(ctxSysTenant, serverident.ServerIdentificationContextKey{}, &idProvider{ + tenantID: roachpb.SystemTenantID, + clusterID: &base.ClusterIDContainer{}, + serverID: &base.NodeIDContainer{}, + }) + ctxAppTenant := context.Background() + ctxAppTenant = context.WithValue(ctxAppTenant, serverident.ServerIdentificationContextKey{}, &idProvider{ + tenantID: appTenantID, + clusterID: &base.ClusterIDContainer{}, + serverID: &base.NodeIDContainer{}, + }) + return ctxSysTenant, ctxAppTenant +} diff --git a/pkg/server/testserver_http.go b/pkg/server/testserver_http.go index 938f970e687e..54d3b1bf7e97 100644 --- a/pkg/server/testserver_http.go +++ b/pkg/server/testserver_http.go @@ -21,6 +21,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -33,7 +35,7 @@ type httpTestServer 
struct { t struct { // We need a sub-struct to avoid ambiguous overlap with the fields // of *Server, which are also embedded in TestServer. - authentication *authenticationServer + authentication authserver.Server sqlServer *SQLServer tenantName roachpb.TenantName } @@ -89,8 +91,8 @@ func (ts *httpTestServer) GetUnauthenticatedHTTPClient() (http.Client, error) { // GetAdminHTTPClient implements the TestServerInterface. func (ts *httpTestServer) GetAdminHTTPClient() (http.Client, error) { - httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie( - authenticatedUserName(), true, serverutils.SingleTenantSession, + httpClient, _, err := ts.GetAuthenticatedHTTPClientAndCookie( + apiconstants.TestingUserName(), true, serverutils.SingleTenantSession, ) return httpClient, err } @@ -99,25 +101,27 @@ func (ts *httpTestServer) GetAdminHTTPClient() (http.Client, error) { func (ts *httpTestServer) GetAuthenticatedHTTPClient( isAdmin bool, session serverutils.SessionType, ) (http.Client, error) { - authUser := authenticatedUserName() + authUser := apiconstants.TestingUserName() if !isAdmin { - authUser = authenticatedUserNameNoAdmin() + authUser = apiconstants.TestingUserNameNoAdmin() } - httpClient, _, err := ts.getAuthenticatedHTTPClientAndCookie(authUser, isAdmin, session) + httpClient, _, err := ts.GetAuthenticatedHTTPClientAndCookie(authUser, isAdmin, session) return httpClient, err } // GetAuthenticatedHTTPClient implements the TestServerInterface. 
func (ts *httpTestServer) GetAuthSession(isAdmin bool) (*serverpb.SessionCookie, error) { - authUser := authenticatedUserName() + authUser := apiconstants.TestingUserName() if !isAdmin { - authUser = authenticatedUserNameNoAdmin() + authUser = apiconstants.TestingUserNameNoAdmin() } - _, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authUser, isAdmin, serverutils.SingleTenantSession) + _, cookie, err := ts.GetAuthenticatedHTTPClientAndCookie(authUser, isAdmin, serverutils.SingleTenantSession) return cookie, err } -func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( +// GetAuthenticatedHTTPClientAndCookie returns an authenticated HTTP +// client and the session cookie for the client. +func (ts *httpTestServer) GetAuthenticatedHTTPClientAndCookie( authUser username.SQLUsername, isAdmin bool, session serverutils.SessionType, ) (http.Client, *serverpb.SessionCookie, error) { authIdx := 0 @@ -129,11 +133,11 @@ func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( // Create an authentication session for an arbitrary admin user. authClient.err = func() error { // The user needs to exist as the admin endpoints will check its role. - if err := ts.createAuthUser(authUser, isAdmin); err != nil { + if err := ts.CreateAuthUser(authUser, isAdmin); err != nil { return err } - id, secret, err := ts.t.authentication.newAuthSession(context.TODO(), authUser) + id, secret, err := ts.t.authentication.NewAuthSession(context.TODO(), authUser) if err != nil { return err } @@ -142,7 +146,7 @@ func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( Secret: secret, } // Encode a session cookie and store it in a cookie jar. 
- cookie, err := EncodeSessionCookie(rawCookie, false /* forHTTPSOnly */) + cookie, err := authserver.EncodeSessionCookie(rawCookie, false /* forHTTPSOnly */) if err != nil { return err } @@ -155,7 +159,7 @@ func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( return err } if session == serverutils.MultiTenantSession { - cookie.Name = SessionCookieName + cookie.Name = authserver.SessionCookieName cookie.Value = fmt.Sprintf("%s,%s", cookie.Value, ts.t.tenantName) } cookieJar.SetCookies(url, []*http.Cookie{cookie}) @@ -184,7 +188,8 @@ func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( return authClient.httpClient, authClient.cookie, authClient.err } -func (ts *httpTestServer) createAuthUser(userName username.SQLUsername, isAdmin bool) error { +// CreateAuthUser is exported for use in tests. +func (ts *httpTestServer) CreateAuthUser(userName username.SQLUsername, isAdmin bool) error { if _, err := ts.t.sqlServer.internalExecutor.ExecEx(context.TODO(), "create-auth-user", nil, sessiondata.RootUserSessionDataOverride, diff --git a/pkg/server/user.go b/pkg/server/user.go index 2cd2b1d8e454..10bc5ecc9d12 100644 --- a/pkg/server/user.go +++ b/pkg/server/user.go @@ -13,7 +13,9 @@ package server import ( "context" + "github.com/cockroachdb/cockroach/pkg/server/authserver" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" ) @@ -22,21 +24,21 @@ import ( func (s *baseStatusServer) UserSQLRoles( ctx context.Context, req *serverpb.UserSQLRolesRequest, ) (_ *serverpb.UserSQLRolesResponse, retErr error) { - ctx = forwardSQLIdentityThroughRPCCalls(ctx) + ctx = authserver.ForwardSQLIdentityThroughRPCCalls(ctx) ctx = s.AnnotateCtx(ctx) - username, isAdmin, err := s.privilegeChecker.getUserAndRole(ctx) + username, isAdmin, err := s.privilegeChecker.GetUserAndRole(ctx) if err != nil { - return nil, 
serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } var resp serverpb.UserSQLRolesResponse if !isAdmin { for _, privKind := range privilege.GlobalPrivileges { privName := privKind.String() - hasPriv, err := s.privilegeChecker.hasGlobalPrivilege(ctx, username, privKind) + hasPriv, err := s.privilegeChecker.HasGlobalPrivilege(ctx, username, privKind) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if hasPriv { resp.Roles = append(resp.Roles, privName) @@ -46,9 +48,9 @@ func (s *baseStatusServer) UserSQLRoles( if !ok { continue } - hasRole, err := s.privilegeChecker.hasRoleOption(ctx, username, roleOpt) + hasRole, err := s.privilegeChecker.HasRoleOption(ctx, username, roleOpt) if err != nil { - return nil, serverError(ctx, err) + return nil, srverrors.ServerError(ctx, err) } if hasRole { resp.Roles = append(resp.Roles, privName) diff --git a/pkg/server/user_test.go b/pkg/server/user_test.go index 866b0457aec4..5676d082fedd 100644 --- a/pkg/server/user_test.go +++ b/pkg/server/user_test.go @@ -19,7 +19,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" @@ -41,7 +43,7 @@ func TestValidRoles(t *testing.T) { for name := range roleoption.ByName { // Test user without the role. 
- hasRole, err := s.(*TestServer).status.baseStatusServer.privilegeChecker.hasRoleOption(ctx, fooUser, roleoption.ByName[name]) + hasRole, err := s.(*TestServer).status.baseStatusServer.privilegeChecker.HasRoleOption(ctx, fooUser, roleoption.ByName[name]) require.NoError(t, err) require.Equal(t, false, hasRole) @@ -60,7 +62,7 @@ func TestValidRoles(t *testing.T) { _, err = sqlDB.Exec(fmt.Sprintf("ALTER USER %s %s%s", fooUser, name, extraInfo)) require.NoError(t, err) - hasRole, err = s.(*TestServer).status.baseStatusServer.privilegeChecker.hasRoleOption(ctx, fooUser, roleoption.ByName[name]) + hasRole, err = s.(*TestServer).status.baseStatusServer.privilegeChecker.HasRoleOption(ctx, fooUser, roleoption.ByName[name]) require.NoError(t, err) expectedHasRole := true @@ -85,38 +87,38 @@ func TestSQLRolesAPI(t *testing.T) { // Admin user. expRoles := []string{"ADMIN"} - err := getStatusJSONProtoWithAdminOption(s, "sqlroles", &res, true) + err := srvtestutils.GetStatusJSONProtoWithAdminOption(s, "sqlroles", &res, true) require.NoError(t, err) require.ElementsMatch(t, expRoles, res.Roles) // No roles added to a non-admin user. expRoles = []string{} - err = getStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) + err = srvtestutils.GetStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) require.NoError(t, err) require.ElementsMatch(t, expRoles, res.Roles) // Role option and global privilege added to the non-admin user. 
- db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - db.Exec(t, fmt.Sprintf("GRANT SYSTEM MODIFYCLUSTERSETTING TO %s", authenticatedUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("GRANT SYSTEM MODIFYCLUSTERSETTING TO %s", apiconstants.TestingUserNameNoAdmin().Normalized())) expRoles = []string{"MODIFYCLUSTERSETTING", "VIEWACTIVITY"} - err = getStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) + err = srvtestutils.GetStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) require.NoError(t, err) require.ElementsMatch(t, expRoles, res.Roles) // Two role options and two global privileges added to the non-admin user. - db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", authenticatedUserNameNoAdmin().Normalized())) - db.Exec(t, fmt.Sprintf("GRANT SYSTEM CANCELQUERY TO %s", authenticatedUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("ALTER USER %s VIEWACTIVITYREDACTED", apiconstants.TestingUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("GRANT SYSTEM CANCELQUERY TO %s", apiconstants.TestingUserNameNoAdmin().Normalized())) expRoles = []string{"CANCELQUERY", "MODIFYCLUSTERSETTING", "VIEWACTIVITY", "VIEWACTIVITYREDACTED"} - err = getStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) + err = srvtestutils.GetStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) sort.Strings(res.Roles) require.NoError(t, err) require.ElementsMatch(t, expRoles, res.Roles) // Remove one role option and one global privilege from non-admin user. 
- db.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", authenticatedUserNameNoAdmin().Normalized())) - db.Exec(t, fmt.Sprintf("REVOKE SYSTEM MODIFYCLUSTERSETTING FROM %s", authenticatedUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("ALTER USER %s NOVIEWACTIVITY", apiconstants.TestingUserNameNoAdmin().Normalized())) + db.Exec(t, fmt.Sprintf("REVOKE SYSTEM MODIFYCLUSTERSETTING FROM %s", apiconstants.TestingUserNameNoAdmin().Normalized())) expRoles = []string{"CANCELQUERY", "VIEWACTIVITYREDACTED"} - err = getStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) + err = srvtestutils.GetStatusJSONProtoWithAdminOption(s, "sqlroles", &res, false) require.NoError(t, err) require.ElementsMatch(t, expRoles, res.Roles) } diff --git a/pkg/testutils/lint/passes/fmtsafe/functions.go b/pkg/testutils/lint/passes/fmtsafe/functions.go index 0e4fa759e23a..ac4339ed4d8e 100644 --- a/pkg/testutils/lint/passes/fmtsafe/functions.go +++ b/pkg/testutils/lint/passes/fmtsafe/functions.go @@ -141,8 +141,8 @@ var requireConstFmt = map[string]bool{ "(*github.com/cockroachdb/cockroach/pkg/sql/logictest.logicTest).Errorf": true, "(*github.com/cockroachdb/cockroach/pkg/sql/logictest.logicTest).Fatalf": true, - "github.com/cockroachdb/cockroach/pkg/server.serverErrorf": true, - "github.com/cockroachdb/cockroach/pkg/server.guaranteedExitFatal": true, + "github.com/cockroachdb/cockroach/pkg/server/srverrors.ServerErrorf": true, + "github.com/cockroachdb/cockroach/pkg/server.guaranteedExitFatal": true, "(*github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl.kafkaLogAdapter).Printf": true, diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index ada35820dcea..d413e72b4dea 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -141,10 +141,6 @@ type TestServerInterface interface { // LeaseManager() returns the *sql.LeaseManager as an interface{}. 
LeaseManager() interface{} - // InternalExecutor returns a *sql.InternalExecutor as an interface{} (which - // also implements insql.InternalExecutor if the test cannot depend on sql). - InternalExecutor() interface{} - // InternalExecutorInternalExecutorFactory returns a // insql.InternalDB as an interface{}. InternalDB() interface{} diff --git a/pkg/testutils/serverutils/test_tenant_shim.go b/pkg/testutils/serverutils/test_tenant_shim.go index 024fd30a77fa..c4a4afa31fac 100644 --- a/pkg/testutils/serverutils/test_tenant_shim.go +++ b/pkg/testutils/serverutils/test_tenant_shim.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -110,6 +111,9 @@ type TestTenantInterface interface { // interface{}. TenantStatusServer() interface{} + // HTTPAuthServer returns the authserver.Server as an interface{}. + HTTPAuthServer() interface{} + // SQLServer returns the *sql.Server as an interface{}. SQLServer() interface{} @@ -119,6 +123,11 @@ type TestTenantInterface interface { // DistSenderI returns the *kvcoord.DistSender as an interface{}. DistSenderI() interface{} + // InternalExecutor returns a *sql.InternalExecutor as an + // interface{} (which also implements insql.InternalExecutor if the + // test cannot depend on sql). + InternalExecutor() interface{} + // JobRegistry returns the *jobs.Registry as an interface{}. JobRegistry() interface{} @@ -181,22 +190,36 @@ type TestTenantInterface interface { // AdminURL returns the URL for the admin UI. AdminURL() *TestURL + // GetUnauthenticatedHTTPClient returns an http client configured with the client TLS // config required by the TestServer's configuration. 
// Discourages implementer from using unauthenticated http connections // with verbose method name. GetUnauthenticatedHTTPClient() (http.Client, error) + // GetAdminHTTPClient returns an http client which has been // authenticated to access Admin API methods (via a cookie). // The user has admin privileges. GetAdminHTTPClient() (http.Client, error) + // GetAuthenticatedHTTPClient returns an http client which has been // authenticated to access Admin API methods (via a cookie). GetAuthenticatedHTTPClient(isAdmin bool, sessionType SessionType) (http.Client, error) - // GetEncodedSession returns a byte array containing a valid auth + + // GetAuthenticatedHTTPClientAndCookie returns an http client which + // has been authenticated to access Admin API methods and + // the corresponding session cookie. + GetAuthenticatedHTTPClientAndCookie( + authUser username.SQLUsername, isAdmin bool, session SessionType, + ) (http.Client, *serverpb.SessionCookie, error) + + // GetAuthSession returns a session cookie holding a valid auth // session. GetAuthSession(isAdmin bool) (*serverpb.SessionCookie, error) + // CreateAuthUser is exported for use in tests. + CreateAuthUser(userName username.SQLUsername, isAdmin bool) error + // DrainClients shuts down client connections. 
DrainClients(ctx context.Context) error diff --git a/pkg/util/safesql/BUILD.bazel b/pkg/util/safesql/BUILD.bazel new file mode 100644 index 000000000000..6bec3592163b --- /dev/null +++ b/pkg/util/safesql/BUILD.bazel @@ -0,0 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "safesql", + srcs = ["safesql.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/util/safesql", + visibility = ["//visibility:public"], + deps = ["@com_github_cockroachdb_errors//:errors"], +) diff --git a/pkg/util/safesql/safesql.go b/pkg/util/safesql/safesql.go new file mode 100644 index 000000000000..c5bb2f8d970a --- /dev/null +++ b/pkg/util/safesql/safesql.go @@ -0,0 +1,88 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package safesql + +import ( + "bytes" + "strconv" + + "github.com/cockroachdb/errors" +) + +// Query allows you to incrementally build a SQL query that uses +// placeholders. Instead of specific placeholders like $1, you instead use the +// temporary placeholder $. +type Query struct { + buf bytes.Buffer + pidx int + qargs []interface{} + errs []error +} + +// NewQuery creates a new Query. +func NewQuery() *Query { + return &Query{} +} + +// String returns the full query. +func (q *Query) String() string { + if len(q.errs) > 0 { + return "couldn't generate query: please check Errors()" + } + return q.buf.String() +} + +// Errors returns a slice containing all errors that have happened during the +// construction of this query. 
+func (q *Query) Errors() []error { + return q.errs +} + +// QueryArguments returns a slice containing, in placeholder order, all arguments +// provided to this query through Append. +func (q *Query) QueryArguments() []interface{} { + return q.qargs +} + +// Append appends the provided string and any number of query parameters. +// Instead of using normal placeholders (e.g. $1, $2), use meta-placeholder $. +// This method rewrites the query so that it uses proper placeholders. +// +// For example, suppose we have the following calls: +// +// query.Append("SELECT * FROM foo WHERE a > $ AND a < $ ", arg1, arg2) +// query.Append("LIMIT $", limit) +// +// The query is rewritten into: +// +// SELECT * FROM foo WHERE a > $1 AND a < $2 LIMIT $3 +// /* $1 = arg1, $2 = arg2, $3 = limit */ +// +// Note that this method does NOT return any errors. Instead, we queue up +// errors, which can later be accessed via Errors(). Returning an error here would make +// query construction code exceedingly tedious. +func (q *Query) Append(s string, params ...interface{}) { + var placeholders int + for _, r := range s { + q.buf.WriteRune(r) + if r == '$' { + q.pidx++ + placeholders++ + q.buf.WriteString(strconv.Itoa(q.pidx)) // SQL placeholders are 1-based + } + } + + if placeholders != len(params) { + q.errs = append(q.errs, + errors.Errorf("# of placeholders %d != # of params %d", placeholders, len(params))) + } + q.qargs = append(q.qargs, params...) +}