Skip to content

Commit

Permalink
net/http/httptest: add support for 1XX responses
Browse files Browse the repository at this point in the history
The existing implementation doesn't allow tracing
1xx responses.
This patch allows using net/http/httptrace to inspect
1XX responses.

Updates golang#26089.
  • Loading branch information
dunglas committed Nov 22, 2022
1 parent 9160e15 commit 6b9da27
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/net/http/httptest/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"strconv"
"strings"
Expand Down Expand Up @@ -42,6 +43,9 @@ type ResponseRecorder struct {
// Flushed is whether the Handler called Flush.
Flushed bool

// ClientTrace is used to trace 1XX responses
ClientTrace *httptrace.ClientTrace

result *http.Response // cache of Result's return value
snapHeader http.Header // snapshot of HeaderMap at first Write
wroteHeader bool
Expand Down Expand Up @@ -146,6 +150,20 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
}

checkWriteHeaderCode(code)

if rw.ClientTrace != nil && code >= 100 && code < 200 {
if code == 100 {
rw.ClientTrace.Got100Continue()
}
// treat 101 as a terminal status, see issue 26161
if code != http.StatusSwitchingProtocols {
if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil {
panic(err)
}
return
}
}

rw.Code = code
rw.wroteHeader = true
if rw.HeaderMap == nil {
Expand Down
57 changes: 57 additions & 0 deletions src/net/http/httptest/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"testing"
)

Expand Down Expand Up @@ -369,3 +371,58 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
})
}
}

func TestRecorderClientTrace(t *testing.T) {
handler := func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusContinue)

rw.Header().Add("Foo", "bar")
rw.WriteHeader(http.StatusEarlyHints)

rw.Header().Add("Baz", "bat")
}

var received100, received103 bool
trace := &httptrace.ClientTrace{
Got100Continue: func() {
received100 = true
},
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
switch code {
case http.StatusContinue:
case http.StatusEarlyHints:
received103 = true
if header.Get("Foo") != "bar" {
t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo"))
}
if header.Get("Bar") != "" {
t.Error("Unexpected Bar header")
}
default:
t.Errorf("Unexpected status code %d", code)
}

return nil
},
}

r, _ := http.NewRequest("GET", "http://example.org/", nil)
rw := NewRecorder()
rw.ClientTrace = trace
handler(rw, r)

if !received100 {
t.Error("Got100Continue not called")
}
if !received103 {
t.Error("103 request not received")
}

header := rw.Result().Header
if header.Get("Foo") != "bar" {
t.Errorf("Expected Foo=bar, got %s", header.Get("Foo"))
}
if header.Get("Baz") != "bat" {
t.Errorf("Expected Baz=bat, got %s", header.Get("Baz"))
}
}

0 comments on commit 6b9da27

Please sign in to comment.