diff --git a/orca/producer_test.go b/orca/producer_test.go
index ecaf57e0631e..5d29e1a0d322 100644
--- a/orca/producer_test.go
+++ b/orca/producer_test.go
@@ -228,7 +228,14 @@ func (f *fakeORCAService) close() {
 
 func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportRequest, stream v3orcaservicegrpc.OpenRcaService_StreamCoreMetricsServer) error {
 	f.reqCh <- req
-	for resp := range f.respCh {
+	for {
+		var resp any
+		select {
+		case resp = <-f.respCh:
+		case <-stream.Context().Done():
+			return stream.Context().Err()
+		}
+
 		if err, ok := resp.(error); ok {
 			return err
 		}
@@ -245,7 +252,6 @@ func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportR
 			return err
 		}
 	}
-	return nil
 }
 
 // TestProducerBackoff verifies that the ORCA producer applies the proper
diff --git a/server.go b/server.go
index a9e104adc0d9..e89c5ac6136c 100644
--- a/server.go
+++ b/server.go
@@ -136,7 +136,8 @@ type Server struct {
 	quit               *grpcsync.Event
 	done               *grpcsync.Event
 	channelzRemoveOnce sync.Once
-	serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop
+	serveWG            sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
+	handlersWG         sync.WaitGroup // counts active method handler goroutines
 	channelzID         *channelz.Identifier
 	czData             *channelzData
 
@@ -173,6 +174,7 @@ type serverOptions struct {
 	headerTableSize       *uint32
 	numServerWorkers      uint32
 	recvBufferPool        SharedBufferPool
+	waitForHandlers       bool
 }
 
 var defaultServerOptions = serverOptions{
@@ -570,6 +572,21 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption {
 	})
 }
 
+// WaitForHandlers causes Stop to wait until all outstanding method handlers have
+// exited before returning. If false, Stop will return as soon as all
+// connections have closed, but method handlers may still be running. By
+// default, Stop does not wait for method handlers to return.
+//
+// # Experimental
+//
+// Notice: This API is EXPERIMENTAL and may be changed or removed in a
+// later release.
+func WaitForHandlers(w bool) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.waitForHandlers = w
+	})
+}
+
 // RecvBufferPool returns a ServerOption that configures the server
 // to use the provided shared buffer pool for parsing incoming messages. Depending
 // on the application's workload, this could result in reduced memory allocation.
@@ -1004,9 +1021,11 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport,
 
 	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(ctx, func(stream *transport.Stream) {
+		s.handlersWG.Add(1)
 		streamQuota.acquire()
 		f := func() {
			defer streamQuota.release()
+			defer s.handlersWG.Done()
 			s.handleStream(st, stream)
 		}
 
@@ -1905,6 +1924,10 @@ func (s *Server) stop(graceful bool) {
 		s.serverWorkerChannelClose()
 	}
 
+	if graceful || s.opts.waitForHandlers {
+		s.handlersWG.Wait()
+	}
+
 	if s.events != nil {
 		s.events.Finish()
 		s.events = nil
diff --git a/server_ext_test.go b/server_ext_test.go
index c065e4ad42a8..7d9f1f5560a8 100644
--- a/server_ext_test.go
+++ b/server_ext_test.go
@@ -185,3 +185,148 @@ func (s) TestStreamWorkers_GracefulStopAndStop(t *testing.T) {
 
 	ss.S.GracefulStop()
 }
+
+// Tests the WaitForHandlers ServerOption by leaving an RPC running while Stop
+// is called, and ensures Stop doesn't return until the handler returns.
+func (s) TestServer_WaitForHandlers(t *testing.T) {
+	started := grpcsync.NewEvent()
+	blockCalls := grpcsync.NewEvent()
+
+	// This stub server does not properly respect the stream context, so it will
+	// not exit when the context is canceled.
+	ss := stubserver.StubServer{
+		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
+			started.Fire()
+			<-blockCalls.Done()
+			return nil
+		},
+	}
+	if err := ss.Start([]grpc.ServerOption{grpc.WaitForHandlers(true)}); err != nil {
+		t.Fatal("Error starting server:", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Start one RPC to the server.
+	ctx1, cancel1 := context.WithCancel(ctx)
+	_, err := ss.Client.FullDuplexCall(ctx1)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// Wait for the handler to be invoked.
+	select {
+	case <-started.Done():
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for RPC to start on server.")
+	}
+
+	// Cancel it on the client. The server handler will still be running.
+	cancel1()
+
+	// Close the connection. This might be sufficient to allow the server to
+	// return if it doesn't properly wait for outstanding method handlers to
+	// return.
+	ss.CC.Close()
+
+	// Try to Stop() the server, which should block indefinitely (until
+	// blockCalls is fired).
+	stopped := grpcsync.NewEvent()
+	go func() {
+		ss.S.Stop()
+		stopped.Fire()
+	}()
+
+	// Wait 100ms and ensure stopped does not fire.
+	select {
+	case <-stopped.Done():
+		trace := make([]byte, 4096)
+		trace = trace[0:runtime.Stack(trace, true)]
+		blockCalls.Fire()
+		t.Fatalf("Server returned from Stop() illegally. Stack trace:\n%v", string(trace))
+	case <-time.After(100 * time.Millisecond):
+		// Success; unblock the call and wait for stopped.
+		blockCalls.Fire()
+	}
+
+	select {
+	case <-stopped.Done():
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for Stop to return.")
+	}
+}
+
+// Tests that GracefulStop will wait for all method handlers to return by
+// blocking a handler and ensuring GracefulStop doesn't return until after it is
+// unblocked.
+func (s) TestServer_GracefulStopWaits(t *testing.T) {
+	started := grpcsync.NewEvent()
+	blockCalls := grpcsync.NewEvent()
+
+	// This stub server does not properly respect the stream context, so it will
+	// not exit when the context is canceled.
+	ss := stubserver.StubServer{
+		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
+			started.Fire()
+			<-blockCalls.Done()
+			return nil
+		},
+	}
+	if err := ss.Start(nil); err != nil {
+		t.Fatal("Error starting server:", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Start one RPC to the server.
+	ctx1, cancel1 := context.WithCancel(ctx)
+	_, err := ss.Client.FullDuplexCall(ctx1)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// Wait for the handler to be invoked.
+	select {
+	case <-started.Done():
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for RPC to start on server.")
+	}
+
+	// Cancel it on the client. The server handler will still be running.
+	cancel1()
+
+	// Close the connection. This might be sufficient to allow the server to
+	// return if it doesn't properly wait for outstanding method handlers to
+	// return.
+	ss.CC.Close()
+
+	// Try to GracefulStop() the server, which should block indefinitely (until
+	// blockCalls is fired).
+	stopped := grpcsync.NewEvent()
+	go func() {
+		ss.S.GracefulStop()
+		stopped.Fire()
+	}()
+
+	// Wait 100ms and ensure stopped does not fire.
+	select {
+	case <-stopped.Done():
+		trace := make([]byte, 4096)
+		trace = trace[0:runtime.Stack(trace, true)]
+		blockCalls.Fire()
+		t.Fatalf("Server returned from GracefulStop() illegally. Stack trace:\n%v", string(trace))
+	case <-time.After(100 * time.Millisecond):
+		// Success; unblock the call and wait for stopped.
+		blockCalls.Fire()
+	}
+
+	select {
+	case <-stopped.Done():
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for GracefulStop to return.")
+	}
+}
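
For reference, a minimal usage sketch of the new option (not part of the diff above): the listener address, the sleep standing in for the server's lifetime, and the omitted service registration are illustrative assumptions; only grpc.WaitForHandlers comes from this change.

// Sketch: a server opting into WaitForHandlers so that Stop blocks until all
// method handler goroutines have returned, not just until connections close.
package main

import (
	"log"
	"net"
	"time"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	s := grpc.NewServer(grpc.WaitForHandlers(true))
	// Register services on s here (omitted in this sketch).
	go s.Serve(lis)

	time.Sleep(time.Second) // stand-in for real serving time

	// With WaitForHandlers(true), Stop also waits for outstanding method
	// handlers to exit; by default it returns once all connections close.
	s.Stop()
}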