diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index be371c6e0f73..28c77af70aba 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -78,6 +78,7 @@ type http2Client struct {
 	framer *framer
 	// controlBuf delivers all the control related tasks (e.g., window
 	// updates, reset streams, and various settings) to the controller.
+	// Do not access controlBuf with mu held.
 	controlBuf *controlBuffer
 	fc         *trInFlow
 	// The scheme used: https if TLS is on, http otherwise.
@@ -109,6 +110,7 @@ type http2Client struct {
 	waitingStreams uint32
 	nextID         uint32
 
+	// Do not access controlBuf with mu held.
 	mu            sync.Mutex // guard the following variables
 	state         transportState
 	activeStreams map[uint32]*Stream
@@ -685,7 +687,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 			cleanup(err)
 			return err
 		}
-		t.activeStreams[id] = s
 		if channelz.IsOn() {
 			atomic.AddInt64(&t.czData.streamsStarted, 1)
 			atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano())
@@ -719,6 +720,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 		t.nextID += 2
 		s.id = h.streamID
 		s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
+		t.mu.Lock()
+		if t.activeStreams == nil { // Can be niled from Close().
+			t.mu.Unlock()
+			return false // Don't create a stream if the transport is already closed.
+		}
+		t.activeStreams[s.id] = s
+		t.mu.Unlock()
 		if t.streamQuota > 0 && t.waitingStreams > 0 {
 			select {
 			case t.streamsQuotaAvailable <- struct{}{}:
@@ -744,13 +752,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 	}
 	for {
 		success, err := t.controlBuf.executeAndPut(func(it interface{}) bool {
-			if !checkForStreamQuota(it) {
-				return false
-			}
-			if !checkForHeaderListSize(it) {
-				return false
-			}
-			return true
+			return checkForHeaderListSize(it) && checkForStreamQuota(it)
 		}, hdr)
 		if err != nil {
 			// Connection closed.
@@ -1003,13 +1005,13 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
 // for the transport and the stream based on the current bdp
 // estimation.
 func (t *http2Client) updateFlowControl(n uint32) {
-	t.mu.Lock()
-	for _, s := range t.activeStreams {
-		s.fc.newLimit(n)
-	}
-	t.mu.Unlock()
 	updateIWS := func(interface{}) bool {
 		t.initialWindowSize = int32(n)
+		t.mu.Lock()
+		for _, s := range t.activeStreams {
+			s.fc.newLimit(n)
+		}
+		t.mu.Unlock()
 		return true
 	}
 	t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)})
@@ -1215,7 +1217,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
 	default:
 		t.setGoAwayReason(f)
 		close(t.goAway)
-		t.controlBuf.put(&incomingGoAway{})
+		defer t.controlBuf.put(&incomingGoAway{}) // Defer as t.mu is currently held.
 		// Notify the clientconn about the GOAWAY before we set the state to
 		// draining, to allow the client to stop attempting to create streams
 		// before disallowing new streams on this connection.
diff --git a/test/end2end_test.go b/test/end2end_test.go
index da0acbf3d75d..c44925f96a62 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -8041,3 +8041,39 @@ func (s) TestServerClosesConn(t *testing.T) {
 	}
 	t.Fatalf("timed out waiting for conns to be closed by server; still open: %v", atomic.LoadInt32(&wrapLis.connsOpen))
 }
+
+// TestUnexpectedEOF tests a scenario where a client invokes two unary RPC
+// calls. The first call receives a payload which exceeds the maximum gRPC
+// receive message size, and the second gets a large response.
+// This second RPC should not fail with an unexpected EOF error.
+func (s) TestUnexpectedEOF(t *testing.T) {
+	ss := &stubserver.StubServer{
+		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
+			return &testpb.SimpleResponse{
+				Payload: &testpb.Payload{
+					Body: bytes.Repeat([]byte("a"), int(in.ResponseSize)),
+				},
+			}, nil
+		},
+	}
+	if err := ss.Start([]grpc.ServerOption{}); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	for i := 0; i < 10; i++ {
+		// The response exceeds the default maximum receive message size (4MB),
+		// so this call should fail with a RESOURCE_EXHAUSTED error.
+		_, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 4194304})
+		if code := status.Code(err); code != codes.ResourceExhausted {
+			t.Fatalf("UnaryCall RPC returned error: %v, want status code %v", err, codes.ResourceExhausted)
+		}
+		// A large response that does not exceed the default maximum receive
+		// message size; this call should succeed.
+		if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 275075}); err != nil {
+			t.Fatalf("UnaryCall RPC failed: %v", err)
+		}
+	}
+}
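
The `defer t.controlBuf.put(&incomingGoAway{})` change in `handleGoAway` works because deferred calls run when the surrounding function returns, which is after the explicit `t.mu.Unlock()` later in that function; the put therefore never executes with `t.mu` held, consistent with the new "Do not access controlBuf with mu held" rule. A minimal, self-contained sketch of that ordering (illustrative names only, not gRPC code):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var mu sync.Mutex // stands in for http2Client.mu

	handleGoAway := func() {
		mu.Lock()
		// The deferred call runs when handleGoAway returns, i.e. after the
		// explicit Unlock below, so the "put" never executes with mu held.
		defer fmt.Println("controlBuf.put runs here, after mu is released")
		fmt.Println("GOAWAY bookkeeping runs here, with mu held")
		mu.Unlock()
	}
	handleGoAway()
}
```

Running this prints the bookkeeping line first and the put line second, mirroring why the deferred put cannot deadlock against a controlBuf callback that itself acquires `mu`.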
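The reordering to `checkForHeaderListSize(it) && checkForStreamQuota(it)` also appears deliberate: `&&` evaluates left to right and short-circuits, and `checkForStreamQuota` now has side effects (it assigns the stream ID and, with this patch, registers the stream in `activeStreams`), so the side-effect-free size check must run first. A toy sketch of that reasoning, under the assumption that quota consumption is the side effect to avoid (`headerSizeOK` and `takeQuota` are hypothetical stand-ins):

```go
package main

import "fmt"

func main() {
	streamQuota := 1

	// Side-effect-free validation, standing in for checkForHeaderListSize.
	headerSizeOK := func() bool {
		return false // pretend the header list is too large
	}
	// Side-effecting check, standing in for checkForStreamQuota, which
	// consumes quota and (after this patch) registers the stream.
	takeQuota := func() bool {
		streamQuota--
		return true
	}

	// && short-circuits: when headerSizeOK fails, takeQuota never runs, so
	// a rejected stream cannot consume quota or leave a registration behind.
	// Running the side-effecting check first, as the old code did, would.
	if headerSizeOK() && takeQuota() {
		fmt.Println("stream created")
	} else {
		fmt.Println("stream rejected; quota still:", streamQuota) // prints 1
	}
}
```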