From 8305873cc92e38c4ef0e69d0013dddd557744ebf Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Sun, 30 Jul 2023 16:16:22 +0200 Subject: [PATCH 1/5] cli: move `addrWithDefaultHost` to package `addr` Release note: None --- pkg/cli/rpc_client.go | 15 ++------------- pkg/cli/start_test.go | 3 ++- pkg/util/netutil/addr/addr.go | 13 +++++++++++++ 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pkg/cli/rpc_client.go b/pkg/cli/rpc_client.go index aa6efda174e8..e28145bb1b84 100644 --- a/pkg/cli/rpc_client.go +++ b/pkg/cli/rpc_client.go @@ -12,13 +12,13 @@ package cli import ( "context" - "net" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" @@ -56,7 +56,7 @@ func getClientGRPCConn( if cfg.TestingKnobs.Server != nil { rpcContext.Knobs = cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs } - addr, err := addrWithDefaultHost(cfg.AdvertiseAddr) + addr, err := addr.AddrWithDefaultLocalhost(cfg.AdvertiseAddr) if err != nil { stopper.Stop(ctx) return nil, nil, nil, err @@ -79,17 +79,6 @@ func getClientGRPCConn( return conn, clock, closer, nil } -func addrWithDefaultHost(addr string) (string, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return "", err - } - if host == "" { - host = "localhost" - } - return net.JoinHostPort(host, port), nil -} - // getAdminClient returns an AdminClient and a closure that must be invoked // to free associated resources. 
func getAdminClient(ctx context.Context, cfg server.Config) (serverpb.AdminClient, func(), error) { diff --git a/pkg/cli/start_test.go b/pkg/cli/start_test.go index 94ab56d879eb..bd5ddd07a8b3 100644 --- a/pkg/cli/start_test.go +++ b/pkg/cli/start_test.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" "github.com/cockroachdb/pebble/vfs" "github.com/stretchr/testify/require" ) @@ -206,7 +207,7 @@ func TestAddrWithDefaultHost(t *testing.T) { } for _, test := range testData { - addr, err := addrWithDefaultHost(test.inAddr) + addr, err := addr.AddrWithDefaultLocalhost(test.inAddr) if err != nil { t.Error(err) } else if addr != test.outAddr { diff --git a/pkg/util/netutil/addr/addr.go b/pkg/util/netutil/addr/addr.go index 0dba2a587906..8899456e8ba5 100644 --- a/pkg/util/netutil/addr/addr.go +++ b/pkg/util/netutil/addr/addr.go @@ -18,6 +18,19 @@ import ( "github.com/spf13/pflag" ) +// AddrWithDefaultLocalhost returns addr with the host set +// to localhost if it is empty. +func AddrWithDefaultLocalhost(addr string) (string, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + if host == "" { + host = "localhost" + } + return net.JoinHostPort(host, port), nil +} + // SplitHostPort is like net.SplitHostPort however it supports // addresses without a port number. In that case, the provided port // number is used. 
From bfed17654520577ed9d10d698a9ddf7e60d4f814 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Sun, 30 Jul 2023 16:16:59 +0200 Subject: [PATCH 2/5] cli: remove the unused return from `getClientGRPCConn` Release note: None --- pkg/cli/debug_reset_quorum.go | 2 +- pkg/cli/debug_send_kv_batch.go | 2 +- pkg/cli/init.go | 4 ++-- pkg/cli/node.go | 4 ++-- pkg/cli/rpc_client.go | 19 ++++++++----------- pkg/cli/tsdump.go | 2 +- pkg/cli/zip.go | 4 ++-- 7 files changed, 17 insertions(+), 20 deletions(-) diff --git a/pkg/cli/debug_reset_quorum.go b/pkg/cli/debug_reset_quorum.go index a15c353860dd..cad02c684d38 100644 --- a/pkg/cli/debug_reset_quorum.go +++ b/pkg/cli/debug_reset_quorum.go @@ -50,7 +50,7 @@ func runDebugResetQuorum(cmd *cobra.Command, args []string) error { } // Set up GRPC Connection for running ResetQuorum. - cc, _, finish, err := getClientGRPCConn(ctx, serverCfg) + cc, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { log.Errorf(ctx, "connection to server failed: %v", err) return err diff --git a/pkg/cli/debug_send_kv_batch.go b/pkg/cli/debug_send_kv_batch.go index 1503587ef3de..7fddac11a0fe 100644 --- a/pkg/cli/debug_send_kv_batch.go +++ b/pkg/cli/debug_send_kv_batch.go @@ -171,7 +171,7 @@ func runSendKVBatch(cmd *cobra.Command, args []string) error { // Send BatchRequest. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return errors.Wrap(err, "failed to connect to the node") } diff --git a/pkg/cli/init.go b/pkg/cli/init.go index ef56f49ab781..f14a8f204264 100644 --- a/pkg/cli/init.go +++ b/pkg/cli/init.go @@ -46,7 +46,7 @@ func runInit(cmd *cobra.Command, args []string) error { if err := dialAndCheckHealth(ctx); err != nil { return err } - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return err } @@ -76,7 +76,7 @@ func dialAndCheckHealth(ctx context.Context) error { // (Attempt to) establish the gRPC connection. If that fails, // it may be that the server hasn't started to listen yet, in // which case we'll retry. - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return err } diff --git a/pkg/cli/node.go b/pkg/cli/node.go index f6fc2efa7321..48ef509ed997 100644 --- a/pkg/cli/node.go +++ b/pkg/cli/node.go @@ -349,7 +349,7 @@ func runDecommissionNode(cmd *cobra.Command, args []string) error { return err } - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return errors.Wrap(err, "failed to connect to the node") } @@ -814,7 +814,7 @@ func runRecommissionNode(cmd *cobra.Command, args []string) error { return err } - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return errors.Wrap(err, "failed to connect to the node") } diff --git a/pkg/cli/rpc_client.go b/pkg/cli/rpc_client.go index e28145bb1b84..fb58a6e2b08a 100644 --- a/pkg/cli/rpc_client.go +++ b/pkg/cli/rpc_client.go @@ -17,7 +17,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/rpc" 
"github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" @@ -26,13 +25,11 @@ import ( "google.golang.org/grpc" ) -// getClientGRPCConn returns a ClientConn, a Clock and a method that blocks +// getClientGRPCConn returns a ClientConn and a method that blocks // until the connection (and its associated goroutines) have terminated. -func getClientGRPCConn( - ctx context.Context, cfg server.Config, -) (*grpc.ClientConn, hlc.WallClock, func(), error) { +func getClientGRPCConn(ctx context.Context, cfg server.Config) (*grpc.ClientConn, func(), error) { if ctx.Done() == nil { - return nil, nil, nil, errors.New("context must be cancellable") + return nil, nil, errors.New("context must be cancellable") } // 0 to disable max offset checks; this RPC context is not a member of the // cluster, so there's no need to enforce that its max offset is the same @@ -59,14 +56,14 @@ func getClientGRPCConn( addr, err := addr.AddrWithDefaultLocalhost(cfg.AdvertiseAddr) if err != nil { stopper.Stop(ctx) - return nil, nil, nil, err + return nil, nil, err } // We use GRPCUnvalidatedDial() here because it does not matter // to which node we're talking to. conn, err := rpcContext.GRPCUnvalidatedDial(addr).Connect(ctx) if err != nil { stopper.Stop(ctx) - return nil, nil, nil, err + return nil, nil, err } stopper.AddCloser(stop.CloserFn(func() { _ = conn.Close() // nolint:grpcconnclose @@ -76,13 +73,13 @@ func getClientGRPCConn( closer := func() { stopper.Stop(ctx) } - return conn, clock, closer, nil + return conn, closer, nil } // getAdminClient returns an AdminClient and a closure that must be invoked // to free associated resources. 
func getAdminClient(ctx context.Context, cfg server.Config) (serverpb.AdminClient, func(), error) { - conn, _, finish, err := getClientGRPCConn(ctx, cfg) + conn, finish, err := getClientGRPCConn(ctx, cfg) if err != nil { return nil, nil, errors.Wrap(err, "failed to connect to the node") } @@ -94,7 +91,7 @@ func getAdminClient(ctx context.Context, cfg server.Config) (serverpb.AdminClien func getStatusClient( ctx context.Context, cfg server.Config, ) (serverpb.StatusClient, func(), error) { - conn, _, finish, err := getClientGRPCConn(ctx, cfg) + conn, finish, err := getClientGRPCConn(ctx, cfg) if err != nil { return nil, nil, errors.Wrap(err, "failed to connect to the node") } diff --git a/pkg/cli/tsdump.go b/pkg/cli/tsdump.go index 3868ad7e66f9..aab88f9cad02 100644 --- a/pkg/cli/tsdump.go +++ b/pkg/cli/tsdump.go @@ -52,7 +52,7 @@ output. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return err } diff --git a/pkg/cli/zip.go b/pkg/cli/zip.go index 97d76dc48acc..e257e44792fd 100644 --- a/pkg/cli/zip.go +++ b/pkg/cli/zip.go @@ -161,7 +161,7 @@ func runDebugZip(cmd *cobra.Command, args []string) (retErr error) { var tenants []*serverpb.Tenant if err := func() error { s := zr.start("discovering virtual clusters") - conn, _, finish, err := getClientGRPCConn(ctx, serverCfg) + conn, finish, err := getClientGRPCConn(ctx, serverCfg) if err != nil { return s.fail(err) } @@ -215,7 +215,7 @@ func runDebugZip(cmd *cobra.Command, args []string) (retErr error) { sqlAddr := tenant.SqlAddr s := zr.start("establishing RPC connection to %s", cfg.AdvertiseAddr) - conn, _, finish, err := getClientGRPCConn(ctx, cfg) + conn, finish, err := getClientGRPCConn(ctx, cfg) if err != nil { return s.fail(err) } From 7380a91bb4a4a10d064d4d6f127cd34577b419cb Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Tue, 1 Aug 2023 23:46:23 +0200 
Subject: [PATCH 3/5] cli: fix a buglet We're noticing that commit da5dca5a528245a20f4255105cbb34c816876ba7 introduced a potential bug. The original code (introduced in commit 7a3cdaa01834915a05dbd027e7d131dddcca8ea5) was doing two separate things: 1. waiting for the context to be cancelled before calling `.Stop` 2. using a separate context for the call to `.Stop` to ensure the already-cancelled context does not get passed to the closers. Point (1) was considered to be a problem - instead we want the caller to call `finish()` explicitly. While that is true (and achieved by the later code), the change introduced a new problem: the possibly-cancelled ctx is passed directly to `.Stop` now. This defeats the protection of point (2). This commit restores it. Release note: None --- pkg/cli/rpc_client.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/cli/rpc_client.go b/pkg/cli/rpc_client.go index fb58a6e2b08a..3b4f42893784 100644 --- a/pkg/cli/rpc_client.go +++ b/pkg/cli/rpc_client.go @@ -69,9 +69,12 @@ func getClientGRPCConn(ctx context.Context, cfg server.Config) (*grpc.ClientConn _ = conn.Close() // nolint:grpcconnclose })) - // Tie the lifetime of the stopper to that of the context. closer := func() { - stopper.Stop(ctx) + // We use context.Background() here and not ctx because we + // want to ensure that the closers always run to completion + // even if the context used to create the client conn is + // canceled. + stopper.Stop(context.Background()) } return conn, closer, nil } From 4ba00f1f870942451e799463125726364a06277a Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Mon, 31 Jul 2023 22:28:18 +0200 Subject: [PATCH 4/5] democluster: rely less on firstServer Prior to this patch, we were over-relying on the "firstServer" in a demo cluster to do decommission/recommission operations. In a later version, we would like to introduce the option to stop/restart the first server. in that case, it would not necessarily be available. 
This commit makes one step in that direction. Release note: None --- pkg/cli/democluster/demo_cluster.go | 41 ++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/pkg/cli/democluster/demo_cluster.go b/pkg/cli/democluster/demo_cluster.go index d4bb253d5f05..74d3c5c49a99 100644 --- a/pkg/cli/democluster/demo_cluster.go +++ b/pkg/cli/democluster/demo_cluster.go @@ -1025,6 +1025,30 @@ func (c *transientCluster) DrainAndShutdown(ctx context.Context, nodeID int32) e return nil } +// findOtherServer returns an admin RPC client to another +// server than the one referred to by the node ID. +func (c *transientCluster) findOtherServer( + ctx context.Context, nodeID int32, op string, +) (serverpb.AdminClient, func(), error) { + // Find a node to use as the sender. + var adminClient serverpb.AdminClient + var finish func() + for _, s := range c.servers { + if s.TestServer != nil && s.nodeID != roachpb.NodeID(nodeID) { + var err error + adminClient, finish, err = c.getAdminClient(ctx, *(s.TestServer.Cfg)) + if err != nil { + return nil, nil, err + } + break + } + } + if adminClient == nil { + return nil, nil, errors.Newf("no other nodes available to send a %s request", op) + } + return adminClient, finish, nil +} + // Recommission recommissions a given node. func (c *transientCluster) Recommission(ctx context.Context, nodeID int32) error { nodeIndex := int(nodeID - 1) @@ -1041,14 +1065,14 @@ func (c *transientCluster) Recommission(ctx context.Context, nodeID int32) error ctx, cancel := context.WithCancel(ctx) defer cancel() - adminClient, finish, err := c.getAdminClient(ctx, *(c.firstServer.Cfg)) + // Find a node to use as the sender. 
+ adminClient, finish, err := c.findOtherServer(ctx, nodeID, "recommission") if err != nil { return err } - defer finish() - _, err = adminClient.Decommission(ctx, req) - if err != nil { + + if _, err = adminClient.Decommission(ctx, req); err != nil { return errors.Wrap(err, "while trying to mark as decommissioning") } @@ -1066,7 +1090,8 @@ func (c *transientCluster) Decommission(ctx context.Context, nodeID int32) error ctx, cancel := context.WithCancel(ctx) defer cancel() - adminClient, finish, err := c.getAdminClient(ctx, *(c.firstServer.Cfg)) + // Find a node to use as the sender. + adminClient, finish, err := c.findOtherServer(ctx, nodeID, "decommission") if err != nil { return err } @@ -1080,8 +1105,7 @@ func (c *transientCluster) Decommission(ctx context.Context, nodeID int32) error NodeIDs: []roachpb.NodeID{roachpb.NodeID(nodeID)}, TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, } - _, err = adminClient.Decommission(ctx, req) - if err != nil { + if _, err := adminClient.Decommission(ctx, req); err != nil { return errors.Wrap(err, "while trying to mark as decommissioning") } } @@ -1091,8 +1115,7 @@ func (c *transientCluster) Decommission(ctx context.Context, nodeID int32) error NodeIDs: []roachpb.NodeID{roachpb.NodeID(nodeID)}, TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONED, } - _, err = adminClient.Decommission(ctx, req) - if err != nil { + if _, err := adminClient.Decommission(ctx, req); err != nil { return errors.Wrap(err, "while trying to mark as decommissioned") } } From 607c69795bea85de6289f0f826e9b738bdbe36c0 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Sun, 30 Jul 2023 19:51:19 +0200 Subject: [PATCH 5/5] rpc,*: sanitize the configuration of rpc.Context The main goal of this change is to offer a `.RPCClientConn()` method on test servers. To achieve this, it was necessary to lift the code previously in `cli.getClientGRPCConn` into a more reusable version of it, now hosted in `rpc.NewClientContext()`. 
I also took the opportunity to remove the dependency of `rpc.Context` on `base.Config`, by spelling out precisely which fields are necessary to RPC connections. Numerous tests could be simplified as a result. In particular: - Complex sequences of calls to open a client RPC connection to a test server can be simplified to a single call to `.RPCClientConn()`. - `testutils.NewNodeTestContext` and `NewTestBaseContext` are not necessary any more. They were previously used as input to build instances of `rpc.ContextOptions`. The latter can now be built using `rpc.DefaultContextOptions()` instead. - `(*localestcluster.LocalTestCluster).Start()` does not need a `*base.Config` any more and so its call sites can be simplified. Release note: None --- pkg/acceptance/localcluster/BUILD.bazel | 2 - pkg/acceptance/localcluster/cluster.go | 22 +-- pkg/base/addr_validation_test.go | 12 +- pkg/base/config.go | 46 +++-- pkg/ccl/oidcccl/BUILD.bazel | 3 - pkg/ccl/oidcccl/authentication_oidc_test.go | 36 +--- .../serverccl/statusccl/tenant_grpc_test.go | 10 +- pkg/cli/BUILD.bazel | 2 +- pkg/cli/debug_recover_loss_of_quorum_test.go | 16 +- pkg/cli/demo.go | 1 - pkg/cli/democluster/demo_cluster.go | 58 +++--- pkg/cli/rpc_client.go | 73 ++----- pkg/gossip/simulation/BUILD.bazel | 2 - pkg/gossip/simulation/network.go | 22 +-- pkg/kv/kvclient/kvcoord/split_test.go | 2 +- .../kvclient/kvcoord/txn_coord_sender_test.go | 4 +- .../kvclient/kvcoord/txn_correctness_test.go | 3 +- pkg/kv/kvclient/kvcoord/txn_test.go | 2 +- pkg/kv/kvserver/client_decommission_test.go | 5 +- pkg/kv/kvserver/idalloc/id_alloc_test.go | 2 +- pkg/kv/kvserver/liveness/client_test.go | 5 +- .../loqrecovery/collect_raft_log_test.go | 3 +- .../loqrecovery/server_integration_test.go | 51 ++--- pkg/kv/kvserver/raft_transport_test.go | 13 +- pkg/kv/kvserver/raft_transport_unit_test.go | 19 +- pkg/kv/kvserver/replica_learner_test.go | 8 +- pkg/kv/kvserver/replicate_queue_test.go | 13 +- pkg/kv/kvserver/scatter_test.go | 3 +- 
pkg/kv/kvserver/store_test.go | 18 +- pkg/obsservice/obslib/ingest/BUILD.bazel | 1 - .../obslib/ingest/ingest_integration_test.go | 19 +- pkg/rpc/BUILD.bazel | 2 + pkg/rpc/client.go | 186 ++++++++++++++++++ pkg/rpc/context.go | 107 ++++++++-- pkg/rpc/context_test.go | 146 +++++++------- pkg/rpc/context_testutils.go | 21 +- pkg/rpc/nodedialer/nodedialer_test.go | 23 +-- pkg/rpc/peer.go | 8 +- pkg/rpc/tls.go | 30 ++- pkg/rpc/tls_test.go | 52 ++--- pkg/security/certs_rotation_test.go | 5 +- pkg/security/certs_test.go | 25 +-- pkg/server/application_api/BUILD.bazel | 2 - pkg/server/application_api/activity_test.go | 12 +- pkg/server/application_api/jobs_test.go | 12 +- pkg/server/application_api/sessions_test.go | 11 +- .../storage_inspection_test.go | 13 +- pkg/server/authserver/authentication_test.go | 26 ++- pkg/server/drain_test.go | 23 +-- pkg/server/server.go | 40 ++-- pkg/server/server_import_ts_test.go | 7 +- pkg/server/srvtestutils/BUILD.bazel | 4 - pkg/server/srvtestutils/testutils.go | 28 --- pkg/server/statements_test.go | 19 +- pkg/server/status_ext_test.go | 11 +- pkg/server/storage_api/BUILD.bazel | 2 - pkg/server/storage_api/decommission_test.go | 43 +--- pkg/server/storage_api/files_test.go | 18 +- pkg/server/storage_api/nodes_test.go | 20 +- pkg/server/storage_api/ranges_test.go | 15 +- pkg/server/tenant.go | 28 ++- pkg/server/testserver.go | 131 ++++++++++++ pkg/server/testserver_http.go | 4 +- pkg/sql/create_role.go | 2 +- pkg/sql/execinfrapb/BUILD.bazel | 1 - pkg/sql/execinfrapb/testutils.go | 18 +- pkg/sql/flowinfra/cluster_test.go | 15 +- pkg/sql/flowinfra/server_test.go | 8 +- pkg/testutils/BUILD.bazel | 4 - pkg/testutils/base.go | 42 ---- .../localtestcluster/local_test_cluster.go | 25 ++- pkg/testutils/serverutils/BUILD.bazel | 1 + pkg/testutils/serverutils/api.go | 22 ++- pkg/testutils/testcluster/BUILD.bazel | 1 - pkg/testutils/testcluster/testcluster.go | 30 ++- pkg/testutils/testcluster/testcluster_test.go | 31 ++- pkg/ts/BUILD.bazel | 2 
+- pkg/ts/db_test.go | 3 +- pkg/ts/server_test.go | 49 +---- 79 files changed, 902 insertions(+), 902 deletions(-) create mode 100644 pkg/rpc/client.go delete mode 100644 pkg/testutils/base.go diff --git a/pkg/acceptance/localcluster/BUILD.bazel b/pkg/acceptance/localcluster/BUILD.bazel index 60bed246c01c..71b3de130011 100644 --- a/pkg/acceptance/localcluster/BUILD.bazel +++ b/pkg/acceptance/localcluster/BUILD.bazel @@ -11,11 +11,9 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/acceptance/cluster", - "//pkg/base", "//pkg/config/zonepb", "//pkg/roachpb", "//pkg/rpc", - "//pkg/security/username", "//pkg/server/serverpb", "//pkg/settings/cluster", "//pkg/testutils", diff --git a/pkg/acceptance/localcluster/cluster.go b/pkg/acceptance/localcluster/cluster.go index f5f3003c6152..b535b3e5f341 100644 --- a/pkg/acceptance/localcluster/cluster.go +++ b/pkg/acceptance/localcluster/cluster.go @@ -29,11 +29,9 @@ import ( "text/tabwriter" "time" - "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/config/zonepb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/testutils" @@ -265,20 +263,12 @@ func (c *Cluster) RPCPort(nodeIdx int) string { } func (c *Cluster) makeNode(ctx context.Context, nodeIdx int, cfg NodeConfig) (*Node, <-chan error) { - baseCtx := &base.Config{ - User: username.NodeUserName(), - Insecure: true, - } - rpcCtx := rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: baseCtx, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 0, - Stopper: c.stopper, - Settings: cluster.MakeTestingClusterSettings(), - - ClientOnly: true, - }) + opts := rpc.DefaultContextOptions() + opts.Insecure = true + opts.Stopper = c.stopper + 
opts.Settings = cluster.MakeTestingClusterSettings() + opts.ClientOnly = true + rpcCtx := rpc.NewContext(ctx, opts) n := &Node{ Cfg: cfg, diff --git a/pkg/base/addr_validation_test.go b/pkg/base/addr_validation_test.go index 81dbafe03ef4..cd08f181d1a3 100644 --- a/pkg/base/addr_validation_test.go +++ b/pkg/base/addr_validation_test.go @@ -156,12 +156,12 @@ func TestValidateAddrs(t *testing.T) { for i, test := range testData { t.Run(fmt.Sprintf("%d/%s", i, test.in), func(t *testing.T) { cfg := base.Config{ - Addr: test.in.listen, - AdvertiseAddr: test.in.adv, - HTTPAddr: test.in.http, - HTTPAdvertiseAddr: test.in.advhttp, - SQLAddr: test.in.sql, - SQLAdvertiseAddr: test.in.advsql, + Addr: test.in.listen, + HTTPAddr: test.in.http, + SQLAddr: test.in.sql, + AdvertiseAddrH: base.AdvertiseAddrH{AdvertiseAddr: test.in.adv}, + HTTPAdvertiseAddrH: base.HTTPAdvertiseAddrH{HTTPAdvertiseAddr: test.in.advhttp}, + SQLAdvertiseAddrH: base.SQLAdvertiseAddrH{SQLAdvertiseAddr: test.in.advsql}, } if err := cfg.ValidateAddrs(context.Background()); err != nil { diff --git a/pkg/base/config.go b/pkg/base/config.go index ae0ed3b53040..c01bc6e44942 100644 --- a/pkg/base/config.go +++ b/pkg/base/config.go @@ -192,7 +192,7 @@ var ( defaultRangeLeaseDuration = envutil.EnvOrDefaultDuration( "COCKROACH_RANGE_LEASE_DURATION", 6*time.Second) - // defaultRPCHeartbeatTimeout is the default RPC heartbeat timeout. It is set + // DefaultRPCHeartbeatTimeout is the default RPC heartbeat timeout. It is set // very high at 3 * NetworkTimeout for several reasons: the gRPC transport may // need to complete a dial/handshake before sending the heartbeat, the // heartbeat has occasionally been seen to require 3 RTTs even post-dial (for @@ -212,7 +212,7 @@ var ( // heartbeats and reduce this to NetworkTimeout (plus DialTimeout for the // initial heartbeat), see: // https://github.com/cockroachdb/cockroach/issues/93397. 
- defaultRPCHeartbeatTimeout = 3 * NetworkTimeout + DefaultRPCHeartbeatTimeout = 3 * NetworkTimeout // defaultRaftTickInterval is the default resolution of the Raft timer. defaultRaftTickInterval = envutil.EnvOrDefaultDuration( @@ -341,10 +341,13 @@ type Config struct { // Addr is the address the server is listening on. Addr string - // AdvertiseAddr is the address advertised by the server to other nodes - // in the cluster. It should be reachable by all other nodes and should - // route to an interface that Addr is listening on. - AdvertiseAddr string + // AdvertiseAddrH contains the address advertised by the server to + // other nodes in the cluster. It should be reachable by all other + // nodes and should route to an interface that Addr is listening on. + // + // It is set after the server instance has been created, when the + // network listeners are being set up. + AdvertiseAddrH // ClusterName is the name used as a sanity check when a node joins // an uninitialized cluster, or when an uninitialized node joins an @@ -367,9 +370,12 @@ type Config struct { // This is used if SplitListenSQL is set to true. SQLAddr string - // SQLAdvertiseAddr is the advertised SQL address. + // SQLAdvertiseAddrH contains the advertised SQL address. // This is computed from SQLAddr if specified otherwise Addr. - SQLAdvertiseAddr string + // + // It is set after the server instance has been created, when the + // network listeners are being set up. + SQLAdvertiseAddrH // SocketFile, if non-empty, sets up a TLS-free local listener using // a unix datagram socket at the specified path for SQL clients. @@ -382,9 +388,12 @@ type Config struct { // DisableTLSForHTTP, if set, disables TLS for the HTTP listener. DisableTLSForHTTP bool - // HTTPAdvertiseAddr is the advertised HTTP address. + // HTTPAdvertiseAddrH contains the advertised HTTP address. // This is computed from HTTPAddr if specified otherwise Addr. 
- HTTPAdvertiseAddr string + // + // It is set after the server instance has been created, when the + // network listeners are being set up. + HTTPAdvertiseAddrH // RPCHeartbeatInterval controls how often Ping requests are sent on peer // connections to determine connection health and update the local view of @@ -421,6 +430,21 @@ type Config struct { LocalityAddresses []roachpb.LocalityAddress } +// AdvertiseAddr is the type of the AdvertiseAddr field in Config. +type AdvertiseAddrH struct { + AdvertiseAddr string +} + +// SQLAdvertiseAddr is the type of the SQLAdvertiseAddr field in Config. +type SQLAdvertiseAddrH struct { + SQLAdvertiseAddr string +} + +// HTTPAdvertiseAddr is the type of the HTTPAdvertiseAddr field in Config. +type HTTPAdvertiseAddrH struct { + HTTPAdvertiseAddr string +} + // HistogramWindowInterval is used to determine the approximate length of time // that individual samples are retained in in-memory histograms. Currently, // it is set to the arbitrary length of six times the Metrics sample interval. 
@@ -455,7 +479,7 @@ func (cfg *Config) InitDefaults() { cfg.SocketFile = "" cfg.SSLCertsDir = DefaultCertsDirectory cfg.RPCHeartbeatInterval = PingInterval - cfg.RPCHeartbeatTimeout = defaultRPCHeartbeatTimeout + cfg.RPCHeartbeatTimeout = DefaultRPCHeartbeatTimeout cfg.ClusterName = "" cfg.DisableClusterNameVerification = false cfg.ClockDevicePath = "" diff --git a/pkg/ccl/oidcccl/BUILD.bazel b/pkg/ccl/oidcccl/BUILD.bazel index ee25740fb320..55ee3487c7b2 100644 --- a/pkg/ccl/oidcccl/BUILD.bazel +++ b/pkg/ccl/oidcccl/BUILD.bazel @@ -48,20 +48,17 @@ go_test( "//pkg/base", "//pkg/ccl", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", "//pkg/server", "//pkg/server/serverpb", - "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", "//pkg/util/log", "//pkg/util/randutil", - "//pkg/util/timeutil", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/ccl/oidcccl/authentication_oidc_test.go b/pkg/ccl/oidcccl/authentication_oidc_test.go index d889bf3503cf..38c7a063d197 100644 --- a/pkg/ccl/oidcccl/authentication_oidc_test.go +++ b/pkg/ccl/oidcccl/authentication_oidc_test.go @@ -24,15 +24,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/stretchr/testify/require" ) @@ -51,22 +48,7 @@ func TestOIDCBadRequestIfDisabled(t *testing.T) { s := serverutils.StartServerOnly(t, 
base.TestServerArgs{}) defer s.Stopper().Stop(ctx) - newRPCContext := func(cfg *base.Config) *rpc.Context { - return rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: cfg, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 1, - Stopper: s.Stopper(), - Settings: s.ClusterSettings(), - - ClientOnly: true, - }) - } - - plainHTTPCfg := testutils.NewTestBaseContext(username.TestUserName()) - testCertsContext := newRPCContext(plainHTTPCfg) + testCertsContext := s.NewClientRPCContext(ctx, username.TestUserName()) client, err := testCertsContext.GetHTTPClient() require.NoError(t, err) @@ -90,19 +72,6 @@ func TestOIDCEnabled(t *testing.T) { s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) - newRPCContext := func(cfg *base.Config) *rpc.Context { - return rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: cfg, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 1, - Stopper: s.Stopper(), - Settings: s.ClusterSettings(), - - ClientOnly: true, - }) - } - // Set up a test OIDC server that serves the JSON discovery document var issuer string oidcHandler := func(w http.ResponseWriter, r *http.Request) { @@ -184,8 +153,7 @@ func TestOIDCEnabled(t *testing.T) { sqlDB.Exec(t, `SET CLUSTER SETTING server.oidc_authentication.redirect_url = "https://cockroachlabs.com/oidc/v1/callback"`) sqlDB.Exec(t, `SET CLUSTER SETTING server.oidc_authentication.enabled = "true"`) - plainHTTPCfg := testutils.NewTestBaseContext(username.TestUserName()) - testCertsContext := newRPCContext(plainHTTPCfg) + testCertsContext := s.NewClientRPCContext(ctx, username.TestUserName()) client, err := testCertsContext.GetHTTPClient() require.NoError(t, err) diff --git a/pkg/ccl/serverccl/statusccl/tenant_grpc_test.go b/pkg/ccl/serverccl/statusccl/tenant_grpc_test.go index e1e415dae4ca..f3d517c28b90 100644 --- a/pkg/ccl/serverccl/statusccl/tenant_grpc_test.go +++ 
b/pkg/ccl/serverccl/statusccl/tenant_grpc_test.go @@ -59,15 +59,7 @@ func TestTenantGRPCServices(t *testing.T) { t.Logf("subtests starting") t.Run("gRPC is running", func(t *testing.T) { - grpcAddr := tenant.RPCAddr() - rpcCtx := tenant.RPCContext() - - nodeID := roachpb.NodeID(tenant.SQLInstanceID()) - conn, err := rpcCtx.GRPCDialNode(grpcAddr, nodeID, rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - - client := serverpb.NewStatusClient(conn) - + client := tenant.GetStatusClient(t) resp, err := client.Statements(ctx, &serverpb.StatementsRequest{NodeID: "local"}) require.NoError(t, err) require.NotEmpty(t, resp.Statements) diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 8c3c358c01c6..7d5146d3ca8c 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -378,7 +378,6 @@ go_test( "//pkg/kv/kvserver/loqrecovery/loqrecoverypb", "//pkg/kv/kvserver/stateloader", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", @@ -405,6 +404,7 @@ go_test( "//pkg/util/log", "//pkg/util/log/logconfig", "//pkg/util/log/logpb", + "//pkg/util/netutil/addr", "//pkg/util/protoutil", "//pkg/util/randutil", "//pkg/util/stop", diff --git a/pkg/cli/debug_recover_loss_of_quorum_test.go b/pkg/cli/debug_recover_loss_of_quorum_test.go index aed90be50b5c..1d7a638a3e1f 100644 --- a/pkg/cli/debug_recover_loss_of_quorum_test.go +++ b/pkg/cli/debug_recover_loss_of_quorum_test.go @@ -26,7 +26,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/loqrecovery" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/loqrecovery/loqrecoverypb" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/testutils" @@ -258,10 +257,7 @@ func TestLossOfQuorumRecovery(t *testing.T) { // attempt. 
That would increase number of replicas on system ranges to 5 and we // would not be able to upreplicate properly. So we need to decommission old nodes // first before proceeding. - grpcConn, err := tcAfter.Server(0).RPCContext().GRPCDialNode( - tcAfter.Server(0).AdvRPCAddr(), tcAfter.Server(0).NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err, "Failed to create test cluster after recovery") - adminClient := serverpb.NewAdminClient(grpcConn) + adminClient := tcAfter.Server(0).GetAdminClient(t) require.NoError(t, runDecommissionNodeImpl( ctx, adminClient, nodeDecommissionWaitNone, nodeDecommissionChecksSkip, false, @@ -353,10 +349,7 @@ func TestStageVersionCheck(t *testing.T) { defer tc.Stopper().Stop(ctx) tc.StopServer(3) - grpcConn, err := tc.Server(0).RPCContext().GRPCDialNode(tc.Server(0).AdvRPCAddr(), - tc.Server(0).NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err, "Failed to create test cluster after recovery") - adminClient := serverpb.NewAdminClient(grpcConn) + adminClient := tc.Server(0).GetAdminClient(t) v := clusterversion.ByKey(clusterversion.BinaryVersionKey) v.Internal++ // To avoid crafting real replicas we use StaleLeaseholderNodeIDs to force @@ -369,7 +362,7 @@ func TestStageVersionCheck(t *testing.T) { StaleLeaseholderNodeIDs: []roachpb.NodeID{1}, } // Attempts to stage plan with different internal version must fail. - _, err = adminClient.RecoveryStagePlan(ctx, &serverpb.RecoveryStagePlanRequest{ + _, err := adminClient.RecoveryStagePlan(ctx, &serverpb.RecoveryStagePlanRequest{ Plan: &p, AllNodes: true, ForcePlan: false, @@ -559,8 +552,7 @@ func TestHalfOnlineLossOfQuorumRecovery(t *testing.T) { // Verifying that post start cleanup performed node decommissioning that // prevents old nodes from rejoining. 
- ac, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + ac := tc.GetAdminClient(t, 0) testutils.SucceedsSoon(t, func() error { dr, err := ac.DecommissionStatus(ctx, &serverpb.DecommissionStatusRequest{NodeIDs: []roachpb.NodeID{2, 3}}) diff --git a/pkg/cli/demo.go b/pkg/cli/demo.go index 6923a7b4e157..e4f023eaf814 100644 --- a/pkg/cli/demo.go +++ b/pkg/cli/demo.go @@ -233,7 +233,6 @@ func runDemoInternal( serverCfg.Stores.Specs = nil return setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) }, - getAdminClient, func(ctx context.Context, ac serverpb.AdminClient) error { return drainAndShutdown(ctx, ac, "local" /* targetNode */) }, diff --git a/pkg/cli/democluster/demo_cluster.go b/pkg/cli/democluster/demo_cluster.go index 74d3c5c49a99..64f2a30227e0 100644 --- a/pkg/cli/democluster/demo_cluster.go +++ b/pkg/cli/democluster/demo_cluster.go @@ -68,7 +68,8 @@ import ( type serverEntry struct { *server.TestServer - nodeID roachpb.NodeID + adminClient serverpb.AdminClient + nodeID roachpb.NodeID } type transientCluster struct { @@ -88,7 +89,6 @@ type transientCluster struct { stickyEngineRegistry server.StickyInMemEnginesRegistry - getAdminClient func(ctx context.Context, cfg server.Config) (serverpb.AdminClient, func(), error) drainAndShutdown func(ctx context.Context, adminClient serverpb.AdminClient) error infoLog LoggerFn @@ -148,12 +148,10 @@ func NewDemoCluster( warnLog LoggerFn, shoutLog ShoutLoggerFn, startStopper func(ctx context.Context) (*stop.Stopper, error), - getAdminClient func(ctx context.Context, cfg server.Config) (serverpb.AdminClient, func(), error), drainAndShutdown func(ctx context.Context, s serverpb.AdminClient) error, ) (DemoCluster, error) { c := &transientCluster{ demoCtx: demoCtx, - getAdminClient: getAdminClient, drainAndShutdown: drainAndShutdown, infoLog: infoLog, warnLog: warnLog, @@ -764,6 +762,13 @@ func (c *transientCluster) waitForNodeIDReadiness( c.infoLog(ctx, "server 
%d: n%d", idx, nodeID) c.servers[idx].nodeID = nodeID } + // We can open a RPC admin connection now. + conn, err := c.servers[idx].RPCClientConnE(username.RootUserName()) + if err != nil { + return err + } + c.servers[idx].adminClient = serverpb.NewAdminClient(conn) + } break } @@ -1012,16 +1017,11 @@ func (c *transientCluster) DrainAndShutdown(ctx context.Context, nodeID int32) e ctx, cancel := context.WithCancel(ctx) defer cancel() - adminClient, finish, err := c.getAdminClient(ctx, *(c.servers[serverIdx].Cfg)) - if err != nil { - return err - } - defer finish() - - if err := c.drainAndShutdown(ctx, adminClient); err != nil { + if err := c.drainAndShutdown(ctx, c.servers[serverIdx].adminClient); err != nil { return err } c.servers[serverIdx].TestServer = nil + c.servers[serverIdx].adminClient = nil return nil } @@ -1029,24 +1029,19 @@ func (c *transientCluster) DrainAndShutdown(ctx context.Context, nodeID int32) e // server than the one referred to by the node ID. func (c *transientCluster) findOtherServer( ctx context.Context, nodeID int32, op string, -) (serverpb.AdminClient, func(), error) { +) (serverpb.AdminClient, error) { // Find a node to use as the sender. var adminClient serverpb.AdminClient - var finish func() for _, s := range c.servers { - if s.TestServer != nil && s.nodeID != roachpb.NodeID(nodeID) { - var err error - adminClient, finish, err = c.getAdminClient(ctx, *(s.TestServer.Cfg)) - if err != nil { - return nil, nil, err - } + if s.adminClient != nil && s.nodeID != roachpb.NodeID(nodeID) { + adminClient = s.adminClient break } } if adminClient == nil { - return nil, nil, errors.Newf("no other nodes available to send a %s request", op) + return nil, errors.Newf("no other nodes available to send a %s request", op) } - return adminClient, finish, nil + return adminClient, nil } // Recommission recommissions a given node. 
@@ -1066,13 +1061,12 @@ func (c *transientCluster) Recommission(ctx context.Context, nodeID int32) error defer cancel() // Find a node to use as the sender. - adminClient, finish, err := c.findOtherServer(ctx, nodeID, "recommission") + adminClient, err := c.findOtherServer(ctx, nodeID, "recommission") if err != nil { return err } - defer finish() - if _, err = adminClient.Decommission(ctx, req); err != nil { + if _, err := adminClient.Decommission(ctx, req); err != nil { return errors.Wrap(err, "while trying to mark as decommissioning") } @@ -1091,11 +1085,10 @@ func (c *transientCluster) Decommission(ctx context.Context, nodeID int32) error defer cancel() // Find a node to use as the sender. - adminClient, finish, err := c.findOtherServer(ctx, nodeID, "decommission") + adminClient, err := c.findOtherServer(ctx, nodeID, "decommission") if err != nil { return err } - defer finish() // This (cumbersome) two step process is due to the allowed state // transitions for membership status. To mark a node as fully @@ -1191,7 +1184,18 @@ func (c *transientCluster) startServerInternal( c.stopper.AddCloser(stop.CloserFn(serv.Stop)) nodeID := serv.NodeID() - c.servers[serverIdx] = serverEntry{TestServer: serv, nodeID: nodeID} + + conn, err := serv.RPCClientConnE(username.RootUserName()) + if err != nil { + serv.Stopper().Stop(ctx) + return 0, err + } + + c.servers[serverIdx] = serverEntry{ + TestServer: serv, + adminClient: serverpb.NewAdminClient(conn), + nodeID: nodeID, + } return int32(nodeID), nil } diff --git a/pkg/cli/rpc_client.go b/pkg/cli/rpc_client.go index 3b4f42893784..83ee7fb518c2 100644 --- a/pkg/cli/rpc_client.go +++ b/pkg/cli/rpc_client.go @@ -13,70 +13,31 @@ package cli import ( "context" - "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" - 
"github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" "google.golang.org/grpc" ) -// getClientGRPCConn returns a ClientConn and a method that blocks -// until the connection (and its associated goroutines) have terminated. -func getClientGRPCConn(ctx context.Context, cfg server.Config) (*grpc.ClientConn, func(), error) { - if ctx.Done() == nil { - return nil, nil, errors.New("context must be cancellable") - } - // 0 to disable max offset checks; this RPC context is not a member of the - // cluster, so there's no need to enforce that its max offset is the same - // as that of nodes in the cluster. - clock := &timeutil.DefaultTimeSource{} - tracer := cfg.Tracer - if tracer == nil { - tracer = tracing.NewTracer() - } - stopper := stop.NewStopper(stop.WithTracer(tracer)) - rpcContext := rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: cfg.Config, - Clock: clock, - Stopper: stopper, - Settings: cfg.Settings, - - ClientOnly: true, - }) - if cfg.TestingKnobs.Server != nil { - rpcContext.Knobs = cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs - } - addr, err := addr.AddrWithDefaultLocalhost(cfg.AdvertiseAddr) - if err != nil { - stopper.Stop(ctx) - return nil, nil, err - } - // We use GRPCUnvalidatedDial() here because it does not matter - // to which node we're talking to. 
- conn, err := rpcContext.GRPCUnvalidatedDial(addr).Connect(ctx) - if err != nil { - stopper.Stop(ctx) - return nil, nil, err - } - stopper.AddCloser(stop.CloserFn(func() { - _ = conn.Close() // nolint:grpcconnclose - })) +func makeRPCClientConfig(cfg server.Config) rpc.ClientConnConfig { + var knobs rpc.ContextTestingKnobs + if sknobs := cfg.TestingKnobs.Server; sknobs != nil { + knobs = sknobs.(*server.TestingKnobs).ContextTestingKnobs + } + return rpc.MakeClientConnConfigFromBaseConfig( + *cfg.Config, + cfg.Config.User, + cfg.BaseConfig.Tracer, + cfg.BaseConfig.Settings, + nil, /* clock */ + knobs, + ) +} - closer := func() { - // We use context.Background() here and not ctx because we - // want to ensure that the closers always run to completion - // even if the context used to create the client conn is - // canceled. - stopper.Stop(context.Background()) - } - return conn, closer, nil +func getClientGRPCConn(ctx context.Context, cfg server.Config) (*grpc.ClientConn, func(), error) { + ccfg := makeRPCClientConfig(cfg) + return rpc.NewClientConn(ctx, ccfg) } // getAdminClient returns an AdminClient and a closure that must be invoked diff --git a/pkg/gossip/simulation/BUILD.bazel b/pkg/gossip/simulation/BUILD.bazel index 1b818b913ac0..17b8e4921195 100644 --- a/pkg/gossip/simulation/BUILD.bazel +++ b/pkg/gossip/simulation/BUILD.bazel @@ -6,7 +6,6 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/gossip/simulation", visibility = ["//visibility:public"], deps = [ - "//pkg/base", "//pkg/config/zonepb", "//pkg/gossip", "//pkg/roachpb", @@ -18,7 +17,6 @@ go_library( "//pkg/util/metric", "//pkg/util/netutil", "//pkg/util/stop", - "//pkg/util/timeutil", "//pkg/util/uuid", "@org_golang_google_grpc//:go_default_library", ], diff --git a/pkg/gossip/simulation/network.go b/pkg/gossip/simulation/network.go index 5d64a4e83e7b..80333550a8a4 100644 --- a/pkg/gossip/simulation/network.go +++ b/pkg/gossip/simulation/network.go @@ -16,7 +16,6 @@ import ( "net" "time" 
- "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/config/zonepb" "github.com/cockroachdb/cockroach/pkg/gossip" "github.com/cockroachdb/cockroach/pkg/roachpb" @@ -28,7 +27,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" "google.golang.org/grpc" ) @@ -70,19 +68,15 @@ func NewNetwork( Nodes: []*Node{}, Stopper: stopper, } - n.RPCContext = rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: &base.Config{Insecure: true}, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 0, - Stopper: n.Stopper, - Settings: cluster.MakeTestingClusterSettings(), + opts := rpc.DefaultContextOptions() + opts.Insecure = true + opts.Stopper = n.Stopper + opts.Settings = cluster.MakeTestingClusterSettings() + opts.Knobs = rpc.ContextTestingKnobs{ + NoLoopbackDialer: true, + } + n.RPCContext = rpc.NewContext(ctx, opts) - Knobs: rpc.ContextTestingKnobs{ - NoLoopbackDialer: true, - }, - }) var err error n.tlsConfig, err = n.RPCContext.GetServerTLSConfig() if err != nil { diff --git a/pkg/kv/kvclient/kvcoord/split_test.go b/pkg/kv/kvclient/kvcoord/split_test.go index d9a394523bdc..1db7a155ef56 100644 --- a/pkg/kv/kvclient/kvcoord/split_test.go +++ b/pkg/kv/kvclient/kvcoord/split_test.go @@ -198,7 +198,7 @@ func TestRangeSplitsWithWritePressure(t *testing.T) { DisableScanner: true, }, } - s.Start(t, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(t, kvcoord.InitFactoryForLocalTestCluster) // This is purely to silence log spam. 
config.TestingSetupZoneConfigHook(s.Stopper()) diff --git a/pkg/kv/kvclient/kvcoord/txn_coord_sender_test.go b/pkg/kv/kvclient/kvcoord/txn_coord_sender_test.go index 088f521a5594..986753449995 100644 --- a/pkg/kv/kvclient/kvcoord/txn_coord_sender_test.go +++ b/pkg/kv/kvclient/kvcoord/txn_coord_sender_test.go @@ -65,7 +65,7 @@ func createTestDBWithKnobs( s := &localtestcluster.LocalTestCluster{ StoreTestingKnobs: knobs, } - s.Start(t, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(t, kvcoord.InitFactoryForLocalTestCluster) return s } @@ -1183,7 +1183,7 @@ func setupMetricsTest( DisableLivenessHeartbeat: true, DontCreateSystemRanges: true, } - s.Start(t, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(t, kvcoord.InitFactoryForLocalTestCluster) metrics := kvcoord.MakeTxnMetrics(metric.TestSampleInterval) s.DB.GetFactory().(*kvcoord.TxnCoordSenderFactory).TestingSetMetrics(metrics) diff --git a/pkg/kv/kvclient/kvcoord/txn_correctness_test.go b/pkg/kv/kvclient/kvcoord/txn_correctness_test.go index 5a8abb96cf15..a92497295e77 100644 --- a/pkg/kv/kvclient/kvcoord/txn_correctness_test.go +++ b/pkg/kv/kvclient/kvcoord/txn_correctness_test.go @@ -30,7 +30,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/concurrency/isolation" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverbase" "github.com/cockroachdb/cockroach/pkg/storage/enginepb" - "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/localtestcluster" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -952,7 +951,7 @@ func checkConcurrency( }, }, } - s.Start(t, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(t, kvcoord.InitFactoryForLocalTestCluster) defer s.Stop() // Reduce the deadlock detection push delay so that txns that encounter locks // begin deadlock detection immediately. 
This speeds up tests significantly. diff --git a/pkg/kv/kvclient/kvcoord/txn_test.go b/pkg/kv/kvclient/kvcoord/txn_test.go index 7724b1f5d307..7f47c63101f8 100644 --- a/pkg/kv/kvclient/kvcoord/txn_test.go +++ b/pkg/kv/kvclient/kvcoord/txn_test.go @@ -111,7 +111,7 @@ func BenchmarkSingleRoundtripWithLatency(b *testing.B) { b.Run(fmt.Sprintf("latency=%s", latency), func(b *testing.B) { var s localtestcluster.LocalTestCluster s.Latency = latency - s.Start(b, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(b, kvcoord.InitFactoryForLocalTestCluster) defer s.Stop() defer b.StopTimer() key := roachpb.Key("key") diff --git a/pkg/kv/kvserver/client_decommission_test.go b/pkg/kv/kvserver/client_decommission_test.go index 95cf2f91e4dd..04fb4797773f 100644 --- a/pkg/kv/kvserver/client_decommission_test.go +++ b/pkg/kv/kvserver/client_decommission_test.go @@ -53,11 +53,10 @@ func TestDecommission(t *testing.T) { defer tc.Stopper().Stop(ctx) k := tc.ScratchRange(t) - admin, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err) + admin := tc.GetAdminClient(t, 0) // Decommission the first node, which holds most of the leases. 
- _, err = admin.Decommission( + _, err := admin.Decommission( ctx, &serverpb.DecommissionRequest{ NodeIDs: []roachpb.NodeID{1}, TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, diff --git a/pkg/kv/kvserver/idalloc/id_alloc_test.go b/pkg/kv/kvserver/idalloc/id_alloc_test.go index b94751a1e5c5..d370ab69e977 100644 --- a/pkg/kv/kvserver/idalloc/id_alloc_test.go +++ b/pkg/kv/kvserver/idalloc/id_alloc_test.go @@ -40,7 +40,7 @@ func newTestAllocator(t testing.TB) (*localtestcluster.LocalTestCluster, *idallo DisableLivenessHeartbeat: true, DontCreateSystemRanges: true, } - s.Start(t, testutils.NewNodeTestBaseContext(), kvcoord.InitFactoryForLocalTestCluster) + s.Start(t, kvcoord.InitFactoryForLocalTestCluster) idAlloc, err := idalloc.NewAllocator(idalloc.Options{ AmbientCtx: s.Cfg.AmbientCtx, Key: keys.RangeIDGenerator, diff --git a/pkg/kv/kvserver/liveness/client_test.go b/pkg/kv/kvserver/liveness/client_test.go index af43fce99934..e933851c783e 100644 --- a/pkg/kv/kvserver/liveness/client_test.go +++ b/pkg/kv/kvserver/liveness/client_test.go @@ -203,10 +203,7 @@ func TestNodeLivenessStatusMap(t *testing.T) { log.Infof(ctx, "checking status map") // See what comes up in the status. 
- admin, err := tc.GetAdminClient(ctx, t, 0) - if err != nil { - t.Fatal(err) - } + admin := tc.GetAdminClient(t, 0) type testCase struct { nodeID roachpb.NodeID diff --git a/pkg/kv/kvserver/loqrecovery/collect_raft_log_test.go b/pkg/kv/kvserver/loqrecovery/collect_raft_log_test.go index 4fb4c2936548..113bc83e1522 100644 --- a/pkg/kv/kvserver/loqrecovery/collect_raft_log_test.go +++ b/pkg/kv/kvserver/loqrecovery/collect_raft_log_test.go @@ -248,8 +248,7 @@ func TestCollectLeaseholderStatus(t *testing.T) { defer tc.Stopper().Stop(ctx) require.NoError(t, tc.WaitForFullReplication()) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) // Note: we need to retry because replica collection is not atomic and // leaseholder could move around so we could see none or more than one. diff --git a/pkg/kv/kvserver/loqrecovery/server_integration_test.go b/pkg/kv/kvserver/loqrecovery/server_integration_test.go index 3326d6f1a500..d4025bf97b3f 100644 --- a/pkg/kv/kvserver/loqrecovery/server_integration_test.go +++ b/pkg/kv/kvserver/loqrecovery/server_integration_test.go @@ -70,8 +70,7 @@ func TestReplicaCollection(t *testing.T) { r := tc.ServerConn(0).QueryRow("select count(*) from crdb_internal.ranges_no_leases") var totalRanges int require.NoError(t, r.Scan(&totalRanges), "failed to query range count") - adm, err := tc.GetAdminClient(ctx, t, 2) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 2) // Collect and assert replica metadata. For expectMeta case we sometimes have // meta and sometimes doesn't depending on which node holds the lease. 
@@ -80,7 +79,7 @@ func TestReplicaCollection(t *testing.T) { var replicas loqrecoverypb.ClusterReplicaInfo var stats loqrecovery.CollectionStats - replicas, stats, err = loqrecovery.CollectRemoteReplicaInfo(ctx, adm) + replicas, stats, err := loqrecovery.CollectRemoteReplicaInfo(ctx, adm) require.NoError(t, err, "failed to retrieve replica info") // Check counters on retrieved replica info. @@ -146,14 +145,13 @@ func TestStreamRestart(t *testing.T) { r := tc.ServerConn(0).QueryRow("select count(*) from crdb_internal.ranges_no_leases") var totalRanges int require.NoError(t, r.Scan(&totalRanges), "failed to query range count") - adm, err := tc.GetAdminClient(ctx, t, 2) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 2) assertReplicas := func(liveNodes int) { var replicas loqrecoverypb.ClusterReplicaInfo var stats loqrecovery.CollectionStats - replicas, stats, err = loqrecovery.CollectRemoteReplicaInfo(ctx, adm) + replicas, stats, err := loqrecovery.CollectRemoteReplicaInfo(ctx, adm) require.NoError(t, err, "failed to retrieve replica info") // Check counters on retrieved replica info. 
@@ -202,8 +200,7 @@ func TestGetPlanStagingState(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) resp, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{}) require.NoError(t, err) @@ -264,8 +261,7 @@ func TestStageRecoveryPlans(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) resp, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{}) require.NoError(t, err) @@ -305,8 +301,7 @@ func TestStageBadVersions(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) sk := tc.ScratchRange(t) plan := makeTestRecoveryPlan(ctx, t, adm) @@ -316,7 +311,7 @@ func TestStageBadVersions(t *testing.T) { plan.Version = clusterversion.ByKey(clusterversion.BinaryMinSupportedVersionKey) plan.Version.Major -= 1 - _, err = adm.RecoveryStagePlan(ctx, &serverpb.RecoveryStagePlanRequest{Plan: &plan, AllNodes: true}) + _, err := adm.RecoveryStagePlan(ctx, &serverpb.RecoveryStagePlanRequest{Plan: &plan, AllNodes: true}) require.Error(t, err, "shouldn't stage plan with old version") plan.Version.Major += 2 @@ -334,8 +329,7 @@ func TestStageConflictingPlans(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) resp, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{}) require.NoError(t, err) @@ -374,8 +368,7 @@ func TestForcePlanUpdate(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer 
tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) resV, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{}) require.NoError(t, err) @@ -416,8 +409,7 @@ func TestNodeDecommissioned(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) tc.StopServer(2) @@ -449,12 +441,11 @@ func TestRejectDecommissionReachableNode(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) plan := makeTestRecoveryPlan(ctx, t, adm) plan.DecommissionedNodeIDs = []roachpb.NodeID{roachpb.NodeID(3)} - _, err = adm.RecoveryStagePlan(ctx, + _, err := adm.RecoveryStagePlan(ctx, &serverpb.RecoveryStagePlanRequest{Plan: &plan, AllNodes: true}) require.ErrorContains(t, err, "was planned for decommission, but is present in cluster", "staging plan decommissioning live nodes must not be allowed") @@ -470,8 +461,7 @@ func TestStageRecoveryPlansToWrongCluster(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) resp, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{}) require.NoError(t, err) @@ -518,8 +508,7 @@ func TestRetrieveRangeStatus(t *testing.T) { tc.StopServer(int(rs[1].NodeID - 1)) admServer := int(rs[2].NodeID - 1) - adm, err := tc.GetAdminClient(ctx, t, admServer) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, admServer) r, err := adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{ DecommissionedNodeIDs: []roachpb.NodeID{rs[0].NodeID, 
rs[1].NodeID}, @@ -577,8 +566,7 @@ func TestRetrieveApplyStatus(t *testing.T) { tc.StopServer(int(rs[0].NodeID - 1)) tc.StopServer(int(rs[1].NodeID - 1)) - adm, err := tc.GetAdminClient(ctx, t, admServer) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, admServer) var replicas loqrecoverypb.ClusterReplicaInfo testutils.SucceedsSoon(t, func() error { @@ -626,8 +614,8 @@ func TestRetrieveApplyStatus(t *testing.T) { // Need to get new client connection as we one we used belonged to killed // server. - adm, err = tc.GetAdminClient(ctx, t, admServer) - require.NoError(t, err, "failed to get admin client") + adm = tc.GetAdminClient(t, admServer) + r, err = adm.RecoveryVerify(ctx, &serverpb.RecoveryVerifyRequest{ PendingPlanID: &plan.PlanID, DecommissionedNodeIDs: plan.DecommissionedNodeIDs, @@ -655,8 +643,7 @@ func TestRejectBadVersionApplication(t *testing.T) { defer reg.CloseAllStickyInMemEngines() defer tc.Stopper().Stop(ctx) - adm, err := tc.GetAdminClient(ctx, t, 0) - require.NoError(t, err, "failed to get admin client") + adm := tc.GetAdminClient(t, 0) var replicas loqrecoverypb.ClusterReplicaInfo testutils.SucceedsSoon(t, func() error { diff --git a/pkg/kv/kvserver/raft_transport_test.go b/pkg/kv/kvserver/raft_transport_test.go index 0990099014b1..8241f2d2cd90 100644 --- a/pkg/kv/kvserver/raft_transport_test.go +++ b/pkg/kv/kvserver/raft_transport_test.go @@ -36,7 +36,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" @@ -126,14 +125,10 @@ func newRaftTransportTestContext(t testing.TB, st *cluster.Settings) *raftTransp transports: map[roachpb.NodeID]*kvserver.RaftTransport{}, st: st, } - rttc.nodeRPCContext = rpc.NewContext(ctx, 
rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: testutils.NewNodeTestBaseContext(), - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: rttc.stopper, - Settings: st, - }) + opts := rpc.DefaultContextOptions() + opts.Stopper = rttc.stopper + opts.Settings = st + rttc.nodeRPCContext = rpc.NewContext(ctx, opts) // Ensure that tests using this test context and restart/shut down // their servers do not inadvertently start talking to servers from // unrelated concurrent tests. diff --git a/pkg/kv/kvserver/raft_transport_unit_test.go b/pkg/kv/kvserver/raft_transport_unit_test.go index 35907a3a7bd8..56f5722f73e6 100644 --- a/pkg/kv/kvserver/raft_transport_unit_test.go +++ b/pkg/kv/kvserver/raft_transport_unit_test.go @@ -18,7 +18,6 @@ import ( "testing" "time" - "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvflowcontrol/kvflowdispatch" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" @@ -29,7 +28,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" @@ -44,15 +42,14 @@ func TestRaftTransportStartNewQueue(t *testing.T) { defer stopper.Stop(ctx) st := cluster.MakeTestingClusterSettings() - rpcC := rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: &base.Config{Insecure: true}, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 500 * time.Millisecond, - Stopper: stopper, - Settings: st, - }) + opts := rpc.DefaultContextOptions() + opts.Insecure = true + opts.ToleratedOffset = 500 * time.Millisecond + opts.Stopper = stopper + opts.Settings = st + + rpcC := rpc.NewContext(ctx, opts) + 
rpcC.StorageClusterID.Set(context.Background(), uuid.MakeV4()) // mrs := &dummyMultiRaftServer{} diff --git a/pkg/kv/kvserver/replica_learner_test.go b/pkg/kv/kvserver/replica_learner_test.go index 83b85b9bdfcd..60e03bf6deba 100644 --- a/pkg/kv/kvserver/replica_learner_test.go +++ b/pkg/kv/kvserver/replica_learner_test.go @@ -985,8 +985,7 @@ func testRaftSnapshotsToNonVoters(t *testing.T, drainReceivingNode bool) { // effect on the outcome of this test. const drainingServerIdx = 1 const drainingNodeID = drainingServerIdx + 1 - client, err := tc.GetAdminClient(ctx, t, drainingServerIdx) - require.NoError(t, err) + client := tc.GetAdminClient(t, drainingServerIdx) drain(ctx, t, client, drainingNodeID) } @@ -1091,13 +1090,12 @@ func TestSnapshotsToDrainingNodes(t *testing.T) { }, ) defer tc.Stopper().Stop(ctx) - client, err := tc.GetAdminClient(ctx, t, drainingServerIdx) - require.NoError(t, err) + client := tc.GetAdminClient(t, drainingServerIdx) drain(ctx, t, client, drainingNodeID) // Now, we try to add a replica to it, we expect that to fail. scratchKey := tc.ScratchRange(t) - _, err = tc.AddVoters(scratchKey, makeReplicationTargets(drainingNodeID)...) + _, err := tc.AddVoters(scratchKey, makeReplicationTargets(drainingNodeID)...) 
require.Regexp(t, "store is draining", err) }) diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index 5be58113966b..97029880014b 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -38,7 +38,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverpb" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -860,10 +859,8 @@ func TestReplicateQueueTracingOnError(t *testing.T) { tc.AddVotersOrFatal(t, scratchKey, tc.Target(decomNodeIdx)) tc.AddVotersOrFatal(t, scratchKey, tc.Target(decomNodeIdx+1)) adminSrv := tc.Server(decomNodeIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode(adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - _, err = adminClient.Decommission( + adminClient := adminSrv.GetAdminClient(t) + _, err := adminClient.Decommission( ctx, &serverpb.DecommissionRequest{ NodeIDs: []roachpb.NodeID{decomNodeID}, TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, @@ -987,10 +984,8 @@ func TestReplicateQueueDecommissionPurgatoryError(t *testing.T) { tc.AddVotersOrFatal(t, scratchKey, tc.Target(decomNodeIdx)) tc.AddVotersOrFatal(t, scratchKey, tc.Target(decomNodeIdx+1)) adminSrv := tc.Server(decomNodeIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode(adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) - _, err = adminClient.Decommission( + adminClient := adminSrv.GetAdminClient(t) + _, err := adminClient.Decommission( ctx, &serverpb.DecommissionRequest{ NodeIDs: []roachpb.NodeID{decomNodeID}, 
TargetMembership: livenesspb.MembershipStatus_DECOMMISSIONING, diff --git a/pkg/kv/kvserver/scatter_test.go b/pkg/kv/kvserver/scatter_test.go index 67d83008d404..0e6c983b2533 100644 --- a/pkg/kv/kvserver/scatter_test.go +++ b/pkg/kv/kvserver/scatter_test.go @@ -57,8 +57,7 @@ func TestAdminScatterWithDrainingNodes(t *testing.T) { scratchKey := tc.ScratchRange(t) tc.AddVotersOrFatal(t, scratchKey, tc.Target(drainingServerIdx)) - client, err := tc.GetAdminClient(ctx, t, drainingServerIdx) - require.NoError(t, err) + client := tc.GetAdminClient(t, drainingServerIdx) drain(ctx, t, client, drainingNodeID) nonDrainingStore := tc.GetFirstStoreFromServer(t, 0) diff --git a/pkg/kv/kvserver/store_test.go b/pkg/kv/kvserver/store_test.go index ef582ce57fb8..df008f3caa73 100644 --- a/pkg/kv/kvserver/store_test.go +++ b/pkg/kv/kvserver/store_test.go @@ -179,16 +179,14 @@ func createTestStoreWithoutStart( // Setup fake zone config handler. config.TestingSetupZoneConfigHook(stopper) - rpcContext := rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: &base.Config{Insecure: true}, - Clock: cfg.Clock.WallClock(), - ToleratedOffset: cfg.Clock.ToleratedOffset(), - Stopper: stopper, - Settings: cfg.Settings, - TenantRPCAuthorizer: tenantcapabilitiesauthorizer.NewAllowEverythingAuthorizer(), - }) + rpcOpts := rpc.DefaultContextOptions() + rpcOpts.Insecure = true + rpcOpts.Clock = cfg.Clock.WallClock() + rpcOpts.ToleratedOffset = cfg.Clock.ToleratedOffset() + rpcOpts.Stopper = stopper + rpcOpts.Settings = cfg.Settings + rpcOpts.TenantRPCAuthorizer = tenantcapabilitiesauthorizer.NewAllowEverythingAuthorizer() + rpcContext := rpc.NewContext(ctx, rpcOpts) stopper.SetTracer(cfg.AmbientCtx.Tracer) server, err := rpc.NewServer(rpcContext) // never started require.NoError(t, err) diff --git a/pkg/obsservice/obslib/ingest/BUILD.bazel b/pkg/obsservice/obslib/ingest/BUILD.bazel index 4bbc33b8abbd..890a3adc293b 100644 --- a/pkg/obsservice/obslib/ingest/BUILD.bazel 
+++ b/pkg/obsservice/obslib/ingest/BUILD.bazel @@ -33,7 +33,6 @@ go_test( "//pkg/obsservice/obspb/opentelemetry-proto/common/v1:common", "//pkg/obsservice/obspb/opentelemetry-proto/logs/v1:logs", "//pkg/obsservice/obspb/opentelemetry-proto/resource/v1:resource", - "//pkg/roachpb", "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", diff --git a/pkg/obsservice/obslib/ingest/ingest_integration_test.go b/pkg/obsservice/obslib/ingest/ingest_integration_test.go index 06b0932bdf04..90c4eaebcd1d 100644 --- a/pkg/obsservice/obslib/ingest/ingest_integration_test.go +++ b/pkg/obsservice/obslib/ingest/ingest_integration_test.go @@ -15,14 +15,12 @@ import ( "net" "strings" "testing" - "time" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/obs" "github.com/cockroachdb/cockroach/pkg/obsservice/obslib/obsutil" "github.com/cockroachdb/cockroach/pkg/obsservice/obspb" logspb "github.com/cockroachdb/cockroach/pkg/obsservice/obspb/opentelemetry-proto/collector/logs/v1" - "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/testutils" @@ -30,7 +28,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -76,16 +73,12 @@ func TestEventIngestionIntegration(t *testing.T) { obsStop := stop.NewStopper() defer obsStop.Stop(ctx) e := MakeEventIngester(ctx, testConsumer, nil) - rpcContext := rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - NodeID: &base.NodeIDContainer{}, - Config: &base.Config{Insecure: true}, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: obsStop, - Settings: 
cluster.MakeTestingClusterSettings(), - }) + opts := rpc.DefaultContextOptions() + opts.NodeID = &base.NodeIDContainer{} + opts.Insecure = true + opts.Stopper = obsStop + opts.Settings = cluster.MakeTestingClusterSettings() + rpcContext := rpc.NewContext(ctx, opts) grpcServer, err := rpc.NewServer(rpcContext) require.NoError(t, err) defer grpcServer.Stop() diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index 3bc169d8122d..951b4bcb366d 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -9,6 +9,7 @@ go_library( "addjoin.go", "auth.go", "auth_tenant.go", + "client.go", "clock_offset.go", "codec.go", "connection.go", @@ -54,6 +55,7 @@ go_library( "//pkg/util/metric", "//pkg/util/metric/aggmetric", "//pkg/util/netutil", + "//pkg/util/netutil/addr", "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", diff --git a/pkg/rpc/client.go b/pkg/rpc/client.go new file mode 100644 index 000000000000..c6dbc68ee514 --- /dev/null +++ b/pkg/rpc/client.go @@ -0,0 +1,186 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+
+package rpc
+
+import (
+	"context"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/security/username"
+	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
+	"github.com/cockroachdb/cockroach/pkg/util/hlc"
+	"github.com/cockroachdb/cockroach/pkg/util/netutil/addr"
+	"github.com/cockroachdb/cockroach/pkg/util/stop"
+	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
+	"github.com/cockroachdb/cockroach/pkg/util/tracing"
+	"github.com/cockroachdb/errors"
+	"google.golang.org/grpc"
+)
+
+// ClientConnConfig contains the configuration for a ClientConn.
+type ClientConnConfig struct {
+	// Tracer is an optional tracer.
+	// May be nil.
+	Tracer *tracing.Tracer
+
+	// Settings are the cluster settings to use.
+	Settings *cluster.Settings
+
+	// ServerAddr is the server address to connect to.
+	ServerAddr string
+
+	// Insecure specifies whether to use an insecure connection.
+	Insecure bool
+
+	// SSLCertsDir is the directory containing the SSL certificates to use.
+	SSLCertsDir string
+
+	// User is the username to use.
+	User username.SQLUsername
+
+	// ClusterName is the name of the cluster to connect to.
+	// This is used to verify the server's identity.
+	ClusterName string
+
+	// DisableClusterNameVerification can be set to avoid checking
+	// the cluster name.
+	DisableClusterNameVerification bool
+
+	// TestingKnobs are optional testing knobs.
+	TestingKnobs ContextTestingKnobs
+
+	// RPCHeartbeatInterval is the interval for sending a heartbeat.
+	RPCHeartbeatInterval time.Duration
+	// RPCHeartbeatTimeout is the timeout for a heartbeat.
+	RPCHeartbeatTimeout time.Duration
+
+	// HistogramWindowInterval is the duration of the sliding window used to
+	// measure clock offsets.
+	HistogramWindowInterval time.Duration
+
+	// Clock is the HLC clock to use. This can be nil to disable
+	// maxoffset verification.
+ Clock *hlc.Clock +} + +// MakeClientConnConfigFromBaseConfig creates a ClientConnConfig from a +// base.Config. +func MakeClientConnConfigFromBaseConfig( + cfg base.Config, + user username.SQLUsername, + tracer *tracing.Tracer, + settings *cluster.Settings, + clock *hlc.Clock, + knobs ContextTestingKnobs, +) ClientConnConfig { + return ClientConnConfig{ + Tracer: tracer, + Settings: settings, + User: user, + TestingKnobs: knobs, + Clock: clock, + ServerAddr: cfg.AdvertiseAddr, + Insecure: cfg.Insecure, + SSLCertsDir: cfg.SSLCertsDir, + ClusterName: cfg.ClusterName, + DisableClusterNameVerification: cfg.DisableClusterNameVerification, + RPCHeartbeatInterval: cfg.RPCHeartbeatInterval, + RPCHeartbeatTimeout: cfg.RPCHeartbeatTimeout, + HistogramWindowInterval: cfg.HistogramWindowInterval(), + } +} + +// NewClientContext creates a new client context. +func NewClientContext(ctx context.Context, cfg ClientConnConfig) (*Context, *stop.Stopper) { + tracer := cfg.Tracer + if tracer == nil { + tracer = tracing.NewTracer() + } + stopper := stop.NewStopper(stop.WithTracer(tracer)) + opts := ContextOptions{ + // Setting TenantID this way here is intentional. It means + // that the TLS cert selection will not attempt to use + // a tenant server cert to connect to the remote server, + // and instead pick up a regular client cert. 
+		TenantID: roachpb.SystemTenantID,
+
+		Stopper:  stopper,
+		Settings: cfg.Settings,
+
+		ClientOnly:                     true,
+		Insecure:                       cfg.Insecure,
+		SSLCertsDir:                    cfg.SSLCertsDir,
+		ClusterName:                    cfg.ClusterName,
+		DisableClusterNameVerification: cfg.DisableClusterNameVerification,
+		RPCHeartbeatInterval:           cfg.RPCHeartbeatInterval,
+		RPCHeartbeatTimeout:            cfg.RPCHeartbeatTimeout,
+		HistogramWindowInterval:        cfg.HistogramWindowInterval,
+		User:                           cfg.User,
+		Knobs:                          cfg.TestingKnobs,
+		AdvertiseAddrH:                 &base.AdvertiseAddrH{},
+		SQLAdvertiseAddrH:              &base.SQLAdvertiseAddrH{},
+	}
+
+	if cfg.Clock == nil {
+		opts.Clock = &timeutil.DefaultTimeSource{}
+		opts.ToleratedOffset = time.Nanosecond
+	} else {
+		opts.Clock = cfg.Clock.WallClock()
+		opts.ToleratedOffset = cfg.Clock.MaxOffset()
+	}
+
+	// Client connections never use loopback dialing.
+	opts.Knobs.NoLoopbackDialer = true
+
+	return NewContext(ctx, opts), stopper
+}
+
+// NewClientConn creates a new client connection.
+// The caller is responsible for calling the returned function
+// to release associated resources.
+func NewClientConn(
+	ctx context.Context, cfg ClientConnConfig,
+) (conn *grpc.ClientConn, cleanup func(), err error) {
+	if ctx.Done() == nil {
+		return nil, nil, errors.New("context must be cancellable")
+	}
+	rpcContext, stopper := NewClientContext(ctx, cfg)
+	closer := func() {
+		// We use context.Background() here and not ctx because we
+		// want to ensure that the closers always run to completion
+		// even if the context used to create the client conn is
+		// canceled.
+		stopper.Stop(context.Background())
+	}
+	defer func() {
+		if err != nil {
+			closer()
+		}
+	}()
+
+	addr, err := addr.AddrWithDefaultLocalhost(cfg.ServerAddr)
+	if err != nil {
+		return nil, nil, err
+	}
+	// We use GRPCUnvalidatedDial() here because it does not matter
+	// which node we're talking to.
+ conn, err = rpcContext.GRPCUnvalidatedDial(addr).Connect(ctx) + if err != nil { + return nil, nil, err + } + stopper.AddCloser(stop.CloserFn(func() { + _ = conn.Close() // nolint:grpcconnclose + })) + + return conn, closer, nil +} diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index 76891b52b138..df0986cd54e2 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -27,6 +27,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/multitenant/tenantcapabilities" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security" + "github.com/cockroachdb/cockroach/pkg/security/certnames" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util" @@ -111,7 +113,7 @@ func NewServerEx( grpc.KeepaliveParams(serverKeepalive), grpc.KeepaliveEnforcementPolicy(serverEnforcement), } - if !rpcCtx.Config.Insecure { + if !rpcCtx.ContextOptions.Insecure { tlsConfig, err := rpcCtx.GetServerTLSConfig() if err != nil { return nil, sii, err @@ -144,7 +146,7 @@ func NewServerEx( }) }) - if !rpcCtx.Config.Insecure { + if !rpcCtx.ContextOptions.Insecure { a := kvAuth{ sv: &rpcCtx.Settings.SV, tenant: tenantAuthorizer{ @@ -288,13 +290,37 @@ func (c *Context) SetLoopbackDialer(loopbackDialFn func(context.Context) (net.Co // ContextOptions are passed to NewContext to set up a new *Context. // All pointer fields and TenantID are required. type ContextOptions struct { - TenantID roachpb.TenantID - Config *base.Config + TenantID roachpb.TenantID + Clock hlc.WallClock ToleratedOffset time.Duration FatalOnOffsetViolation bool Stopper *stop.Stopper Settings *cluster.Settings + + // Fields from base.Config used by a rpc.Context, either in clients + // or servers. 
+ SSLCertsDir string + Insecure bool + ClusterName string + DisableClusterNameVerification bool + RPCHeartbeatInterval time.Duration + RPCHeartbeatTimeout time.Duration + HistogramWindowInterval time.Duration + + // Fields from base.Config used only in RPC servers. + LocalityAddresses []roachpb.LocalityAddress + // We include the advertised address by reference because it is set + // during server startup after the rpc.Context is created. + *base.AdvertiseAddrH + + // The following fields come from base.Config and only needed by + // SecurityContextOptions. They are not really used by the 'rpc' + // code. They really belong elsewhere - see comments in + // SecurityContextOptions. + *base.SQLAdvertiseAddrH // only for servers + DisableTLSForHTTP bool // only for servers + // OnIncomingPing is called when handling a PingRequest, after // preliminary checks but before recording clock offset information. // It can inject an error or modify the response. @@ -328,6 +354,9 @@ type ContextOptions struct { // cluster version, a node ID, etc. ClientOnly bool + // User is used in case ClientOnly is true. + User username.SQLUsername + // UseNodeAuth is only used when ClientOnly is not set. // When set, it indicates that this rpc.Context is running inside // the same process as a KV layer and thus should feel empowered @@ -350,13 +379,44 @@ type ContextOptions struct { NeedsDialback bool } +// DefaultContextOptions are mostly used in tests. 
+func DefaultContextOptions() ContextOptions { + return ContextOptions{ + SSLCertsDir: certnames.EmbeddedCertsDir, + TenantID: roachpb.SystemTenantID, + User: username.NodeUserName(), + HistogramWindowInterval: base.DefaultHistogramWindowInterval(), + RPCHeartbeatInterval: base.PingInterval, + RPCHeartbeatTimeout: base.DefaultRPCHeartbeatTimeout, + AdvertiseAddrH: &base.AdvertiseAddrH{}, + SQLAdvertiseAddrH: &base.SQLAdvertiseAddrH{}, + Clock: &timeutil.DefaultTimeSource{}, + ToleratedOffset: time.Nanosecond, + } +} + +// ServerContextOptionsFromBaseConfig initializes a ContextOptions struct from +// a base.Config struct. +func ServerContextOptionsFromBaseConfig(cfg *base.Config) ContextOptions { + return ContextOptions{ + SSLCertsDir: cfg.SSLCertsDir, + Insecure: cfg.Insecure, + ClusterName: cfg.ClusterName, + DisableClusterNameVerification: cfg.DisableClusterNameVerification, + RPCHeartbeatInterval: cfg.RPCHeartbeatInterval, + RPCHeartbeatTimeout: cfg.RPCHeartbeatTimeout, + HistogramWindowInterval: cfg.HistogramWindowInterval(), + LocalityAddresses: cfg.LocalityAddresses, + AdvertiseAddrH: &cfg.AdvertiseAddrH, + SQLAdvertiseAddrH: &cfg.SQLAdvertiseAddrH, + DisableTLSForHTTP: cfg.DisableTLSForHTTP, + } +} + func (c ContextOptions) validate() error { if c.TenantID == (roachpb.TenantID{}) { return errors.New("must specify TenantID") } - if c.Config == nil { - return errors.New("Config must be set") - } if c.Clock == nil { return errors.New("Clock must be set") } @@ -366,6 +426,9 @@ func (c ContextOptions) validate() error { if c.Settings == nil { return errors.New("Settings must be set") } + if c.RPCHeartbeatInterval == 0 { + return errors.New("RPCHeartbeatInterval must be set") + } // NB: OnOutgoingPing and OnIncomingPing default to noops. // This is used both for testing and the cli. 
@@ -454,7 +517,13 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { masterCtx, _ := opts.Stopper.WithCancelOnQuiesce(ctx) secCtx := NewSecurityContext( - opts.Config, + SecurityContextOptions{ + SSLCertsDir: opts.SSLCertsDir, + Insecure: opts.Insecure, + AdvertiseAddrH: opts.AdvertiseAddrH, + SQLAdvertiseAddrH: opts.SQLAdvertiseAddrH, + DisableTLSForHTTP: opts.DisableTLSForHTTP, + }, security.ClusterTLSSettings(opts.Settings), opts.TenantID, opts.TenantRPCAuthorizer, @@ -467,8 +536,8 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { rpcCompression: enableRPCCompression, MasterCtx: masterCtx, metrics: makeMetrics(), - heartbeatInterval: opts.Config.RPCHeartbeatInterval, - heartbeatTimeout: opts.Config.RPCHeartbeatTimeout, + heartbeatInterval: opts.RPCHeartbeatInterval, + heartbeatTimeout: opts.RPCHeartbeatTimeout, } rpcCtx.dialbackMu.Lock() @@ -479,7 +548,7 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { panic("tenant ID not set") } - if opts.ClientOnly && opts.Config.User.Undefined() { + if opts.ClientOnly && opts.User.Undefined() { panic("client username not set") } @@ -492,7 +561,7 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { // Ensure we still have a working dial function in that case. rpcCtx.loopbackDialFn = func(ctx context.Context) (net.Conn, error) { d := onlyOnceDialer{} - return d.dial(ctx, opts.Config.AdvertiseAddr) + return d.dial(ctx, opts.AdvertiseAddr) } } @@ -500,7 +569,7 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { // CLI commands are exempted. 
if !opts.ClientOnly { rpcCtx.RemoteClocks = newRemoteClockMonitor( - opts.Clock, opts.ToleratedOffset, 10*opts.Config.RPCHeartbeatTimeout, opts.Config.HistogramWindowInterval()) + opts.Clock, opts.ToleratedOffset, 10*opts.RPCHeartbeatTimeout, opts.HistogramWindowInterval) } if id := opts.Knobs.StorageClusterID; id != nil { @@ -546,7 +615,7 @@ func (rpcCtx *Context) ClusterName() string { // This is used in tests. return "" } - return rpcCtx.Config.ClusterName + return rpcCtx.ContextOptions.ClusterName } // Metrics returns the Context's Metrics struct. @@ -1444,7 +1513,7 @@ func (rpcCtx *Context) GRPCDialOptions( ctx context.Context, target string, class ConnectionClass, ) ([]grpc.DialOption, error) { transport := tcpTransport - if rpcCtx.Config.AdvertiseAddr == target && !rpcCtx.ClientOnly { + if rpcCtx.ContextOptions.AdvertiseAddr == target && !rpcCtx.ClientOnly { // See the explanation on loopbackDialFn for an explanation about this. transport = loopbackTransport } @@ -1543,7 +1612,7 @@ func (rpcCtx *Context) GetClientTLSConfig() (*tls.Config, error) { switch { case rpcCtx.ClientOnly: // A CLI command is performing a remote RPC. - tlsCfg, err := cm.GetClientTLSConfig(rpcCtx.config.User) + tlsCfg, err := cm.GetClientTLSConfig(rpcCtx.User) return tlsCfg, wrapError(err) case rpcCtx.UseNodeAuth || rpcCtx.tenID.IsSystem(): @@ -1568,7 +1637,7 @@ func (rpcCtx *Context) GetClientTLSConfig() (*tls.Config, error) { // RPC client authenticates itself to the remote server. 
func (rpcCtx *Context) dialOptsNetworkCredentials() ([]grpc.DialOption, error) { var dialOpts []grpc.DialOption - if rpcCtx.Config.Insecure { + if rpcCtx.ContextOptions.Insecure { dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { tlsConfig, err := rpcCtx.GetClientTLSConfig() @@ -2010,7 +2079,7 @@ func (rpcCtx *Context) grpcDialRaw( ctx context.Context, target string, class ConnectionClass, additionalOpts ...grpc.DialOption, ) (*grpc.ClientConn, error) { transport := tcpTransport - if rpcCtx.Config.AdvertiseAddr == target && !rpcCtx.ClientOnly { + if rpcCtx.ContextOptions.AdvertiseAddr == target && !rpcCtx.ClientOnly { // See the explanation on loopbackDialFn for an explanation about this. transport = loopbackTransport } @@ -2116,7 +2185,7 @@ func (rpcCtx *Context) NewHeartbeatService() *HeartbeatService { clock: rpcCtx.Clock, remoteClockMonitor: rpcCtx.RemoteClocks, clusterName: rpcCtx.ClusterName(), - disableClusterNameVerification: rpcCtx.Config.DisableClusterNameVerification, + disableClusterNameVerification: rpcCtx.ContextOptions.DisableClusterNameVerification, clusterID: rpcCtx.StorageClusterID, nodeID: rpcCtx.NodeID, version: rpcCtx.Settings.Version, diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 8ba6bc5f02b5..6bce3dbb7bd8 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -98,15 +98,18 @@ func newTestServer(t testing.TB, ctx *Context, extraOpts ...grpc.ServerOption) * func newTestContextWithKnobs( clock hlc.WallClock, maxOffset time.Duration, stopper *stop.Stopper, knobs ContextTestingKnobs, ) *Context { - return NewContext(context.Background(), ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: testutils.NewNodeTestBaseContext(), - Clock: clock, - ToleratedOffset: maxOffset, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - Knobs: knobs, - }) + opts := DefaultContextOptions() + opts.Clock = clock + opts.ToleratedOffset = maxOffset + 
opts.Stopper = stopper + opts.Knobs = knobs + opts.Settings = cluster.MakeTestingClusterSettings() + + // Make tests faster. + opts.RPCHeartbeatInterval = 10 * time.Millisecond + opts.RPCHeartbeatTimeout = 0 + + return NewContext(context.Background(), opts) } func newTestContext( @@ -131,26 +134,23 @@ func TestPingInterceptors(t *testing.T) { errBoomSend := errors.Handled(errors.New("boom due to onSendPing")) recvMsg := "boom due to onHandlePing" errBoomRecv := status.Error(codes.FailedPrecondition, recvMsg) - opts := ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: testutils.NewNodeTestBaseContext(), - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 500 * time.Millisecond, - Stopper: stop.NewStopper(), - Settings: cluster.MakeTestingClusterSettings(), - OnOutgoingPing: func(ctx context.Context, req *PingRequest) error { - if req.TargetNodeID == blockedTargetNodeID { - return errBoomSend - } - return nil - }, - OnIncomingPing: func(ctx context.Context, req *PingRequest, resp *PingResponse) error { - if req.OriginNodeID == blockedOriginNodeID { - return errBoomRecv - } - return nil - }, + opts := DefaultContextOptions() + opts.ToleratedOffset = 500 * time.Millisecond + opts.Stopper = stop.NewStopper() + opts.Settings = cluster.MakeTestingClusterSettings() + opts.OnOutgoingPing = func(ctx context.Context, req *PingRequest) error { + if req.TargetNodeID == blockedTargetNodeID { + return errBoomSend + } + return nil + } + opts.OnIncomingPing = func(ctx context.Context, req *PingRequest, resp *PingResponse) error { + if req.OriginNodeID == blockedOriginNodeID { + return errBoomRecv + } + return nil } + defer opts.Stopper.Stop(ctx) rpcCtx := NewContext(ctx, opts) @@ -203,21 +203,16 @@ func testClockOffsetInPingRequestInternal(t *testing.T, clientOnly bool) { defer func() { close(done) }() // Build a minimal server. 
- opts := ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: testutils.NewNodeTestBaseContext(), - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: 500 * time.Millisecond, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - } + opts := DefaultContextOptions() + opts.ToleratedOffset = 500 * time.Millisecond + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() rpcCtxServer := NewContext(ctx, opts) clientOpts := opts - clientOpts.Config = testutils.NewNodeTestBaseContext() // Experimentally, values below 50ms seem to incur flakiness. - clientOpts.Config.RPCHeartbeatInterval = 100 * time.Millisecond - clientOpts.Config.RPCHeartbeatTimeout = 100 * time.Millisecond + clientOpts.RPCHeartbeatInterval = 100 * time.Millisecond + clientOpts.RPCHeartbeatTimeout = 100 * time.Millisecond clientOpts.ClientOnly = clientOnly clientOpts.OnOutgoingPing = func(ctx context.Context, req *PingRequest) error { select { @@ -418,8 +413,7 @@ func TestInternalServerAddress(t *testing.T) { maxOffset := time.Duration(0) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - serverCtx.Config.Addr = "127.0.0.1:9999" - serverCtx.Config.AdvertiseAddr = "127.0.0.1:8888" + serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) internal := &internalServer{} @@ -446,8 +440,7 @@ func TestInternalClientAdapterRunsInterceptors(t *testing.T) { maxOffset := time.Duration(0) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - serverCtx.Config.Addr = "127.0.0.1:9999" - serverCtx.Config.AdvertiseAddr = "127.0.0.1:8888" + serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) _ /* server */, serverInterceptors, err := NewServerEx(serverCtx) @@ -547,8 +540,7 @@ func TestInternalClientAdapterWithClientStreamInterceptors(t *testing.T) { maxOffset := time.Duration(0) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - 
serverCtx.Config.Addr = "127.0.0.1:9999" - serverCtx.Config.AdvertiseAddr = "127.0.0.1:8888" + serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) _ /* server */, serverInterceptors, err := NewServerEx(serverCtx) @@ -623,8 +615,7 @@ func TestInternalClientAdapterWithServerStreamInterceptors(t *testing.T) { maxOffset := time.Duration(0) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - serverCtx.Config.Addr = "127.0.0.1:9999" - serverCtx.Config.AdvertiseAddr = "127.0.0.1:8888" + serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) _ /* server */, serverInterceptors, err := NewServerEx(serverCtx) @@ -778,8 +769,7 @@ func BenchmarkInternalClientAdapter(b *testing.B) { clock := timeutil.NewManualTime(timeutil.Unix(0, 1)) maxOffset := time.Duration(0) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - serverCtx.Config.Addr = "127.0.0.1:9999" - serverCtx.Config.AdvertiseAddr = "127.0.0.1:8888" + serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) _, interceptors, err := NewServerEx(serverCtx) @@ -844,7 +834,7 @@ func TestConnectLoopback(t *testing.T) { require.NoError(t, err) addr := ln.Addr().String() - clientCtx.Config.AdvertiseAddr = addr + clientCtx.AdvertiseAddr = addr // Connect and get the error that comes from loopbackLn, proving that we were routed there. _, err = clientCtx.GRPCDialNode(addr, nodeID, DefaultClass).Connect(ctx) @@ -887,8 +877,8 @@ func TestOffsetMeasurement(t *testing.T) { clientMaxOffset := time.Duration(0) clientCtx := newTestContext(clusterID, clientClock, clientMaxOffset, stopper) // Make the interval shorter to speed up the test. 
- clientCtx.Config.RPCHeartbeatInterval = 1 * time.Millisecond - clientCtx.Config.RPCHeartbeatTimeout = 1 * time.Millisecond + clientCtx.RPCHeartbeatInterval = 1 * time.Millisecond + clientCtx.RPCHeartbeatTimeout = 1 * time.Millisecond clientCtx.RemoteClocks.offsetTTL = 5 * clientClock.getAdvancementInterval() if _, err := clientCtx.GRPCDialNode(remoteAddr, serverNodeID, DefaultClass).Connect(ctx); err != nil { t.Fatal(err) @@ -1033,8 +1023,8 @@ func TestLatencyInfoCleanupOnClosedConnection(t *testing.T) { clientMaxOffset := time.Duration(0) clientCtx := newTestContext(clusterID, clientClock, clientMaxOffset, stopper) // Make the interval shorter to speed up the test. - clientCtx.Config.RPCHeartbeatInterval = 1 * time.Millisecond - clientCtx.Config.RPCHeartbeatTimeout = 1 * time.Millisecond + clientCtx.RPCHeartbeatInterval = 1 * time.Millisecond + clientCtx.RPCHeartbeatTimeout = 1 * time.Millisecond var hbDecommission atomic.Value hbDecommission.Store(false) @@ -1159,8 +1149,8 @@ func TestRemoteOffsetUnhealthy(t *testing.T) { clock := timeutil.NewManualTime(timeutil.Unix(0, start.Add(nodeCtxs[i].offset).UnixNano())) nodeCtxs[i].errChan = make(chan error, 1) nodeCtxs[i].ctx = newTestContext(clusterID, clock, maxOffset, stopper) - nodeCtxs[i].ctx.Config.RPCHeartbeatInterval = maxOffset - nodeCtxs[i].ctx.Config.RPCHeartbeatTimeout = maxOffset + nodeCtxs[i].ctx.RPCHeartbeatInterval = maxOffset + nodeCtxs[i].ctx.RPCHeartbeatTimeout = maxOffset nodeCtxs[i].ctx.NodeID.Set(context.Background(), roachpb.NodeID(i+1)) s := newTestServer(t, nodeCtxs[i].ctx) @@ -1175,7 +1165,7 @@ func TestRemoteOffsetUnhealthy(t *testing.T) { if err != nil { t.Fatal(err) } - nodeCtxs[i].ctx.Config.Addr = ln.Addr().String() + nodeCtxs[i].ctx.AdvertiseAddr = ln.Addr().String() } // Fully connect the nodes. 
@@ -1184,7 +1174,7 @@ func TestRemoteOffsetUnhealthy(t *testing.T) { if i == j { continue } - if _, err := clientNodeContext.ctx.GRPCDialNode(serverNodeContext.ctx.Config.Addr, serverNodeContext.ctx.NodeID.Get(), DefaultClass).Connect(ctx); err != nil { + if _, err := clientNodeContext.ctx.GRPCDialNode(serverNodeContext.ctx.AdvertiseAddr, serverNodeContext.ctx.NodeID.Get(), DefaultClass).Connect(ctx); err != nil { t.Fatal(err) } } @@ -1377,8 +1367,8 @@ func grpcRunKeepaliveTestCase(testCtx context.Context, c grpcKeepaliveTestCase) log.Infof(ctx, "setting up client") clientCtx := newTestContext(clusterID, clock, maxOffset, stopper) // Disable automatic heartbeats. We'll send them by hand. - clientCtx.Config.RPCHeartbeatInterval = math.MaxInt64 - clientCtx.Config.RPCHeartbeatTimeout = math.MaxInt64 + clientCtx.RPCHeartbeatInterval = math.MaxInt64 + clientCtx.RPCHeartbeatTimeout = math.MaxInt64 var firstConn int32 = 1 @@ -1709,8 +1699,8 @@ func TestClusterNameMismatch(t *testing.T) { defer stopper.Stop(context.Background()) serverCtx := newTestContext(uuid.MakeV4(), clock, maxOffset, stopper) - serverCtx.Config.ClusterName = c.serverName - serverCtx.Config.DisableClusterNameVerification = c.serverDisablePeerCheck + serverCtx.ContextOptions.ClusterName = c.serverName + serverCtx.ContextOptions.DisableClusterNameVerification = c.serverDisablePeerCheck s := newTestServer(t, serverCtx) RegisterHeartbeatServer(s, &HeartbeatService{ @@ -1719,8 +1709,8 @@ func TestClusterNameMismatch(t *testing.T) { clusterID: serverCtx.StorageClusterID, nodeID: serverCtx.NodeID, version: serverCtx.Settings.Version, - clusterName: serverCtx.Config.ClusterName, - disableClusterNameVerification: serverCtx.Config.DisableClusterNameVerification, + clusterName: serverCtx.ContextOptions.ClusterName, + disableClusterNameVerification: serverCtx.ContextOptions.DisableClusterNameVerification, }) ln, err := netutil.ListenAndServeGRPC(serverCtx.Stopper, s, util.TestAddr) @@ -1730,8 +1720,8 @@ func 
TestClusterNameMismatch(t *testing.T) { remoteAddr := ln.Addr().String() clientCtx := newTestContext(serverCtx.StorageClusterID.Get(), clock, maxOffset, stopper) - clientCtx.Config.ClusterName = c.clientName - clientCtx.Config.DisableClusterNameVerification = c.clientDisablePeerCheck + clientCtx.ContextOptions.ClusterName = c.clientName + clientCtx.ContextOptions.DisableClusterNameVerification = c.clientDisablePeerCheck var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -2188,19 +2178,17 @@ func newRegisteredServer( nodeID roachpb.NodeID, ) (*Context, string, *trackingListener) { clock := timeutil.NewManualTime(timeutil.Unix(0, 1)) - opts := ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: testutils.NewNodeTestBaseContext(), - Clock: clock, - ToleratedOffset: time.Duration(0), - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - NeedsDialback: true, - Knobs: ContextTestingKnobs{NoLoopbackDialer: true}, - } + opts := DefaultContextOptions() + opts.Clock = clock + opts.ToleratedOffset = time.Duration(0) + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() + opts.NeedsDialback = true + opts.Knobs = ContextTestingKnobs{NoLoopbackDialer: true} + // Heartbeat faster so we don't have to wait as long. - opts.Config.RPCHeartbeatInterval = 10 * time.Millisecond - opts.Config.RPCHeartbeatTimeout = 1000 * time.Millisecond + opts.RPCHeartbeatInterval = 10 * time.Millisecond + opts.RPCHeartbeatTimeout = 1000 * time.Millisecond rpcCtx := NewContext(context.Background(), opts) // This is normally set up inside the server, we want to hold onto all PingRequests that come through. @@ -2237,7 +2225,7 @@ func newRegisteredServer( log.Infof(ctx, "Listening on %s", addr) // This needs to be set once we know our address so that ping requests have // the correct reverse addr in them. 
- rpcCtx.Config.AdvertiseAddr = addr + rpcCtx.AdvertiseAddr = addr return rpcCtx, addr, &tracker } diff --git a/pkg/rpc/context_testutils.go b/pkg/rpc/context_testutils.go index 723a5beb0366..b50d643f94a4 100644 --- a/pkg/rpc/context_testutils.go +++ b/pkg/rpc/context_testutils.go @@ -14,8 +14,6 @@ import ( "context" "time" - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/stop" @@ -92,14 +90,13 @@ func NewInsecureTestingContextWithClusterID( func NewInsecureTestingContextWithKnobs( ctx context.Context, clock *hlc.Clock, stopper *stop.Stopper, knobs ContextTestingKnobs, ) *Context { - return NewContext(ctx, - ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: &base.Config{Insecure: true}, - Clock: clock.WallClock(), - ToleratedOffset: clock.ToleratedOffset(), - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - Knobs: knobs, - }) + opts := DefaultContextOptions() + opts.Insecure = true + opts.Clock = clock.WallClock() + opts.ToleratedOffset = clock.ToleratedOffset() + opts.Settings = cluster.MakeTestingClusterSettings() + opts.Knobs = knobs + opts.Stopper = stopper + + return NewContext(ctx, opts) } diff --git a/pkg/rpc/nodedialer/nodedialer_test.go b/pkg/rpc/nodedialer/nodedialer_test.go index a7a7fa25e2ae..1ef4ad3a48a1 100644 --- a/pkg/rpc/nodedialer/nodedialer_test.go +++ b/pkg/rpc/nodedialer/nodedialer_test.go @@ -233,7 +233,7 @@ func TestConnHealthInternal(t *testing.T) { &internalServer{}, rpc.ServerInterceptorInfo{}, rpc.ClientInterceptorInfo{}) rpcCtx.NodeID.Set(ctx, staticNodeID) - rpcCtx.Config.AdvertiseAddr = localAddr.String() + rpcCtx.AdvertiseAddr = localAddr.String() nd := New(rpcCtx, newSingleNodeResolver(staticNodeID, localAddr)) defer stopper.Stop(ctx) @@ -310,19 +310,16 @@ func newTestServer( func newTestContext( clock 
hlc.WallClock, maxOffset time.Duration, stopper *stop.Stopper, ) *rpc.Context { - cfg := testutils.NewNodeTestBaseContext() - cfg.Insecure = true - cfg.RPCHeartbeatInterval = 100 * time.Millisecond - cfg.RPCHeartbeatTimeout = 500 * time.Millisecond ctx := context.Background() - rctx := rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: cfg, - Clock: clock, - ToleratedOffset: maxOffset, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - }) + opts := rpc.DefaultContextOptions() + opts.Insecure = true + opts.RPCHeartbeatInterval = 100 * time.Millisecond + opts.RPCHeartbeatTimeout = 500 * time.Millisecond + opts.Clock = clock + opts.ToleratedOffset = maxOffset + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() + rctx := rpc.NewContext(ctx, opts) // Ensure that tests using this test context and restart/shut down // their servers do not inadvertently start talking to servers from // unrelated concurrent tests. diff --git a/pkg/rpc/peer.go b/pkg/rpc/peer.go index d70eec0f0ed8..5ff1372af218 100644 --- a/pkg/rpc/peer.go +++ b/pkg/rpc/peer.go @@ -370,10 +370,10 @@ func runSingleHeartbeat( // heartbeat to heartbeat: we compute a new .Offset at the end of // the current heartbeat as input to the next one. request := &PingRequest{ - OriginAddr: opts.Config.AdvertiseAddr, + OriginAddr: opts.AdvertiseAddr, TargetNodeID: k.NodeID, ServerVersion: opts.Settings.Version.BinaryVersion(), - LocalityAddress: opts.Config.LocalityAddresses, + LocalityAddress: opts.LocalityAddresses, ClusterID: &clusterID, OriginNodeID: opts.NodeID.Get(), NeedsDialback: preferredDialback, @@ -414,9 +414,9 @@ func runSingleHeartbeat( // new node in a cluster and mistakenly joins the wrong // cluster gets a chance to see the error message on their // management console. 
- if !opts.Config.DisableClusterNameVerification && !response.DisableClusterNameVerification { + if !opts.DisableClusterNameVerification && !response.DisableClusterNameVerification { err = errors.Wrap( - checkClusterName(opts.Config.ClusterName, response.ClusterName), + checkClusterName(opts.ClusterName, response.ClusterName), "cluster name check failed on ping response") if err != nil { return err diff --git a/pkg/rpc/tls.go b/pkg/rpc/tls.go index 052bc31c8dc1..2ed75821a9f4 100644 --- a/pkg/rpc/tls.go +++ b/pkg/rpc/tls.go @@ -54,12 +54,33 @@ func wrapError(err error) error { return err } +// SecurityContextOptions contains the subset of base.Config +// useful to define a SecurityContext. +type SecurityContextOptions struct { + SSLCertsDir string + Insecure bool + + // DisableTLSForHTTP is only used by GetUIServerTLSConfig(). + // + // TODO(kv): it's a bit strange that the 'rpc' package + // is responsible for the HTTP TLS config. Maybe move it + // elsewhere? + DisableTLSForHTTP bool + // AdvertiseAddrH is only used by CheckCertificateAddrs(). + *base.AdvertiseAddrH + // SQLAdvertiseAddrH is only used by CheckCertificateAddrs(). + // + // TODO(kv): it's a bit strange that the 'rpc' package is + // responsible for the SQL TLS config. Maybe move it elsewhere? + *base.SQLAdvertiseAddrH +} + // SecurityContext is a wrapper providing transport security helpers such as // the certificate manager. type SecurityContext struct { certnames.Locator security.TLSSettings - config *base.Config + config SecurityContextOptions tenID roachpb.TenantID capabilitiesAuthorizer tenantcapabilities.Authorizer lazy struct { @@ -75,7 +96,7 @@ type SecurityContext struct { // // TODO(tbg): don't take a whole Config. This can be trimmed down significantly. 
func NewSecurityContext( - cfg *base.Config, + cfg SecurityContextOptions, tlsSettings security.TLSSettings, tenID roachpb.TenantID, capabilitiesAuthorizer tenantcapabilities.Authorizer, @@ -284,7 +305,10 @@ func (ctx *SecurityContext) CheckCertificateAddrs(cctx context.Context) { // HTTPRequestScheme returns "http" or "https" based on the value of // Insecure and DisableTLSForHTTP. func (ctx *SecurityContext) HTTPRequestScheme() string { - return ctx.config.HTTPRequestScheme() + if ctx.config.Insecure || ctx.config.DisableTLSForHTTP { + return "http" + } + return "https" } // certAddrs formats the list of addresses included in a certificate for diff --git a/pkg/rpc/tls_test.go b/pkg/rpc/tls_test.go index 0712ba28e6b9..6112aecc6ad9 100644 --- a/pkg/rpc/tls_test.go +++ b/pkg/rpc/tls_test.go @@ -13,17 +13,13 @@ package rpc import ( "context" "testing" - "time" - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/certnames" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" ) func TestClientSSLSettings(t *testing.T) { @@ -55,29 +51,26 @@ func TestClientSSLSettings(t *testing.T) { for _, tc := range testCases { t.Run("", func(t *testing.T) { - cfg := &base.Config{Insecure: tc.insecure, User: tc.user} + opts := DefaultContextOptions() + opts.Insecure = tc.insecure + opts.User = tc.user if tc.hasCerts { - cfg.SSLCertsDir = certnames.EmbeddedCertsDir + opts.SSLCertsDir = certnames.EmbeddedCertsDir } else { // We can't leave this empty because otherwise it refers to the cwd which // always exists. 
- cfg.SSLCertsDir = "i-do-not-exist" + opts.SSLCertsDir = "i-do-not-exist" } ctx := context.Background() stopper := stop.NewStopper() defer stopper.Stop(ctx) - rpcContext := NewContext(ctx, ContextOptions{ - TenantID: roachpb.SystemTenantID, - ClientOnly: true, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - Config: cfg, - }) + opts.ClientOnly = true + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() + rpcContext := NewContext(ctx, opts) - if cfg.HTTPRequestScheme() != tc.requestScheme { - t.Fatalf("expected HTTPRequestScheme=%s, got: %s", tc.requestScheme, cfg.HTTPRequestScheme()) + if expected, actual := tc.requestScheme, rpcContext.SecurityContext.HTTPRequestScheme(); expected != actual { + t.Fatalf("expected HTTPRequestScheme=%s, got: %s", expected, actual) } tlsConfig, err := rpcContext.GetClientTLSConfig() if !testutils.IsError(err, tc.configErr) { @@ -119,23 +112,20 @@ func TestServerSSLSettings(t *testing.T) { for tcNum, tc := range testCases { t.Run("", func(t *testing.T) { - cfg := &base.Config{Insecure: tc.insecure, User: username.NodeUserName()} - if tc.hasCerts { - cfg.SSLCertsDir = certnames.EmbeddedCertsDir + opts := DefaultContextOptions() + opts.Insecure = tc.insecure + if !tc.hasCerts { + opts.SSLCertsDir = "i-do-not-exist" } ctx := context.Background() stopper := stop.NewStopper() defer stopper.Stop(ctx) - rpcContext := NewContext(ctx, ContextOptions{ - TenantID: roachpb.SystemTenantID, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - Config: cfg, - }) - if cfg.HTTPRequestScheme() != tc.requestScheme { - t.Fatalf("#%d: expected HTTPRequestScheme=%s, got: %s", tcNum, tc.requestScheme, cfg.HTTPRequestScheme()) + + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() + rpcContext := NewContext(ctx, 
opts) + if actual, expected := rpcContext.HTTPRequestScheme(), tc.requestScheme; actual != expected { + t.Fatalf("#%d: expected HTTPRequestScheme=%s, got: %s", tcNum, expected, actual) } tlsConfig, err := rpcContext.GetServerTLSConfig() if (err == nil) != tc.configSuccess { diff --git a/pkg/security/certs_rotation_test.go b/pkg/security/certs_rotation_test.go index da777e396045..a9b939792033 100644 --- a/pkg/security/certs_rotation_test.go +++ b/pkg/security/certs_rotation_test.go @@ -104,8 +104,7 @@ func TestRotateCerts(t *testing.T) { const kBadCertificate = "tls: bad certificate" // Test client with the same certs. - clientContext := testutils.NewNodeTestBaseContext() - clientContext.SSLCertsDir = certsDir + clientContext := rpc.SecurityContextOptions{SSLCertsDir: certsDir} firstSCtx := rpc.NewSecurityContext( clientContext, security.CommandTLSSettings{}, @@ -140,7 +139,6 @@ func TestRotateCerts(t *testing.T) { // Setup a second http client. It will load the new certs. // We need to use a new context as it keeps the certificate manager around. // Fails on crypto errors. - clientContext = testutils.NewNodeTestBaseContext() clientContext.SSLCertsDir = certsDir secondSCtx := rpc.NewSecurityContext( @@ -255,7 +253,6 @@ func TestRotateCerts(t *testing.T) { // Setup a third http client. It will load the new certs. // We need to use a new context as it keeps the certificate manager around. // This is HTTP and succeeds because we do not ask for or verify client certificates. 
- clientContext = testutils.NewNodeTestBaseContext() clientContext.SSLCertsDir = certsDir thirdSCtx := rpc.NewSecurityContext( clientContext, diff --git a/pkg/security/certs_test.go b/pkg/security/certs_test.go index fb18c807fa33..7ab7a05bb17e 100644 --- a/pkg/security/certs_test.go +++ b/pkg/security/certs_test.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -376,6 +377,8 @@ func generateSplitCACerts(certsDir string) error { // We construct SSL server and clients and use the generated certs. func TestUseCerts(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + // Do not mock cert access for this test. securityassets.ResetLoader() defer ResetTest() @@ -398,8 +401,7 @@ func TestUseCerts(t *testing.T) { defer s.Stopper().Stop(context.Background()) // Insecure mode. - clientContext := testutils.NewNodeTestBaseContext() - clientContext.Insecure = true + clientContext := rpc.SecurityContextOptions{Insecure: true} sCtx := rpc.NewSecurityContext( clientContext, security.CommandTLSSettings{}, @@ -422,8 +424,7 @@ func TestUseCerts(t *testing.T) { } // New client. With certs this time. - clientContext = testutils.NewNodeTestBaseContext() - clientContext.SSLCertsDir = certsDir + clientContext = rpc.SecurityContextOptions{SSLCertsDir: certsDir} { secondSCtx := rpc.NewSecurityContext( clientContext, @@ -468,6 +469,8 @@ func makeSecurePGUrl(addr, user, certsDir, caName, certName, keyName string) str // We construct SSL server and clients and use the generated certs. func TestUseSplitCACerts(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + // Do not mock cert access for this test. 
securityassets.ResetLoader() defer ResetTest() @@ -490,8 +493,7 @@ func TestUseSplitCACerts(t *testing.T) { defer s.Stopper().Stop(context.Background()) // Insecure mode. - clientContext := testutils.NewNodeTestBaseContext() - clientContext.Insecure = true + clientContext := rpc.SecurityContextOptions{Insecure: true} sCtx := rpc.NewSecurityContext( clientContext, security.CommandTLSSettings{}, @@ -514,8 +516,7 @@ func TestUseSplitCACerts(t *testing.T) { } // New client. With certs this time. - clientContext = testutils.NewNodeTestBaseContext() - clientContext.SSLCertsDir = certsDir + clientContext = rpc.SecurityContextOptions{SSLCertsDir: certsDir} { secondSCtx := rpc.NewSecurityContext( clientContext, @@ -587,6 +588,8 @@ func TestUseSplitCACerts(t *testing.T) { // We construct SSL server and clients and use the generated certs. func TestUseWrongSplitCACerts(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + // Do not mock cert access for this test. securityassets.ResetLoader() defer ResetTest() @@ -618,8 +621,7 @@ func TestUseWrongSplitCACerts(t *testing.T) { defer s.Stopper().Stop(context.Background()) // Insecure mode. - clientContext := testutils.NewNodeTestBaseContext() - clientContext.Insecure = true + clientContext := rpc.SecurityContextOptions{Insecure: true} sCtx := rpc.NewSecurityContext( clientContext, security.CommandTLSSettings{}, @@ -642,8 +644,7 @@ func TestUseWrongSplitCACerts(t *testing.T) { } // New client with certs, but the UI CA is gone, we have no way to verify the Admin UI cert. 
- clientContext = testutils.NewNodeTestBaseContext() - clientContext.SSLCertsDir = certsDir + clientContext = rpc.SecurityContextOptions{SSLCertsDir: certsDir} { secondCtx := rpc.NewSecurityContext( clientContext, diff --git a/pkg/server/application_api/BUILD.bazel b/pkg/server/application_api/BUILD.bazel index 3ecf46413d2b..873a529e897f 100644 --- a/pkg/server/application_api/BUILD.bazel +++ b/pkg/server/application_api/BUILD.bazel @@ -42,7 +42,6 @@ go_test( "//pkg/kv/kvclient/kvtenant", "//pkg/kv/kvserver", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", @@ -79,7 +78,6 @@ go_test( "//pkg/util/randident", "//pkg/util/randutil", "//pkg/util/safesql", - "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", "//pkg/util/uuid", diff --git a/pkg/server/application_api/activity_test.go b/pkg/server/application_api/activity_test.go index c38ebd97c9eb..6f4519f05119 100644 --- a/pkg/server/application_api/activity_test.go +++ b/pkg/server/application_api/activity_test.go @@ -17,8 +17,6 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/serverpb" @@ -110,15 +108,7 @@ func TestListActivitySecurity(t *testing.T) { } // gRPC requests behave as root and thus are always allowed. 
- rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, rootConfig) - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) { request := &serverpb.ListContentionEventsRequest{} if resp, err := client.ListLocalContentionEvents(ctx, request); err != nil || len(resp.Errors) > 0 { diff --git a/pkg/server/application_api/jobs_test.go b/pkg/server/application_api/jobs_test.go index ae3e0256af8d..e9338dc3c078 100644 --- a/pkg/server/application_api/jobs_test.go +++ b/pkg/server/application_api/jobs_test.go @@ -22,7 +22,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" @@ -420,16 +419,7 @@ func TestJobStatusResponse(t *testing.T) { ts := serverutils.StartServerOnly(t, base.TestServerArgs{}) defer ts.Stopper().Stop(context.Background()) - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts.(*server.TestServer), rootConfig) - - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) request := &serverpb.JobStatusRequest{JobId: -1} response, err := client.JobStatus(context.Background(), request) diff --git a/pkg/server/application_api/sessions_test.go b/pkg/server/application_api/sessions_test.go index 6c7d00ff9682..171b04fe0827 
100644 --- a/pkg/server/application_api/sessions_test.go +++ b/pkg/server/application_api/sessions_test.go @@ -22,7 +22,6 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" @@ -93,15 +92,7 @@ func TestListSessionsSecurity(t *testing.T) { } // gRPC requests behave as root and thus are always allowed. - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, rootConfig) - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) for _, user := range []string{"", apiconstants.TestingUser, username.RootUser} { request := &serverpb.ListSessionsRequest{Username: user} diff --git a/pkg/server/application_api/storage_inspection_test.go b/pkg/server/application_api/storage_inspection_test.go index bea7b266bea5..1ec62453d090 100644 --- a/pkg/server/application_api/storage_inspection_test.go +++ b/pkg/server/application_api/storage_inspection_test.go @@ -20,7 +20,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/rangetestutils" @@ -34,7 +33,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/grunning" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" ) @@ -457,9 +455,6 @@ func TestSpanStatsGRPCResponse(t *testing.T) { defer s.Stopper().Stop(ctx) ts := s.(*server.TestServer) - rpcStopper := stop.NewStopper() - defer rpcStopper.Stop(ctx) - rpcContext := srvtestutils.NewRPCTestContext(ctx, ts, ts.RPCContext().Config) span := roachpb.Span{ Key: roachpb.RKeyMin.AsRawKey(), EndKey: roachpb.RKeyMax.AsRawKey(), @@ -469,13 +464,7 @@ func TestSpanStatsGRPCResponse(t *testing.T) { Spans: []roachpb.Span{span}, } - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) response, err := client.SpanStats(ctx, &request) if err != nil { diff --git a/pkg/server/authserver/authentication_test.go b/pkg/server/authserver/authentication_test.go index 6707dcd26921..f01ed95de099 100644 --- a/pkg/server/authserver/authentication_test.go +++ b/pkg/server/authserver/authentication_test.go @@ -99,29 +99,25 @@ func TestSSLEnforcement(t *testing.T) { }) defer s.Stopper().Stop(ctx) - newRPCContext := func(cfg *base.Config) *rpc.Context { - return rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: cfg, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: s.Stopper(), - Settings: s.ClusterSettings(), - }) + newRPCContext := func(insecure bool, user username.SQLUsername) *rpc.Context { + opts := rpc.DefaultContextOptions() + opts.Insecure = insecure + opts.User = user + opts.Stopper = s.Stopper() + opts.Settings = s.ClusterSettings() + return rpc.NewContext(ctx, opts) } // HTTPS with client certs for security.RootUser. - rootCertsContext := newRPCContext(testutils.NewTestBaseContext(username.RootUserName())) + rootCertsContext := newRPCContext(false, username.RootUserName()) // HTTPS with client certs for security.NodeUser. 
- nodeCertsContext := newRPCContext(testutils.NewNodeTestBaseContext()) + nodeCertsContext := newRPCContext(false, username.NodeUserName()) // HTTPS with client certs for TestUser. - testCertsContext := newRPCContext(testutils.NewTestBaseContext(username.TestUserName())) + testCertsContext := newRPCContext(false, username.TestUserName()) // HTTPS without client certs. The user does not matter. noCertsContext := insecureCtx{} // Plain http. - plainHTTPCfg := testutils.NewTestBaseContext(username.TestUserName()) - plainHTTPCfg.Insecure = true - insecureContext := newRPCContext(plainHTTPCfg) + insecureContext := newRPCContext(true, username.TestUserName()) kvGet := &kvpb.GetRequest{} kvGet.Key = roachpb.Key("/") diff --git a/pkg/server/drain_test.go b/pkg/server/drain_test.go index 0dfecab0dfa2..64d234592583 100644 --- a/pkg/server/drain_test.go +++ b/pkg/server/drain_test.go @@ -31,7 +31,6 @@ import ( "github.com/cockroachdb/errors" "github.com/kr/pretty" "github.com/stretchr/testify/require" - "google.golang.org/grpc" ) // TestDrain tests the Drain RPC. @@ -192,12 +191,8 @@ func newTestDrainContext(t *testing.T, drainSleepCallCount *int) *testDrainConte } // We'll have the RPC talk to the first node. 
- var err error - tc.c, tc.connCloser, err = getAdminClientForServer(tc.tc.Server(0)) - if err != nil { - tc.Close() - t.Fatal(err) - } + tc.c = tc.tc.Server(0).GetAdminClient(t) + tc.connCloser = func() {} return tc } @@ -294,20 +289,6 @@ func (t *testDrainContext) getDrainResponse( return resp, nil } -func getAdminClientForServer( - s serverutils.TestServerInterface, -) (c serverpb.AdminClient, closer func(), err error) { - //lint:ignore SA1019 grpc.WithInsecure is deprecated - conn, err := grpc.Dial(s.AdvRPCAddr(), grpc.WithInsecure()) - if err != nil { - return nil, nil, err - } - client := serverpb.NewAdminClient(conn) - return client, func() { - _ = conn.Close() // nolint:grpcconnclose - }, nil -} - func TestServerShutdownReleasesSession(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) diff --git a/pkg/server/server.go b/pkg/server/server.go index 7f1a23854288..7147690251ec 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -311,26 +311,26 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { tenantCapabilitiesTestingKnobs, _ := cfg.TestingKnobs.TenantCapabilitiesTestingKnobs.(*tenantcapabilities.TestingKnobs) authorizer := tenantcapabilitiesauthorizer.New(cfg.Settings, tenantCapabilitiesTestingKnobs) - rpcCtxOpts := rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - UseNodeAuth: true, - NodeID: nodeIDContainer, - StorageClusterID: cfg.ClusterIDContainer, - Config: cfg.Config, - Clock: clock.WallClock(), - ToleratedOffset: clock.ToleratedOffset(), - FatalOnOffsetViolation: true, - Stopper: stopper, - Settings: cfg.Settings, - OnOutgoingPing: func(ctx context.Context, req *rpc.PingRequest) error { - // Outgoing ping will block requests with codes.FailedPrecondition to - // notify caller that this replica is decommissioned but others could - // still be tried as caller node is valid, but not the destination. 
- return decommissionCheck(ctx, req.TargetNodeID, codes.FailedPrecondition) - }, - TenantRPCAuthorizer: authorizer, - NeedsDialback: true, - } + rpcCtxOpts := rpc.ServerContextOptionsFromBaseConfig(cfg.BaseConfig.Config) + + rpcCtxOpts.TenantID = roachpb.SystemTenantID + rpcCtxOpts.UseNodeAuth = true + rpcCtxOpts.NodeID = nodeIDContainer + rpcCtxOpts.StorageClusterID = cfg.ClusterIDContainer + rpcCtxOpts.Clock = clock.WallClock() + rpcCtxOpts.ToleratedOffset = clock.ToleratedOffset() + rpcCtxOpts.FatalOnOffsetViolation = true + rpcCtxOpts.Stopper = stopper + rpcCtxOpts.Settings = cfg.Settings + rpcCtxOpts.OnOutgoingPing = func(ctx context.Context, req *rpc.PingRequest) error { + // Outgoing ping will block requests with codes.FailedPrecondition to + // notify caller that this replica is decommissioned but others could + // still be tried as caller node is valid, but not the destination. + return decommissionCheck(ctx, req.TargetNodeID, codes.FailedPrecondition) + } + rpcCtxOpts.TenantRPCAuthorizer = authorizer + rpcCtxOpts.NeedsDialback = true + if knobs := cfg.TestingKnobs.Server; knobs != nil { serverKnobs := knobs.(*TestingKnobs) rpcCtxOpts.Knobs = serverKnobs.ContextTestingKnobs diff --git a/pkg/server/server_import_ts_test.go b/pkg/server/server_import_ts_test.go index b44734b92537..b3cfd6a90e7f 100644 --- a/pkg/server/server_import_ts_test.go +++ b/pkg/server/server_import_ts_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -51,8 +52,7 @@ func TestServerWithTimeseriesImport(t *testing.T) { srv := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) defer srv.Stopper().Stop(ctx) - cc, err := srv.Servers[0].RPCContext().GRPCUnvalidatedDial(srv.Servers[0].RPCAddr()).Connect(ctx) - require.NoError(t, 
err) + cc := srv.Server(0).RPCClientConn(t, username.RootUserName()) bytesDumped = dumpTSNonempty(t, cc, path) t.Logf("dumped %s bytes", humanizeutil.IBytes(bytesDumped)) }() @@ -68,8 +68,7 @@ func TestServerWithTimeseriesImport(t *testing.T) { } srv := testcluster.StartTestCluster(t, 1, args) defer srv.Stopper().Stop(ctx) - cc, err := srv.Servers[0].RPCContext().GRPCUnvalidatedDial(srv.Servers[0].RPCAddr()).Connect(ctx) - require.NoError(t, err) + cc := srv.Server(0).RPCClientConn(t, username.RootUserName()) // This would fail if we didn't supply a dump. Just the fact that it returns // successfully proves that we ingested at least some time series (or that we // failed to disable time series). diff --git a/pkg/server/srvtestutils/BUILD.bazel b/pkg/server/srvtestutils/BUILD.bazel index 126c3da4aec9..8ce0c31e8ff9 100644 --- a/pkg/server/srvtestutils/BUILD.bazel +++ b/pkg/server/srvtestutils/BUILD.bazel @@ -6,13 +6,9 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/server/srvtestutils", visibility = ["//visibility:public"], deps = [ - "//pkg/base", - "//pkg/roachpb", - "//pkg/rpc", "//pkg/server/apiconstants", "//pkg/testutils/serverutils", "//pkg/util/protoutil", "@com_github_cockroachdb_errors//:errors", - "@com_github_cockroachdb_logtags//:logtags", ], ) diff --git a/pkg/server/srvtestutils/testutils.go b/pkg/server/srvtestutils/testutils.go index 2c295d3bf14a..5beab3e79a39 100644 --- a/pkg/server/srvtestutils/testutils.go +++ b/pkg/server/srvtestutils/testutils.go @@ -11,18 +11,13 @@ package srvtestutils import ( - "context" "encoding/json" "io" - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" ) // GetAdminJSONProto 
performs a RPC-over-HTTP request to the admin endpoint @@ -132,26 +127,3 @@ func GetJSON(ts serverutils.ApplicationLayerInterface, url string) (interface{}, } return jI, nil } - -// NewRPCTestContext constructs a RPC context for use in API tests. -func NewRPCTestContext( - ctx context.Context, ts serverutils.TestServerInterface, cfg *base.Config, -) *rpc.Context { - var c base.NodeIDContainer - ctx = logtags.AddTag(ctx, "n", &c) - rpcContext := rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - NodeID: &c, - Config: cfg, - Clock: ts.Clock().WallClock(), - ToleratedOffset: ts.Clock().ToleratedOffset(), - Stopper: ts.Stopper(), - Settings: ts.ClusterSettings(), - Knobs: rpc.ContextTestingKnobs{NoLoopbackDialer: true}, - }) - // Ensure that the RPC client context validates the server cluster ID. - // This ensures that a test where the server is restarted will not let - // its test RPC client talk to a server started by an unrelated concurrent test. - rpcContext.StorageClusterID.Set(ctx, ts.StorageClusterID()) - return rpcContext -} diff --git a/pkg/server/statements_test.go b/pkg/server/statements_test.go index 92010328b51f..4890592d5709 100644 --- a/pkg/server/statements_test.go +++ b/pkg/server/statements_test.go @@ -14,7 +14,6 @@ import ( "context" "testing" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/tests" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -36,15 +35,10 @@ func TestStatements(t *testing.T) { testServer, db, _ := serverutils.StartServer(t, params) defer testServer.Stopper().Stop(ctx) - conn, err := testServer.RPCContext().GRPCDialNode( - testServer.RPCAddr(), testServer.NodeID(), rpc.DefaultClass, - ).Connect(ctx) - require.NoError(t, err) - - client := serverpb.NewStatusClient(conn) + client := testServer.GetStatusClient(t) testQuery := "CREATE TABLE foo (id INT8)" - _, err = db.Exec(testQuery) + _, err := 
db.Exec(testQuery) require.NoError(t, err) resp, err := client.Statements(ctx, &serverpb.StatementsRequest{NodeID: "local"}) @@ -69,15 +63,10 @@ func TestStatementsExcludeStats(t *testing.T) { testServer, db, _ := serverutils.StartServer(t, params) defer testServer.Stopper().Stop(ctx) - conn, err := testServer.RPCContext().GRPCDialNode( - testServer.RPCAddr(), testServer.NodeID(), rpc.DefaultClass, - ).Connect(ctx) - require.NoError(t, err) - - client := serverpb.NewStatusClient(conn) + client := testServer.GetStatusClient(t) testQuery := "CREATE TABLE foo (id INT8)" - _, err = db.Exec(testQuery) + _, err := db.Exec(testQuery) require.NoError(t, err) t.Run("exclude-statements", func(t *testing.T) { diff --git a/pkg/server/status_ext_test.go b/pkg/server/status_ext_test.go index d3dd04f9ba94..9b0e636add3d 100644 --- a/pkg/server/status_ext_test.go +++ b/pkg/server/status_ext_test.go @@ -18,11 +18,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" - "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/stretchr/testify/require" ) // TestStatusLocalStacks verifies that goroutine stack traces are available @@ -33,12 +32,10 @@ func TestStatusLocalStacks(t *testing.T) { skip.UnderRaceWithIssue(t, 74133) ctx := context.Background() - var args base.TestClusterArgs - tc := testcluster.StartTestCluster(t, 1 /* nodes */, args) - defer tc.Stopper().Stop(ctx) + s := serverutils.StartServerOnly(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) - cc, err := tc.GetStatusClient(ctx, t, 0 /* idx */) - require.NoError(t, err) + cc := s.GetStatusClient(t) testCases := []struct { stackType serverpb.StacksType diff --git a/pkg/server/storage_api/BUILD.bazel 
b/pkg/server/storage_api/BUILD.bazel index aa86913d2d6f..2733923e2212 100644 --- a/pkg/server/storage_api/BUILD.bazel +++ b/pkg/server/storage_api/BUILD.bazel @@ -39,7 +39,6 @@ go_test( "//pkg/kv/kvserver/liveness", "//pkg/kv/kvserver/liveness/livenesspb", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", @@ -63,7 +62,6 @@ go_test( "@com_github_pkg_errors//:errors", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", - "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", ], diff --git a/pkg/server/storage_api/decommission_test.go b/pkg/server/storage_api/decommission_test.go index 8266f0b34544..9cb09ee2a7a4 100644 --- a/pkg/server/storage_api/decommission_test.go +++ b/pkg/server/storage_api/decommission_test.go @@ -25,7 +25,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" @@ -36,7 +35,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -342,10 +340,7 @@ func TestDecommissionPreCheckBasicReadiness(t *testing.T) { defer tc.Stopper().Stop(ctx) adminSrv := tc.Server(4) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := adminSrv.GetAdminClient(t) resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ 
NodeIDs: []roachpb.NodeID{tc.Server(5).NodeID()}, @@ -379,10 +374,7 @@ func TestDecommissionPreCheckUnready(t *testing.T) { adminSrv := tc.Server(adminSrvIdx) decommissioningSrv := tc.Server(decommissioningSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := adminSrv.GetAdminClient(t) checkNodeReady := func(nID roachpb.NodeID, replicaCount int64, strict bool) { resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ @@ -557,10 +549,7 @@ func TestDecommissionPreCheckMultiple(t *testing.T) { } adminSrv := tc.Server(adminSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := adminSrv.GetAdminClient(t) // We expect to be able to decommission the targeted nodes simultaneously. resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ @@ -622,10 +611,7 @@ func TestDecommissionPreCheckInvalidNode(t *testing.T) { } adminSrv := tc.Server(adminSrvIdx) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := adminSrv.GetAdminClient(t) // We expect the pre-check to fail as some node IDs are invalid. resp, err := adminClient.DecommissionPreCheck(ctx, &serverpb.DecommissionPreCheckRequest{ @@ -662,10 +648,7 @@ func TestDecommissionSelf(t *testing.T) { // admin server's logic, which involves a subsequent DecommissionStatus // call which could fail if used from a node that's just decommissioned. 
adminSrv := tc.Server(4) - conn, err := adminSrv.RPCContext().GRPCDialNode( - adminSrv.RPCAddr(), adminSrv.NodeID(), rpc.DefaultClass).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := adminSrv.GetAdminClient(t) decomNodeIDs := []roachpb.NodeID{ tc.Server(4).NodeID(), tc.Server(5).NodeID(), @@ -751,13 +734,9 @@ func TestDecommissionEnqueueReplicas(t *testing.T) { decommissioningSrv := tc.Server(decommissioningSrvIdx) tc.AddVotersOrFatal(t, scratchKey, tc.Target(decommissioningSrvIdx)) - conn, err := decommissioningSrv.RPCContext().GRPCDialNode( - decommissioningSrv.RPCAddr(), decommissioningSrv.NodeID(), rpc.DefaultClass, - ).Connect(ctx) - require.NoError(t, err) - adminClient := serverpb.NewAdminClient(conn) + adminClient := decommissioningSrv.GetAdminClient(t) decomNodeIDs := []roachpb.NodeID{tc.Server(decommissioningSrvIdx).NodeID()} - _, err = adminClient.Decommission( + _, err := adminClient.Decommission( ctx, &serverpb.DecommissionRequest{ NodeIDs: decomNodeIDs, @@ -837,13 +816,7 @@ func TestAdminDecommissionedOperations(t *testing.T) { }, 10*time.Second) // Set up an admin client. - //lint:ignore SA1019 grpc.WithInsecure is deprecated - conn, err := grpc.Dial(decomSrv.AdvRPCAddr(), grpc.WithInsecure()) - require.NoError(t, err) - defer func() { - _ = conn.Close() // nolint:grpcconnclose - }() - adminClient := serverpb.NewAdminClient(conn) + adminClient := decomSrv.GetAdminClient(t) // Run some operations on the decommissioned node. The ones that require // access to the cluster should fail, other should succeed. 
We're mostly diff --git a/pkg/server/storage_api/files_test.go b/pkg/server/storage_api/files_test.go index e248de2d5d35..0b1ba76302cb 100644 --- a/pkg/server/storage_api/files_test.go +++ b/pkg/server/storage_api/files_test.go @@ -19,11 +19,8 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" - "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -48,16 +45,7 @@ func TestStatusGetFiles(t *testing.T) { ts := tsI.(*server.TestServer) defer ts.Stopper().Stop(context.Background()) - rootConfig := testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts, rootConfig) - - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) // Test fetching heap files. 
t.Run("heap", func(t *testing.T) { @@ -137,7 +125,7 @@ func TestStatusGetFiles(t *testing.T) { t.Run("path separators", func(t *testing.T) { request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true, Type: serverpb.FileType_HEAP, Patterns: []string{"pattern/with/separators"}} - _, err = client.GetFiles(context.Background(), &request) + _, err := client.GetFiles(context.Background(), &request) if !testutils.IsError(err, "invalid pattern: cannot have path seperators") { t.Errorf("GetFiles: path separators allowed in pattern") } @@ -147,7 +135,7 @@ func TestStatusGetFiles(t *testing.T) { t.Run("filetypes", func(t *testing.T) { request := serverpb.GetFilesRequest{NodeId: "local", ListOnly: true, Type: -1, Patterns: []string{"*"}} - _, err = client.GetFiles(context.Background(), &request) + _, err := client.GetFiles(context.Background(), &request) if !testutils.IsError(err, "unknown file type: -1") { t.Errorf("GetFiles: invalid file type allowed") } diff --git a/pkg/server/storage_api/nodes_test.go b/pkg/server/storage_api/nodes_test.go index 6ccaee414b51..afd2267dea52 100644 --- a/pkg/server/storage_api/nodes_test.go +++ b/pkg/server/storage_api/nodes_test.go @@ -18,8 +18,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/build" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" - "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/apiconstants" "github.com/cockroachdb/cockroach/pkg/server/serverpb" @@ -187,23 +185,17 @@ func TestMetricsEndpoint(t *testing.T) { func TestNodesGRPCResponse(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) + + ctx := context.Background() + s := serverutils.StartServerOnly(t, base.TestServerArgs{}) - defer s.Stopper().Stop(context.Background()) - ts := s.(*server.TestServer) + defer s.Stopper().Stop(ctx) - rootConfig := 
testutils.NewTestBaseContext(username.RootUserName()) - rpcContext := srvtestutils.NewRPCTestContext(context.Background(), ts, rootConfig) var request serverpb.NodesRequest - url := ts.AdvRPCAddr() - nodeID := ts.NodeID() - conn, err := rpcContext.GRPCDialNode(url, nodeID, rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := s.GetStatusClient(t) - response, err := client.Nodes(context.Background(), &request) + response, err := client.Nodes(ctx, &request) if err != nil { t.Fatal(err) } diff --git a/pkg/server/storage_api/ranges_test.go b/pkg/server/storage_api/ranges_test.go index c27c9ad03d3f..e37934df98b0 100644 --- a/pkg/server/storage_api/ranges_test.go +++ b/pkg/server/storage_api/ranges_test.go @@ -18,7 +18,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/server/srvtestutils" @@ -82,11 +81,7 @@ func TestRangesResponse(t *testing.T) { rpcStopper := stop.NewStopper() defer rpcStopper.Stop(ctx) - conn, err := ts.RPCContext().GRPCDialNode(ts.AdvRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) + client := ts.GetStatusClient(t) resp1, err := client.Ranges(ctx, &serverpb.RangesRequest{ Limit: 1, }) @@ -119,12 +114,8 @@ func TestTenantRangesResponse(t *testing.T) { rpcStopper := stop.NewStopper() defer rpcStopper.Stop(ctx) - conn, err := ts.RPCContext().GRPCDialNode(ts.AdvRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - client := serverpb.NewStatusClient(conn) - _, err = client.TenantRanges(ctx, &serverpb.TenantRangesRequest{}) + client := ts.GetStatusClient(t) + _, 
err := client.TenantRanges(ctx, &serverpb.TenantRangesRequest{}) require.Error(t, err) require.Contains(t, err.Error(), "no tenant ID found in context") }) diff --git a/pkg/server/tenant.go b/pkg/server/tenant.go index c42fca7c77c9..33cdb605ef1f 100644 --- a/pkg/server/tenant.go +++ b/pkg/server/tenant.go @@ -992,23 +992,21 @@ func makeTenantSQLServerArgs( testingKnobShutdownTenantConnectorEarlyIfNoRecordPresent = p.ShutdownTenantConnectorEarlyIfNoRecordPresent } + rpcCtxOpts := rpc.ServerContextOptionsFromBaseConfig(baseCfg.Config) + rpcCtxOpts.TenantID = sqlCfg.TenantID + rpcCtxOpts.UseNodeAuth = sqlCfg.LocalKVServerInfo != nil + rpcCtxOpts.NodeID = baseCfg.IDContainer + rpcCtxOpts.StorageClusterID = baseCfg.ClusterIDContainer + rpcCtxOpts.Clock = clock.WallClock() + rpcCtxOpts.ToleratedOffset = clock.ToleratedOffset() + rpcCtxOpts.Stopper = stopper + rpcCtxOpts.Settings = st + rpcCtxOpts.Knobs = rpcTestingKnobs // This tenant's SQL server only serves SQL connections and SQL-to-SQL // RPCs; so it should refuse to serve SQL-to-KV RPCs completely. - authorizer := tenantcapabilitiesauthorizer.NewAllowNothingAuthorizer() - - rpcContext := rpc.NewContext(startupCtx, rpc.ContextOptions{ - TenantID: sqlCfg.TenantID, - UseNodeAuth: sqlCfg.LocalKVServerInfo != nil, - NodeID: baseCfg.IDContainer, - StorageClusterID: baseCfg.ClusterIDContainer, - Config: baseCfg.Config, - Clock: clock.WallClock(), - ToleratedOffset: clock.ToleratedOffset(), - Stopper: stopper, - Settings: st, - Knobs: rpcTestingKnobs, - TenantRPCAuthorizer: authorizer, - }) + rpcCtxOpts.TenantRPCAuthorizer = tenantcapabilitiesauthorizer.NewAllowNothingAuthorizer() + + rpcContext := rpc.NewContext(startupCtx, rpcCtxOpts) if !baseCfg.Insecure { // This check mirrors that done in NewServer(). 
diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 052f7f5e47e7..376507b2d8b9 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -76,7 +76,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" "github.com/gogo/protobuf/proto" + "google.golang.org/grpc" ) // makeTestConfig returns a config for testing. It overrides the @@ -2272,3 +2274,132 @@ func TestingMakeLoggingContexts( }) return ctxSysTenant, ctxAppTenant } + +// NewClientRPCContext is part of the serverutils.ApplicationLayerInterface. +func (ts *TestServer) NewClientRPCContext( + ctx context.Context, user username.SQLUsername, +) *rpc.Context { + return newClientRPCContext(ctx, user, + ts.Server.cfg.Config, + ts.Server.cfg.TestingKnobs.Server, + ts.Server.cfg.ClusterIDContainer, + ts) +} + +// RPCClientConn is part of the serverutils.ApplicationLayerInterface. +func (ts *TestServer) RPCClientConn( + test serverutils.TestFataler, user username.SQLUsername, +) *grpc.ClientConn { + conn, err := ts.RPCClientConnE(user) + if err != nil { + test.Fatal(err) + } + return conn +} + +// RPCClientConnE is part of the serverutils.ApplicationLayerInterface. +func (ts *TestServer) RPCClientConnE(user username.SQLUsername) (*grpc.ClientConn, error) { + ctx := context.Background() + rpcCtx := ts.NewClientRPCContext(ctx, user) + return rpcCtx.GRPCDialNode(ts.AdvRPCAddr(), ts.NodeID(), rpc.DefaultClass).Connect(ctx) +} + +// GetAdminClient is part of the serverutils.ApplicationLayerInterface. +func (ts *TestServer) GetAdminClient(test serverutils.TestFataler) serverpb.AdminClient { + conn := ts.RPCClientConn(test, username.RootUserName()) + return serverpb.NewAdminClient(conn) +} + +// GetStatusClient is part of the serverutils.ApplicationLayerInterface. 
+func (ts *TestServer) GetStatusClient(test serverutils.TestFataler) serverpb.StatusClient { + conn := ts.RPCClientConn(test, username.RootUserName()) + return serverpb.NewStatusClient(conn) +} + +// NewClientRPCContext is part of the serverutils.ApplicationLayerInterface. +func (t *TestTenant) NewClientRPCContext( + ctx context.Context, user username.SQLUsername, +) *rpc.Context { + return newClientRPCContext(ctx, user, + t.Cfg.Config, + t.Cfg.TestingKnobs.Server, + t.Cfg.ClusterIDContainer, + t) +} + +// RPCClientConn is part of the serverutils.ApplicationLayerInterface. +func (t *TestTenant) RPCClientConn( + test serverutils.TestFataler, user username.SQLUsername, +) *grpc.ClientConn { + conn, err := t.RPCClientConnE(user) + if err != nil { + test.Fatal(err) + } + return conn +} + +// RPCClientConnE is part of the serverutils.ApplicationLayerInterface. +func (t *TestTenant) RPCClientConnE(user username.SQLUsername) (*grpc.ClientConn, error) { + ctx := context.Background() + rpcCtx := t.NewClientRPCContext(ctx, user) + return rpcCtx.GRPCDialPod(t.AdvRPCAddr(), t.SQLInstanceID(), rpc.DefaultClass).Connect(ctx) +} + +// GetAdminClient is part of the serverutils.ApplicationLayerInterface. +func (t *TestTenant) GetAdminClient(test serverutils.TestFataler) serverpb.AdminClient { + conn := t.RPCClientConn(test, username.RootUserName()) + return serverpb.NewAdminClient(conn) +} + +// GetStatusClient is part of the serverutils.ApplicationLayerInterface. 
+func (t *TestTenant) GetStatusClient(test serverutils.TestFataler) serverpb.StatusClient { + conn := t.RPCClientConn(test, username.RootUserName()) + return serverpb.NewStatusClient(conn) +} + +func newClientRPCContext( + ctx context.Context, + user username.SQLUsername, + cfg *base.Config, + sknobs base.ModuleTestingKnobs, + cid *base.ClusterIDContainer, + s serverutils.ApplicationLayerInterface, +) *rpc.Context { + ctx = logtags.AddTag(ctx, "testclient", nil) + ctx = logtags.AddTag(ctx, "user", user) + ctx = logtags.AddTag(ctx, "nsql", s.SQLInstanceID()) + + stopper := s.Stopper() + if ctx.Done() == nil { + // The RPCContext initialization wants a cancellable context, + // since that will be used to stop async goroutines. Help + // the test by making one. + cctx, cancel := context.WithCancel(ctx) + stopper.AddCloser(stop.CloserFn(cancel)) + ctx = cctx + } + var knobs rpc.ContextTestingKnobs + if sknobs != nil { + knobs = sknobs.(*TestingKnobs).ContextTestingKnobs + } + ccfg := rpc.MakeClientConnConfigFromBaseConfig(*cfg, + user, + // We pass nil as tracer parameter below, instead of the server's + // tracer, because the tracer used for the incoming context and + // passed to NewClientContext may not be using the same tracer and + // we cannot mix and match tracers. + nil, /* tracer */ + s.ClusterSettings(), + s.Clock(), + knobs, + ) + + rpcCtx, clientStopper := rpc.NewClientContext(ctx, ccfg) + // Ensure that the RPC client context validates the server cluster ID. + // This ensures that a test where the server is restarted will not let + // its test RPC client talk to a server started by an unrelated concurrent test. 
+ rpcCtx.StorageClusterID.Set(ctx, cid.Get()) + + stopper.AddCloser(stop.CloserFn(func() { clientStopper.Stop(ctx) })) + return rpcCtx +} diff --git a/pkg/server/testserver_http.go b/pkg/server/testserver_http.go index 13fe171260cb..597ae58a820b 100644 --- a/pkg/server/testserver_http.go +++ b/pkg/server/testserver_http.go @@ -69,7 +69,7 @@ var _ http.RoundTripper = &tenantHeaderDecorator{} // AdminURL implements TestServerInterface. func (ts *httpTestServer) AdminURL() *serverutils.TestURL { - u := ts.t.sqlServer.execCfg.RPCContext.Config.AdminURL() + u := ts.t.sqlServer.cfg.Config.AdminURL() if ts.t.tenantName != "" { q := u.Query() q.Add(ClusterNameParamInQueryURL, string(ts.t.tenantName)) @@ -156,7 +156,7 @@ func (ts *httpTestServer) GetAuthenticatedHTTPClientAndCookie( if err != nil { return err } - url, err := url.Parse(ts.t.sqlServer.execCfg.RPCContext.Config.AdminURL().String()) + url, err := url.Parse(ts.t.sqlServer.cfg.Config.AdminURL().String()) if err != nil { return err } diff --git a/pkg/sql/create_role.go b/pkg/sql/create_role.go index 2db390628fc1..1f4dfe6d4190 100644 --- a/pkg/sql/create_role.go +++ b/pkg/sql/create_role.go @@ -274,7 +274,7 @@ func retrievePasswordFromRoleOptions( if err != nil { return true, nil, err } - if !isNull && params.extendedEvalCtx.ExecCfg.RPCContext.Config.Insecure { + if !isNull && params.extendedEvalCtx.ExecCfg.RPCContext.Insecure { // We disallow setting a non-empty password in insecure mode // because insecure means an observer may have MITM'ed the change // and learned the password. 
diff --git a/pkg/sql/execinfrapb/BUILD.bazel b/pkg/sql/execinfrapb/BUILD.bazel index bf86cf0e0ba9..9878204bd4e7 100644 --- a/pkg/sql/execinfrapb/BUILD.bazel +++ b/pkg/sql/execinfrapb/BUILD.bazel @@ -54,7 +54,6 @@ go_library( "//pkg/util/protoutil", "//pkg/util/stop", "//pkg/util/syncutil", - "//pkg/util/timeutil", "//pkg/util/tracing/tracingpb", "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/sql/execinfrapb/testutils.go b/pkg/sql/execinfrapb/testutils.go index d71b65296950..8ee29d5c58ca 100644 --- a/pkg/sql/execinfrapb/testutils.go +++ b/pkg/sql/execinfrapb/testutils.go @@ -13,7 +13,6 @@ package execinfrapb import ( "context" "net" - "time" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" @@ -24,7 +23,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/logtags" "google.golang.org/grpc" @@ -33,16 +31,12 @@ import ( func newInsecureRPCContext(ctx context.Context, stopper *stop.Stopper) *rpc.Context { nc := &base.NodeIDContainer{} ctx = logtags.AddTag(ctx, "n", nc) - return rpc.NewContext(ctx, - rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - NodeID: nc, - Config: &base.Config{Insecure: true}, - Clock: &timeutil.DefaultTimeSource{}, - ToleratedOffset: time.Nanosecond, - Stopper: stopper, - Settings: cluster.MakeTestingClusterSettings(), - }) + opts := rpc.DefaultContextOptions() + opts.Insecure = true + opts.Stopper = stopper + opts.Settings = cluster.MakeTestingClusterSettings() + opts.NodeID = nc + return rpc.NewContext(ctx, opts) } // StartMockDistSQLServer starts a MockDistSQLServer and returns the address on diff --git a/pkg/sql/flowinfra/cluster_test.go b/pkg/sql/flowinfra/cluster_test.go index dc11b919b66f..2355c56bd7fa 100644 --- 
a/pkg/sql/flowinfra/cluster_test.go +++ b/pkg/sql/flowinfra/cluster_test.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverbase" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils" "github.com/cockroachdb/cockroach/pkg/sql/catalog/fetchpb" @@ -258,7 +259,7 @@ func runTestClusterFlow( func TestClusterFlow(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - ctx := context.Background() + const numNodes = 3 args := base.TestClusterArgs{ReplicationMode: base.ReplicationManual} @@ -273,10 +274,7 @@ func TestClusterFlow(t *testing.T) { s := tc.Server(i) servers[i] = s conns[i] = tc.ServerConn(i) - conn, err := s.RPCContext().GRPCDialNode(s.AdvRPCAddr(), s.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } + conn := s.RPCClientConn(t, username.RootUserName()) clients[i] = execinfrapb.NewDistSQLClient(conn) } @@ -772,13 +770,10 @@ func BenchmarkInfrastructure(b *testing.B) { reqs[0].Flow.Processors = append(reqs[0].Flow.Processors, lastProc) var clients []execinfrapb.DistSQLClient - ctx := context.Background() + for i := 0; i < numNodes; i++ { s := tc.Server(i) - conn, err := s.RPCContext().GRPCDialNode(s.AdvRPCAddr(), s.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - b.Fatal(err) - } + conn := s.RPCClientConn(b, username.RootUserName()) clients = append(clients, execinfrapb.NewDistSQLClient(conn)) } diff --git a/pkg/sql/flowinfra/server_test.go b/pkg/sql/flowinfra/server_test.go index dfcd4904a3e6..9cf3eebb030d 100644 --- a/pkg/sql/flowinfra/server_test.go +++ b/pkg/sql/flowinfra/server_test.go @@ -20,7 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/roachpb" - 
"github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils" "github.com/cockroachdb/cockroach/pkg/sql/distsql" @@ -45,11 +45,7 @@ func TestServer(t *testing.T) { ctx := context.Background() s, sqlDB, kvDB := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) - conn, err := s.RPCContext().GRPCDialNode(s.AdvRPCAddr(), s.NodeID(), - rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } + conn := s.RPCClientConn(t, username.RootUserName()) r := sqlutils.MakeSQLRunner(sqlDB) diff --git a/pkg/testutils/BUILD.bazel b/pkg/testutils/BUILD.bazel index 91349456cd68..417016c1bd63 100644 --- a/pkg/testutils/BUILD.bazel +++ b/pkg/testutils/BUILD.bazel @@ -4,7 +4,6 @@ go_library( name = "testutils", srcs = [ "backup.go", - "base.go", "dir.go", "error.go", "files.go", @@ -22,11 +21,8 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/testutils", visibility = ["//visibility:public"], deps = [ - "//pkg/base", "//pkg/build/bazel", "//pkg/kv/kvpb", - "//pkg/security/certnames", - "//pkg/security/username", "//pkg/sql/pgwire/pgerror", "//pkg/util", "//pkg/util/fileutil", diff --git a/pkg/testutils/base.go b/pkg/testutils/base.go deleted file mode 100644 index efb9a50ebbfb..000000000000 --- a/pkg/testutils/base.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2015 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -package testutils - -import ( - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/security/certnames" - "github.com/cockroachdb/cockroach/pkg/security/username" -) - -// NewNodeTestBaseContext creates a base context for testing. This uses -// embedded certs and the default node user. The default node user has both -// server and client certificates. -func NewNodeTestBaseContext() *base.Config { - return NewTestBaseContext(username.NodeUserName()) -} - -// NewTestBaseContext creates a secure base context for user. -func NewTestBaseContext(user username.SQLUsername) *base.Config { - cfg := &base.Config{ - Insecure: false, - User: user, - RPCHeartbeatInterval: time.Millisecond, - } - FillCerts(cfg) - return cfg -} - -// FillCerts sets the certs on a base.Config. -func FillCerts(cfg *base.Config) { - cfg.SSLCertsDir = certnames.EmbeddedCertsDir -} diff --git a/pkg/testutils/localtestcluster/local_test_cluster.go b/pkg/testutils/localtestcluster/local_test_cluster.go index 4024308fb481..c03619538a2f 100644 --- a/pkg/testutils/localtestcluster/local_test_cluster.go +++ b/pkg/testutils/localtestcluster/local_test_cluster.go @@ -52,8 +52,7 @@ import ( // usage of a LocalTestCluster follows: // // s := &LocalTestCluster{} -// s.Start(t, testutils.NewNodeTestBaseContext(), -// kv.InitFactoryForLocalTestCluster) +// s.Start(t, kv.InitFactoryForLocalTestCluster) // defer s.Stop() // // Note that the LocalTestCluster is different from server.TestCluster @@ -114,7 +113,7 @@ func (ltc *LocalTestCluster) Stopper() *stop.Stopper { // node RPC server and all HTTP endpoints. Use the value of // TestServer.Addr after Start() for client connections. Use Stop() // to shutdown the server after the test completes. 
-func (ltc *LocalTestCluster) Start(t testing.TB, baseCtx *base.Config, initFactory InitFactoryFn) { +func (ltc *LocalTestCluster) Start(t testing.TB, initFactory InitFactoryFn) { manualClock := timeutil.NewManualTime(timeutil.Unix(0, 123)) clock := hlc.NewClock(manualClock, 50*time.Millisecond /* maxOffset */, 50*time.Millisecond /* toleratedOffset */) @@ -137,17 +136,17 @@ func (ltc *LocalTestCluster) Start(t testing.TB, baseCtx *base.Config, initFacto } ltc.tester = t - cfg.RPCContext = rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: baseCtx, - Clock: ltc.Clock.WallClock(), - ToleratedOffset: ltc.Clock.ToleratedOffset(), - Stopper: ltc.stopper, - Settings: cfg.Settings, - NodeID: nc, - TenantRPCAuthorizer: tenantcapabilitiesauthorizer.NewAllowEverythingAuthorizer(), - }) + opts := rpc.DefaultContextOptions() + opts.Clock = ltc.Clock.WallClock() + opts.ToleratedOffset = ltc.Clock.ToleratedOffset() + opts.Stopper = ltc.stopper + opts.Settings = cfg.Settings + opts.NodeID = nc + opts.TenantRPCAuthorizer = tenantcapabilitiesauthorizer.NewAllowEverythingAuthorizer() + + cfg.RPCContext = rpc.NewContext(ctx, opts) + cfg.RPCContext.NodeID.Set(ctx, nodeID) clusterID := cfg.RPCContext.StorageClusterID ltc.Gossip = gossip.New(ambient, clusterID, nc, ltc.stopper, metric.NewRegistry(), roachpb.Locality{}, zonepb.DefaultZoneConfigRef()) diff --git a/pkg/testutils/serverutils/BUILD.bazel b/pkg/testutils/serverutils/BUILD.bazel index 1a373e456e41..5213f3eb30f5 100644 --- a/pkg/testutils/serverutils/BUILD.bazel +++ b/pkg/testutils/serverutils/BUILD.bazel @@ -43,5 +43,6 @@ go_library( "//pkg/util/tracing", "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", + "@org_golang_google_grpc//:go_default_library", ], ) diff --git a/pkg/testutils/serverutils/api.go b/pkg/testutils/serverutils/api.go index 3864c5c1763e..ff1377695b59 100644 --- a/pkg/testutils/serverutils/api.go +++ b/pkg/testutils/serverutils/api.go @@ -38,6 +38,7 @@ import ( 
"github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "google.golang.org/grpc" ) // TestServerInterface defines test server functionality that tests need; it is @@ -177,9 +178,28 @@ type ApplicationLayerInterface interface { // JobRegistry returns the *jobs.Registry as an interface{}. JobRegistry() interface{} - // RPCContext returns the *rpc.Context used by the test tenant. + // RPCContext returns the *rpc.Context used by the server. RPCContext() *rpc.Context + // NewClientRPCContext creates a new rpc.Context suitable to open + // client RPC connections to the server. + NewClientRPCContext(ctx context.Context, userName username.SQLUsername) *rpc.Context + + // RPCClientConn opens an RPC client connection to the server. + RPCClientConn(t TestFataler, userName username.SQLUsername) *grpc.ClientConn + + // RPCClientConnE is like RPCClientConn but it allows the test to check the + // error. + RPCClientConnE(userName username.SQLUsername) (*grpc.ClientConn, error) + + // GetAdminClient creates a serverpb.AdminClient connection to the server. + // Shorthand for serverpb.AdminClient(.RPCClientConn(t, "root")) + GetAdminClient(t TestFataler) serverpb.AdminClient + + // GetStatusClient creates a serverpb.StatusClient connection to the server. + // Shorthand for serverpb.StatusClient(.RPCClientConn(t, "root")) + GetStatusClient(t TestFataler) serverpb.StatusClient + + // AnnotateCtx annotates a context.
AnnotateCtx(context.Context) context.Context diff --git a/pkg/testutils/testcluster/BUILD.bazel b/pkg/testutils/testcluster/BUILD.bazel index 2fc609dfba15..20daadef51c8 100644 --- a/pkg/testutils/testcluster/BUILD.bazel +++ b/pkg/testutils/testcluster/BUILD.bazel @@ -58,7 +58,6 @@ go_test( "//pkg/kv", "//pkg/kv/kvpb", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/server", diff --git a/pkg/testutils/testcluster/testcluster.go b/pkg/testutils/testcluster/testcluster.go index f34a7dbaefac..3420360ac542 100644 --- a/pkg/testutils/testcluster/testcluster.go +++ b/pkg/testutils/testcluster/testcluster.go @@ -488,6 +488,8 @@ func (tc *TestCluster) Start(t serverutils.TestFataler) { for _, ssrv := range tc.Servers { for _, dsrv := range tc.Servers { stl := dsrv.StorageLayer() + // Note: we avoid using .RPCClientConn() here to avoid accumulating + // stopper closures in RAM during the SucceedsSoon iterations. _, e := ssrv.RPCContext().GRPCDialNode(dsrv.SystemLayer().AdvRPCAddr(), stl.NodeID(), rpc.DefaultClass).Connect(context.TODO()) err = errors.CombineErrors(err, e) @@ -1527,10 +1529,14 @@ func (tc *TestCluster) WaitForZoneConfigPropagation() error { // store in the cluster. func (tc *TestCluster) WaitForNodeStatuses(t serverutils.TestFataler) { testutils.SucceedsSoon(t, func() error { - client, err := tc.GetStatusClient(context.Background(), t, 0) + // Note: we avoid using .RPCClientConn() here to avoid accumulating + // stopper closures in RAM during the SucceedsSoon iterations. 
+ srv := tc.Server(0).SystemLayer() + conn, err := srv.RPCContext().GRPCDialNode(srv.AdvRPCAddr(), tc.Server(0).NodeID(), rpc.DefaultClass).Connect(context.TODO()) if err != nil { return err } + client := serverpb.NewStatusClient(conn) response, err := client.Nodes(context.Background(), &serverpb.NodesRequest{}) if err != nil { return err @@ -1847,26 +1853,16 @@ func (tc *TestCluster) GetRaftLeader( // GetAdminClient gets the severpb.AdminClient for the specified server. func (tc *TestCluster) GetAdminClient( - ctx context.Context, t serverutils.TestFataler, serverIdx int, -) (serverpb.AdminClient, error) { - srv := tc.Server(serverIdx) - cc, err := srv.RPCContext().GRPCDialNode(srv.RPCAddr(), srv.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - return nil, errors.Wrap(err, "failed to create an admin client") - } - return serverpb.NewAdminClient(cc), nil + t serverutils.TestFataler, serverIdx int, +) serverpb.AdminClient { + return tc.Server(serverIdx).GetAdminClient(t) } // GetStatusClient gets the severpb.StatusClient for the specified server. func (tc *TestCluster) GetStatusClient( - ctx context.Context, t serverutils.TestFataler, serverIdx int, -) (serverpb.StatusClient, error) { - srv := tc.Server(serverIdx) - cc, err := srv.RPCContext().GRPCDialNode(srv.RPCAddr(), srv.NodeID(), rpc.DefaultClass).Connect(ctx) - if err != nil { - return nil, errors.Wrap(err, "failed to create a status client") - } - return serverpb.NewStatusClient(cc), nil + t serverutils.TestFataler, serverIdx int, +) serverpb.StatusClient { + return tc.Server(serverIdx).GetStatusClient(t) } // SplitTable implements TestClusterInterface. 
diff --git a/pkg/testutils/testcluster/testcluster_test.go b/pkg/testutils/testcluster/testcluster_test.go index ecffba83bc42..6e9820dbcfb1 100644 --- a/pkg/testutils/testcluster/testcluster_test.go +++ b/pkg/testutils/testcluster/testcluster_test.go @@ -21,7 +21,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils" @@ -34,8 +33,17 @@ import ( "github.com/stretchr/testify/require" ) +func TestClusterStart(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + tc := StartTestCluster(t, 3, base.TestClusterArgs{}) + defer tc.Stopper().Stop(context.Background()) +} + func TestManualReplication(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) tc := StartTestCluster(t, 3, base.TestClusterArgs{ @@ -143,6 +151,7 @@ func TestManualReplication(t *testing.T) { // waiting for all of the stores to initialize. 
func TestBasicManualReplication(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) tc := StartTestCluster(t, 3, base.TestClusterArgs{ReplicationMode: base.ReplicationManual}) defer tc.Stopper().Stop(context.Background()) @@ -175,6 +184,7 @@ func TestBasicManualReplication(t *testing.T) { func TestBasicAutoReplication(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) tc := StartTestCluster(t, 3, base.TestClusterArgs{ReplicationMode: base.ReplicationAuto}) defer tc.Stopper().Stop(context.Background()) @@ -183,6 +193,7 @@ func TestBasicAutoReplication(t *testing.T) { func TestStopServer(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) // Use insecure mode so our servers listen on util.IsolatedTestAddr // and they fail cleanly instead of interfering with other tests. @@ -209,22 +220,7 @@ func TestStopServer(t *testing.T) { } ctx := context.Background() - rpcContext := rpc.NewContext(ctx, rpc.ContextOptions{ - TenantID: roachpb.SystemTenantID, - Config: server1.RPCContext().Config, - Clock: server1.Clock().WallClock(), - ToleratedOffset: server1.Clock().ToleratedOffset(), - Stopper: tc.Stopper(), - Settings: server1.ClusterSettings(), - - ClientOnly: true, - }) - conn, err := rpcContext.GRPCDialNode(server1.AdvRPCAddr(), server1.NodeID(), - rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } - statusClient1 := serverpb.NewStatusClient(conn) + statusClient1 := server1.GetStatusClient(t) var cancel func() ctx, cancel = context.WithTimeout(ctx, 2*time.Second) defer cancel() @@ -273,6 +269,7 @@ func TestStopServer(t *testing.T) { func TestRestart(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) stickyEngineRegistry := server.NewStickyInMemEnginesRegistry() defer stickyEngineRegistry.CloseAllStickyInMemEngines() diff --git a/pkg/ts/BUILD.bazel b/pkg/ts/BUILD.bazel index 24404ac19284..294f081df13d 100644 --- a/pkg/ts/BUILD.bazel +++ b/pkg/ts/BUILD.bazel @@ 
-78,9 +78,9 @@ go_test( "//pkg/kv/kvserver", "//pkg/multitenant/tenantcapabilities", "//pkg/roachpb", - "//pkg/rpc", "//pkg/security/securityassets", "//pkg/security/securitytest", + "//pkg/security/username", "//pkg/server", "//pkg/settings/cluster", "//pkg/storage", diff --git a/pkg/ts/db_test.go b/pkg/ts/db_test.go index a1d3601ecebe..c968bda3c5d4 100644 --- a/pkg/ts/db_test.go +++ b/pkg/ts/db_test.go @@ -110,8 +110,7 @@ func newTestModelRunner(t *testing.T) testModelRunner { // Start constructs and starts the local test server and creates a // time series DB. func (tm *testModelRunner) Start() { - tm.LocalTestCluster.Start(tm.t, testutils.NewNodeTestBaseContext(), - kvcoord.InitFactoryForLocalTestCluster) + tm.LocalTestCluster.Start(tm.t, kvcoord.InitFactoryForLocalTestCluster) tm.DB = NewDB(tm.LocalTestCluster.DB, tm.Cfg.Settings) } diff --git a/pkg/ts/server_test.go b/pkg/ts/server_test.go index 824a1bc622ed..2a18ea93cd58 100644 --- a/pkg/ts/server_test.go +++ b/pkg/ts/server_test.go @@ -25,7 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver" "github.com/cockroachdb/cockroach/pkg/multitenant/tenantcapabilities" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/ts" @@ -179,11 +179,7 @@ func TestServerQuery(t *testing.T) { }, } - conn, err := tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } + conn := tsrv.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) response, err := client.Query(context.Background(), &tspb.TimeSeriesQueryRequest{ StartNanos: 500 * 1e9, @@ -277,11 +273,7 @@ func TestServerQueryStarvation(t *testing.T) { t.Fatal(err) } - conn, err := 
tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } + conn := tsrv.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) queries := make([]tspb.Query, 0, seriesCount) @@ -417,11 +409,7 @@ func TestServerQueryTenant(t *testing.T) { }, } - conn, err := tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } + conn := tsrv.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) systemResponse, err := client.Query(context.Background(), &tspb.TimeSeriesQueryRequest{ StartNanos: 400 * 1e9, @@ -492,10 +480,7 @@ func TestServerQueryTenant(t *testing.T) { } capability := map[tenantcapabilities.ID]string{tenantcapabilities.CanViewTSDBMetrics: "true"} serverutils.WaitForTenantCapabilities(t, s, tenantID, capability, "") - tenantConn, err := tenant.(*server.TestTenant).RPCContext().GRPCDialNode(tenant.(*server.TestTenant).Cfg.AdvertiseAddr, tsrv.NodeID(), rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } + tenantConn := tenant.RPCClientConn(t, username.RootUserName()) tenantClient := tspb.NewTimeSeriesClient(tenantConn) tenantResponse, err := tenantClient.Query(context.Background(), &tspb.TimeSeriesQueryRequest{ @@ -554,11 +539,7 @@ func TestServerQueryMemoryManagement(t *testing.T) { t.Fatal(err) } - conn, err := tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - rpc.DefaultClass).Connect(context.Background()) - if err != nil { - t.Fatal(err) - } + conn := tsrv.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) queries := make([]tspb.Query, 0, seriesCount) @@ -633,11 +614,7 @@ func TestServerDump(t *testing.T) { names = append(names, seriesName(series)) } - conn, err := tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - 
rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } + conn := tsrv.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) dumpClient, err := client.Dump(ctx, &tspb.DumpRequest{ @@ -730,11 +707,7 @@ func TestServerDump(t *testing.T) { for i := 0; i < 3; i++ { require.NoError(t, s.DB().Run(ctx, &b)) - conn, err := s.RPCContext().GRPCDialNode(s.AdvRPCAddr(), s.NodeID(), - rpc.DefaultClass).Connect(ctx) - if err != nil { - t.Fatal(err) - } + conn := s.RPCClientConn(t, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) dumpClient, err := client.Dump(ctx, &tspb.DumpRequest{ @@ -765,11 +738,7 @@ func BenchmarkServerQuery(b *testing.B) { b.Fatal(err) } - conn, err := tsrv.RPCContext().GRPCDialNode(tsrv.Cfg.Addr, tsrv.NodeID(), - rpc.DefaultClass).Connect(context.Background()) - if err != nil { - b.Fatal(err) - } + conn := tsrv.RPCClientConn(b, username.RootUserName()) client := tspb.NewTimeSeriesClient(conn) queries := make([]tspb.Query, 0, seriesCount)