From a4630153038d3cb8c57f83d95200aea356145cf5 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Sat, 1 Sep 2018 02:43:59 +0000 Subject: [PATCH] [release-branch.go1.11] http2: don't leak streams on broken body Updates golang/go#28673 Change-Id: I5d9a643f33d27d33b24f670c98f5a51aa6000967 GitHub-Last-Rev: 3ac4a573b62846ef4944599085218e119819383c GitHub-Pull-Request: golang/net#18 Reviewed-on: https://go-review.googlesource.com/c/132715 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick (cherry picked from commit 1c5f79cfb1642860bbe00b6cfce66700c01e04f6) Reviewed-on: https://go-review.googlesource.com/c/154237 --- http2/transport.go | 2 + http2/transport_test.go | 96 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/http2/transport.go b/http2/transport.go index 9d1f2fadd..ef356d6d9 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1060,6 +1060,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf default: } if err != nil { + cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), err } bodyWritten = true @@ -1181,6 +1182,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( sawEOF = true err = nil } else if err != nil { + cc.writeStreamReset(cs.ID, ErrCodeCancel, err) return err } diff --git a/http2/transport_test.go b/http2/transport_test.go index 5b5c0768f..2c0f53e5c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -4183,3 +4183,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) { t.Fatalf("wrong kind %T; want *Transport", v.Interface()) } } + +type errReader struct { + body []byte + err error +} + +func (r *errReader) Read(p []byte) (int, error) { + if len(r.body) > 0 { + n := copy(p, r.body) + r.body = r.body[n:] + return n, nil + } + return 0, r.err +} + +func testTransportBodyReadError(t *testing.T, body []byte) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + checkNoStreams := func() error { + cp, ok := ct.tr.connPool().(*clientConnPool) + if !ok { + return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) + } + cp.mu.Lock() + defer cp.mu.Unlock() + conns, ok := cp.conns["dummy.tld:443"] + if !ok { + return fmt.Errorf("missing connection") + } + if len(conns) != 1 { + return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) + } + if activeStreams(conns[0]) != 0 { + return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) + } + return nil + } + bodyReadError := errors.New("body read error") + body := &errReader{body, bodyReadError} + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + return err + } + _, err = ct.tr.RoundTrip(req) + if err != bodyReadError { + return fmt.Errorf("err = %v; want %v", err, bodyReadError) + } + if err = checkNoStreams(); err != nil { + return err + } + return nil + } + ct.server = func() error { + ct.greet() + var receivedBody []byte + var resetCount int + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + if bytes.Compare(receivedBody, body) != 0 { + return fmt.Errorf("body: %v; expected %v", receivedBody, body) + } + if resetCount != 1 { + return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) + } + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + resetCount++ + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } +func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }