From 84fff6bd307a7d9171efc75967420e7806e358ed Mon Sep 17 00:00:00 2001 From: Andrew Baptist Date: Thu, 5 Jan 2023 14:12:18 -0500 Subject: [PATCH] kv: convert uni-directional network partitions to bi-directional Previously one-way partitions where a node could initiate a successful TCP connection in one direction, but the reverse connection fails which causes problems. The node that initiates outgoing connections can acquire leases and cause failures for reads and writes to those ranges. This is particularly a problem if it acquires the liveness range leases, but is a problem even for other ranges. This commit adds an additional check during server-to-server communication where the recipient of a new PingRequest first validates that it is able to open a reverse connection to the initiator before responding. Additionally, it will monitor whether it has a successful reverse connection over time and asynchronously validate reverse connections to the sender. The ongoing validation is asynchronous to avoid adding delays to PingResponses as they are used for measuring clock offsets. Release note (bug fix): RPC connections between nodes now require RPC connections to be established in both directions, otherwise the connection will be closed. This is done to prevent asymmetric network partitions where nodes are able to send outbound messages but not receive inbound messages, which could result in persistent unavailability. This behavior can be disabled by the cluster setting rpc.dialback.enabled. Also the onlyOnceDialer prevents retrying after a dial error, however this can get into a state where it continually retries for certain network connections. This is not easy to reproduce in a unit tests as it requires killing the connection using iptables (normal closes don't cauuse this). After this change the onlyOnceDialer will no longer repeatedly retry to reconnect after a broken connection during setup. Epic: none Release note: None --- pkg/acceptance/localcluster/cluster.go | 4 +- pkg/base/config.go | 4 + pkg/cmd/allocsim/main.go | 2 +- pkg/roachpb/metadata.go | 31 ++-- pkg/rpc/BUILD.bazel | 1 + pkg/rpc/context.go | 233 ++++++++++++++++++++----- pkg/rpc/context_test.go | 188 +++++++++++++++++++- pkg/rpc/heartbeat.go | 42 ++--- pkg/rpc/heartbeat.proto | 35 +++- pkg/rpc/nodedialer/nodedialer.go | 7 + pkg/rpc/nodedialer/nodedialer_test.go | 8 +- pkg/server/config.go | 4 - pkg/server/server.go | 27 +-- 13 files changed, 485 insertions(+), 101 deletions(-) diff --git a/pkg/acceptance/localcluster/cluster.go b/pkg/acceptance/localcluster/cluster.go index 94ffd1b51418..f5f3003c6152 100644 --- a/pkg/acceptance/localcluster/cluster.go +++ b/pkg/acceptance/localcluster/cluster.go @@ -482,7 +482,7 @@ func (n *Node) Alive() bool { } // StatusClient returns a StatusClient set up to talk to this node. -func (n *Node) StatusClient() serverpb.StatusClient { +func (n *Node) StatusClient(ctx context.Context) serverpb.StatusClient { n.Lock() existingClient := n.statusClient n.Unlock() @@ -491,7 +491,7 @@ func (n *Node) StatusClient() serverpb.StatusClient { return existingClient } - conn, err := n.rpcCtx.GRPCDialRaw(n.RPCAddr()) + conn, err := n.rpcCtx.GRPCUnvalidatedDial(n.RPCAddr()).Connect(ctx) if err != nil { log.Fatalf(context.Background(), "failed to initialize status client: %s", err) } diff --git a/pkg/base/config.go b/pkg/base/config.go index dec64dd60ad8..0f4a79fe0e49 100644 --- a/pkg/base/config.go +++ b/pkg/base/config.go @@ -376,6 +376,10 @@ type Config struct { // The flag exists mostly for the benefit of tests, and for // `cockroach start-single-node`. AutoInitializeCluster bool + + // LocalityAddresses contains private IP addresses that can only be accessed + // in the corresponding locality. + LocalityAddresses []roachpb.LocalityAddress } // HistogramWindowInterval is used to determine the approximate length of time diff --git a/pkg/cmd/allocsim/main.go b/pkg/cmd/allocsim/main.go index 03d5fb92c241..fb76f79809c4 100644 --- a/pkg/cmd/allocsim/main.go +++ b/pkg/cmd/allocsim/main.go @@ -243,7 +243,7 @@ func (a *allocSim) rangeInfo() allocStats { for i := 0; i < len(a.Nodes); i++ { go func(i int) { defer wg.Done() - status := a.Nodes[i].StatusClient() + status := a.Nodes[i].StatusClient(context.Background()) if status == nil { // Cluster is shutting down. return diff --git a/pkg/roachpb/metadata.go b/pkg/roachpb/metadata.go index b2064c85dc9c..6c58f709819f 100644 --- a/pkg/roachpb/metadata.go +++ b/pkg/roachpb/metadata.go @@ -552,21 +552,11 @@ func (sc StoreCapacity) Load() load.Load { // AddressForLocality returns the network address that nodes in the specified // locality should use when connecting to the node described by the descriptor. func (n *NodeDescriptor) AddressForLocality(loc Locality) *util.UnresolvedAddr { - // If the provided locality has any tiers that are an exact exact match (key + // If the provided locality has any tiers that are an exact match (key // and value) with a tier in the node descriptor's custom LocalityAddress // list, return the corresponding address. Otherwise, return the default // address. - // - // O(n^2), but we expect very few locality tiers in practice. - for i := range n.LocalityAddress { - nLoc := &n.LocalityAddress[i] - for _, loc := range loc.Tiers { - if loc == nLoc.LocalityTier { - return &nLoc.Address - } - } - } - return &n.Address + return loc.LookupAddress(n.LocalityAddress, &n.Address) } // CheckedSQLAddress returns the value of SQLAddress if set. If not, either @@ -644,6 +634,23 @@ func (l Locality) Equals(r Locality) bool { return true } +// LookupAddress is given a set of LocalityAddresses and finds the one that +// exactly matches my Locality. O(n^2), but we expect very few locality tiers in +// practice. +func (l Locality) LookupAddress( + address []LocalityAddress, base *util.UnresolvedAddr, +) *util.UnresolvedAddr { + for i := range address { + nLoc := &address[i] + for _, loc := range l.Tiers { + if loc == nLoc.LocalityTier { + return &nLoc.Address + } + } + } + return base +} + // MaxDiversityScore is the largest possible diversity score, indicating that // two localities are as different from each other as possible. const MaxDiversityScore = 1.0 diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index 884006042a71..4fa9f8d335d4 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "//pkg/security/username", "//pkg/settings", "//pkg/settings/cluster", + "//pkg/util", "//pkg/util/buildutil", "//pkg/util/contextutil", "//pkg/util/envutil", diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index 44df5e9dba75..e68f4126e44e 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -28,7 +28,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/multitenant/tenantcapabilities" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security" + "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/buildutil" "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/envutil" @@ -303,9 +305,9 @@ type Connection struct { // err is nil initially; eventually set to the dial or heartbeat error that // tore down the connection. err atomic.Value - // initialHeartbeatDone is closed in `runHeartbeat` once grpcConn is - // populated. This means that access to that field must read this channel - // first. + // initialHeartbeatDone is closed in `runHeartbeat` once grpcConn is populated + // and a heartbeat is successfully returned. This means that access to that + // field must read this channel first. initialHeartbeatDone chan struct{} // closed after first heartbeat grpcConn *grpc.ClientConn // present when initialHeartbeatDone is closed; must read that channel first } @@ -350,10 +352,10 @@ func (c *Connection) Health() error { err, _ := c.err.Load().(error) return err default: - // TODO(tbg): would be better if this returned ErrNoConnection, as this - // is what's happening here. There might be a connection attempt going - // on, but not one that has proven conclusively that the peer is even - // reachable. + // There might be a connection attempt going on, but not one that has proven + // conclusively that the peer is reachable and able to connect back to us. + // Ideally we could return ErrNoConnection, but it is hard to separate out + // these cases. return ErrNotHeartbeated } } @@ -362,6 +364,8 @@ func (c *Connection) Health() error { // // TODO(tbg): rename at the very least the `ctx` receiver, but possibly the whole // thing. +// TODO(baptist): Remove the inheritance on ContextOptions directly construct +// the object with what it needs. type Context struct { ContextOptions *SecurityContext @@ -380,6 +384,14 @@ type Context struct { m connMap + // dialbackMap is a map of currently executing dialback connections. This map + // is typically empty or close to empty. It only holds entries that are being + // verified for dialback due to failing a health check. + dialbackMu struct { + syncutil.Mutex + m map[roachpb.NodeID]*Connection + } + metrics Metrics // For unittesting. @@ -470,9 +482,8 @@ type ContextOptions struct { Settings *cluster.Settings // OnIncomingPing is called when handling a PingRequest, after // preliminary checks but before recording clock offset information. - // - // It can inject an error. - OnIncomingPing func(context.Context, *PingRequest) error + // It can inject an error or modify the response. + OnIncomingPing func(context.Context, *PingRequest, *PingResponse) error // OnOutgoingPing intercepts outgoing PingRequests. It may inject an // error. OnOutgoingPing func(context.Context, *PingRequest) error @@ -512,6 +523,16 @@ type ContextOptions struct { // subsystem. It allows KV nodes to perform capability checks for incoming // tenant requests. TenantRPCAuthorizer tenantcapabilities.Authorizer + + // NeedsDialback indicates that connections created with this RPC context + // should be verified after they are established by the recipient having a + // backwards connection to us. This is used for KV server to KV server + // communication. If there is already a healthy connection, then the + // PingResponse is sent like normal, however if there is no connection then a + // throwaway reverse TCP connection is made. This is set to true on + // node-to-node connections and prevents one-way partitions from occurring by + // turing them into two-way partitions. + NeedsDialback bool } func (c ContextOptions) validate() error { @@ -639,6 +660,10 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { logClosingConnEvery: log.Every(time.Second), } + rpcCtx.dialbackMu.Lock() + rpcCtx.dialbackMu.m = map[roachpb.NodeID]*Connection{} + rpcCtx.dialbackMu.Unlock() + if !opts.TenantID.IsSet() { panic("tenant ID not set") } @@ -1763,6 +1788,10 @@ func (rpcCtx *Context) dialOptsNetwork( } dialOpts = append(dialOpts, grpc.WithContextDialer(dialerFunc)) + // Don't retry on dial errors either, otherwise the onlyOnceDialer will get + // into a bad state for connection errors. + dialOpts = append(dialOpts, grpc.FailOnNonTempDialError(true)) + return dialOpts, nil } @@ -1879,7 +1908,7 @@ func (ood *onlyOnceDialer) dial(ctx context.Context, addr string) (net.Conn, err // We set up onlyOnceDialer to avoid returning any errors that could look // temporary to gRPC, and so we don't expect it to re-dial a connection // twice (the first re-dial is supposed to surface the permanent error). - return nil, errors.NewAssertionErrorWithWrappedErrf(err, "gRPC connection unexpectedly re-dialed") + return nil, ¬TemporaryError{errors.NewAssertionErrorWithWrappedErrf(err, "gRPC connection unexpectedly re-dialed")} } ood.mu.redialed = true return nil, err @@ -2061,22 +2090,11 @@ func (rpcCtx *Context) makeDialCtx( return dialCtx } -// GRPCDialRaw calls grpc.Dial with options appropriate for the context. -// Unlike GRPCDialNode, it does not start an RPC heartbeat to validate the -// connection. This connection will not be reconnected automatically; -// the returned channel is closed when a reconnection is attempted. -// This method implies a DefaultClass ConnectionClass for the returned -// ClientConn. -func (rpcCtx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, error) { - ctx := rpcCtx.makeDialCtx(target, 0, DefaultClass) - return rpcCtx.grpcDialRaw(ctx, target, DefaultClass) -} - // grpcDialRaw connects to the remote node. // The ctx passed as argument must be derived from rpcCtx.masterCtx, so // that it respects the same cancellation policy. func (rpcCtx *Context) grpcDialRaw( - ctx context.Context, target string, class ConnectionClass, + ctx context.Context, target string, class ConnectionClass, additionalOpts ...grpc.DialOption, ) (*grpc.ClientConn, error) { transport := tcpTransport if rpcCtx.Config.AdvertiseAddr == target && !rpcCtx.ClientOnly { @@ -2088,10 +2106,7 @@ func (rpcCtx *Context) grpcDialRaw( return nil, err } - // Add testingDialOpts at the end because one of our tests - // uses a custom dialer (this disables the only-one-connection - // behavior and redialChan will never be closed). - dialOpts = append(dialOpts, rpcCtx.testingDialOpts...) + dialOpts = append(dialOpts, additionalOpts...) return grpc.DialContext(ctx, target, dialOpts...) } @@ -2229,7 +2244,8 @@ func (rpcCtx *Context) grpcDialNodeInternal( // Run the heartbeat; this will block until the connection breaks for // whatever reason. We don't actually have to do anything with the error, // so we ignore it. - _ = rpcCtx.runHeartbeat(ctx, conn, target) + err := rpcCtx.runHeartbeat(ctx, conn, target) + log.Health.Infof(ctx, "connection heartbeat loop ended with err: %v", err) maybeFatal(ctx, rpcCtx.m.Remove(k, conn)) // Context gets canceled on server shutdown, and if that's likely why @@ -2272,6 +2288,14 @@ var ErrNotHeartbeated = errors.New("not yet heartbeated") // the node. var ErrNoConnection = errors.New("no connection found") +// TODO(baptist): Remove in 23.2 (or 24.1) once validating dialback works for all scenarios. +var useDialback = settings.RegisterBoolSetting( + settings.SystemOnly, + "rpc.dialback.enabled", + "if true, require bidirectional RPC connections between nodes to prevent one-way network unavailability", + true, +) + // runHeartbeat synchronously runs the heartbeat loop for the given RPC // connection. The ctx passed as argument must be derived from rpcCtx.masterCtx, // so that it respects the same cancellation policy. @@ -2326,7 +2350,7 @@ func (rpcCtx *Context) runHeartbeat( { var err error - conn.grpcConn, err = rpcCtx.grpcDialRaw(ctx, target, conn.class) + conn.grpcConn, err = rpcCtx.grpcDialRaw(ctx, target, conn.class, rpcCtx.testingDialOpts...) if err != nil { // Note that grpcConn will actually connect in the background, so it's // unusual to hit this case. @@ -2343,9 +2367,10 @@ func (rpcCtx *Context) runHeartbeat( // heartbeat to heartbeat: we compute a new .Offset at the end of // the current heartbeat as input to the next one. request := &PingRequest{ - DeprecatedOriginAddr: rpcCtx.Config.Addr, - TargetNodeID: conn.remoteNodeID, - ServerVersion: rpcCtx.Settings.Version.BinaryVersion(), + OriginAddr: rpcCtx.Config.AdvertiseAddr, + TargetNodeID: conn.remoteNodeID, + ServerVersion: rpcCtx.Settings.Version.BinaryVersion(), + LocalityAddress: rpcCtx.Config.LocalityAddresses, } heartbeatClient := NewHeartbeatClient(conn.grpcConn) @@ -2363,7 +2388,8 @@ func (rpcCtx *Context) runHeartbeat( // This simple model should work well in practice and it avoids serious // problems that could arise from keeping unhealthy connections in the pool. connFailedCh := make(chan connectivity.State, 1) - for i := 0; ; i++ { + first := true + for { select { case <-ctx.Done(): return nil // server shutting down @@ -2395,6 +2421,16 @@ func (rpcCtx *Context) runHeartbeat( return err } var err error + // Check the setting lazily to allow toggling on/off without a restart. + if rpcCtx.NeedsDialback && useDialback.Get(&rpcCtx.Settings.SV) { + if first { + request.NeedsDialback = PingRequest_BLOCKING + } else { + request.NeedsDialback = PingRequest_NON_BLOCKING + } + } else { + request.NeedsDialback = PingRequest_NONE + } response, err = heartbeatClient.Ping(ctx, request) return err } @@ -2430,12 +2466,13 @@ func (rpcCtx *Context) runHeartbeat( return err } - // Only a server connecting to another server needs to check - // clock offsets. A CLI command does not need to update its - // local HLC, nor does it care that strictly about - // client-server latency, nor does it need to track the - // offsets. - if rpcCtx.RemoteClocks != nil { + // Only a server connecting to another server needs to check clock + // offsets. A CLI command does not need to update its local HLC, nor does + // it care that strictly about client-server latency, nor does it need to + // track the offsets. A response.ServerTime of 0 means we can not use this + // response for updating our clocks. This can occur if the server added + // delays before sending a response. + if rpcCtx.RemoteClocks != nil && response.ServerTime != 0 { receiveTime := rpcCtx.Clock.Now() // Only update the clock offset measurement if we actually got a @@ -2466,11 +2503,11 @@ func (rpcCtx *Context) runHeartbeat( return err } - if i == 0 { + if first { // First heartbeat succeeded. rpcCtx.metrics.HeartbeatsNominal.Inc(1) - log.Health.Infof(ctx, "connection is now ready") close(conn.initialHeartbeatDone) + log.Health.Infof(ctx, "connection is now ready") // The connection should be `Ready` now since we just used it for a // heartbeat RPC. Any additional state transition indicates that we need // to remove it, and we want to do so reactively. Unfortunately, gRPC @@ -2494,6 +2531,7 @@ func (rpcCtx *Context) runHeartbeat( } heartbeatTimer.Reset(rpcCtx.heartbeatInterval) + first = false } } @@ -2511,3 +2549,116 @@ func (rpcCtx *Context) NewHeartbeatService() *HeartbeatService { testingAllowNamedRPCToAnonymousServer: rpcCtx.TestingAllowNamedRPCToAnonymousServer, } } + +// VerifyDialback verifies connectivity from the recipient of a PingRequest +// back to the sender. If there is already a connection in place, it will return +// immediately without error. If there is no connection in place and the +// NeedsDialback on the PingRequest is not set to NONE, then it will establish a +// connection in either blocking or non-blocking mode. +// BLOCKING mode delays sending a PingResponse until the connection is +// validated, and is only used on the first PingRequest after a connection is +// established. +// NON_BLOCKING mode will attempt to establish a reverse connection and send the +// result on the next PingRequest that is sent on this connection. +func (rpcCtx *Context) VerifyDialback( + ctx context.Context, request *PingRequest, response *PingResponse, locality roachpb.Locality, +) error { + if request.NeedsDialback == PingRequest_NONE { + return nil + } + + baseAddr := util.UnresolvedAddr{NetworkField: "tcp", AddressField: request.OriginAddr} + target := locality.LookupAddress(request.LocalityAddress, &baseAddr).AddressField + nodeID := request.OriginNodeID + + // Initially the nodeID might not be set since it is assigned by the cluster + // not the node. In that case, we can't look up if we have a connection to the + // node and instead need to always try dialback. + var err error + if nodeID != 0 { + prevErr, found := rpcCtx.previousAttempt(nodeID) + if found { + return prevErr + } + + // Check in our regular connection map to see if we are healthy. We only + // care about the SystemClass because that is what is important from a Raft + // liveness perspective. + err = rpcCtx.ConnHealth(target, nodeID, SystemClass) + // We have a successful connection, nothing else to do. + if err == nil { + return nil + } + } + + log.VEventf(ctx, 2, "unable to verify health on open conn, trying dialback conn to %s, n%d, %v", target, nodeID, err) + if nodeID == 0 || request.NeedsDialback == PingRequest_BLOCKING { + // We want this connection to block while connecting to verify it succeeds. + // Clear out the ServerTime on our response so the receiver does not use + // this response in its latency calculations. + response.ServerTime = 0 + + // Since we don't have a successful reverse connection, try and dial back + // manually. We don't use the regular dialer pool since we don't want to wait + // for heartbeats on this connection. + // TODO(baptist): Consider using GRPCUnvalidatedDial and use the + // WaitForStateChange to detect when the TCP connection is established. This + // will keep this connection in the pool after establishment. Note the class + // here matter since this connection is not added to a pool and immediately + // closed. + ctx := rpcCtx.makeDialCtx(target, 0, SystemClass) + conn, err := rpcCtx.grpcDialRaw(ctx, target, SystemClass, grpc.WithBlock()) + if err != nil { + log.Warningf(ctx, "dialback connection failed to %s, n%d, %v", target, nodeID, err) + return err + } + _ = conn.Close() // nolint:grpcconnclose + return nil + } else { + // We don't have a previous attempt and the current health was not healthy, + // but we can't just trust that because there might not be new connection + // attempts. Instead, establish a new connection using the standard + // GRPCDialNode, this connection is added to the connection pool. Always + // return success on this ping, but check this connection attempt on future + // pings. Use the SystemClass to ensure that Raft traffic is not + // interrupted. It is unusual for some classes to be affected and not others + // but the SystemClass is the one we really care about. + rpcCtx.dialbackMu.Lock() + defer rpcCtx.dialbackMu.Unlock() + rpcCtx.dialbackMu.m[nodeID] = rpcCtx.GRPCDialNode(target, nodeID, SystemClass) + return nil + } +} + +// previousAttempt checks if any prior attempt that started but hadn't complete +// has now completed. Until this attempt completes, the Pings will continue to +// return success. Once this completes, we remove this from our map and return +// whatever error this attempt returned. +func (rpcCtx *Context) previousAttempt(nodeID roachpb.NodeID) (error, bool) { + // Check if there was a previous attempt and if so use that and clear out + // the previous attempt. + rpcCtx.dialbackMu.Lock() + defer rpcCtx.dialbackMu.Unlock() + previousAttempt := rpcCtx.dialbackMu.m[nodeID] + + // Block here for the previous connection to be completed (successfully or + // not). This happens only on a second ping after the first didn't detect + // a reverse connection. The connection setup can take longer than a ping + // interval. This ensures the reverse connection is good before allowing the + // sender to continue. + if previousAttempt != nil { + select { + case <-previousAttempt.initialHeartbeatDone: + // The connection attempt was completed, return the outcome of it. + err, _ := previousAttempt.err.Load().(error) + rpcCtx.dialbackMu.m[nodeID] = nil + return err, true + default: + // We still don't know the outcome of the previous attempt. For now + // allow this Ping to continue and check on the following attempt. + return nil, true + } + } + // There is no previous attempt in place. + return nil, false +} diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 20070aed63b9..705f4a5f21e9 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -81,9 +81,7 @@ func (rpcCtx *Context) AddTestingDialOpts(opts ...grpc.DialOption) { func newTestServer(t testing.TB, ctx *Context, extraOpts ...grpc.ServerOption) *grpc.Server { tlsConfig, err := ctx.GetServerTLSConfig() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) opts := []grpc.ServerOption{ grpc.Creds(credentials.NewTLS(tlsConfig)), } @@ -193,7 +191,7 @@ func TestPingInterceptors(t *testing.T) { } return nil }, - OnIncomingPing: func(ctx context.Context, req *PingRequest) error { + OnIncomingPing: func(ctx context.Context, req *PingRequest, resp *PingResponse) error { if req.OriginNodeID == blockedOriginNodeID { return errBoomRecv } @@ -2467,7 +2465,7 @@ func BenchmarkGRPCDial(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := rpcCtx.GRPCDialNode(remoteAddr, serverNodeID, DefaultClass).Connect(context.Background()) + _, err := rpcCtx.grpcDialRaw(ctx, remoteAddr, DefaultClass) if err != nil { b.Fatal(err) } @@ -2557,3 +2555,183 @@ func TestOnlyOnceDialer(t *testing.T) { } } } + +type trackingListener struct { + net.Listener + mu syncutil.Mutex + connections []net.Conn + closed bool +} + +func (d *trackingListener) Accept() (net.Conn, error) { + c, err := d.Listener.Accept() + + d.mu.Lock() + defer d.mu.Unlock() + // If we get any trailing accepts after we close, just close the connection immediately. + if err == nil { + if d.closed { + _ = c.Close() + } else { + d.connections = append(d.connections, c) + } + } + return c, err +} + +func (d *trackingListener) Close() error { + d.mu.Lock() + defer d.mu.Unlock() + d.closed = true + for _, c := range d.connections { + _ = c.Close() + } + err := d.Listener.Close() + if err != nil { + return err + } + return nil +} + +func newRegisteredServer( + t testing.TB, stopper *stop.Stopper, clusterID uuid.UUID, nodeID roachpb.NodeID, +) (*Context, string, chan *PingRequest, *trackingListener) { + clock := timeutil.NewManualTime(timeutil.Unix(0, 1)) + // We don't want to stall sending to this channel. + pingChan := make(chan *PingRequest, 5) + + opts := ContextOptions{ + TenantID: roachpb.SystemTenantID, + Config: testutils.NewNodeTestBaseContext(), + Clock: clock, + ToleratedOffset: time.Duration(0), + Stopper: stopper, + Settings: cluster.MakeTestingClusterSettings(), + NeedsDialback: true, + Knobs: ContextTestingKnobs{NoLoopbackDialer: true}, + } + // Heartbeat faster so we don't have to wait as long. + opts.Config.RPCHeartbeatInterval = 10 * time.Millisecond + opts.Config.RPCHeartbeatTimeout = 100 * time.Millisecond + + rpcCtx := NewContext(context.Background(), opts) + // This is normally set up inside the server, we want to hold onto all PingRequests that come through. + rpcCtx.OnIncomingPing = func(ctx context.Context, req *PingRequest, resp *PingResponse) error { + pingChan <- req + err := rpcCtx.VerifyDialback(ctx, req, resp, roachpb.Locality{}) + // On success store the ping to the channel for test analysis. + return err + } + + rpcCtx.NodeID.Set(context.Background(), nodeID) + rpcCtx.StorageClusterID.Set(context.Background(), clusterID) + s := newTestServer(t, rpcCtx) + + RegisterHeartbeatServer(s, rpcCtx.NewHeartbeatService()) + + ln, err := net.Listen("tcp", util.TestAddr.String()) + require.Nil(t, err) + tracker := trackingListener{Listener: ln} + _ = stopper.RunAsyncTask(context.Background(), "serve", func(context.Context) { + closeReason := s.Serve(&tracker) + log.Infof(context.Background(), "Closed listener with reason %v", closeReason) + }) + + addr := ln.Addr().String() + log.Infof(context.Background(), "Listening on %s", addr) + // This needs to be set once we know our address so that ping requests have + // the correct reverse addr in them. + rpcCtx.Config.AdvertiseAddr = addr + return rpcCtx, addr, pingChan, &tracker +} + +func TestHeartbeatDialback(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + clusterID := uuid.MakeV4() + + ctx1, remoteAddr1, pingChan1, ln1 := newRegisteredServer(t, stopper, clusterID, 1) + ctx2, remoteAddr2, pingChan2, ln2 := newRegisteredServer(t, stopper, clusterID, 2) + defer func() { netutil.FatalIfUnexpected(ln1.Close()) }() + defer func() { netutil.FatalIfUnexpected(ln2.Close()) }() + + // Test an incorrect remoteNodeID, this should fail with a heartbeat error. + // This invariant is important to make sure we don't try and connect to the + // wrong node. + { + _, err := ctx1.GRPCDialNode(remoteAddr2, 3, DefaultClass).Connect(ctx) + var respErr *netutil.InitialHeartbeatFailedError + require.ErrorAs(t, err, &respErr) + // Verify no heartbeat received in either direction. + require.Equal(t, 0, len(pingChan1)) + require.Equal(t, 0, len(pingChan2)) + } + + // Initiate connection from node 1 to node 2 which will create a dialback + // connection back to 1. This will be a blocking connection since there is no + // reverse connection. + { + conn, err := ctx1.GRPCDialNode(remoteAddr2, 2, DefaultClass).Connect(ctx) + defer func() { + _ = conn.Close() // nolint:grpcconnclose + }() + require.NoError(t, err) + require.NotNil(t, conn) + require.Equal(t, 1, len(pingChan2)) + pingReq := <-pingChan2 + require.Equal(t, PingRequest_BLOCKING, pingReq.NeedsDialback) + require.Equal(t, 0, len(pingChan1)) + } + + //Now connect back in the opposite direction. This should not initiate any + //dialback since we are already connected. + { + conn, err := ctx1.GRPCDialNode(remoteAddr2, 2, DefaultClass).Connect(ctx) + defer func() { + _ = conn.Close() // nolint:grpcconnclose + }() + require.NoError(t, err) + require.NotNil(t, conn) + // The reverse connection was already set up, but we are still blocking. + pingReq := <-pingChan1 + require.Equal(t, PingRequest_BLOCKING, pingReq.NeedsDialback) + + // At this point, node 1 has a fully established connection to node 2, however node 2 has not yet finished connecting back. + require.Equal(t, nil, ctx1.ConnHealth(remoteAddr2, 2, DefaultClass)) + } + + // Verify we get non-blocking requests in both directions now. + require.Equal(t, PingRequest_NON_BLOCKING, (<-pingChan2).NeedsDialback) + require.Equal(t, PingRequest_NON_BLOCKING, (<-pingChan1).NeedsDialback) + + // Verify we are fully healthy in both directions (note the dialback is on the + // system class). + require.Equal(t, nil, ctx1.ConnHealth(remoteAddr2, 2, DefaultClass)) + require.Equal(t, nil, ctx2.ConnHealth(remoteAddr1, 1, SystemClass)) + + // Forcibly shut down listener 2 and the connection node1 -> node2. + // Verify the reverse connection will also close within a DialTimeout. + log.Info(ctx, "Closing node 2 listener") + _ = ln2.Close() + + // Wait for a few more pings to go through to make sure it has a chance to + // shut down the reverse connection. Normally the connect attempt times out + // immediately and returns an error, but occasionally it needs to wait for the + // RPCHeartbeatTimeout (100 ms). Wait until pings have stopped in both + // directions for at least 1 second before checking health. + for { + select { + case ping := <-pingChan1: + log.Infof(ctx, "Received %+v", ping) + case ping := <-pingChan2: + log.Infof(ctx, "Received %+v", ping) + case <-time.After(1 * time.Second): + require.ErrorAs(t, ctx1.ConnHealth(remoteAddr2, 2, DefaultClass), &ErrNoConnection) + require.ErrorAs(t, ctx2.ConnHealth(remoteAddr1, 1, SystemClass), &ErrNoConnection) + return + } + } +} diff --git a/pkg/rpc/heartbeat.go b/pkg/rpc/heartbeat.go index cde8390a8c0c..9653dbe9f9dc 100644 --- a/pkg/rpc/heartbeat.go +++ b/pkg/rpc/heartbeat.go @@ -52,7 +52,7 @@ type HeartbeatService struct { clusterName string disableClusterNameVerification bool - onHandlePing func(context.Context, *PingRequest) error // see ContextOptions.OnIncomingPing + onHandlePing func(context.Context, *PingRequest, *PingResponse) error // see ContextOptions.OnIncomingPing // TestingAllowNamedRPCToAnonymousServer, when defined (in tests), // disables errors in case a heartbeat requests a specific node ID but @@ -115,13 +115,13 @@ func checkVersion( // server's current clock value, allowing the requester to measure its clock. // The requester should also estimate its offset from this server along // with the requester's address. -func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingResponse, error) { +func (hs *HeartbeatService) Ping(ctx context.Context, request *PingRequest) (*PingResponse, error) { if log.ExpensiveLogEnabled(ctx, 2) { - log.Dev.Infof(ctx, "received heartbeat: %+v vs local cluster %+v node %+v", args, hs.clusterID, hs.nodeID) + log.Dev.Infof(ctx, "received heartbeat: %+v vs local cluster %+v node %+v", request, hs.clusterID, hs.nodeID) } // Check that cluster IDs match. clusterID := hs.clusterID.Get() - if args.ClusterID != nil && *args.ClusterID != uuid.Nil && clusterID != uuid.Nil { + if request.ClusterID != nil && *request.ClusterID != uuid.Nil && clusterID != uuid.Nil { // There is a cluster ID on both sides. Use that to verify the connection. // // Note: we could be checking the cluster name here too, however @@ -129,9 +129,9 @@ func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingR // initiating the connection), so that the user of a newly started // node gets a chance to see a cluster name mismatch as an error message // on their side. - if *args.ClusterID != clusterID { + if *request.ClusterID != clusterID { return nil, errors.Errorf( - "client cluster ID %q doesn't match server cluster ID %q", args.ClusterID, clusterID) + "client cluster ID %q doesn't match server cluster ID %q", request.ClusterID, clusterID) } } // Check that node IDs match. @@ -139,7 +139,7 @@ func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingR if hs.nodeID != nil { nodeID = hs.nodeID.Get() } - if args.TargetNodeID != 0 && (!hs.testingAllowNamedRPCToAnonymousServer || nodeID != 0) && args.TargetNodeID != nodeID { + if request.TargetNodeID != 0 && (!hs.testingAllowNamedRPCToAnonymousServer || nodeID != 0) && request.TargetNodeID != nodeID { // If nodeID != 0, the situation is clear (we are checking that // the other side is talking to the right node). // @@ -149,29 +149,31 @@ func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingR // however we can still serve connections that don't need a node // ID, e.g. during initial gossip. return nil, errors.Errorf( - "client requested node ID %d doesn't match server node ID %d", args.TargetNodeID, nodeID) + "client requested node ID %d doesn't match server node ID %d", request.TargetNodeID, nodeID) } // Check version compatibility. - if err := checkVersion(ctx, hs.version, args.ServerVersion); err != nil { + if err := checkVersion(ctx, hs.version, request.ServerVersion); err != nil { return nil, errors.Wrap(err, "version compatibility check failed on ping request") } - if fn := hs.onHandlePing; fn != nil { - if err := fn(ctx, args); err != nil { - return nil, err - } - } - - serverOffset := args.Offset + serverOffset := request.Offset // The server offset should be the opposite of the client offset. serverOffset.Offset = -serverOffset.Offset - hs.remoteClockMonitor.UpdateOffset(ctx, args.OriginNodeID, serverOffset, 0 /* roundTripLatency */) - return &PingResponse{ - Pong: args.Ping, + hs.remoteClockMonitor.UpdateOffset(ctx, request.OriginNodeID, serverOffset, 0 /* roundTripLatency */) + response := PingResponse{ + Pong: request.Ping, ServerTime: hs.clock.Now().UnixNano(), ServerVersion: hs.version.BinaryVersion(), ClusterName: hs.clusterName, DisableClusterNameVerification: hs.disableClusterNameVerification, - }, nil + } + + if fn := hs.onHandlePing; fn != nil { + if err := fn(ctx, request, &response); err != nil { + return nil, err + } + } + + return &response, nil } diff --git a/pkg/rpc/heartbeat.proto b/pkg/rpc/heartbeat.proto index be21a62564da..d561d8e8f3ee 100644 --- a/pkg/rpc/heartbeat.proto +++ b/pkg/rpc/heartbeat.proto @@ -41,9 +41,8 @@ message PingRequest { optional string ping = 1 [(gogoproto.nullable) = false]; // The last offset the client measured with the server. optional RemoteOffset offset = 2 [(gogoproto.nullable) = false]; - // The address of the client. - // TODO(baptist): Remove this field in v23.2. It is no longer read. - optional string deprecated_origin_addr = 3 [(gogoproto.nullable) = false]; + // The advertised address of the client. + optional string origin_addr = 3 [(gogoproto.nullable) = false]; // Cluster ID to prevent connections between nodes in different clusters. optional bytes origin_cluster_id = 5 [ (gogoproto.customname) = "ClusterID", @@ -61,6 +60,36 @@ message PingRequest { (gogoproto.customname) = "OriginNodeID", (gogoproto.customtype) = "github.com/cockroachdb/cockroach/pkg/roachpb.NodeID"]; + // The mapping of locality addresses for this node. These are used by the + // receiver of the node to initiate a dialback connection. This same + // information is also sent over gossip, but the first ping is sent prior to + // gossip being available, so it is included here also. + repeated cockroach.roachpb.LocalityAddress locality_address = 9 [(gogoproto.nullable) = false]; + + + enum DialbackType { + // The recipient should send a PingResponse without checking if there is a + // reverse connection. + NONE = 0; + // If there is already an established reverse connection, respond + // immediately, otherwise create a reverse connection and wait until this + // connection is successfully established before responding. + BLOCKING = 1; + // If there is already an established reverse connection, respond + // immediately, otherwise if there is no reverse connection, attempt to + // create one asynchronously. If that fails to connect, then respond with + // failure to the next PingRequest on this connection. + NON_BLOCKING = 2; + } + // What type of dialback is requested for the recipient of this PingRequest + // NB: A node that receives this request without this field (from a + // pre-dialback client) set will treat it as NONE mode and not attempt + // dialback. + // A node that sets this field to a server that doesn't understand it will + // result in it being ignored. + // As this is just an additional validation, both these behaviors are OK. + optional DialbackType needs_dialback = 10 [(gogoproto.nullable) = false]; + reserved 4; } diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index d8a409d77662..9fb39f38b0d1 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -250,6 +250,11 @@ func (n *Dialer) ConnHealth(nodeID roachpb.NodeID, class rpc.ConnectionClass) er // down), and should be avoided in latency-sensitive code paths. Preferably, // this should be replaced by some other mechanism to maintain RPC connections. // See also: https://github.com/cockroachdb/cockroach/issues/70111 +// TODO(baptist): This method is poorly named and confusing. It is used as a +// "hint" to use a connection if it already exists, but simultaneously kick off +// a connection attempt in the background if it doesn't and always return +// immediately. It is only used today by DistSQL and it should probably be +// removed and moved into that code. func (n *Dialer) ConnHealthTryDial(nodeID roachpb.NodeID, class rpc.ConnectionClass) error { err := n.ConnHealth(nodeID, class) if err == nil || !n.getBreaker(nodeID, class).Ready() { @@ -259,6 +264,8 @@ func (n *Dialer) ConnHealthTryDial(nodeID roachpb.NodeID, class rpc.ConnectionCl if err != nil { return err } + // NB: This will always return `ErrNotHeartbeated` since the heartbeat will + // not be done by the time `Health` is called since GRPCDialNode is async. return n.rpcContext.GRPCDialNode(addr.String(), nodeID, class).Health() } diff --git a/pkg/rpc/nodedialer/nodedialer_test.go b/pkg/rpc/nodedialer/nodedialer_test.go index f33bbe6f8809..c0a13d37b1fd 100644 --- a/pkg/rpc/nodedialer/nodedialer_test.go +++ b/pkg/rpc/nodedialer/nodedialer_test.go @@ -140,7 +140,9 @@ func TestConnHealth(t *testing.T) { // After dialing the node, ConnHealth should return nil. _, err := nd.Dial(ctx, staticNodeID, rpc.DefaultClass) require.NoError(t, err) - require.NoError(t, nd.ConnHealth(staticNodeID, rpc.DefaultClass)) + require.Eventually(t, func() bool { + return nd.ConnHealth(staticNodeID, rpc.DefaultClass) == nil + }, time.Second, 10*time.Millisecond) // ConnHealth should still error for other node ID and class. require.Error(t, nd.ConnHealth(9, rpc.DefaultClass)) @@ -166,7 +168,9 @@ func TestConnHealth(t *testing.T) { _, err := nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass) return err == nil }, 10*time.Second, time.Millisecond) - require.NoError(t, nd.ConnHealth(staticNodeID, rpc.DefaultClass)) + require.Eventually(t, func() bool { + return nd.ConnHealth(staticNodeID, rpc.DefaultClass) == nil + }, time.Second, 10*time.Millisecond) } // Tripping the breaker should return ErrBreakerOpen. diff --git a/pkg/server/config.go b/pkg/server/config.go index 0390fb1fc6a7..da2ca00d8b4a 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -422,10 +422,6 @@ type KVConfig struct { // DefaultSystemZoneConfigOverride server testing knob. DefaultSystemZoneConfig zonepb.ZoneConfig - // LocalityAddresses contains private IP addresses the can only be accessed - // in the corresponding locality. - LocalityAddresses []roachpb.LocalityAddress - // EventLogEnabled is a switch which enables recording into cockroach's SQL // event log tables. These tables record transactional events about changes // to cluster metadata, such as DDL statements and range rebalancing diff --git a/pkg/server/server.go b/pkg/server/server.go index 6fe4a0aa3e96..30cc3db2c325 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -313,18 +313,8 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { // still be tried as caller node is valid, but not the destination. return checkPingFor(ctx, req.TargetNodeID, codes.FailedPrecondition) }, - OnIncomingPing: func(ctx context.Context, req *rpc.PingRequest) error { - // Decommission state is only tracked for the system tenant. - if tenantID, isTenant := roachpb.ClientTenantFromContext(ctx); isTenant && - !roachpb.IsSystemTenantID(tenantID.ToUint64()) { - return nil - } - // Incoming ping will reject requests with codes.PermissionDenied to - // signal remote node that it is not considered valid anymore and - // operations should fail immediately. - return checkPingFor(ctx, req.OriginNodeID, codes.PermissionDenied) - }, TenantRPCAuthorizer: authorizer, + NeedsDialback: true, } if knobs := cfg.TestingKnobs.Server; knobs != nil { serverKnobs := knobs.(*TestingKnobs) @@ -332,6 +322,21 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { } rpcContext := rpc.NewContext(ctx, rpcCtxOpts) + rpcContext.OnIncomingPing = func(ctx context.Context, req *rpc.PingRequest, resp *rpc.PingResponse) error { + // Decommission state is only tracked for the system tenant. + if tenantID, isTenant := roachpb.ClientTenantFromContext(ctx); isTenant && + !roachpb.IsSystemTenantID(tenantID.ToUint64()) { + return nil + } + if err := rpcContext.VerifyDialback(ctx, req, resp, cfg.Locality); err != nil { + return err + } + // Incoming ping will reject requests with codes.PermissionDenied to + // signal remote node that it is not considered valid anymore and + // operations should fail immediately. + return checkPingFor(ctx, req.OriginNodeID, codes.PermissionDenied) + } + rpcContext.HeartbeatCB = func() { if err := rpcContext.RemoteClocks.VerifyClockOffset(ctx); err != nil { log.Ops.Fatalf(ctx, "%v", err)