diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 5a4bb0f270b2..6a3b251b89e6 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -820,6 +820,104 @@ func (s) TestUpdateLRSServer(t *testing.T) { } } +type myPicker struct { + // You can add any necessary fields here to store + // state or configuration for your picker's behavior +} + +func (p *myPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + // Implement the picking logic here based on your test requirements + // For this example, let's assume a simple round-robin approach + + // Replace this with your actual subConn selection logic + var selectedSubConn balancer.SubConn + + return balancer.PickResult{ + SubConn: selectedSubConn, + Done: func(info balancer.DoneInfo) { + // Handle any post-pick actions if necessary + }, + }, nil +} + +// TestPickerUpdatedSynchronouslyOnConfigUpdate covers the case picker is updated +// synchronous on reciept of configuration update. +func (s) TestPickerUpdatedSynchronouslyOnConfigUpdate(t *testing.T) { + // Override the newConfigHook to ensure picker was updated. + clientConnUpdateDone := make(chan struct{}, 1) + origClientConnUpdateHook := clientConnUpdateHook + clientConnUpdateHook = func() { clientConnUpdateDone <- struct{}{} } + defer func() { clientConnUpdateHook = origClientConnUpdateHook }() + + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + + builder := balancer.Get(Name) + cc := testutils.NewBalancerClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + // Create a stub balancer which waits for the cluster_impl policy to be + // closed before sending a picker update (upon receipt of a subConn state + // change). + const childPolicyName = "stubBalancer-TestPickerUpdateAfterClose" + stub.Register(childPolicyName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + // Create a subConn which will be used later on to test the race + // between StateListener() and Close(). + bd.ClientConn.UpdateState(balancer.State{ + Picker: &myPicker{}, + }) + t.Logf("Picker sent from child policy.") + return nil + }, + }) + + const ( + dropReason = "test-dropping-category" + dropNumerator = 1 + dropDenominator = 2 + ) + testLRSServerConfig, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{ + URI: "trafficdirector.googleapis.com:443", + ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}}, + }) + if err != nil { + t.Fatalf("Failed to create LRS server config for testing: %v", err) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServer: testLRSServerConfig, + DropCategories: []DropConfig{{ + Category: dropReason, + RequestsPerMillion: million * dropNumerator / dropDenominator, + }}, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: childPolicyName, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultShortTestTimeout) + defer cancel() + select { + case <-cc.NewPickerCh: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for the picker update on receipt of configuration update.") + } + + select { + case <-clientConnUpdateDone: + case <-ctx.Done(): + t.Fatal("Timed out waiting for client conn update to be completed.") + } +} + func assertString(f func() (string, error)) string { s, err := f() if err != nil { diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index b2ea2815e30b..d90c3785f6ae 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -56,6 +56,9 @@ const ( var ( connectedAddress = internal.ConnectedAddress.(func(balancer.SubConnState) resolver.Address) errBalancerClosed = fmt.Errorf("%s LB policy is closed", Name) + // Below function is no-op in actual code, but can be overridden in + // tests to give tests visibility into exactly when certain events happen. + clientConnUpdateHook = func() {} ) func init() { @@ -102,6 +105,12 @@ type clusterImplBalancer struct { lrsServer *bootstrap.ServerConfig loadWrapper *loadstore.Wrapper + // Set during UpdateClientConnState when pushing updates to child policies. + // Prevents state updates from child policies causing new pickers to be sent + // up the channel. Cleared after all child policies have processed the + // updates sent to them, after which a new picker is sent up the channel. + inhibitPickerUpdates bool + clusterNameMu sync.Mutex clusterName string @@ -231,16 +240,17 @@ func (b *clusterImplBalancer) updateClientConnState(s balancer.ClientConnState) return err } + b.inhibitPickerUpdates = true if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { if err := b.child.SwitchTo(bb); err != nil { return fmt.Errorf("error switching to child of type %q: %v", newConfig.ChildPolicy.Name, err) } } b.config = newConfig - + b.inhibitPickerUpdates = false b.telemetryLabels = newConfig.TelemetryLabels dc := b.handleDropAndRequestCount(newConfig) - if dc != nil && b.childState.Picker != nil { + if dc != nil && b.childState.Picker != nil && !b.inhibitPickerUpdates { b.ClientConn.UpdateState(balancer.State{ ConnectivityState: b.childState.ConnectivityState, Picker: b.newPicker(dc), @@ -259,6 +269,7 @@ func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) errCh := make(chan error, 1) callback := func(context.Context) { errCh <- b.updateClientConnState(s) + clientConnUpdateHook() } onFailure := func() { // The call to Schedule returns false *only* if the serializer has been @@ -322,14 +333,16 @@ func (b *clusterImplBalancer) ExitIdle() { func (b *clusterImplBalancer) UpdateState(state balancer.State) { b.serializer.TrySchedule(func(context.Context) { b.childState = state - b.ClientConn.UpdateState(balancer.State{ - ConnectivityState: b.childState.ConnectivityState, - Picker: b.newPicker(&dropConfigs{ - drops: b.drops, - requestCounter: b.requestCounter, - requestCountMax: b.requestCountMax, - }), - }) + if !b.inhibitPickerUpdates { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: b.newPicker(&dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + }), + }) + } }) }