From 6b9da27e58a4ed030a734e667c52cea9a212df6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Tue, 11 Oct 2022 17:30:00 +0200 Subject: [PATCH 1/2] net/http/httptest: add support for 1XX responses The existing implementation doesn't allow tracing 1xx responses. This patch allows using net/http/httptrace to inspect 1XX responses. Updates #26089. --- src/net/http/httptest/recorder.go | 18 ++++++++ src/net/http/httptest/recorder_test.go | 57 ++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 1c1d8801558ed..3ad206feb88a1 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptrace" "net/textproto" "strconv" "strings" @@ -42,6 +43,9 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool + // ClientTrace is used to trace 1XX responses + ClientTrace *httptrace.ClientTrace + result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool @@ -146,6 +150,20 @@ func (rw *ResponseRecorder) WriteHeader(code int) { } checkWriteHeaderCode(code) + + if rw.ClientTrace != nil && code >= 100 && code < 200 { + if code == 100 { + rw.ClientTrace.Got100Continue() + } + // treat 101 as a terminal status, see issue 26161 + if code != http.StatusSwitchingProtocols { + if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil { + panic(err) + } + return + } + } + rw.Code = code rw.wroteHeader = true if rw.HeaderMap == nil { diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 4782eced43e6c..c3b5cf7aa5bba 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "net/http" + "net/http/httptrace" + "net/textproto" "testing" ) @@ -369,3 +371,58 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { }) } } + +func TestRecorderClientTrace(t *testing.T) { + handler := func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusContinue) + + rw.Header().Add("Foo", "bar") + rw.WriteHeader(http.StatusEarlyHints) + + rw.Header().Add("Baz", "bat") + } + + var received100, received103 bool + trace := &httptrace.ClientTrace{ + Got100Continue: func() { + received100 = true + }, + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusContinue: + case http.StatusEarlyHints: + received103 = true + if header.Get("Foo") != "bar" { + t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo")) + } + if header.Get("Bar") != "" { + t.Error("Unexpected Bar header") + } + default: + t.Errorf("Unexpected status code %d", code) + } + + return nil + }, + } + + r, _ := http.NewRequest("GET", "http://example.org/", nil) + rw := NewRecorder() + rw.ClientTrace = trace + handler(rw, r) + + if !received100 { + t.Error("Got100Continue not called") + } + if !received103 { + t.Error("103 request not received") + } + + header := rw.Result().Header + if header.Get("Foo") != "bar" { + t.Errorf("Expected Foo=bar, got %s", header.Get("Foo")) + } + if header.Get("Baz") != "bat" { + t.Errorf("Expected Baz=bat, got %s", header.Get("Baz")) + } +} From f6e891ed395680c26530065fa5a87a95915bdc36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Wed, 23 Nov 2022 15:40:10 +0100 Subject: [PATCH 2/2] new API --- src/net/http/httptest/recorder.go | 41 ++++++------ src/net/http/httptest/recorder_test.go | 87 +++++++++----------------- 2 files changed, 53 insertions(+), 75 deletions(-) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 3ad206feb88a1..4e38f1f449644 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net/http" - "net/http/httptrace" "net/textproto" "strconv" "strings" @@ -17,6 +16,17 @@ import ( "golang.org/x/net/http/httpguts" ) +// InformationalResponse is an HTTP response sent with a [1xx status code]. +// +// [1xx status code]: https://httpwg.org/specs/rfc9110.html#status.1xx +type InformationalResponse struct { + // Code is the 1xx HTTP response code of this informational response. + Code int + + // Header contains the headers of this informational response. + Header http.Header +} + // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { @@ -28,6 +38,9 @@ type ResponseRecorder struct { // method. Code int + // Informational HTTP responses (1xx status code) sent before the main response. + InformationalResponses []InformationalResponse + // HeaderMap contains the headers explicitly set by the Handler. // It is an internal detail. // @@ -43,9 +56,6 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool - // ClientTrace is used to trace 1XX responses - ClientTrace *httptrace.ClientTrace - result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool @@ -151,24 +161,19 @@ func (rw *ResponseRecorder) WriteHeader(code int) { checkWriteHeaderCode(code) - if rw.ClientTrace != nil && code >= 100 && code < 200 { - if code == 100 { - rw.ClientTrace.Got100Continue() - } - // treat 101 as a terminal status, see issue 26161 - if code != http.StatusSwitchingProtocols { - if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil { - panic(err) - } - return - } + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + + if code >= 100 && code < 200 { + ir := InformationalResponse{code, rw.HeaderMap.Clone()} + rw.InformationalResponses = append(rw.InformationalResponses, ir) + + return } rw.Code = code rw.wroteHeader = true - if rw.HeaderMap == nil { - rw.HeaderMap = make(http.Header) - } rw.snapHeader = rw.HeaderMap.Clone() } diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index c3b5cf7aa5bba..5fd48be78c17a 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -8,8 +8,7 @@ import ( "fmt" "io" "net/http" - "net/http/httptrace" - "net/textproto" + "reflect" "testing" ) @@ -125,6 +124,15 @@ func TestRecorder(t *testing.T) { return nil } } + hasInformationalResponses := func(ir []InformationalResponse) checkFunc { + return func(rec *ResponseRecorder) error { + if !reflect.DeepEqual(ir, rec.InformationalResponses) { + return fmt.Errorf("InformationalResponses = %v; want %v", rec.InformationalResponses, ir) + } + + return nil + } + } for _, tt := range [...]struct { name string @@ -296,6 +304,26 @@ func TestRecorder(t *testing.T) { check(hasResultContents("")), // check we don't crash reading the body }, + { + "1xx status code", + func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusContinue) + rw.Header().Add("Foo", "bar") + + rw.WriteHeader(http.StatusEarlyHints) + rw.Header().Add("Baz", "bat") + + rw.Header().Del("Foo") + }, + check( + hasInformationalResponses([]InformationalResponse{ + InformationalResponse{100, http.Header{}}, + InformationalResponse{103, http.Header{"Foo": []string{"bar"}}}, + }), + hasHeader("Baz", "bat"), + hasNotHeaders("Foo"), + ), + }, } { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest("GET", "http://foo.com/", nil) @@ -371,58 +399,3 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { }) } } - -func TestRecorderClientTrace(t *testing.T) { - handler := func(rw http.ResponseWriter, _ *http.Request) { - rw.WriteHeader(http.StatusContinue) - - rw.Header().Add("Foo", "bar") - rw.WriteHeader(http.StatusEarlyHints) - - rw.Header().Add("Baz", "bat") - } - - var received100, received103 bool - trace := &httptrace.ClientTrace{ - Got100Continue: func() { - received100 = true - }, - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - switch code { - case http.StatusContinue: - case http.StatusEarlyHints: - received103 = true - if header.Get("Foo") != "bar" { - t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo")) - } - if header.Get("Bar") != "" { - t.Error("Unexpected Bar header") - } - default: - t.Errorf("Unexpected status code %d", code) - } - - return nil - }, - } - - r, _ := http.NewRequest("GET", "http://example.org/", nil) - rw := NewRecorder() - rw.ClientTrace = trace - handler(rw, r) - - if !received100 { - t.Error("Got100Continue not called") - } - if !received103 { - t.Error("103 request not received") - } - - header := rw.Result().Header - if header.Get("Foo") != "bar" { - t.Errorf("Expected Foo=bar, got %s", header.Get("Foo")) - } - if header.Get("Baz") != "bat" { - t.Errorf("Expected Baz=bat, got %s", header.Get("Baz")) - } -}