From 78b95994ca9ba11d6a7d659a1d24bcb4bb4bff0a Mon Sep 17 00:00:00 2001 From: Dustin Zeisler Date: Wed, 29 May 2024 16:24:02 -0700 Subject: [PATCH] OnRequestReplay callback to allow mutating request like force reading the body --- cassette/cassette.go | 22 ++++++++++ recorder/recorder.go | 5 +++ recorder/recorder_test.go | 87 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+) diff --git a/cassette/cassette.go b/cassette/cassette.go index 73af81a..7cbe7a8 100644 --- a/cassette/cassette.go +++ b/cassette/cassette.go @@ -213,6 +213,15 @@ func DefaultMatcher(r *http.Request, i Request) bool { return r.Method == i.Method && r.URL.String() == i.URL } +// OnRequestReplayFunc function is called when a request is being replayed. +// This is helpful when you want to modify the request like forcing the body to be read. +type OnRequestReplayFunc func(*http.Request) error + +// DefaultOnRequestReplayFunc is used when a custom function is not defined +func DefaultOnRequestReplayFunc(*http.Request) error { + return nil +} + // Cassette type type Cassette struct { // Name of the cassette @@ -238,6 +247,9 @@ type Cassette struct { // Matches actual request with interaction requests. Matcher MatcherFunc `yaml:"-"` + // OnRequestReplay is called when a request is being replayed. + OnRequestReplay OnRequestReplayFunc `yaml:"-"` + // IsNew specifies whether this is a newly created cassette. // Returns false, when the cassette was loaded from an // existing source, e.g. a file. @@ -254,6 +266,7 @@ func New(name string) *Cassette { Version: CassetteFormatV2, Interactions: make([]*Interaction, 0), Matcher: DefaultMatcher, + OnRequestReplay: DefaultOnRequestReplayFunc, ReplayableInteractions: false, IsNew: true, nextInteractionId: 0, @@ -294,6 +307,15 @@ func (c *Cassette) AddInteraction(i *Interaction) { // GetInteraction retrieves a recorded request/response interaction func (c *Cassette) GetInteraction(r *http.Request) (*Interaction, error) { + i, err := c.getInteraction(r) + if err != nil { + return nil, err + } + // Ensure OnRequestReplay is not wrapped with a lock. + return i, c.OnRequestReplay(r) +} + +func (c *Cassette) getInteraction(r *http.Request) (*Interaction, error) { c.Mu.Lock() defer c.Mu.Unlock() for _, i := range c.Interactions { diff --git a/recorder/recorder.go b/recorder/recorder.go index f95784a..9dc9a87 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -502,6 +502,11 @@ func (rec *Recorder) SetMatcher(matcher cassette.MatcherFunc) { rec.cassette.Matcher = matcher } +// OnRequestReplay sets a function to be called when replaying a request. +func (rec *Recorder) OnRequestReplay(onRequestReplay cassette.OnRequestReplayFunc) { + rec.cassette.OnRequestReplay = onRequestReplay +} + // SetReplayableInteractions defines whether to allow interactions to // be replayed or not. This is useful in cases when you need to hit // the same endpoint multiple times and want to replay the interaction diff --git a/recorder/recorder_test.go b/recorder/recorder_test.go index 0340c91..b916f02 100644 --- a/recorder/recorder_test.go +++ b/recorder/recorder_test.go @@ -1317,3 +1317,90 @@ func TestDiscardInteractionsOnSave(t *testing.T) { t.Fatalf("expected %d interactions, got %d", wantInteractions, gotInteractions) } } + +func TestOnRequestReplay(t *testing.T) { + tests := []testCase{ + { + method: http.MethodPost, + body: "foo", + wantBody: "POST go-vcr\nfoo", + wantStatus: http.StatusOK, + wantContentLength: 15, + path: "/api/v1/foo", + }, + } + + server := newEchoHttpServer() + serverUrl := server.URL + + cassPath, err := newCassettePath("test_on_request_replay") + if err != nil { + t.Fatal(err) + } + + // Create recorder + rec, err := recorder.New(cassPath) + if err != nil { + t.Fatal(err) + } + + if rec.Mode() != recorder.ModeRecordOnce { + t.Fatal("recorder is not in the correct mode") + } + + if rec.IsRecording() != true { + t.Fatal("recorder is not recording") + } + + // populate the cassette + ctx := context.Background() + client := rec.GetDefaultClient() + for _, test := range tests { + if err := test.run(client, ctx, serverUrl); err != nil { + t.Fatal(err) + } + } + + server.Close() + rec.Stop() + + // Re-run the tests with the recorder in replay mode + opts := &recorder.Options{ + CassetteName: cassPath, + Mode: recorder.ModeReplayOnly, + } + rec, err = recorder.NewWithOptions(opts) + if err != nil { + t.Fatal(err) + } + defer rec.Stop() + + if rec.Mode() != recorder.ModeReplayOnly { + t.Fatal("recorder is not in the correct mode") + } + + if rec.IsRecording() != false { + t.Fatal("recorder should not be recording") + } + + var onReplayRequest *http.Request + // Add a hook to capture the request being replayed + rec.OnRequestReplay(func(r *http.Request) error { + onReplayRequest = r + return nil + }) + + // Set replayable interactions to true, so that we can match + // against the already recorded interactions. + rec.SetReplayableInteractions(true) + + for _, test := range tests { + if err := test.run(rec.GetDefaultClient(), ctx, serverUrl); err != nil { + t.Fatal(err) + } + } + + if onReplayRequest == nil { + t.Fatal("expected replaced request not be nil") + } +}