diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index 6c9092e36078..0b6a3e44853c 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -83,33 +83,13 @@ func (n *Dialer) Dial(ctx context.Context, nodeID roachpb.NodeID) (_ *grpc.Clien return nil, ctxErr } breaker := n.getBreaker(nodeID) - - if !breaker.Ready() { - err := errors.Wrapf(circuit.ErrBreakerOpen, "unable to dial n%d", nodeID) - return nil, err - } - - defer func() { - // Enforce a minimum interval between warnings for failed connections. - if err != nil && breaker.ShouldLog() { - log.Infof(ctx, "unable to connect to n%d: %s", nodeID, err) - } - }() - addr, err := n.resolver(nodeID) if err != nil { err = errors.Wrapf(err, "failed to resolve n%d", nodeID) breaker.Fail(err) return nil, err } - conn, err := n.rpcContext.GRPCDial(addr.String()).Connect(ctx) - if err != nil { - err = errors.Wrapf(err, "failed to grpc dial n%d at %v", nodeID, addr) - breaker.Fail(err) - return nil, err - } - breaker.Success() - return conn, nil + return n.dial(ctx, nodeID, addr, breaker) } // DialInternalClient is a specialization of Dial for callers that @@ -124,10 +104,6 @@ func (n *Dialer) DialInternalClient( if n == nil || n.resolver == nil { return nil, nil, errors.New("no node dialer configured") } - // Don't trip the breaker if we're already canceled. - if ctxErr := ctx.Err(); ctxErr != nil { - return nil, nil, ctxErr - } addr, err := n.resolver(nodeID) if err != nil { return nil, nil, err @@ -141,24 +117,51 @@ func (n *Dialer) DialInternalClient( return localCtx, localClient, nil } - - breaker := n.getBreaker(nodeID) - log.VEventf(ctx, 2, "sending request to %s", addr) + conn, err := n.dial(ctx, nodeID, addr, n.getBreaker(nodeID)) + if err != nil { + return nil, nil, err + } + return ctx, roachpb.NewInternalClient(conn), err +} + +// dial performs the dialing of the remove connection. +func (n *Dialer) dial( + ctx context.Context, nodeID roachpb.NodeID, addr net.Addr, breaker *wrappedBreaker, +) (_ *grpc.ClientConn, err error) { + // Don't trip the breaker if we're already canceled. + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } + if !breaker.Ready() { + err = errors.Wrapf(circuit.ErrBreakerOpen, "unable to dial n%d", nodeID) + return nil, err + } + defer func() { + // Enforce a minimum interval between warnings for failed connections. + if err != nil && err != ctx.Err() && breaker.ShouldLog() { + log.Infof(ctx, "unable to connect to n%d: %s", nodeID, err) + } + }() conn, err := n.rpcContext.GRPCDial(addr.String()).Connect(ctx) if err != nil { + // If we were canceled during the dial, don't trip the breaker. + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } err = errors.Wrapf(err, "failed to connect to n%d at %v", nodeID, addr) breaker.Fail(err) - return nil, nil, err + return nil, err } // Check to see if the connection is in the transient failure state. This can // happen if the connection already existed, but a recent heartbeat has // failed and we haven't yet torn down the connection. if err := grpcutil.ConnectionReady(conn); err != nil { - err = errors.Wrapf(err, "failed to check for connection ready to n%d at %v", nodeID, addr) + err = errors.Wrapf(err, "failed to check for ready connection to n%d at %v", nodeID, addr) breaker.Fail(err) - return nil, nil, err + return nil, err } + // TODO(bdarnell): Reconcile the different health checks and circuit breaker // behavior in this file. Note that this different behavior causes problems // for higher-levels in the system. For example, DistSQL checks for @@ -166,7 +169,7 @@ func (n *Dialer) DialInternalClient( // RPCs fail when dial fails due to an open breaker. Reset the breaker here // as a stop-gap before the reconciliation occurs. breaker.Success() - return ctx, roachpb.NewInternalClient(conn), nil + return conn, nil } // ConnHealth returns nil if we have an open connection to the given node diff --git a/pkg/rpc/nodedialer/nodedialer_test.go b/pkg/rpc/nodedialer/nodedialer_test.go new file mode 100644 index 000000000000..8f9ed2a2b7bb --- /dev/null +++ b/pkg/rpc/nodedialer/nodedialer_test.go @@ -0,0 +1,317 @@ +// Copyright 2019 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package nodedialer + +import ( + "context" + "fmt" + "math/rand" + "net" + "sync" + "testing" + "time" + + circuit "github.com/cockroachdb/circuitbreaker" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestNodedialerPositive(t *testing.T) { + defer leaktest.AfterTest(t)() + stopper, rpcCtx, ln, _ := setUpNodedialerTest(t) + defer stopper.Stop(context.TODO()) + nd := New(rpcCtx, newSingleNodeResolver(1, ln.Addr())) + // Ensure that dialing works. + breaker := nd.GetCircuitBreaker(1) + assert.True(t, breaker.Ready()) + ctx := context.Background() + _, err := nd.Dial(ctx, 1) + assert.Nil(t, err, "failed to dial") + assert.True(t, breaker.Ready()) +} + +func TestConcurrentCancellationAndTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + stopper, rpcCtx, ln, _ := setUpNodedialerTest(t) + defer stopper.Stop(context.TODO()) + nd := New(rpcCtx, newSingleNodeResolver(1, ln.Addr())) + ctx := context.Background() + breaker := nd.GetCircuitBreaker(1) + // Test that when a context is canceled during dialing we always return that + // error but we never trip the breaker. + const N = 1000 + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(2) + // Jiggle when we cancel relative to when we dial to try to hit cases where + // cancellation happens during the call to GRPCDial. + iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond)) + go func() { + time.Sleep(randDuration(time.Millisecond)) + cancel() + wg.Done() + }() + go func() { + time.Sleep(randDuration(time.Millisecond)) + _, err := nd.Dial(iCtx, 1) + if err != nil && + err != context.Canceled && + err != context.DeadlineExceeded { + t.Errorf("got an unexpected error from Dial: %v", err) + } + wg.Done() + }() + } + wg.Wait() + assert.Equal(t, breaker.Failures(), int64(0)) +} + +func TestResolverErrorsTrip(t *testing.T) { + defer leaktest.AfterTest(t)() + stopper, rpcCtx, _, _ := setUpNodedialerTest(t) + defer stopper.Stop(context.TODO()) + boom := fmt.Errorf("boom") + nd := New(rpcCtx, func(id roachpb.NodeID) (net.Addr, error) { + return nil, boom + }) + _, err := nd.Dial(context.Background(), 1) + assert.Equal(t, errors.Cause(err), boom) + breaker := nd.GetCircuitBreaker(1) + assert.False(t, breaker.Ready()) +} + +func TestDisconnectsTrip(t *testing.T) { + defer leaktest.AfterTest(t)() + stopper, rpcCtx, ln, hb := setUpNodedialerTest(t) + defer stopper.Stop(context.TODO()) + nd := New(rpcCtx, newSingleNodeResolver(1, ln.Addr())) + ctx := context.Background() + breaker := nd.GetCircuitBreaker(1) + + // Now close the underlying connection from the server side and set the + // heartbeat service to return errors. This will eventually lead to the client + // connection being removed and Dial attempts to return an error. + // While this is going on there will be many clients attempting to + // connect. These connecting clients will send interesting errors they observe + // on the errChan. Once an error from Dial is observed the test re-enables the + // heartbeat service. The test will confirm that the only errors they record + // in to the breaker are interesting ones as determined by shouldTrip. + hb.setErr(fmt.Errorf("boom")) + underlyingNetConn := ln.popConn() + assert.Nil(t, underlyingNetConn.Close()) + const N = 1000 + breakerEventChan := make(chan circuit.ListenerEvent, N) + breaker.AddListener(breakerEventChan) + errChan := make(chan error, N) + shouldTrip := func(err error) bool { + return err != nil && + err != context.DeadlineExceeded && + err != context.Canceled && + errors.Cause(err) != circuit.ErrBreakerOpen + } + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(2) + iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond)) + go func() { + time.Sleep(randDuration(time.Millisecond)) + cancel() + wg.Done() + }() + go func() { + time.Sleep(randDuration(time.Millisecond)) + _, err := nd.Dial(iCtx, 1) + if shouldTrip(err) { + errChan <- err + } + wg.Done() + }() + } + go func() { wg.Wait(); close(errChan) }() + var errorsSeen int + for range errChan { + if errorsSeen == 0 { + hb.setErr(nil) + } + errorsSeen++ + } + breaker.RemoveListener(breakerEventChan) + close(breakerEventChan) + var failsSeen int + for ev := range breakerEventChan { + if ev.Event == circuit.BreakerFail { + failsSeen++ + } + } + // Ensure that all of the interesting errors were seen by the breaker. + assert.Equal(t, errorsSeen, failsSeen) + + // Ensure that the connection become healthy soon now that the heartbeat + // service is not returning errors. + hb.setErr(nil) // reset in case there were no errors + testutils.SucceedsSoon(t, func() error { + return rpcCtx.ConnHealth(ln.Addr().String()) + }) +} + +func setUpNodedialerTest( + t *testing.T, +) (stopper *stop.Stopper, rpcCtx *rpc.Context, ln *interceptingListener, hb *heartbeatService) { + stopper = stop.NewStopper() + clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond) + // Create an rpc Context and then + rpcCtx = newTestContext(clock, stopper) + _, ln, hb = newTestServer(t, clock, stopper) + testutils.SucceedsSoon(t, func() error { + return rpcCtx.ConnHealth(ln.Addr().String()) + }) + return stopper, rpcCtx, ln, hb +} + +// randDuration returns a uniform random duration between 0 and max. +func randDuration(max time.Duration) time.Duration { + return time.Duration(rand.Intn(int(max))) +} + +func newTestServer( + t testing.TB, clock *hlc.Clock, stopper *stop.Stopper, +) (*grpc.Server, *interceptingListener, *heartbeatService) { + ctx := context.Background() + localAddr := "127.0.0.1:0" + ln, err := net.Listen("tcp", localAddr) + if err != nil { + t.Fatalf("failed to listed on %v: %v", localAddr, err) + } + il := &interceptingListener{Listener: ln} + s := grpc.NewServer() + serverVersion := cluster.MakeTestingClusterSettings().Version.ServerVersion + hb := &heartbeatService{ + clock: clock, + serverVersion: serverVersion, + } + rpc.RegisterHeartbeatServer(s, hb) + if err := stopper.RunAsyncTask(ctx, "localServer", func(ctx context.Context) { + if err := s.Serve(il); err != nil { + log.Infof(ctx, "server stopped: %v", err) + } + }); err != nil { + t.Fatalf("failed to run test server: %v", err) + } + go func() { <-stopper.ShouldQuiesce(); s.Stop() }() + return s, il, hb +} + +func newTestContext(clock *hlc.Clock, stopper *stop.Stopper) *rpc.Context { + cfg := testutils.NewNodeTestBaseContext() + cfg.Insecure = true + return rpc.NewContext( + log.AmbientContext{Tracer: tracing.NewTracer()}, + cfg, + clock, + stopper, + &cluster.MakeTestingClusterSettings().Version, + ) +} + +// interceptingListener wraps a net.Listener and provides access to the +// underlying net.Conn objects which that listener Accepts. +type interceptingListener struct { + net.Listener + mu struct { + syncutil.Mutex + conns []net.Conn + } +} + +// newSingleNodeResolver returns a Resolver that resolve a single node id +func newSingleNodeResolver(id roachpb.NodeID, addr net.Addr) AddressResolver { + return func(toResolve roachpb.NodeID) (net.Addr, error) { + if id == toResolve { + return addr, nil + } + return nil, fmt.Errorf("unknown node id %d", toResolve) + } +} + +func (il *interceptingListener) Accept() (c net.Conn, err error) { + defer func() { + if err == nil { + il.mu.Lock() + il.mu.conns = append(il.mu.conns, c) + il.mu.Unlock() + } + }() + return il.Listener.Accept() +} + +func (il *interceptingListener) popConn() net.Conn { + il.mu.Lock() + defer il.mu.Unlock() + if len(il.mu.conns) == 0 { + return nil + } + c := il.mu.conns[0] + il.mu.conns = il.mu.conns[1:] + return c +} + +type errContainer struct { + syncutil.RWMutex + err error +} + +func (ec *errContainer) getErr() error { + ec.RLock() + defer ec.RUnlock() + return ec.err +} + +func (ec *errContainer) setErr(err error) { + ec.Lock() + defer ec.Unlock() + ec.err = err +} + +// heartbeatService is a dummy rpc.HeartbeatService which provides a mechanism +// to inject errors. +type heartbeatService struct { + errContainer + clock *hlc.Clock + serverVersion roachpb.Version +} + +func (hb *heartbeatService) Ping( + ctx context.Context, args *rpc.PingRequest, +) (*rpc.PingResponse, error) { + if err := hb.getErr(); err != nil { + return nil, err + } + return &rpc.PingResponse{ + Pong: args.Ping, + ServerTime: hb.clock.PhysicalNow(), + ServerVersion: hb.serverVersion, + }, nil +} diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index e4c8e73399a0..065a27469b89 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -480,6 +480,7 @@ func TestLint(t *testing.T) { "*.go", ":!rpc/context_test.go", ":!rpc/context.go", + ":!rpc/nodedialer/nodedialer_test.go", ":!util/grpcutil/grpc_util_test.go", ":!cli/systembench/network_test_server.go", )