diff --git a/recorder/recorder.go b/recorder/recorder.go index 9dc9a87..0f0cace 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -145,6 +145,27 @@ func NewHook(handler HookFunc, kind HookKind) *Hook { // otherwise. type PassthroughFunc func(req *http.Request) bool +// ErrUnsafeRequestMethod is returned when Options.BlockRealTransportUnsafeMethods is true, and +// an request with an unsafe request is made. +var ErrUnsafeRequestMethod = errors.New("request has unsafe method") + +type blockUnsafeMethodsRoundTripper struct { + RoundTripper http.RoundTripper +} + +func (r *blockUnsafeMethodsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + safeMethods := map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodOptions: true, + http.MethodTrace: true, + } + if _, ok := safeMethods[req.Method]; !ok { + return nil, ErrUnsafeRequestMethod + } + return r.RoundTripper.RoundTrip(req) +} + // Option represents the Recorder options type Options struct { // CassetteName is the name of the cassette @@ -157,6 +178,14 @@ type Options struct { // the real requests RealTransport http.RoundTripper + // Block unsafe methods from ever being called with RealTransport. + // The definition of "Safe Methods" comes from + // https://datatracker.ietf.org/doc/html/rfc9110#name-safe-methods + // and means that Safe Methods SHOULD NOT have side effects on the server. + // The use case for this flag is to prevent unsafe methods being used when executing tests + // thare are known to be "read-only". + BlockRealTransportUnsafeMethods bool + // SkipRequestLatency, if set to true will not simulate the // latency of the recorded interaction. When set to false // (default) it will block for the period of time taken by the @@ -255,6 +284,15 @@ func NewWithOptions(opts *Options) (*Recorder, error) { } } +func (rec *Recorder) getRoundTripper() http.RoundTripper { + if rec.options.BlockRealTransportUnsafeMethods { + return &blockUnsafeMethodsRoundTripper{ + RoundTripper: rec.options.RealTransport, + } + } + return rec.options.RealTransport +} + // Proxies client requests to their original destination func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, error) { if err := r.Context().Err(); err != nil { @@ -328,7 +366,7 @@ func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, err // Perform request to it's original destination and record the interactions var start time.Time start = time.Now() - resp, err := rec.options.RealTransport.RoundTrip(r) + resp, err := rec.getRoundTripper().RoundTrip(r) if err != nil { return nil, err } @@ -448,17 +486,27 @@ func (rec *Recorder) SetRealTransport(t http.RoundTripper) { rec.options.RealTransport = t } +// Block unsafe methods from ever being called with RealTransport. +// The definition of "Safe Methods" comes from +// https://datatracker.ietf.org/doc/html/rfc9110#name-safe-methods +// and means that Safe Methods SHOULD NOT have side effects on the server. +// The use case for this flag is to prevent unsafe methods being used when executing tests +// thare are known to be "read-only". +func (rec *Recorder) SetBlockRealTransportUnsafeMethods(value bool) { + rec.options.BlockRealTransportUnsafeMethods = value +} + // RoundTrip implements the http.RoundTripper interface func (rec *Recorder) RoundTrip(req *http.Request) (*http.Response, error) { // Passthrough mode, use real transport if rec.options.Mode == ModePassthrough { - return rec.options.RealTransport.RoundTrip(req) + return rec.getRoundTripper().RoundTrip(req) } // Apply passthrough handler functions for _, passthroughFunc := range rec.passthroughs { if passthroughFunc(req) { - return rec.options.RealTransport.RoundTrip(req) + return rec.getRoundTripper().RoundTrip(req) } } @@ -491,7 +539,7 @@ func (rec *Recorder) CancelRequest(req *http.Request) { type cancelableTransport interface { CancelRequest(req *http.Request) } - if ct, ok := rec.options.RealTransport.(cancelableTransport); ok { + if ct, ok := rec.getRoundTripper().(cancelableTransport); ok { ct.CancelRequest(req) } } diff --git a/recorder/recorder_test.go b/recorder/recorder_test.go index 1a6fb65..1de2956 100644 --- a/recorder/recorder_test.go +++ b/recorder/recorder_test.go @@ -49,6 +49,7 @@ type testCase struct { wantBody string wantStatus int wantContentLength int + wantError error path string } @@ -61,6 +62,9 @@ func (tc testCase) run(client *http.Client, ctx context.Context, serverUrl strin resp, err := client.Do(req.WithContext(ctx)) if err != nil { + if tc.wantError != nil && errors.Is(err, tc.wantError) { + return nil + } return err } defer resp.Body.Close() @@ -1227,6 +1231,97 @@ func TestRecordOnlyMode(t *testing.T) { } } +func TestBlockRealTransportUnsafeMethods(t *testing.T) { + // Set things up + tests := []testCase{ + { + method: http.MethodGet, + wantBody: "GET go-vcr\n", + wantStatus: http.StatusOK, + wantContentLength: 11, + path: "/api/v1/foo", + }, + { + method: http.MethodHead, + wantStatus: http.StatusOK, + wantContentLength: 12, + path: "/api/v1/bar", + }, + { + method: http.MethodOptions, + wantBody: "OPTIONS go-vcr\n", + wantStatus: http.StatusOK, + wantContentLength: 15, + path: "/api/v1/foo", + }, + { + method: http.MethodTrace, + wantBody: "TRACE go-vcr\n", + wantStatus: http.StatusOK, + wantContentLength: 13, + path: "/api/v1/foo", + }, + { + method: http.MethodPost, + body: "foo", + wantError: recorder.ErrUnsafeRequestMethod, + path: "/api/v1/baz", + }, + { + method: http.MethodPut, + body: "foo", + wantError: recorder.ErrUnsafeRequestMethod, + path: "/api/v1/baz", + }, + { + method: http.MethodDelete, + wantError: recorder.ErrUnsafeRequestMethod, + path: "/api/v1/baz", + }, + { + method: http.MethodConnect, + wantError: recorder.ErrUnsafeRequestMethod, + path: "/api/v1/baz", + }, + { + method: http.MethodPatch, + body: "foo", + wantError: recorder.ErrUnsafeRequestMethod, + path: "/api/v1/baz", + }, + } + + server := newEchoHttpServer() + serverUrl := server.URL + defer server.Close() + + cassPath, err := newCassettePath("test_record_only") + if err != nil { + t.Fatal(err) + } + + // Create recorder + opts := &recorder.Options{ + CassetteName: cassPath, + Mode: recorder.ModeRecordOnly, + BlockRealTransportUnsafeMethods: true, + } + rec, err := recorder.NewWithOptions(opts) + if err != nil { + t.Fatal(err) + } + defer rec.Stop() + + // Run tests + ctx := context.Background() + client := rec.GetDefaultClient() + for _, test := range tests { + if err := test.run(client, ctx, serverUrl); err != nil { + t.Fatal(err) + } + } +} + func TestInvalidRecorderMode(t *testing.T) { // Create recorder opts := &recorder.Options{