From 902f8a7377ffe4dfd144f9854b07732df900a97d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Tue, 4 Jan 2022 21:30:37 +0100 Subject: [PATCH] net/http: reverseproxy: forward 1xx responses Support for 1xx responses has recently been merged in net/http (#42597). As discussed in this CL (https://go-review.googlesource.com/c/go/+/269997/comments/1ff70bef_c25a829a), support for forwarding 1xx responses in ReverseProxy has been extracted in this separate patch. According to RFC 7231, "a proxy MUST forward 1xx responses unless the proxy itself requested the generation of the 1xx response". Consequently, all received 1xx responses are automatically forwarded as long as the underlying transport supports ClientTrace.Got1xxResponse. Fixes #26088 Fixes #36734 --- src/net/http/httputil/reverseproxy.go | 30 ++++++++- src/net/http/httputil/reverseproxy_test.go | 76 ++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index b5d3ce711097b1..9e952891841015 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -14,6 +14,7 @@ import ( "mime" "net" "net/http" + "net/http/httptrace" "net/http/internal/ascii" "net/textproto" "net/url" @@ -40,6 +41,9 @@ import ( // To prevent IP spoofing, be sure to delete any pre-existing // X-Forwarded-For header coming from the client or // an untrusted proxy. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. type ReverseProxy struct { // Director must be a function which modifies // the request into a new request to be sent @@ -307,6 +311,23 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + var headerSet bool + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + for k, _ := range h { + h.Del(k) + } + + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + res, err := transport.RoundTrip(outreq) if err != nil { p.getErrorHandler()(rw, outreq, err) @@ -332,7 +353,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - copyHeader(rw.Header(), res.Header) + h := rw.Header() + if headerSet { + for k, _ := range h { + h.Del(k) + } + } + + copyHeader(h, res.Header) // The "Trailer" header isn't included in the Transport's response, // at least for *http.Transport. Build it up from Trailer. diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 90e8903e9cfd5b..e33f3154c67cf1 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -16,7 +16,9 @@ import ( "log" "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/internal/ascii" + "net/textproto" "net/url" "os" "reflect" @@ -1537,3 +1539,77 @@ func TestJoinURLPath(t *testing.T) { } } } + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Excpected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +}