diff --git a/pkg/cassette/server_replay.go b/pkg/cassette/server_replay.go index 5853a83..a18908d 100644 --- a/pkg/cassette/server_replay.go +++ b/pkg/cassette/server_replay.go @@ -11,6 +11,26 @@ import ( "testing" ) +// ReplayAssertFunc is used to assert the results of replaying a recorded request against a handler. +// It receives the current Interaction and the httptest.ResponseRecorder. +type ReplayAssertFunc func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder) + +// DefaultReplayAssertFunc compares the response status code, body, and headers. +// It can be overridden for more specific tests or to use your preferred assertion libraries +var DefaultReplayAssertFunc ReplayAssertFunc = func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder) { + if expected.Response.Code != actual.Result().StatusCode { + t.Errorf("status code does not match: expected=%d actual=%d", expected.Response.Code, actual.Result().StatusCode) + } + + if expected.Response.Body != actual.Body.String() { + t.Errorf("body does not match: expected=%s actual=%s", expected.Response.Body, actual.Body.String()) + } + + if !headersEqual(expected.Response.Headers, actual.Header()) { + t.Errorf("header values do not match. expected=%v actual=%v", expected.Response.Headers, actual.Header()) + } +} + // TestServerReplay loads a Cassette and replays each Interaction with the provided Handler, then compares the response func TestServerReplay(t *testing.T, cassetteName string, handler http.Handler) { t.Helper() @@ -47,27 +67,7 @@ func TestInteractionReplay(t *testing.T, handler http.Handler, interaction *Inte w := httptest.NewRecorder() handler.ServeHTTP(w, req) - expectedResp, err := interaction.GetHTTPResponse() - if err != nil { - t.Errorf("unexpected error getting interaction response: %v", err) - } - - if expectedResp.StatusCode != w.Result().StatusCode { - t.Errorf("status code does not match: expected=%d actual=%d", expectedResp.StatusCode, w.Result().StatusCode) - } - - expectedBody, err := io.ReadAll(expectedResp.Body) - if err != nil { - t.Errorf("unexpected reading response body: %v", err) - } - - if string(expectedBody) != w.Body.String() { - t.Errorf("body does not match: expected=%s actual=%s", expectedBody, w.Body.String()) - } - - if !headersEqual(expectedResp.Header, w.Header()) { - t.Errorf("header values do not match. expected=%v actual=%v", expectedResp.Header, w.Header()) - } + DefaultReplayAssertFunc(t, interaction, w) } func headersEqual(expected, actual http.Header) bool {