diff --git a/clientserver_test.go b/clientserver_test.go index 71b2a32c..3fc9fcf1 100644 --- a/clientserver_test.go +++ b/clientserver_test.go @@ -9,6 +9,7 @@ package http_test import ( "bytes" "compress/gzip" + "context" "crypto/rand" "crypto/sha1" "crypto/tls" @@ -19,7 +20,9 @@ import ( "net" . "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/httputil" + "net/textproto" "net/url" "os" "reflect" @@ -1616,3 +1619,95 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { t.Errorf("got response body = %q; want %q", got, want) } } + +func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } +func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } +func testEarlyHintsRequest(t *testing.T, h2 bool) { + defer afterTest(t) + if h2 { + t.Skip("Waiting for H2 support to be merged: https://go-review.googlesource.com/c/net/+/406494") + } + + var wg sync.WaitGroup + wg.Add(1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + h := w.Header() + + h.Add("Content-Length", "123") // must be ignored + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(StatusEarlyHints) + + wg.Wait() + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(StatusEarlyHints) + + w.Write([]byte("Hello")) + })) + defer cst.close() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("got %d expected %d", len(got), len(expected)) + } + + for i := range expected { + if expected[i] != got[i] { + t.Errorf("got %q expected %q", got[i], expected[i]) + } + } + } + + checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { + t.Helper() + + for _, h := range []string{"Content-Length", "Transfer-Encoding"} { + if v, ok := header[h]; ok { + t.Errorf("%s is %q; must not be sent", h, v) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch respCounter { + case 0: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + checkExcludedHeaders(t, header) + + wg.Done() + case 1: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + checkExcludedHeaders(t, header) + + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) + + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + if cl := res.Header.Get("Content-Length"); cl != "123" { + t.Errorf("Content-Length is %q; want 123", cl) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} diff --git a/serve_test.go b/serve_test.go index 404cca08..464e0f73 100644 --- a/serve_test.go +++ b/serve_test.go @@ -3873,7 +3873,7 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { // Issue 6157, Issue 6685 func TestCodesPreventingContentTypeAndBody(t *testing.T) { - for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} { + for _, code := range []int{StatusNotModified, StatusNoContent} { ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/header" { w.Header().Set("Content-Length", "123") @@ -6725,3 +6725,35 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) } } + +func TestEarlyHints(t *testing.T) { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(StatusEarlyHints) + + w.Write([]byte("stuff")) + })) + + got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org") + expected := "HTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\nLink: ; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\nLink: ; rel=preload; as=script\r\nLink: ; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: ; rel=preload; as=style\r\nLink: ; rel=preload; as=script\r\nLink: ; rel=preload; as=script\r\nDate: " // dynamic content expected + if !strings.Contains(got, expected) { + t.Errorf("unexpected response; got %q; should start by %q", got, expected) + } +} +func TestProcessing(t *testing.T) { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusProcessing) + w.Write([]byte("stuff")) + })) + + got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org") + expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: " // dynamic content expected + if !strings.Contains(got, expected) { + t.Errorf("unexpected response; got %q; should start by %q", got, expected) + } +} diff --git a/server.go b/server.go index d44b0fb2..bc3a4633 100644 --- a/server.go +++ b/server.go @@ -98,8 +98,8 @@ type ResponseWriter interface { // Handlers can set HTTP trailers. // // Changing the header map after a call to WriteHeader (or - // Write) has no effect unless the modified headers are - // trailers. + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. // // There are two ways to set Trailers. The preferred way is to // predeclare in the headers which trailers you will later @@ -144,13 +144,18 @@ type ResponseWriter interface { // If WriteHeader is not called explicitly, the first call to Write // will trigger an implicit WriteHeader(http.StatusOK). // Thus explicit calls to WriteHeader are mainly used to - // send error codes. + // send error codes or 1xx informational responses. // // The provided code must be a valid HTTP 1xx-5xx status code. - // Only one header may be written. Go does not currently - // support sending user-defined 1xx informational headers, - // with the exception of 100-continue response header that the - // Server sends automatically when the Request.Body is read. + // Any number of 1xx headers may be written, followed by at most + // one 2xx-5xx header. 1xx headers are sent immediately, but 2xx-5xx + // headers may be buffered. Use the Flusher interface to send + // buffered data. The header map is cleared when 2xx-5xx headers are + // sent, but not with 1xx headers. + // + // The server will automatically send a 100 (Continue) header + // on the first read from the request body if the request has + // an "Expect: 100-continue" header. WriteHeader(statusCode int) } @@ -420,7 +425,7 @@ type response struct { req *Request // request for this response reqBody io.ReadCloser cancelCtx context.CancelFunc // when ServeHTTP exits - wroteHeader bool // reply header has been (logically) written + wroteHeader bool // a non-1xx header has been (logically) written wroteContinue bool // 100 Continue response was written wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" @@ -1100,8 +1105,7 @@ func checkWriteHeaderCode(code int) { // Issue 22880: require valid WriteHeader status codes. // For now we only enforce that it's three digits. // In the future we might block things over 599 (600 and above aren't defined - // at https://httpwg.org/specs/rfc7231.html#status.codes) - // and we might block under 200 (once we have more mature 1xx support). + // at https://httpwg.org/specs/rfc7231.html#status.codes). // But for now any three digits. // // We used to send "HTTP/1.1 000 0" on the wire in responses but there's @@ -1144,6 +1148,26 @@ func (w *response) WriteHeader(code int) { return } checkWriteHeaderCode(code) + + // Handle informational headers + if code >= 100 && code <= 199 { + // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() + if code == 100 && w.canWriteContinue.isSet() { + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) + + // Per RFC 8297 we must not clear the current header map + w.handlerHeader.WriteSubset(w.conn.bufw, excludedHeadersNoBody) + w.conn.bufw.Write(crlf) + w.conn.bufw.Flush() + + return + } + w.wroteHeader = true w.status = code diff --git a/transfer.go b/transfer.go index 7bea5866..6957b246 100644 --- a/transfer.go +++ b/transfer.go @@ -468,6 +468,7 @@ func bodyAllowedForStatus(status int) bool { var ( suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} + excludedHeadersNoBody = map[string]bool{"Content-Length": true, "Transfer-Encoding": true} ) func suppressedHeaders(status int) []string {