diff --git a/pkg/sql/distsqlrun/flow.go b/pkg/sql/distsqlrun/flow.go index c0cf89e6b123..494bb6f9d342 100644 --- a/pkg/sql/distsqlrun/flow.go +++ b/pkg/sql/distsqlrun/flow.go @@ -159,6 +159,11 @@ type Flow struct { localProcessors []LocalProcessor + // startedGoroutines specifies whether this flow started any goroutines. This + // is used in Wait() to avoid the overhead of waiting for non-existent + // goroutines. + startedGoroutines bool + localStreams map[StreamID]RowReceiver // inboundStreams are streams that receive data from other hosts; this map @@ -535,6 +540,7 @@ func (f *Flow) startInternal(ctx context.Context, doneFn func()) error { f.waitGroup.Add(1) go f.processors[i].Run(ctx, &f.waitGroup) } + f.startedGoroutines = len(f.startables) > 0 || len(f.processors) > 1 || !f.isLocal() return nil } @@ -562,6 +568,7 @@ func (f *Flow) StartAsync(ctx context.Context, doneFn func()) error { if len(f.processors) > 0 { f.waitGroup.Add(1) go f.processors[len(f.processors)-1].Run(ctx, &f.waitGroup) + f.startedGoroutines = true } return nil } @@ -588,6 +595,9 @@ func (f *Flow) StartSync(ctx context.Context, doneFn func()) error { // Wait waits for all the goroutines for this flow to exit. If the context gets // canceled before all goroutines exit, it calls f.cancel(). func (f *Flow) Wait() { + if !f.startedGoroutines { + return + } waitChan := make(chan struct{}) go func() { diff --git a/pkg/sql/distsqlrun/server.go b/pkg/sql/distsqlrun/server.go index 80ba73a09175..5082ed0a11df 100644 --- a/pkg/sql/distsqlrun/server.go +++ b/pkg/sql/distsqlrun/server.go @@ -501,7 +501,7 @@ func (ds *ServerImpl) RunSyncFlow(stream DistSQL_RunSyncFlowServer) error { if err := ds.Stopper.RunTask(ctx, "distsqlrun.ServerImpl: sync flow", func(ctx context.Context) { ctx, ctxCancel := contextutil.WithCancel(ctx) defer ctxCancel() - mbox.start(ctx, &f.waitGroup, ctxCancel) + f.startables = append(f.startables, mbox) ds.Metrics.FlowStart() if err := f.StartSync(ctx, func() {}); err != nil { log.Fatalf(ctx, "unexpected error from syncFlow.Start(): %s "+