diff --git a/cassette/cassette.go b/cassette/cassette.go index 7cbe7a8..8d4d279 100644 --- a/cassette/cassette.go +++ b/cassette/cassette.go @@ -35,6 +35,8 @@ import ( "strings" "sync" "time" + "reflect" + "bytes" "gopkg.in/yaml.v3" ) @@ -207,10 +209,98 @@ func (i *Interaction) GetHTTPResponse() (*http.Response, error) { // criteria. type MatcherFunc func(*http.Request, Request) bool -// DefaultMatcher is used when a custom matcher is not defined and -// compares only the method and of the HTTP request. +// Similar to reflect.DeepEqual, but considers the contents of collections, so {} and nil would be +// considered equal. works with Array, Map, Slice, or pointer to Array. +func deepEqualContents(x, y any) bool { + if reflect.ValueOf(x).IsNil() { + if reflect.ValueOf(y).IsNil() { + return true + } else { + return reflect.ValueOf(y).Len() == 0 + } + } else { + if reflect.ValueOf(y).IsNil() { + return reflect.ValueOf(x).Len() == 0 + } else { + return reflect.DeepEqual(x, y) + } + } +} + +func bodyMatches(r *http.Request, i Request) bool { + if r.Body != nil { + var buffer bytes.Buffer + + if _, err := buffer.ReadFrom(r.Body); err != nil { + panic(fmt.Sprintf("failed to read %s %s body: %s", r.Method, r.URL.String(), err)) + } + + r.Body = io.NopCloser(bytes.NewBuffer(buffer.Bytes())) + + if buffer.String() != i.Body { + return false + } + } else { + if len(i.Body) != 0 { + return false + } + } + + return true +} + +// DefaultMatcher is used when a custom matcher is not defined. It compares the whole HTTP request +// and only matches, if everything (eg: method, url, headers, body, ) matches. func DefaultMatcher(r *http.Request, i Request) bool { - return r.Method == i.Method && r.URL.String() == i.URL + + if r.Method != i.Method { + return false + } + + if r.URL.String() != i.URL { + return false + } + + if r.ProtoMajor != i.ProtoMajor { + return false + } + if r.ProtoMinor != i.ProtoMinor { + return false + } + + if !deepEqualContents(r.Header, i.Headers) { + return false + } + + if !bodyMatches(r, i) { + return false + } + + if r.ContentLength != i.ContentLength { + return false + } + + if !deepEqualContents(r.TransferEncoding, i.TransferEncoding) { + return false + } + + if r.Host != i.Host { + return false + } + + if !deepEqualContents(r.Trailer, i.Trailer) { + return false + } + + if r.RemoteAddr != i.RemoteAddr { + return false + } + + if r.RequestURI != i.RequestURI { + return false + } + + return true } // OnRequestReplayFunc function is called when a request is being replayed. diff --git a/recorder/recorder_test.go b/recorder/recorder_test.go index b916f02..1a6fb65 100644 --- a/recorder/recorder_test.go +++ b/recorder/recorder_test.go @@ -1111,10 +1111,10 @@ func TestWithCustomMatcher(t *testing.T) { rec.SetReplayableInteractions(true) // All requests which hit the same URL and use the same method - // will match against the first recorded interaction. + // will match against the recorded interaction. client = rec.GetDefaultClient() url := fmt.Sprintf("%s%s", serverUrl, "/api/v1/foo") // Same URL as the test cases - req, err := http.NewRequest(http.MethodPost, url, strings.NewReader("any body will match")) + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader("foo")) if err != nil { t.Fatal(err) }