diff --git a/balancer_wrapper.go b/balancer_wrapper.go index 554fd3c64afe..5895732225d2 100644 --- a/balancer_wrapper.go +++ b/balancer_wrapper.go @@ -95,7 +95,7 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper { // it is safe to call into the balancer here. func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error { errCh := make(chan error) - ok := ccb.serializer.Schedule(func(ctx context.Context) { + uccs := func(ctx context.Context) { defer close(errCh) if ctx.Err() != nil || ccb.balancer == nil { return @@ -110,17 +110,23 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat logger.Infof("error from balancer.UpdateClientConnState: %v", err) } errCh <- err - }) - if !ok { - return nil } + onFailure := func() { close(errCh) } + + // UpdateClientConnState can race with Close, and when the latter wins, the + // serializer is closed, and the attempt to schedule the callback will fail. + // It is acceptable to ignore this failure. But since we want to handle the + // state update in a blocking fashion (when we successfully schedule the + // callback), we have to use the ScheduleOr method and not the MaybeSchedule + // method on the serializer. + ccb.serializer.ScheduleOr(uccs, onFailure) return <-errCh } // resolverError is invoked by grpc to push a resolver error to the underlying // balancer. The call to the balancer is executed from the serializer. func (ccb *ccBalancerWrapper) resolverError(err error) { - ccb.serializer.Schedule(func(ctx context.Context) { + ccb.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || ccb.balancer == nil { return } @@ -136,7 +142,7 @@ func (ccb *ccBalancerWrapper) close() { ccb.closed = true ccb.mu.Unlock() channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing") - ccb.serializer.Schedule(func(context.Context) { + ccb.serializer.TrySchedule(func(context.Context) { if ccb.balancer == nil { return } @@ -148,7 +154,7 @@ func (ccb *ccBalancerWrapper) close() { // exitIdle invokes the balancer's exitIdle method in the serializer. func (ccb *ccBalancerWrapper) exitIdle() { - ccb.serializer.Schedule(func(ctx context.Context) { + ccb.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || ccb.balancer == nil { return } @@ -256,7 +262,7 @@ type acBalancerWrapper struct { // updateState is invoked by grpc to push a subConn state update to the // underlying balancer. func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) { - acbw.ccb.serializer.Schedule(func(ctx context.Context) { + acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || acbw.ccb.balancer == nil { return } diff --git a/internal/grpcsync/callback_serializer.go b/internal/grpcsync/callback_serializer.go index f7f40a16acee..19b9d639275a 100644 --- a/internal/grpcsync/callback_serializer.go +++ b/internal/grpcsync/callback_serializer.go @@ -53,16 +53,28 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer { return cs } -// Schedule adds a callback to be scheduled after existing callbacks are run. +// TrySchedule tries to schedules the provided callback function f to be +// executed in the order it was added. This is a best-effort operation. If the +// context passed to NewCallbackSerializer was canceled before this method is +// called, the callback will not be scheduled. // // Callbacks are expected to honor the context when performing any blocking // operations, and should return early when the context is canceled. +func (cs *CallbackSerializer) TrySchedule(f func(ctx context.Context)) { + cs.callbacks.Put(f) +} + +// ScheduleOr schedules the provided callback function f to be executed in the +// order it was added. If the context passed to NewCallbackSerializer has been +// canceled before this method is called, the onFailure callback will be +// executed inline instead. // -// Return value indicates if the callback was successfully added to the list of -// callbacks to be executed by the serializer. It is not possible to add -// callbacks once the context passed to NewCallbackSerializer is cancelled. -func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool { - return cs.callbacks.Put(f) == nil +// Callbacks are expected to honor the context when performing any blocking +// operations, and should return early when the context is canceled. +func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure func()) { + if cs.callbacks.Put(f) != nil { + onFailure() + } } func (cs *CallbackSerializer) run(ctx context.Context) { diff --git a/internal/grpcsync/callback_serializer_test.go b/internal/grpcsync/callback_serializer_test.go index 7e2e50b1f04e..34aa5191612e 100644 --- a/internal/grpcsync/callback_serializer_test.go +++ b/internal/grpcsync/callback_serializer_test.go @@ -55,7 +55,7 @@ func (s) TestCallbackSerializer_Schedule_FIFO(t *testing.T) { mu.Lock() defer mu.Unlock() scheduleOrderCh <- id - cs.Schedule(func(ctx context.Context) { + cs.TrySchedule(func(ctx context.Context) { select { case <-ctx.Done(): return @@ -115,7 +115,7 @@ func (s) TestCallbackSerializer_Schedule_Concurrent(t *testing.T) { wg.Add(numCallbacks) for i := 0; i < numCallbacks; i++ { go func() { - cs.Schedule(func(context.Context) { + cs.TrySchedule(func(context.Context) { wg.Done() }) }() @@ -148,7 +148,7 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) { // Schedule a callback which blocks until the context passed to it is // canceled. It also closes a channel to signal that it has started. firstCallbackStartedCh := make(chan struct{}) - cs.Schedule(func(ctx context.Context) { + cs.TrySchedule(func(ctx context.Context) { close(firstCallbackStartedCh) <-ctx.Done() }) @@ -159,9 +159,9 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) { callbackCh := make(chan int, numCallbacks) for i := 0; i < numCallbacks; i++ { num := i - if !cs.Schedule(func(context.Context) { callbackCh <- num }) { - t.Fatal("Schedule failed to accept a callback when the serializer is yet to be closed") - } + callback := func(context.Context) { callbackCh <- num } + onFailure := func() { t.Fatal("Schedule failed to accept a callback when the serializer is yet to be closed") } + cs.ScheduleOr(callback, onFailure) } // Ensure that none of the newer callbacks are executed at this point. @@ -192,15 +192,15 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) { } <-cs.Done() + // Ensure that a callback cannot be scheduled after the serializer is + // closed. done := make(chan struct{}) - if cs.Schedule(func(context.Context) { close(done) }) { - t.Fatal("Scheduled a callback after closing the serializer") - } - - // Ensure that the latest callback is executed at this point. + callback := func(context.Context) { t.Fatal("Scheduled a callback after closing the serializer") } + onFailure := func() { close(done) } + cs.ScheduleOr(callback, onFailure) select { - case <-time.After(defaultTestShortTimeout): + case <-time.After(defaultTestTimeout): + t.Fatal("Successfully scheduled callback after serializer is closed") case <-done: - t.Fatal("Newer callback executed when scheduled after closing serializer") } } diff --git a/internal/grpcsync/pubsub.go b/internal/grpcsync/pubsub.go index aef8cec1ab0c..6d8c2f518dff 100644 --- a/internal/grpcsync/pubsub.go +++ b/internal/grpcsync/pubsub.go @@ -77,7 +77,7 @@ func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) { if ps.msg != nil { msg := ps.msg - ps.cs.Schedule(func(context.Context) { + ps.cs.TrySchedule(func(context.Context) { ps.mu.Lock() defer ps.mu.Unlock() if !ps.subscribers[sub] { @@ -103,7 +103,7 @@ func (ps *PubSub) Publish(msg any) { ps.msg = msg for sub := range ps.subscribers { s := sub - ps.cs.Schedule(func(context.Context) { + ps.cs.TrySchedule(func(context.Context) { ps.mu.Lock() defer ps.mu.Unlock() if !ps.subscribers[s] { diff --git a/resolver_wrapper.go b/resolver_wrapper.go index c5fb45236faf..d25052457c1b 100644 --- a/resolver_wrapper.go +++ b/resolver_wrapper.go @@ -66,7 +66,7 @@ func newCCResolverWrapper(cc *ClientConn) *ccResolverWrapper { // any newly created ccResolverWrapper, except that close may be called instead. func (ccr *ccResolverWrapper) start() error { errCh := make(chan error) - ccr.serializer.Schedule(func(ctx context.Context) { + ccr.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil { return } @@ -85,7 +85,7 @@ func (ccr *ccResolverWrapper) start() error { } func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) { - ccr.serializer.Schedule(func(ctx context.Context) { + ccr.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || ccr.resolver == nil { return } @@ -102,7 +102,7 @@ func (ccr *ccResolverWrapper) close() { ccr.closed = true ccr.mu.Unlock() - ccr.serializer.Schedule(func(context.Context) { + ccr.serializer.TrySchedule(func(context.Context) { if ccr.resolver == nil { return } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index 798d1c09e381..438d08c4e190 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -309,8 +309,8 @@ func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) erro b.lbCfg = lbCfg // Handle the update in a blocking fashion. - done := make(chan struct{}) - ok = b.serializer.Schedule(func(context.Context) { + errCh := make(chan error, 1) + callback := func(context.Context) { // A config update with a changed top-level cluster name means that none // of our old watchers make any sense any more. b.closeAllWatchers() @@ -319,20 +319,20 @@ func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) erro // could end up creating more watchers if turns out to be an aggregate // cluster. b.createAndAddWatcherForCluster(lbCfg.ClusterName) - close(done) - }) - if !ok { + errCh <- nil + } + onFailure := func() { // The call to Schedule returns false *only* if the serializer has been // closed, which happens only when we receive an update after close. - return errBalancerClosed + errCh <- errBalancerClosed } - <-done - return nil + b.serializer.ScheduleOr(callback, onFailure) + return <-errCh } // ResolverError handles errors reported by the xdsResolver. func (b *cdsBalancer) ResolverError(err error) { - b.serializer.Schedule(func(context.Context) { + b.serializer.TrySchedule(func(context.Context) { // Resource not found error is reported by the resolver when the // top-level cluster resource is removed by the management server. if xdsresource.ErrType(err) == xdsresource.ErrorTypeResourceNotFound { @@ -364,7 +364,7 @@ func (b *cdsBalancer) closeAllWatchers() { // Close cancels the CDS watch, closes the child policy and closes the // cdsBalancer. func (b *cdsBalancer) Close() { - b.serializer.Schedule(func(ctx context.Context) { + b.serializer.TrySchedule(func(ctx context.Context) { b.closeAllWatchers() if b.childLB != nil { @@ -384,7 +384,7 @@ func (b *cdsBalancer) Close() { } func (b *cdsBalancer) ExitIdle() { - b.serializer.Schedule(func(context.Context) { + b.serializer.TrySchedule(func(context.Context) { if b.childLB == nil { b.logger.Warningf("Received ExitIdle with no child policy") return diff --git a/xds/internal/balancer/cdsbalancer/cluster_watcher.go b/xds/internal/balancer/cdsbalancer/cluster_watcher.go index 0b0d168376d7..be01e8f07a07 100644 --- a/xds/internal/balancer/cdsbalancer/cluster_watcher.go +++ b/xds/internal/balancer/cdsbalancer/cluster_watcher.go @@ -33,19 +33,19 @@ type clusterWatcher struct { } func (cw *clusterWatcher) OnUpdate(u *xdsresource.ClusterResourceData) { - cw.parent.serializer.Schedule(func(context.Context) { + cw.parent.serializer.TrySchedule(func(context.Context) { cw.parent.onClusterUpdate(cw.name, u.Resource) }) } func (cw *clusterWatcher) OnError(err error) { - cw.parent.serializer.Schedule(func(context.Context) { + cw.parent.serializer.TrySchedule(func(context.Context) { cw.parent.onClusterError(cw.name, err) }) } func (cw *clusterWatcher) OnResourceDoesNotExist() { - cw.parent.serializer.Schedule(func(context.Context) { + cw.parent.serializer.TrySchedule(func(context.Context) { cw.parent.onClusterResourceNotFound(cw.name) }) } diff --git a/xds/internal/balancer/clusterresolver/resource_resolver.go b/xds/internal/balancer/clusterresolver/resource_resolver.go index 151c54dae6d0..ab93a1fa5e6b 100644 --- a/xds/internal/balancer/clusterresolver/resource_resolver.go +++ b/xds/internal/balancer/clusterresolver/resource_resolver.go @@ -287,7 +287,7 @@ func (rr *resourceResolver) generateLocked() { } func (rr *resourceResolver) onUpdate() { - rr.serializer.Schedule(func(context.Context) { + rr.serializer.TrySchedule(func(context.Context) { rr.mu.Lock() rr.generateLocked() rr.mu.Unlock() diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go index f5bfc500c11a..aec81489e5fc 100644 --- a/xds/internal/resolver/serviceconfig.go +++ b/xds/internal/resolver/serviceconfig.go @@ -182,7 +182,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP if v := atomic.AddInt32(ref, -1); v == 0 { // This entry will be removed from activeClusters when // producing the service config for the empty update. - cs.r.serializer.Schedule(func(context.Context) { + cs.r.serializer.TrySchedule(func(context.Context) { cs.r.onClusterRefDownToZero() }) } @@ -326,7 +326,7 @@ func (cs *configSelector) stop() { // selector; we need another update to delete clusters from the config (if // we don't have another update pending already). if needUpdate { - cs.r.serializer.Schedule(func(context.Context) { + cs.r.serializer.TrySchedule(func(context.Context) { cs.r.onClusterRefDownToZero() }) } diff --git a/xds/internal/resolver/watch_service.go b/xds/internal/resolver/watch_service.go index abb3c2c5acf1..a4532eccc3ec 100644 --- a/xds/internal/resolver/watch_service.go +++ b/xds/internal/resolver/watch_service.go @@ -37,19 +37,19 @@ func newListenerWatcher(resourceName string, parent *xdsResolver) *listenerWatch } func (l *listenerWatcher) OnUpdate(update *xdsresource.ListenerResourceData) { - l.parent.serializer.Schedule(func(context.Context) { + l.parent.serializer.TrySchedule(func(context.Context) { l.parent.onListenerResourceUpdate(update.Resource) }) } func (l *listenerWatcher) OnError(err error) { - l.parent.serializer.Schedule(func(context.Context) { + l.parent.serializer.TrySchedule(func(context.Context) { l.parent.onListenerResourceError(err) }) } func (l *listenerWatcher) OnResourceDoesNotExist() { - l.parent.serializer.Schedule(func(context.Context) { + l.parent.serializer.TrySchedule(func(context.Context) { l.parent.onListenerResourceNotFound() }) } @@ -72,19 +72,19 @@ func newRouteConfigWatcher(resourceName string, parent *xdsResolver) *routeConfi } func (r *routeConfigWatcher) OnUpdate(update *xdsresource.RouteConfigResourceData) { - r.parent.serializer.Schedule(func(context.Context) { + r.parent.serializer.TrySchedule(func(context.Context) { r.parent.onRouteConfigResourceUpdate(r.resourceName, update.Resource) }) } func (r *routeConfigWatcher) OnError(err error) { - r.parent.serializer.Schedule(func(context.Context) { + r.parent.serializer.TrySchedule(func(context.Context) { r.parent.onRouteConfigResourceError(r.resourceName, err) }) } func (r *routeConfigWatcher) OnResourceDoesNotExist() { - r.parent.serializer.Schedule(func(context.Context) { + r.parent.serializer.TrySchedule(func(context.Context) { r.parent.onRouteConfigResourceNotFound(r.resourceName) }) } diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index 75500b102466..7214eb03c9b1 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -210,7 +210,7 @@ func (a *authority) updateResourceStateAndScheduleCallbacks(rType xdsresource.Ty for watcher := range state.watchers { watcher := watcher err := uErr.err - a.serializer.Schedule(func(context.Context) { watcher.OnError(err) }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnError(err) }) } continue } @@ -225,7 +225,7 @@ func (a *authority) updateResourceStateAndScheduleCallbacks(rType xdsresource.Ty for watcher := range state.watchers { watcher := watcher resource := uErr.resource - a.serializer.Schedule(func(context.Context) { watcher.OnUpdate(resource) }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnUpdate(resource) }) } } // Sync cache. @@ -300,7 +300,7 @@ func (a *authority) updateResourceStateAndScheduleCallbacks(rType xdsresource.Ty state.md = xdsresource.UpdateMetadata{Status: xdsresource.ServiceStatusNotExist} for watcher := range state.watchers { watcher := watcher - a.serializer.Schedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) } } } @@ -428,7 +428,7 @@ func (a *authority) newConnectionError(err error) { // Propagate the connection error from the transport layer to all watchers. for watcher := range state.watchers { watcher := watcher - a.serializer.Schedule(func(context.Context) { + a.serializer.TrySchedule(func(context.Context) { watcher.OnError(xdsresource.NewErrorf(xdsresource.ErrorTypeConnection, "xds: error received from xDS stream: %v", err)) }) } @@ -495,7 +495,7 @@ func (a *authority) watchResource(rType xdsresource.Type, resourceName string, w a.logger.Infof("Resource type %q with resource name %q found in cache: %s", rType.TypeName(), resourceName, state.cache.ToJSON()) } resource := state.cache - a.serializer.Schedule(func(context.Context) { watcher.OnUpdate(resource) }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnUpdate(resource) }) } return func() { @@ -548,7 +548,7 @@ func (a *authority) handleWatchTimerExpiryLocked(rType xdsresource.Type, resourc state.md = xdsresource.UpdateMetadata{Status: xdsresource.ServiceStatusNotExist} for watcher := range state.watchers { watcher := watcher - a.serializer.Schedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) } } @@ -574,7 +574,7 @@ func (a *authority) triggerResourceNotFoundForTesting(rType xdsresource.Type, re state.md = xdsresource.UpdateMetadata{Status: xdsresource.ServiceStatusNotExist} for watcher := range state.watchers { watcher := watcher - a.serializer.Schedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) + a.serializer.TrySchedule(func(context.Context) { watcher.OnResourceDoesNotExist() }) } } diff --git a/xds/internal/xdsclient/clientimpl_watchers.go b/xds/internal/xdsclient/clientimpl_watchers.go index 22b8eb0107c9..f064394aa41c 100644 --- a/xds/internal/xdsclient/clientimpl_watchers.go +++ b/xds/internal/xdsclient/clientimpl_watchers.go @@ -44,7 +44,7 @@ func (c *clientImpl) WatchResource(rType xdsresource.Type, resourceName string, if err := c.resourceTypes.maybeRegister(rType); err != nil { logger.Warningf("Watch registered for name %q of type %q which is already registered", rType.TypeName(), resourceName) - c.serializer.Schedule(func(context.Context) { watcher.OnError(err) }) + c.serializer.TrySchedule(func(context.Context) { watcher.OnError(err) }) return func() {} } @@ -54,7 +54,7 @@ func (c *clientImpl) WatchResource(rType xdsresource.Type, resourceName string, a, unref, err := c.findAuthority(n) if err != nil { logger.Warningf("Watch registered for name %q of type %q, authority %q is not found", rType.TypeName(), resourceName, n.Authority) - c.serializer.Schedule(func(context.Context) { watcher.OnError(err) }) + c.serializer.TrySchedule(func(context.Context) { watcher.OnError(err) }) return func() {} } cancelF := a.watchResource(rType, n.String(), watcher)