diff --git a/format/format.go b/format/format.go index 4f7462ab1..c6e4f2783 100644 --- a/format/format.go +++ b/format/format.go @@ -53,6 +53,13 @@ var Indent = " " var longFormThreshold = 20 +// GomegaStringer allows for custom formating of objects for gomega. +type GomegaStringer interface { + // GomegaString will be used to custom format an object. + // It does not follow UseStringerRepresentation value and will always be called regardless. + GomegaString() string +} + /* Generates a formatted matcher success/failure message of the form: @@ -219,9 +226,14 @@ func formatValue(value reflect.Value, indentation uint) string { return "nil" } - if UseStringerRepresentation { - if value.CanInterface() { - obj := value.Interface() + // GomegaStringer will take precedence to other representations and disregards UseStringerRepresentation + if value.CanInterface() { + obj := value.Interface() + if x, ok := obj.(GomegaStringer); ok { + return x.GomegaString() + } + + if UseStringerRepresentation { switch x := obj.(type) { case fmt.GoStringer: return x.GoString() diff --git a/format/format_test.go b/format/format_test.go index 8f20767e8..c7e67cec2 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -74,6 +74,13 @@ func (g Stringer) String() string { return "string" } +type gomegaStringer struct { +} + +func (g gomegaStringer) GomegaString() string { + return "gomegastring" +} + var _ = Describe("Format", func() { match := func(typeRepresentation string, valueRepresentation string, args ...interface{}) types.GomegaMatcher { if len(args) > 0 { @@ -654,6 +661,14 @@ var _ = Describe("Format", func() { Expect(Object(Stringer{}, 1)).Should(ContainSubstring(": string")) }) }) + + When("passed a GomegaStringer", func() { + It("should use what GomegaString() returns", func() { + Expect(Object(gomegaStringer{}, 1)).Should(ContainSubstring(": gomegastring")) + UseStringerRepresentation = false + Expect(Object(gomegaStringer{}, 1)).Should(ContainSubstring(": gomegastring")) + }) + }) }) Describe("Printing a context.Context field", func() {