diff --git a/http2/server.go b/http2/server.go index de31d72b2..293324574 100644 --- a/http2/server.go +++ b/http2/server.go @@ -1696,6 +1696,7 @@ func (sc *serverConn) processData(f *DataFrame) error { if len(data) > 0 { wrote, err := st.body.Write(data) if err != nil { + sc.sendWindowUpdate(nil, int(f.Length)-wrote) return streamError(id, ErrCodeStreamClosed) } if wrote != len(data) { diff --git a/http2/server_test.go b/http2/server_test.go index c4f1b1a81..3593f9b5d 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4098,3 +4098,62 @@ func TestContentEncodingNoSniffing(t *testing.T) { }) } } + +func TestServerWindowUpdateOnBodyClose(t *testing.T) { + const content = "12345678" + blockCh := make(chan bool) + errc := make(chan error, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + buf := make([]byte, 4) + n, err := io.ReadFull(r.Body, buf) + if err != nil { + errc <- err + return + } + if n != len(buf) { + errc <- fmt.Errorf("too few bytes read: %d", n) + return + } + blockCh <- true + <-blockCh + errc <- nil + }) + defer st.Close() + + st.greet() + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader( + ":method", "POST", + "content-length", strconv.Itoa(len(content)), + ), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, false, []byte(content[:5])) + <-blockCh + st.stream(1).body.CloseWithError(io.EOF) + st.writeData(1, false, []byte(content[5:])) + blockCh <- true + + increments := len(content) + for { + f, err := st.readFrame() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if wu, ok := f.(*WindowUpdateFrame); ok && wu.StreamID == 0 { + increments -= int(wu.Increment) + if increments == 0 { + break + } + } + } + + if err := <-errc; err != nil { + t.Error(err) + } +}