diff --git a/orca/producer_test.go b/orca/producer_test.go index ecaf57e0631e..5d29e1a0d322 100644 --- a/orca/producer_test.go +++ b/orca/producer_test.go @@ -228,7 +228,14 @@ func (f *fakeORCAService) close() { func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportRequest, stream v3orcaservicegrpc.OpenRcaService_StreamCoreMetricsServer) error { f.reqCh <- req - for resp := range f.respCh { + for { + var resp any + select { + case resp = <-f.respCh: + case <-stream.Context().Done(): + return stream.Context().Err() + } + if err, ok := resp.(error); ok { return err } @@ -245,7 +252,6 @@ func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportR return err } } - return nil } // TestProducerBackoff verifies that the ORCA producer applies the proper diff --git a/server.go b/server.go index 682fa1831ec8..e22c22122fc0 100644 --- a/server.go +++ b/server.go @@ -1009,10 +1009,13 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, }() streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) + wg := &sync.WaitGroup{} st.HandleStreams(ctx, func(stream *transport.Stream) { + wg.Add(1) streamQuota.acquire() f := func() { defer streamQuota.release() + defer wg.Done() s.handleStream(st, stream) } @@ -1026,6 +1029,7 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, } go f() }) + wg.Wait() } var _ http.Handler = (*Server)(nil) diff --git a/server_ext_test.go b/server_ext_test.go index c065e4ad42a8..d61d431ba4e2 100644 --- a/server_ext_test.go +++ b/server_ext_test.go @@ -185,3 +185,73 @@ func (s) TestStreamWorkers_GracefulStopAndStop(t *testing.T) { ss.S.GracefulStop() } + +func (s) TestHandlersReturnBeforeStop(t *testing.T) { + started := grpcsync.NewEvent() + blockCalls := grpcsync.NewEvent() + + // This stub server does not properly respect the stream context, so it will + // not exit when the context is canceled. + ss := stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + started.Fire() + <-blockCalls.Done() + return nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Start one RPC to the server. + ctx1, cancel1 := context.WithCancel(ctx) + _, err := ss.Client.FullDuplexCall(ctx1) + if err != nil { + t.Fatal("Error staring call:", err) + } + + // Wait for the handler to be invoked. + select { + case <-started.Done(): + case <-ctx.Done(): + t.Fatalf("Timed out waiting for RPC to start on server.") + } + + // Cancel it on the client. The server handler will still be running. + cancel1() + + // Close the connection. This might be sufficient to allow the server to + // return if it doesn't properly wait for outstanding method handlers to + // return. + ss.CC.Close() + + // Try to Stop() the server, which should block indefinitely (until + // blockCalls is fired). + stopped := grpcsync.NewEvent() + go func() { + ss.S.Stop() + stopped.Fire() + }() + + // Wait 100ms and ensure stopped does not fire. + select { + case <-stopped.Done(): + trace := make([]byte, 4096) + trace = trace[0:runtime.Stack(trace, true)] + blockCalls.Fire() + t.Fatalf("Server returned from Stop() illegally. Stack trace:\n%v", string(trace)) + case <-time.After(100 * time.Millisecond): + // Success; unblock the call and wait for stopped. + blockCalls.Fire() + } + + select { + case <-stopped.Done(): + case <-ctx.Done(): + t.Fatalf("Timed out waiting for second RPC to start on server.") + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 97a7f1812553..93a519146443 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -1035,7 +1035,7 @@ func (s) TestDetailedConnectionCloseErrorPropagatesToRpcError(t *testing.T) { // connection for the RPC to go out on initially, and that the TCP connection will shut down strictly after // the RPC has been started on it. <-rpcStartedOnServer - ss.S.Stop() + go ss.S.Stop() // The precise behavior of this test is subject to raceyness around the timing // of when TCP packets are sent from client to server, and when we tell the // server to stop, so we need to account for both possible error messages. diff --git a/test/gracefulstop_test.go b/test/gracefulstop_test.go index ecf07d984359..3a767e300b97 100644 --- a/test/gracefulstop_test.go +++ b/test/gracefulstop_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/status" @@ -269,7 +270,7 @@ func (s) TestGracefulStopBlocksUntilGRPCConnectionsTerminate(t *testing.T) { // TestStopAbortsBlockingGRPCCall ensures that when Stop() is called while an ongoing RPC // is blocking that: // - Stop() returns -// - and the RPC fails with an connection closed error on the client-side +// - and the RPC fails with an connection closed error on the client-side func (s) TestStopAbortsBlockingGRPCCall(t *testing.T) { unblockGRPCCall := make(chan struct{}) grpcCallExecuting := make(chan struct{}) @@ -298,8 +299,13 @@ func (s) TestStopAbortsBlockingGRPCCall(t *testing.T) { }() <-grpcCallExecuting - ss.S.Stop() + stopReturned := grpcsync.NewEvent() + go func() { + ss.S.Stop() + stopReturned.Fire() + }() - unblockGRPCCall <- struct{}{} <-grpcClientCallReturned + unblockGRPCCall <- struct{}{} + <-stopReturned.Done() }