Skip to content

Commit

Permalink
Allow custom assertions with ReplayAssertFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinmclean authored and dnaeon committed Aug 19, 2024
1 parent 98b2115 commit ab3bb92
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions pkg/cassette/server_replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ import (
"testing"
)

// ReplayAssertFunc is used to assert the results of replaying a recorded request against a handler.
// It receives the current Interaction and the httptest.ResponseRecorder.
type ReplayAssertFunc func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder)

// DefaultReplayAssertFunc compares the response status code, body, and headers.
// It can be overridden for more specific tests or to use your preferred assertion libraries
var DefaultReplayAssertFunc ReplayAssertFunc = func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder) {
if expected.Response.Code != actual.Result().StatusCode {
t.Errorf("status code does not match: expected=%d actual=%d", expected.Response.Code, actual.Result().StatusCode)
}

if expected.Response.Body != actual.Body.String() {
t.Errorf("body does not match: expected=%s actual=%s", expected.Response.Body, actual.Body.String())
}

if !headersEqual(expected.Response.Headers, actual.Header()) {
t.Errorf("header values do not match. expected=%v actual=%v", expected.Response.Headers, actual.Header())
}
}

// TestServerReplay loads a Cassette and replays each Interaction with the provided Handler, then compares the response
func TestServerReplay(t *testing.T, cassetteName string, handler http.Handler) {
t.Helper()
Expand Down Expand Up @@ -47,27 +67,7 @@ func TestInteractionReplay(t *testing.T, handler http.Handler, interaction *Inte
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

expectedResp, err := interaction.GetHTTPResponse()
if err != nil {
t.Errorf("unexpected error getting interaction response: %v", err)
}

if expectedResp.StatusCode != w.Result().StatusCode {
t.Errorf("status code does not match: expected=%d actual=%d", expectedResp.StatusCode, w.Result().StatusCode)
}

expectedBody, err := io.ReadAll(expectedResp.Body)
if err != nil {
t.Errorf("unexpected reading response body: %v", err)
}

if string(expectedBody) != w.Body.String() {
t.Errorf("body does not match: expected=%s actual=%s", expectedBody, w.Body.String())
}

if !headersEqual(expectedResp.Header, w.Header()) {
t.Errorf("header values do not match. expected=%v actual=%v", expectedResp.Header, w.Header())
}
DefaultReplayAssertFunc(t, interaction, w)
}

func headersEqual(expected, actual http.Header) bool {
Expand Down

0 comments on commit ab3bb92

Please sign in to comment.