diff --git a/server.go b/server.go index e4cc87534..01f8f7fb8 100644 --- a/server.go +++ b/server.go @@ -201,8 +201,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { bw: newBufferedWriter(c), handler: h, streams: make(map[uint32]*stream), - readFrameCh: make(chan frameAndGate), - readFrameErrCh: make(chan error, 1), // must be buffered for 1 + readFrameCh: make(chan readFrameResult), wantWriteFrameCh: make(chan frameWriteMsg, 8), wroteFrameCh: make(chan struct{}, 1), // buffered; one send in reading goroutine bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way @@ -309,16 +308,6 @@ func (sc *serverConn) rejectConn(err ErrCode, debug string) { sc.conn.Close() } -// frameAndGates coordinates the readFrames and serve -// goroutines. Because the Framer interface only permits the most -// recently-read Frame from being accessed, the readFrames goroutine -// blocks until it has a frame, passes it to serve, and then waits for -// serve to be done with it before reading the next one. -type frameAndGate struct { - f Frame - g gate -} - type serverConn struct { // Immutable: srv *Server @@ -328,9 +317,8 @@ type serverConn struct { handler http.Handler framer *Framer hpackDecoder *hpack.Decoder - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan frameAndGate // written by serverConn.readFrames - readFrameErrCh chan error + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan readFrameResult // written by serverConn.readFrames wantWriteFrameCh chan frameWriteMsg // from handlers -> serve wroteFrameCh chan struct{} // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan bodyReadMsg // from handlers -> serve @@ -541,24 +529,34 @@ func (sc *serverConn) canonicalHeader(v string) string { return cv } +type readFrameResult struct { + f Frame // valid until readMore is called + err error + + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() +} + // readFrames is the loop that reads incoming frames. +// It takes care to only read one frame at a time, blocking until the +// consumer is done with the frame. // It's run on its own goroutine. func (sc *serverConn) readFrames() { - g := make(gate, 1) + gate := make(gate) for { f, err := sc.framer.ReadFrame() - if err != nil { - sc.readFrameErrCh <- err - close(sc.readFrameCh) + select { + case sc.readFrameCh <- readFrameResult{f, err, gate.Done}: + case <-sc.doneServing: + return + } + select { + case <-gate: + case <-sc.doneServing: return } - sc.readFrameCh <- frameAndGate{f, g} - // We can't read another frame until this one is - // processed, as the ReadFrame interface doesn't copy - // memory. The Frame accessor methods access the last - // frame's (shared) buffer. So we wait for the - // serve goroutine to tell us it's done: - g.Wait() } } @@ -648,13 +646,11 @@ func (sc *serverConn) serve() { } sc.writingFrame = false sc.scheduleFrameWrite() - case fg, ok := <-sc.readFrameCh: - if !ok { - sc.readFrameCh = nil - } - if !sc.processFrameFromReader(fg, ok) { + case res := <-sc.readFrameCh: + if !sc.processFrameFromReader(res) { return } + res.readMore() if settingsTimer.C != nil { settingsTimer.Stop() settingsTimer.C = nil @@ -901,17 +897,15 @@ func (sc *serverConn) curHeaderStreamID() uint32 { // processFrameFromReader processes the serve loop's read from readFrameCh from the // frame-reading goroutine. // processFrameFromReader returns whether the connection should be kept open. -func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool { +func (sc *serverConn) processFrameFromReader(res readFrameResult) bool { sc.serveG.check() - var clientGone bool - var err error - if !fgValid { - err = <-sc.readFrameErrCh + err := res.err + if err != nil { if err == ErrFrameTooLarge { sc.goAway(ErrCodeFrameSize) return true // goAway will close the loop } - clientGone = err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") + clientGone := err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") if clientGone { // TODO: could we also get into this state if // the peer does a half close @@ -923,13 +917,10 @@ func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool // just for testing we could have a non-TLS mode. return false } - } - - if fgValid { - f := fg.f + } else { + f := res.f sc.vlogf("got %v: %#v", f.Header(), f) err = sc.processFrame(f) - fg.g.Done() // unblock the readFrames goroutine if err == nil { return true } @@ -947,13 +938,13 @@ func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool sc.goAway(ErrCode(ev)) return true // goAway will handle shutdown default: - if !fgValid { + if res.err != nil { sc.logf("disconnecting; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) } else { sc.logf("disconnection due to other error: %v", err) } + return false } - return false } func (sc *serverConn) processFrame(f Frame) error { diff --git a/server_test.go b/server_test.go index 5413f5569..6021435f9 100644 --- a/server_test.go +++ b/server_test.go @@ -34,6 +34,14 @@ import ( var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") +func stderrv() io.Writer { + if *stderrVerbose { + return os.Stderr + } + + return ioutil.Discard +} + type serverTester struct { cc net.Conn // client conn t testing.TB @@ -106,13 +114,8 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) - var stderrv io.Writer = ioutil.Discard - if *stderrVerbose { - stderrv = os.Stderr - } - ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config - ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv, twriter{t: t, st: st}, logBuf), "", log.LstdFlags) + ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, logBuf), "", log.LstdFlags) ts.StartTLS() if VerboseLogs { @@ -124,7 +127,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} st.sc = v st.sc.testHookCh = make(chan func()) } - log.SetOutput(io.MultiWriter(stderrv, twriter{t: t, st: st})) + log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st})) if !onlyServer { cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) if err != nil { @@ -2328,3 +2331,52 @@ func BenchmarkServerPosts(b *testing.B) { } } } + +// go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53 +// Verify we don't hang. +func TestIssue53(t *testing.T) { + const data = "PRI * HTTP/2.0\r\n\r\nSM" + + "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad" + s := &http.Server{ + ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags), + } + s2 := &Server{MaxReadFrameSize: 1 << 16, PermitProhibitedCipherSuites: true} + c := &issue53Conn{[]byte(data), false, false} + s2.handleConn(s, c, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("hello")) + })) + if !c.closed { + t.Fatal("connection is not closed") + } +} + +type issue53Conn struct { + data []byte + closed bool + written bool +} + +func (c *issue53Conn) Read(b []byte) (n int, err error) { + if len(c.data) == 0 { + return 0, io.EOF + } + n = copy(b, c.data) + c.data = c.data[n:] + return +} + +func (c *issue53Conn) Write(b []byte) (n int, err error) { + c.written = true + return len(b), nil +} + +func (c *issue53Conn) Close() error { + c.closed = true + return nil +} + +func (c *issue53Conn) LocalAddr() net.Addr { return &net.TCPAddr{net.IP{127, 0, 0, 1}, 49706, ""} } +func (c *issue53Conn) RemoteAddr() net.Addr { return &net.TCPAddr{net.IP{127, 0, 0, 1}, 49706, ""} } +func (c *issue53Conn) SetDeadline(t time.Time) error { return nil } +func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil } +func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil }