diff --git a/src/net/http/readrequest_test.go b/src/net/http/readrequest_test.go index 1225d97edbe63c..4bf646b0a63ce0 100644 --- a/src/net/http/readrequest_test.go +++ b/src/net/http/readrequest_test.go @@ -380,6 +380,27 @@ var reqTests = []reqTest{ noTrailer, noError, }, + + // http2 client preface: + { + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", + &Request{ + Method: "PRI", + URL: &url.URL{ + Path: "*", + }, + Header: Header{}, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + RequestURI: "*", + ContentLength: -1, + Close: true, + }, + noBody, + noTrailer, + noError, + }, } func TestReadRequest(t *testing.T) { diff --git a/src/net/http/request.go b/src/net/http/request.go index 9cf2d2576fbe3e..d9ebb26dfc6123 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -343,6 +343,12 @@ func (r *Request) multipartReader() (*multipart.Reader, error) { return multipart.NewReader(r.Body, boundary), nil } +// isH2Upgrade reports whether r represents the http2 "client preface" +// magic string. +func (r *Request) isH2Upgrade() bool { + return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0" +} + // Return value if nonempty, def otherwise. func valueOrDefault(value, def string) string { if value != "" { @@ -794,6 +800,16 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro return nil, err } + if req.isH2Upgrade() { + // Because it's neither chunked, nor declared: + req.ContentLength = -1 + + // We want to give handlers a chance to hijack the + // connection, but we need to prevent the Server from + // dealing with the connection further if it's not + // hijacked. Set Close to ensure that: + req.Close = true + } return req, nil } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index c49262201a872f..638ba5f48f6373 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -741,6 +741,13 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) { })) } +func TestHTTP2UpgradeClosesConnection(t *testing.T) { + testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. (if not hijacked, the server should close the connection + // afterwards) + })) +} + func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } @@ -3877,10 +3884,17 @@ func TestServerValidatesHostHeader(t *testing.T) { {"HTTP/1.0", "", 200}, {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400}, {"HTTP/1.0", "Host: \xff\r\n", 400}, + + // Make an exception for HTTP upgrade requests: + {"PRI * HTTP/2.0", "", 200}, } for _, tt := range tests { conn := &testConn{closec: make(chan bool, 1)} - io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n") + methodTarget := "GET / " + if !strings.HasPrefix(tt.proto, "HTTP/") { + methodTarget = "" + } + io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n") ln := &oneConnListener{conn} go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) @@ -3896,6 +3910,45 @@ func TestServerValidatesHostHeader(t *testing.T) { } } +func TestServerHandlersCanHandleH2PRI(t *testing.T) { + const upgradeResponse = "upgrade here" + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + conn, br, err := w.(Hijacker).Hijack() + defer conn.Close() + if r.Method != "PRI" || r.RequestURI != "*" { + t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI) + return + } + if !r.Close { + t.Errorf("Request.Close = true; want false") + } + const want = "SM\r\n\r\n" + buf := make([]byte, len(want)) + n, err := io.ReadFull(br, buf) + if err != nil || string(buf[:n]) != want { + t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want) + return + } + io.WriteString(conn, upgradeResponse) + })) + defer ts.Close() + + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + slurp, err := ioutil.ReadAll(c) + if err != nil { + t.Fatal(err) + } + if string(slurp) != upgradeResponse { + t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse) + } +} + // Test that we validate the valid bytes in HTTP/1 headers. // Issue 11207. func TestServerValidatesHeaders(t *testing.T) { diff --git a/src/net/http/server.go b/src/net/http/server.go index 17c2890aa7b351..5718cafbc3d728 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -714,7 +714,8 @@ func (c *conn) readRequest() (w *response, err error) { c.r.setInfiniteReadLimit() hosts, haveHost := req.Header["Host"] - if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) { + isH2Upgrade := req.isH2Upgrade() + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade { return nil, badRequestError("missing required Host header") } if len(hosts) > 1 { @@ -748,6 +749,9 @@ func (c *conn) readRequest() (w *response, err error) { handlerHeader: make(Header), contentLength: -1, } + if isH2Upgrade { + w.closeAfterReply = true + } w.cw.res = w w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil