Skip to content

Commit

Permalink
OnRequestReplay callback to allow mutating request like force reading…
Browse files Browse the repository at this point in the history
… the body
  • Loading branch information
Dustin Zeisler authored and dnaeon committed May 30, 2024
1 parent 1c1cba9 commit 78b9599
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
22 changes: 22 additions & 0 deletions cassette/cassette.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,15 @@ func DefaultMatcher(r *http.Request, i Request) bool {
return r.Method == i.Method && r.URL.String() == i.URL
}

// OnRequestReplayFunc function is called when a request is being replayed.
// This is helpful when you want to modify the request like forcing the body to be read.
type OnRequestReplayFunc func(*http.Request) error

// DefaultOnRequestReplayFunc is used when a custom function is not defined
func DefaultOnRequestReplayFunc(*http.Request) error {
return nil
}

// Cassette type
type Cassette struct {
// Name of the cassette
Expand All @@ -238,6 +247,9 @@ type Cassette struct {
// Matches actual request with interaction requests.
Matcher MatcherFunc `yaml:"-"`

// OnRequestReplay is called when a request is being replayed.
OnRequestReplay OnRequestReplayFunc `yaml:"-"`

// IsNew specifies whether this is a newly created cassette.
// Returns false, when the cassette was loaded from an
// existing source, e.g. a file.
Expand All @@ -254,6 +266,7 @@ func New(name string) *Cassette {
Version: CassetteFormatV2,
Interactions: make([]*Interaction, 0),
Matcher: DefaultMatcher,
OnRequestReplay: DefaultOnRequestReplayFunc,
ReplayableInteractions: false,
IsNew: true,
nextInteractionId: 0,
Expand Down Expand Up @@ -294,6 +307,15 @@ func (c *Cassette) AddInteraction(i *Interaction) {

// GetInteraction retrieves a recorded request/response interaction
func (c *Cassette) GetInteraction(r *http.Request) (*Interaction, error) {
i, err := c.getInteraction(r)
if err != nil {
return nil, err
}
// Ensure OnRequestReplay is not wrapped with a lock.
return i, c.OnRequestReplay(r)
}

func (c *Cassette) getInteraction(r *http.Request) (*Interaction, error) {
c.Mu.Lock()
defer c.Mu.Unlock()
for _, i := range c.Interactions {
Expand Down
5 changes: 5 additions & 0 deletions recorder/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,11 @@ func (rec *Recorder) SetMatcher(matcher cassette.MatcherFunc) {
rec.cassette.Matcher = matcher
}

// OnRequestReplay sets a function to be called when replaying a request.
func (rec *Recorder) OnRequestReplay(onRequestReplay cassette.OnRequestReplayFunc) {
rec.cassette.OnRequestReplay = onRequestReplay
}

// SetReplayableInteractions defines whether to allow interactions to
// be replayed or not. This is useful in cases when you need to hit
// the same endpoint multiple times and want to replay the interaction
Expand Down
87 changes: 87 additions & 0 deletions recorder/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1317,3 +1317,90 @@ func TestDiscardInteractionsOnSave(t *testing.T) {
t.Fatalf("expected %d interactions, got %d", wantInteractions, gotInteractions)
}
}

func TestOnRequestReplay(t *testing.T) {
tests := []testCase{
{
method: http.MethodPost,
body: "foo",
wantBody: "POST go-vcr\nfoo",
wantStatus: http.StatusOK,
wantContentLength: 15,
path: "/api/v1/foo",
},
}

server := newEchoHttpServer()
serverUrl := server.URL

cassPath, err := newCassettePath("test_on_request_replay")
if err != nil {
t.Fatal(err)
}

// Create recorder
rec, err := recorder.New(cassPath)
if err != nil {
t.Fatal(err)
}

if rec.Mode() != recorder.ModeRecordOnce {
t.Fatal("recorder is not in the correct mode")
}

if rec.IsRecording() != true {
t.Fatal("recorder is not recording")
}

// populate the cassette
ctx := context.Background()
client := rec.GetDefaultClient()
for _, test := range tests {
if err := test.run(client, ctx, serverUrl); err != nil {
t.Fatal(err)
}
}

server.Close()
rec.Stop()

// Re-run the tests with the recorder in replay mode
opts := &recorder.Options{
CassetteName: cassPath,
Mode: recorder.ModeReplayOnly,
}
rec, err = recorder.NewWithOptions(opts)
if err != nil {
t.Fatal(err)
}
defer rec.Stop()

if rec.Mode() != recorder.ModeReplayOnly {
t.Fatal("recorder is not in the correct mode")
}

if rec.IsRecording() != false {
t.Fatal("recorder should not be recording")
}

var onReplayRequest *http.Request
// Add a hook to capture the request being replayed
rec.OnRequestReplay(func(r *http.Request) error {
onReplayRequest = r
return nil
})

// Set replayable interactions to true, so that we can match
// against the already recorded interactions.
rec.SetReplayableInteractions(true)

for _, test := range tests {
if err := test.run(rec.GetDefaultClient(), ctx, serverUrl); err != nil {
t.Fatal(err)
}
}

if onReplayRequest == nil {
t.Fatal("expected replaced request not be nil")
}
}

0 comments on commit 78b9599

Please sign in to comment.