Skip to content

Commit

Permalink
add BlockRealTransportUnsafeMethods
Browse files Browse the repository at this point in the history
  • Loading branch information
fornellas-udemy authored and dnaeon committed Aug 8, 2024
1 parent 4bc3b10 commit 90b25a8
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 4 deletions.
56 changes: 52 additions & 4 deletions recorder/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ func NewHook(handler HookFunc, kind HookKind) *Hook {
// otherwise.
type PassthroughFunc func(req *http.Request) bool

// ErrUnsafeRequestMethod is returned when Options.BlockRealTransportUnsafeMethods is true, and
// an request with an unsafe request is made.
var ErrUnsafeRequestMethod = errors.New("request has unsafe method")

type blockUnsafeMethodsRoundTripper struct {
RoundTripper http.RoundTripper
}

func (r *blockUnsafeMethodsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
safeMethods := map[string]bool{
http.MethodGet: true,
http.MethodHead: true,
http.MethodOptions: true,
http.MethodTrace: true,
}
if _, ok := safeMethods[req.Method]; !ok {
return nil, ErrUnsafeRequestMethod
}
return r.RoundTripper.RoundTrip(req)
}

// Option represents the Recorder options
type Options struct {
// CassetteName is the name of the cassette
Expand All @@ -157,6 +178,14 @@ type Options struct {
// the real requests
RealTransport http.RoundTripper

// Block unsafe methods from ever being called with RealTransport.
// The definition of "Safe Methods" comes from
// https://datatracker.ietf.org/doc/html/rfc9110#name-safe-methods
// and means that Safe Methods SHOULD NOT have side effects on the server.
// The use case for this flag is to prevent unsafe methods being used when executing tests
// thare are known to be "read-only".
BlockRealTransportUnsafeMethods bool

// SkipRequestLatency, if set to true will not simulate the
// latency of the recorded interaction. When set to false
// (default) it will block for the period of time taken by the
Expand Down Expand Up @@ -255,6 +284,15 @@ func NewWithOptions(opts *Options) (*Recorder, error) {
}
}

func (rec *Recorder) getRoundTripper() http.RoundTripper {
if rec.options.BlockRealTransportUnsafeMethods {
return &blockUnsafeMethodsRoundTripper{
RoundTripper: rec.options.RealTransport,
}
}
return rec.options.RealTransport
}

// Proxies client requests to their original destination
func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, error) {
if err := r.Context().Err(); err != nil {
Expand Down Expand Up @@ -328,7 +366,7 @@ func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, err
// Perform request to it's original destination and record the interactions
var start time.Time
start = time.Now()
resp, err := rec.options.RealTransport.RoundTrip(r)
resp, err := rec.getRoundTripper().RoundTrip(r)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -448,17 +486,27 @@ func (rec *Recorder) SetRealTransport(t http.RoundTripper) {
rec.options.RealTransport = t
}

// Block unsafe methods from ever being called with RealTransport.
// The definition of "Safe Methods" comes from
// https://datatracker.ietf.org/doc/html/rfc9110#name-safe-methods
// and means that Safe Methods SHOULD NOT have side effects on the server.
// The use case for this flag is to prevent unsafe methods being used when executing tests
// thare are known to be "read-only".
func (rec *Recorder) SetBlockRealTransportUnsafeMethods(value bool) {
rec.options.BlockRealTransportUnsafeMethods = value
}

// RoundTrip implements the http.RoundTripper interface
func (rec *Recorder) RoundTrip(req *http.Request) (*http.Response, error) {
// Passthrough mode, use real transport
if rec.options.Mode == ModePassthrough {
return rec.options.RealTransport.RoundTrip(req)
return rec.getRoundTripper().RoundTrip(req)
}

// Apply passthrough handler functions
for _, passthroughFunc := range rec.passthroughs {
if passthroughFunc(req) {
return rec.options.RealTransport.RoundTrip(req)
return rec.getRoundTripper().RoundTrip(req)
}
}

Expand Down Expand Up @@ -491,7 +539,7 @@ func (rec *Recorder) CancelRequest(req *http.Request) {
type cancelableTransport interface {
CancelRequest(req *http.Request)
}
if ct, ok := rec.options.RealTransport.(cancelableTransport); ok {
if ct, ok := rec.getRoundTripper().(cancelableTransport); ok {
ct.CancelRequest(req)
}
}
Expand Down
95 changes: 95 additions & 0 deletions recorder/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type testCase struct {
wantBody string
wantStatus int
wantContentLength int
wantError error
path string
}

Expand All @@ -61,6 +62,9 @@ func (tc testCase) run(client *http.Client, ctx context.Context, serverUrl strin

resp, err := client.Do(req.WithContext(ctx))
if err != nil {
if tc.wantError != nil && errors.Is(err, tc.wantError) {
return nil
}
return err
}
defer resp.Body.Close()
Expand Down Expand Up @@ -1227,6 +1231,97 @@ func TestRecordOnlyMode(t *testing.T) {
}
}

func TestBlockRealTransportUnsafeMethods(t *testing.T) {
// Set things up
tests := []testCase{
{
method: http.MethodGet,
wantBody: "GET go-vcr\n",
wantStatus: http.StatusOK,
wantContentLength: 11,
path: "/api/v1/foo",
},
{
method: http.MethodHead,
wantStatus: http.StatusOK,
wantContentLength: 12,
path: "/api/v1/bar",
},
{
method: http.MethodOptions,
wantBody: "OPTIONS go-vcr\n",
wantStatus: http.StatusOK,
wantContentLength: 15,
path: "/api/v1/foo",
},
{
method: http.MethodTrace,
wantBody: "TRACE go-vcr\n",
wantStatus: http.StatusOK,
wantContentLength: 13,
path: "/api/v1/foo",
},
{
method: http.MethodPost,
body: "foo",
wantError: recorder.ErrUnsafeRequestMethod,
path: "/api/v1/baz",
},
{
method: http.MethodPut,
body: "foo",
wantError: recorder.ErrUnsafeRequestMethod,
path: "/api/v1/baz",
},
{
method: http.MethodDelete,
wantError: recorder.ErrUnsafeRequestMethod,
path: "/api/v1/baz",
},
{
method: http.MethodConnect,
wantError: recorder.ErrUnsafeRequestMethod,
path: "/api/v1/baz",
},
{
method: http.MethodPatch,
body: "foo",
wantError: recorder.ErrUnsafeRequestMethod,
path: "/api/v1/baz",
},
}

server := newEchoHttpServer()
serverUrl := server.URL
defer server.Close()

cassPath, err := newCassettePath("test_record_only")
if err != nil {
t.Fatal(err)
}

// Create recorder
opts := &recorder.Options{
CassetteName: cassPath,
Mode: recorder.ModeRecordOnly,
BlockRealTransportUnsafeMethods: true,
}
rec, err := recorder.NewWithOptions(opts)
if err != nil {
t.Fatal(err)
}
defer rec.Stop()

// Run tests
ctx := context.Background()
client := rec.GetDefaultClient()
for _, test := range tests {
if err := test.run(client, ctx, serverUrl); err != nil {
t.Fatal(err)
}
}
}

func TestInvalidRecorderMode(t *testing.T) {
// Create recorder
opts := &recorder.Options{
Expand Down

0 comments on commit 90b25a8

Please sign in to comment.