From e2b98f96c919c49d1f9dd731c1a03fb49a6281a7 Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:34:17 +0530 Subject: [PATCH] pickfirst: Implement Happy Eyeballs (#7725) --- balancer/pickfirst/internal/internal.go | 17 +- .../pickfirst/pickfirstleaf/pickfirstleaf.go | 181 +++++++++---- .../pickfirstleaf/pickfirstleaf_ext_test.go | 243 ++++++++++++++++-- .../pickfirstleaf/pickfirstleaf_test.go | 13 - 4 files changed, 369 insertions(+), 85 deletions(-) diff --git a/balancer/pickfirst/internal/internal.go b/balancer/pickfirst/internal/internal.go index c51978945844..7d66cb491c40 100644 --- a/balancer/pickfirst/internal/internal.go +++ b/balancer/pickfirst/internal/internal.go @@ -18,7 +18,18 @@ // Package internal contains code internal to the pickfirst package. package internal -import "math/rand" +import ( + rand "math/rand/v2" + "time" +) -// RandShuffle pseudo-randomizes the order of addresses. -var RandShuffle = rand.Shuffle +var ( + // RandShuffle pseudo-randomizes the order of addresses. + RandShuffle = rand.Shuffle + // TimeAfterFunc allows mocking the timer for testing connection delay + // related functionality. + TimeAfterFunc = func(d time.Duration, f func()) func() { + timer := time.AfterFunc(d, f) + return func() { timer.Stop() } + } +) diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go index 4b54866058d5..aaec87497fd4 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go @@ -31,6 +31,7 @@ import ( "fmt" "net" "sync" + "time" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/pickfirst/internal" @@ -59,8 +60,13 @@ var ( Name = "pick_first_leaf" ) -// TODO: change to pick-first when this becomes the default pick_first policy. -const logPrefix = "[pick-first-leaf-lb %p] " +const ( + // TODO: change to pick-first when this becomes the default pick_first policy. + logPrefix = "[pick-first-leaf-lb %p] " + // connectionDelayInterval is the time to wait for during the happy eyeballs + // pass before starting the next connection attempt. + connectionDelayInterval = 250 * time.Millisecond +) type ipAddrFamily int @@ -76,11 +82,12 @@ type pickfirstBuilder struct{} func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { b := &pickfirstBalancer{ - cc: cc, - addressList: addressList{}, - subConns: resolver.NewAddressMap(), - state: connectivity.Connecting, - mu: sync.Mutex{}, + cc: cc, + addressList: addressList{}, + subConns: resolver.NewAddressMap(), + state: connectivity.Connecting, + mu: sync.Mutex{}, + cancelConnectionTimer: func() {}, } b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) return b @@ -115,8 +122,9 @@ type scData struct { subConn balancer.SubConn addr resolver.Address - state connectivity.State - lastErr error + state connectivity.State + lastErr error + connectionFailedInFirstPass bool } func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { @@ -148,10 +156,11 @@ type pickfirstBalancer struct { mu sync.Mutex state connectivity.State // scData for active subonns mapped by address. - subConns *resolver.AddressMap - addressList addressList - firstPass bool - numTF int + subConns *resolver.AddressMap + addressList addressList + firstPass bool + numTF int + cancelConnectionTimer func() } // ResolverError is called by the ClientConn when the name resolver produces @@ -186,6 +195,7 @@ func (b *pickfirstBalancer) resolverErrorLocked(err error) { func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { b.mu.Lock() defer b.mu.Unlock() + b.cancelConnectionTimer() if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { // Cleanup state pertaining to the previous resolver state. // Treat an empty address list like an error by calling b.ResolverError. @@ -239,12 +249,8 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState // Not de-duplicating would result in attempting to connect to the same // SubConn multiple times in the same pass. We don't want this. newAddrs = deDupAddresses(newAddrs) - newAddrs = interleaveAddresses(newAddrs) - // Since we have a new set of addresses, we are again at first pass. - b.firstPass = true - // If the previous ready SubConn exists in new address list, // keep this connection and don't create new SubConns. prevAddr := b.addressList.currentAddress() @@ -269,11 +275,11 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) - b.requestConnectionLocked() + b.startFirstPassLocked() } else if b.state == connectivity.TransientFailure { // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until // we're READY. See A62. - b.requestConnectionLocked() + b.startFirstPassLocked() } return nil } @@ -288,6 +294,7 @@ func (b *pickfirstBalancer) Close() { b.mu.Lock() defer b.mu.Unlock() b.closeSubConnsLocked() + b.cancelConnectionTimer() b.state = connectivity.Shutdown } @@ -297,12 +304,21 @@ func (b *pickfirstBalancer) Close() { func (b *pickfirstBalancer) ExitIdle() { b.mu.Lock() defer b.mu.Unlock() - if b.state == connectivity.Idle && b.addressList.currentAddress() == b.addressList.first() { - b.firstPass = true - b.requestConnectionLocked() + if b.state == connectivity.Idle { + b.startFirstPassLocked() } } +func (b *pickfirstBalancer) startFirstPassLocked() { + b.firstPass = true + b.numTF = 0 + // Reset the connection attempt record for existing SubConns. + for _, sd := range b.subConns.Values() { + sd.(*scData).connectionFailedInFirstPass = false + } + b.requestConnectionLocked() +} + func (b *pickfirstBalancer) closeSubConnsLocked() { for _, sd := range b.subConns.Values() { sd.(*scData).subConn.Shutdown() @@ -413,6 +429,7 @@ func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) // shutdownRemainingLocked shuts down remaining subConns. Called when a subConn // becomes ready, which means that all other subConn must be shutdown. func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) { + b.cancelConnectionTimer() for _, v := range b.subConns.Values() { sd := v.(*scData) if sd.subConn != selected.subConn { @@ -456,30 +473,69 @@ func (b *pickfirstBalancer) requestConnectionLocked() { switch scd.state { case connectivity.Idle: scd.subConn.Connect() + b.scheduleNextConnectionLocked() + return case connectivity.TransientFailure: - // Try the next address. + // The SubConn is being re-used and failed during a previous pass + // over the addressList. It has not completed backoff yet. + // Mark it as having failed and try the next address. + scd.connectionFailedInFirstPass = true lastErr = scd.lastErr continue - case connectivity.Ready: - // Should never happen. - b.logger.Errorf("Requesting a connection even though we have a READY SubConn") - case connectivity.Shutdown: - // Should never happen. - b.logger.Errorf("SubConn with state SHUTDOWN present in SubConns map") case connectivity.Connecting: - // Wait for the SubConn to report success or failure. + // Wait for the connection attempt to complete or the timer to fire + // before attempting the next address. + b.scheduleNextConnectionLocked() + return + default: + b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", scd.state) + return + } - return } + // All the remaining addresses in the list are in TRANSIENT_FAILURE, end the - // first pass. - b.endFirstPassLocked(lastErr) + // first pass if possible. + b.endFirstPassIfPossibleLocked(lastErr) +} + +func (b *pickfirstBalancer) scheduleNextConnectionLocked() { + b.cancelConnectionTimer() + if !b.addressList.hasNext() { + return + } + curAddr := b.addressList.currentAddress() + cancelled := false // Access to this is protected by the balancer's mutex. + closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() { + b.mu.Lock() + defer b.mu.Unlock() + // If the scheduled task is cancelled while acquiring the mutex, return. + if cancelled { + return + } + if b.logger.V(2) { + b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr) + } + if b.addressList.increment() { + b.requestConnectionLocked() + } + }) + // Access to the cancellation callback held by the balancer is guarded by + // the balancer's mutex, so it's safe to set the boolean from the callback. + b.cancelConnectionTimer = sync.OnceFunc(func() { + cancelled = true + closeFn() + }) } func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) { b.mu.Lock() defer b.mu.Unlock() oldState := sd.state + // Record a connection attempt when exiting CONNECTING. + if newState.ConnectivityState == connectivity.TransientFailure { + sd.connectionFailedInFirstPass = true + } sd.state = newState.ConnectivityState // Previously relevant SubConns can still callback with state updates. // To prevent pickers from returning these obsolete SubConns, this logic @@ -545,17 +601,20 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub sd.lastErr = newState.ConnectionError // Since we're re-using common SubConns while handling resolver // updates, we could receive an out of turn TRANSIENT_FAILURE from - // a pass over the previous address list. We ignore such updates. - - if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) { - return - } - if b.addressList.increment() { - b.requestConnectionLocked() - return + // a pass over the previous address list. Happy Eyeballs will also + // cause out of order updates to arrive. + + if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) { + b.cancelConnectionTimer() + if b.addressList.increment() { + b.requestConnectionLocked() + return + } } - // End of the first pass. - b.endFirstPassLocked(newState.ConnectionError) + + // End the first pass if we've seen a TRANSIENT_FAILURE from all + // SubConns once. + b.endFirstPassIfPossibleLocked(newState.ConnectionError) } return } @@ -580,9 +639,22 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub } } -func (b *pickfirstBalancer) endFirstPassLocked(lastErr error) { +// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the +// addresses are tried and their SubConns have reported a failure. +func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) { + // An optimization to avoid iterating over the entire SubConn map. + if b.addressList.isValid() { + return + } + // Connect() has been called on all the SubConns. The first pass can be + // ended if all the SubConns have reported a failure. + for _, v := range b.subConns.Values() { + sd := v.(*scData) + if !sd.connectionFailedInFirstPass { + return + } + } b.firstPass = false - b.numTF = 0 b.state = connectivity.TransientFailure b.cc.UpdateState(balancer.State{ @@ -654,15 +726,6 @@ func (al *addressList) currentAddress() resolver.Address { return al.addresses[al.idx] } -// first returns the first address in the list. If the list is empty, it returns -// an empty address instead. -func (al *addressList) first() resolver.Address { - if len(al.addresses) == 0 { - return resolver.Address{} - } - return al.addresses[0] -} - func (al *addressList) reset() { al.idx = 0 } @@ -685,6 +748,16 @@ func (al *addressList) seekTo(needle resolver.Address) bool { return false } +// hasNext returns whether incrementing the addressList will result in moving +// past the end of the list. If the list has already moved past the end, it +// returns false. +func (al *addressList) hasNext() bool { + if !al.isValid() { + return false + } + return al.idx+1 < len(al.addresses) +} + // equalAddressIgnoringBalAttributes returns true is a and b are considered // equal. This is different from the Equal method on the resolver.Address type // which considers all fields to determine equality. Here, we only consider diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go index 46e47be43ffa..bf957f98b119 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/balancer" + pfinternal "google.golang.org/grpc/balancer/pickfirst/internal" "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" @@ -66,8 +67,7 @@ func Test(t *testing.T) { } // setupPickFirstLeaf performs steps required for pick_first tests. It starts a -// bunch of backends exporting the TestService, creates a ClientConn to them -// with service config specifying the use of the state_storing LB policy. +// bunch of backends exporting the TestService, and creates a ClientConn to them. func setupPickFirstLeaf(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, *backendManager) { t.Helper() r := manual.NewBuilderWithScheme("whatever") @@ -86,7 +86,6 @@ func setupPickFirstLeaf(t *testing.T, backendCount int, opts ...grpc.DialOption) dopts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r), - grpc.WithDefaultServiceConfig(stateStoringServiceConfig), } dopts = append(dopts, opts...) cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) @@ -121,7 +120,7 @@ func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerReady(t *testing.T) { balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -161,7 +160,7 @@ func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerUnReady(t *testing.T) balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -203,7 +202,7 @@ func (s) TestPickFirstLeaf_SimpleResolverUpdate_DuplicateAddrs(t *testing.T) { balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -259,7 +258,7 @@ func (s) TestPickFirstLeaf_ResolverUpdates_DisjointLists(t *testing.T) { balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 4) + cc, r, bm := setupPickFirstLeaf(t, 4, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -322,7 +321,7 @@ func (s) TestPickFirstLeaf_ResolverUpdates_ActiveBackendInUpdatedList(t *testing balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 3) + cc, r, bm := setupPickFirstLeaf(t, 3, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -386,7 +385,7 @@ func (s) TestPickFirstLeaf_ResolverUpdates_InActiveBackendInUpdatedList(t *testi balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 3) + cc, r, bm := setupPickFirstLeaf(t, 3, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -451,7 +450,7 @@ func (s) TestPickFirstLeaf_ResolverUpdates_IdenticalLists(t *testing.T) { balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -524,7 +523,7 @@ func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerRestart(t *testing.T) balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -589,7 +588,7 @@ func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerRestart(t *testing.T) balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -661,7 +660,7 @@ func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerToFirst(t *testing.T) balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -733,7 +732,7 @@ func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerToSecond(t *testing.T) balCh := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) - cc, r, bm := setupPickFirstLeaf(t, 2) + cc, r, bm := setupPickFirstLeaf(t, 2, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) @@ -807,7 +806,7 @@ func (s) TestPickFirstLeaf_EmptyAddressList(t *testing.T) { defer cancel() balChan := make(chan *stateStoringBalancer, 1) balancer.Register(&stateStoringBalancerBuilder{balancer: balChan}) - cc, r, bm := setupPickFirstLeaf(t, 1) + cc, r, bm := setupPickFirstLeaf(t, 1, grpc.WithDefaultServiceConfig(stateStoringServiceConfig)) addrs := bm.resolverAddrs() stateSubscriber := &ccStateSubscriber{} @@ -850,6 +849,189 @@ func (s) TestPickFirstLeaf_EmptyAddressList(t *testing.T) { } } +// Test verifies that pickfirst correctly detects the end of the first happy +// eyeballs pass when the timer causes pickfirst to reach the end of the address +// list and failures are reported out of order. +func (s) TestPickFirstLeaf_HappyEyeballs_TF_AfterEndOfList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + originalTimer := pfinternal.TimeAfterFunc + defer func() { + pfinternal.TimeAfterFunc = originalTimer + }() + triggerTimer, timeAfter := mockTimer() + pfinternal.TimeAfterFunc = timeAfter + + dialer := testutils.NewBlockingDialer() + opts := []grpc.DialOption{ + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)), + grpc.WithContextDialer(dialer.DialContext), + } + cc, rb, bm := setupPickFirstLeaf(t, 3, opts...) + addrs := bm.resolverAddrs() + holds := bm.holds(dialer) + rb.UpdateState(resolver.State{Addresses: addrs}) + cc.Connect() + + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) + + // Verify that only the first server is contacted. + if holds[0].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 0, addrs[0]) + } + if holds[1].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 1, addrs[1]) + } + if holds[2].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 2, addrs[2]) + } + + // Make the happy eyeballs timer fire once and verify that the + // second server is contacted, but the third isn't. + triggerTimer() + if holds[1].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 1, addrs[1]) + } + if holds[2].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 2, addrs[2]) + } + + // Make the happy eyeballs timer fire once more and verify that the + // third server is contacted. + triggerTimer() + if holds[2].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 2, addrs[2]) + } + + // First SubConn Fails. + holds[0].Fail(fmt.Errorf("test error")) + + // No TF should be reported until the first pass is complete. + shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer shortCancel() + testutils.AwaitNotState(shortCtx, t, cc, connectivity.TransientFailure) + + // Third SubConn fails. + shortCtx, shortCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer shortCancel() + holds[2].Fail(fmt.Errorf("test error")) + testutils.AwaitNotState(shortCtx, t, cc, connectivity.TransientFailure) + + // Last SubConn fails, this should result in a TF update. + holds[1].Fail(fmt.Errorf("test error")) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) +} + +// Test verifies that pickfirst attempts to connect to the second backend once +// the happy eyeballs timer expires. +func (s) TestPickFirstLeaf_HappyEyeballs_TriggerConnectionDelay(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + originalTimer := pfinternal.TimeAfterFunc + defer func() { + pfinternal.TimeAfterFunc = originalTimer + }() + triggerTimer, timeAfter := mockTimer() + pfinternal.TimeAfterFunc = timeAfter + + dialer := testutils.NewBlockingDialer() + opts := []grpc.DialOption{ + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)), + grpc.WithContextDialer(dialer.DialContext), + } + cc, rb, bm := setupPickFirstLeaf(t, 2, opts...) + addrs := bm.resolverAddrs() + holds := bm.holds(dialer) + rb.UpdateState(resolver.State{Addresses: addrs}) + cc.Connect() + + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) + + // Verify that only the first server is contacted. + if holds[0].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 0, addrs[0]) + } + if holds[1].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 1, addrs[1]) + } + + // Make the happy eyeballs timer fire once and verify that the + // second server is contacted. + triggerTimer() + if holds[1].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 1, addrs[1]) + } + + // Get the connection attempt to the second server to succeed and verify + // that the channel becomes READY. + holds[1].Resume() + testutils.AwaitState(ctx, t, cc, connectivity.Ready) +} + +// Test tests the pickfirst balancer by causing a SubConn to fail and then +// jumping to the 3rd SubConn after the happy eyeballs timer expires. +func (s) TestPickFirstLeaf_HappyEyeballs_TF_ThenTimerFires(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + originalTimer := pfinternal.TimeAfterFunc + defer func() { + pfinternal.TimeAfterFunc = originalTimer + }() + triggerTimer, timeAfter := mockTimer() + pfinternal.TimeAfterFunc = timeAfter + + dialer := testutils.NewBlockingDialer() + opts := []grpc.DialOption{ + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)), + grpc.WithContextDialer(dialer.DialContext), + } + cc, rb, bm := setupPickFirstLeaf(t, 3, opts...) + addrs := bm.resolverAddrs() + holds := bm.holds(dialer) + rb.UpdateState(resolver.State{Addresses: addrs}) + cc.Connect() + + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) + + // Verify that only the first server is contacted. + if holds[0].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 0, addrs[0]) + } + if holds[1].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 1, addrs[1]) + } + if holds[2].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 2, addrs[2]) + } + + // First SubConn Fails. + holds[0].Fail(fmt.Errorf("test error")) + + // Verify that only the second server is contacted. + if holds[1].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 1, addrs[1]) + } + if holds[2].IsStarted() != false { + t.Fatalf("Server %d with address %q contacted unexpectedly", 2, addrs[2]) + } + + // The happy eyeballs timer expires, pickfirst should stop waiting for + // server[1] to report a failure/success and request the creation of a third + // SubConn. + triggerTimer() + if holds[2].Wait(ctx) != true { + t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 2, addrs[2]) + } + + // Get the connection attempt to the second server to succeed and verify + // that the channel becomes READY. + holds[1].Resume() + testutils.AwaitState(ctx, t, cc, connectivity.Ready) +} + func (s) TestPickFirstLeaf_InterleavingIPV4Preffered(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -1106,6 +1288,14 @@ func (b *backendManager) resolverAddrs() []resolver.Address { return addrs } +func (b *backendManager) holds(dialer *testutils.BlockingDialer) []*testutils.Hold { + holds := []*testutils.Hold{} + for _, addr := range b.resolverAddrs() { + holds = append(holds, dialer.Hold(addr.Addr)) + } + return holds +} + type ccStateSubscriber struct { transitions []connectivity.State } @@ -1113,3 +1303,26 @@ type ccStateSubscriber struct { func (c *ccStateSubscriber) OnMessage(msg any) { c.transitions = append(c.transitions, msg.(connectivity.State)) } + +// mockTimer returns a fake timeAfterFunc that will not trigger automatically. +// It returns a function that can be called to manually trigger the execution +// of the scheduled callback. +func mockTimer() (triggerFunc func(), timerFunc func(_ time.Duration, f func()) func()) { + timerCh := make(chan struct{}) + triggerFunc = func() { + timerCh <- struct{}{} + } + return triggerFunc, func(_ time.Duration, f func()) func() { + stopCh := make(chan struct{}) + go func() { + select { + case <-timerCh: + f() + case <-stopCh: + } + }() + return sync.OnceFunc(func() { + close(stopCh) + }) + } +} diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go index 84b3cb65bed4..71984a238cd5 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go @@ -73,21 +73,8 @@ func (s) TestAddressList_Iteration(t *testing.T) { } addressList := addressList{} - emptyAddress := resolver.Address{} - if got, want := addressList.first(), emptyAddress; got != want { - t.Fatalf("addressList.first() = %v, want %v", got, want) - } - addressList.updateAddrs(addrs) - if got, want := addressList.first(), addressList.currentAddress(); got != want { - t.Fatalf("addressList.first() = %v, want %v", got, want) - } - - if got, want := addressList.first(), addrs[0]; got != want { - t.Fatalf("addressList.first() = %v, want %v", got, want) - } - for i := 0; i < len(addrs); i++ { if got, want := addressList.isValid(), true; got != want { t.Fatalf("addressList.isValid() = %t, want %t", got, want)