Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor reporter implementation #112

Merged
merged 1 commit into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 112 additions & 64 deletions cmp/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,32 @@ var nothing = reflect.Value{}
// Map keys are equal according to the == operator.
// To use custom comparisons for map keys, consider using cmpopts.SortMaps.
func Equal(x, y interface{}, opts ...Option) bool {
vx := reflect.ValueOf(x)
vy := reflect.ValueOf(y)

// If the inputs are different types, auto-wrap them in an empty interface
// so that they have the same parent type.
var t reflect.Type
if !vx.IsValid() || !vy.IsValid() || vx.Type() != vy.Type() {
t = reflect.TypeOf((*interface{})(nil)).Elem()
if vx.IsValid() {
vvx := reflect.New(t).Elem()
vvx.Set(vx)
vx = vvx
}
if vy.IsValid() {
vvy := reflect.New(t).Elem()
vvy.Set(vy)
vy = vvy
}
} else {
t = vx.Type()
}

s := newState(opts)
s.compareAny(reflect.ValueOf(x), reflect.ValueOf(y))
s.pushStep(&pathStep{typ: t}, vx, vy)
s.compareAny(vx, vy)
s.popStep()
return s.result.Equal()
}

Expand All @@ -91,7 +115,7 @@ func Equal(x, y interface{}, opts ...Option) bool {
// Do not depend on this output being stable.
func Diff(x, y interface{}, opts ...Option) string {
r := new(defaultReporter)
opts = Options{Options(opts), r}
opts = Options{Options(opts), reporter(r)}
eq := Equal(x, y, opts...)
d := r.String()
if (d == "") != eq {
Expand All @@ -103,9 +127,9 @@ func Diff(x, y interface{}, opts ...Option) string {
type state struct {
// These fields represent the "comparison state".
// Calling statelessCompare must not result in observable changes to these.
result diff.Result // The current result of comparison
curPath Path // The current path in the value tree
reporter reporter // Optional reporter used for difference formatting
result diff.Result // The current result of comparison
curPath Path // The current path in the value tree
reporters []reporterOption // Optional reporters

// recChecker checks for infinite cycles applying the same set of
// transformers upon the output of itself.
Expand Down Expand Up @@ -150,11 +174,8 @@ func (s *state) processOption(opt Option) {
for t := range opt {
s.exporters[t] = true
}
case reporter:
if s.reporter != nil {
panic("difference reporter already registered")
}
s.reporter = opt
case reporterOption:
s.reporters = append(s.reporters, opt)
default:
panic(fmt.Sprintf("unknown option %T", opt))
}
Expand All @@ -169,12 +190,12 @@ func (s *state) statelessCompare(vx, vy reflect.Value) diff.Result {
// It is an implementation bug if the contents of curPath differs from
// when calling this function to when returning from it.

oldResult, oldReporter := s.result, s.reporter
oldResult, oldReporters := s.result, s.reporters
s.result = diff.Result{} // Reset result
s.reporter = nil // Remove reporter to avoid spurious printouts
s.reporters = nil // Remove reporters to avoid spurious printouts
s.compareAny(vx, vy)
res := s.result
s.result, s.reporter = oldResult, oldReporter
s.result, s.reporters = oldResult, oldReporters
return res
}

Expand All @@ -184,18 +205,14 @@ func (s *state) compareAny(vx, vy reflect.Value) {

// Rule 0: Differing types are never equal.
if !vx.IsValid() || !vy.IsValid() {
s.report(vx.IsValid() == vy.IsValid(), vx, vy)
s.report(vx.IsValid() == vy.IsValid())
return
}
if vx.Type() != vy.Type() {
s.report(false, vx, vy) // Possible for path to be empty
s.report(false)
return
}
t := vx.Type()
if len(s.curPath) == 0 {
s.curPath.push(&pathStep{typ: t})
defer s.curPath.pop()
}
vx, vy = s.tryExporting(vx, vy)

// Rule 1: Check whether an option applies on this node in the value tree.
Expand All @@ -211,35 +228,35 @@ func (s *state) compareAny(vx, vy reflect.Value) {
// Rule 3: Recursively descend into each value's underlying kind.
switch t.Kind() {
case reflect.Bool:
s.report(vx.Bool() == vy.Bool(), vx, vy)
s.report(vx.Bool() == vy.Bool())
return
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s.report(vx.Int() == vy.Int(), vx, vy)
s.report(vx.Int() == vy.Int())
return
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
s.report(vx.Uint() == vy.Uint(), vx, vy)
s.report(vx.Uint() == vy.Uint())
return
case reflect.Float32, reflect.Float64:
s.report(vx.Float() == vy.Float(), vx, vy)
s.report(vx.Float() == vy.Float())
return
case reflect.Complex64, reflect.Complex128:
s.report(vx.Complex() == vy.Complex(), vx, vy)
s.report(vx.Complex() == vy.Complex())
return
case reflect.String:
s.report(vx.String() == vy.String(), vx, vy)
s.report(vx.String() == vy.String())
return
case reflect.Chan, reflect.UnsafePointer:
s.report(vx.Pointer() == vy.Pointer(), vx, vy)
s.report(vx.Pointer() == vy.Pointer())
return
case reflect.Func:
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
s.report(vx.IsNil() && vy.IsNil())
return
case reflect.Struct:
s.compareStruct(vx, vy, t)
return
case reflect.Slice:
if vx.IsNil() || vy.IsNil() {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
s.report(vx.IsNil() && vy.IsNil())
return
}
fallthrough
Expand All @@ -251,25 +268,27 @@ func (s *state) compareAny(vx, vy reflect.Value) {
return
case reflect.Ptr:
if vx.IsNil() || vy.IsNil() {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
s.report(vx.IsNil() && vy.IsNil())
return
}
s.curPath.push(&indirect{pathStep{t.Elem()}})
defer s.curPath.pop()
s.compareAny(vx.Elem(), vy.Elem())
vx, vy = vx.Elem(), vy.Elem()
s.pushStep(&indirect{pathStep{t.Elem()}}, vx, vy)
s.compareAny(vx, vy)
s.popStep()
return
case reflect.Interface:
if vx.IsNil() || vy.IsNil() {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
s.report(vx.IsNil() && vy.IsNil())
return
}
if vx.Elem().Type() != vy.Elem().Type() {
s.report(false, vx.Elem(), vy.Elem())
vx, vy = vx.Elem(), vy.Elem()
if vx.Type() != vy.Type() {
s.report(false)
return
}
s.curPath.push(&typeAssertion{pathStep{vx.Elem().Type()}})
defer s.curPath.pop()
s.compareAny(vx.Elem(), vy.Elem())
s.pushStep(&typeAssertion{pathStep{vx.Type()}}, vx, vy)
s.compareAny(vx, vy)
s.popStep()
return
default:
panic(fmt.Sprintf("%v kind not handled", t.Kind()))
Expand Down Expand Up @@ -318,7 +337,7 @@ func (s *state) tryMethod(vx, vy reflect.Value, t reflect.Type) bool {
}

eq := s.callTTBFunc(m.Func, vx, vy)
s.report(eq, vx, vy)
s.report(eq)
return true
}

Expand Down Expand Up @@ -391,11 +410,8 @@ func (s *state) compareStruct(vx, vy reflect.Value, t reflect.Type) {
var vax, vay reflect.Value // Addressable versions of vx and vy

step := &structField{}
s.curPath.push(step)
defer s.curPath.pop()
for i := 0; i < t.NumField(); i++ {
vvx := vx.Field(i)
vvy := vy.Field(i)
vvx, vvy := vx.Field(i), vy.Field(i)
step.typ = t.Field(i).Type
step.name = t.Field(i).Name
step.idx = i
Expand All @@ -418,18 +434,22 @@ func (s *state) compareStruct(vx, vy reflect.Value, t reflect.Type) {
step.pvy = vay
step.field = t.Field(i)
}
s.pushStep(step, vvx, vvy)
s.compareAny(vvx, vvy)
s.popStep()
}
}

func (s *state) compareSlice(vx, vy reflect.Value, t reflect.Type) {
step := &sliceIndex{pathStep{t.Elem()}, 0, 0}
s.curPath.push(step)

// Compute an edit-script for slices vx and vy.
es := diff.Difference(vx.Len(), vy.Len(), func(ix, iy int) diff.Result {
step.xkey, step.ykey = ix, iy
return s.statelessCompare(vx.Index(ix), vy.Index(iy))
s.curPath.push(step)
ret := s.statelessCompare(vx.Index(ix), vy.Index(iy))
s.curPath.pop()
return ret
})

// Report the entire slice as is if the arrays are of primitive kind,
Expand All @@ -442,8 +462,7 @@ func (s *state) compareSlice(vx, vy reflect.Value, t reflect.Type) {
isPrimitive = true
}
if isPrimitive && es.Dist() > (vx.Len()+vy.Len())/4 {
s.curPath.pop() // Pop first since we are reporting the whole slice
s.report(false, vx, vy)
s.report(false)
return
}

Expand All @@ -453,49 +472,55 @@ func (s *state) compareSlice(vx, vy reflect.Value, t reflect.Type) {
switch e {
case diff.UniqueX:
step.xkey, step.ykey = ix, -1
s.report(false, vx.Index(ix), nothing)
vvx := vx.Index(ix)
s.pushStep(step, vvx, nothing)
s.report(false)
s.popStep()
ix++
case diff.UniqueY:
step.xkey, step.ykey = -1, iy
s.report(false, nothing, vy.Index(iy))
vvy := vy.Index(iy)
s.pushStep(step, nothing, vvy)
s.report(false)
s.popStep()
iy++
default:
step.xkey, step.ykey = ix, iy
vvx, vvy := vx.Index(ix), vy.Index(iy)
s.pushStep(step, vvx, vvy)
if e == diff.Identity {
s.report(true, vx.Index(ix), vy.Index(iy))
s.report(true)
} else {
s.compareAny(vx.Index(ix), vy.Index(iy))
s.compareAny(vvx, vvy)
}
s.popStep()
ix++
iy++
}
}
s.curPath.pop()
return
}

func (s *state) compareMap(vx, vy reflect.Value, t reflect.Type) {
if vx.IsNil() || vy.IsNil() {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
s.report(vx.IsNil() && vy.IsNil())
return
}

// We combine and sort the two map keys so that we can perform the
// comparisons in a deterministic order.
step := &mapIndex{pathStep: pathStep{t.Elem()}}
s.curPath.push(step)
defer s.curPath.pop()
for _, k := range value.SortKeys(append(vx.MapKeys(), vy.MapKeys()...)) {
step.key = k
vvx := vx.MapIndex(k)
vvy := vy.MapIndex(k)
vvx, vvy := vx.MapIndex(k), vy.MapIndex(k)
s.pushStep(step, vvx, vvy)
switch {
case vvx.IsValid() && vvy.IsValid():
s.compareAny(vvx, vvy)
case vvx.IsValid() && !vvy.IsValid():
s.report(false, vvx, nothing)
s.report(false)
case !vvx.IsValid() && vvy.IsValid():
s.report(false, nothing, vvy)
s.report(false)
default:
// It is possible for both vvx and vvy to be invalid if the
// key contained a NaN value in it.
Expand All @@ -514,19 +539,42 @@ func (s *state) compareMap(vx, vy reflect.Value, t reflect.Type) {
const help = "consider providing a Comparer to compare the map"
panic(fmt.Sprintf("%#v has map key with NaNs\n%s", s.curPath, help))
}
s.popStep()
}
}

// report records the result of a single comparison.
// It also calls Report if any reporter is registered.
func (s *state) report(eq bool, vx, vy reflect.Value) {
func (s *state) pushStep(ps PathStep, x, y reflect.Value) {
s.curPath.push(ps)
for _, r := range s.reporters {
r.PushStep(ps, x, y)
}
}

func (s *state) popStep() {
s.curPath.pop()
for _, r := range s.reporters {
r.PopStep()
}
}

func (s *state) report(eq bool) {
if eq {
s.result.NSame++
} else {
s.result.NDiff++
}
if s.reporter != nil {
s.reporter.Report(vx, vy, eq, s.curPath)
for _, r := range s.reporters {
if eq {
r.Report(reportEqual)
} else {
r.Report(reportUnequal)
}
}
}

func (s *state) reportIgnore() {
for _, r := range s.reporters {
r.Report(reportIgnore)
}
}

Expand Down
4 changes: 2 additions & 2 deletions cmp/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ root:
x: new(fmt.Stringer),
y: nil,
wantDiff: `
:
root:
-: &<nil>
+: <non-existent>`,
}, {
Expand Down Expand Up @@ -426,7 +426,7 @@ root:
// Ensure Stringer avoids double-quote escaping if possible.
label: label,
x: []*pb.Stringer{{`multi\nline\nline\nline`}},
wantDiff: ":\n\t-: []*testprotos.Stringer{s`multi\\nline\\nline\\nline`}\n\t+: <non-existent>",
wantDiff: "root:\n\t-: []*testprotos.Stringer{s`multi\\nline\\nline\\nline`}\n\t+: <non-existent>",
}, {
label: label,
x: struct{ I Iface2 }{},
Expand Down
Loading