diff --git a/issue131_test.go b/issue131_test.go index 5ca24cf..47e8fc6 100644 --- a/issue131_test.go +++ b/issue131_test.go @@ -9,6 +9,9 @@ import ( type foz struct { A *bool B string + C *bool + D *bool + E *bool } func TestIssue131MergeWithOverwriteWithEmptyValue(t *testing.T) { @@ -30,3 +33,38 @@ func TestIssue131MergeWithOverwriteWithEmptyValue(t *testing.T) { t.Errorf("dest.B not merged in properly: %v != %v", src.B, dest.B) } } + +func TestIssue131MergeWithoutDereference(t *testing.T) { + src := foz{ + A: func(v bool) *bool { return &v }(false), + B: "src", + C: nil, + D: func(v bool) *bool { return &v }(false), + E: func(v bool) *bool { return &v }(true), + } + dest := foz{ + A: func(v bool) *bool { return &v }(true), + B: "dest", + C: func(v bool) *bool { return &v }(false), + D: nil, + E: func(v bool) *bool { return &v }(false), + } + if err := mergo.Merge(&dest, src, mergo.WithoutDereference, mergo.WithOverride); err != nil { + t.Error(err) + } + if *src.A != *dest.A { + t.Errorf("dest.A not merged in properly: %v != %v", *src.A, *dest.A) + } + if src.B != dest.B { + t.Errorf("dest.B not merged in properly: %v != %v", src.B, dest.B) + } + if *dest.C != false { + t.Errorf("dest.C not merged in properly: %v != %v", src.C, *dest.C) + } + if *dest.D != false { + t.Errorf("dest.D not merged in properly: %v != %v", src.D, *dest.D) + } + if *dest.E != true { + t.Errorf("dest.E not merged in properly: %v != %v", src.E, *dest.E) + } +} diff --git a/map.go b/map.go index f739255..ce5b0f0 100644 --- a/map.go +++ b/map.go @@ -58,7 +58,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf } fieldName := field.Name fieldName = changeInitialCase(fieldName, unicode.ToLower) - if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) { + if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v), !config.ShouldNotDereference) || overwrite) { dstMap[fieldName] = src.Field(i).Interface() } } diff --git a/merge.go b/merge.go index 4b47d0b..a4aa281 100644 --- a/merge.go +++ b/merge.go @@ -40,6 +40,7 @@ func isExportedComponent(field *reflect.StructField) bool { type Config struct { Transformers Transformers Overwrite bool + ShouldNotDereference bool AppendSlice bool TypeCheck bool overwriteWithEmptyValue bool @@ -95,7 +96,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co } } } else { - if dst.CanSet() && (isReflectNil(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc) { + if dst.CanSet() && (isReflectNil(dst) || overwrite) && (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc) { dst.Set(src) } } @@ -162,7 +163,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co dstSlice = reflect.ValueOf(dstElement.Interface()) } - if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy { + if (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) && !config.AppendSlice && !sliceDeepCopy { if typeCheck && srcSlice.Type() != dstSlice.Type() { return fmt.Errorf("cannot override two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) } @@ -194,11 +195,11 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co dst.SetMapIndex(key, dstSlice) } } - if dstElement.IsValid() && !isEmptyValue(dstElement) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) { + if dstElement.IsValid() && !isEmptyValue(dstElement, !config.ShouldNotDereference) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) { continue } - if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement)) { + if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement, !config.ShouldNotDereference)) { if dst.IsNil() { dst.Set(reflect.MakeMap(dst.Type())) } @@ -209,7 +210,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co if !dst.CanSet() { break } - if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy { + if (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) && !config.AppendSlice && !sliceDeepCopy { dst.Set(src) } else if config.AppendSlice { if src.Type() != dst.Type() { @@ -244,12 +245,18 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co if src.Kind() != reflect.Interface { if dst.IsNil() || (src.Kind() != reflect.Ptr && overwrite) { - if dst.CanSet() && (overwrite || isEmptyValue(dst)) { + if dst.CanSet() && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) { dst.Set(src) } } else if src.Kind() == reflect.Ptr { - if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { - return + if !config.ShouldNotDereference { + if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { + return + } + } else { + if overwriteWithEmptySrc || !src.IsNil() { + dst.Set(src) + } } } else if dst.Elem().Type() == src.Type() { if err = deepMerge(dst.Elem(), src, visited, depth+1, config); err != nil { @@ -262,7 +269,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co } if dst.IsNil() || overwrite { - if dst.CanSet() && (overwrite || isEmptyValue(dst)) { + if dst.CanSet() && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) { dst.Set(src) } break @@ -275,7 +282,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co break } default: - mustSet := (isEmptyValue(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc) + mustSet := (isEmptyValue(dst, !config.ShouldNotDereference) || overwrite) && (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc) if mustSet { if dst.CanSet() { dst.Set(src) @@ -326,6 +333,12 @@ func WithOverrideEmptySlice(config *Config) { config.overwriteSliceWithEmptyValue = true } +// WithoutDereference prevents dereferencing pointers when evaluating whether they are empty +// (i.e. a non-nil pointer is never considered empty). +func WithoutDereference(config *Config) { + config.ShouldNotDereference = true +} + // WithAppendSlice will make merge append slices instead of overwriting it. func WithAppendSlice(config *Config) { config.AppendSlice = true diff --git a/mergo.go b/mergo.go index d2af2a9..dd61685 100644 --- a/mergo.go +++ b/mergo.go @@ -34,7 +34,7 @@ type visit struct { } // From src/pkg/encoding/json/encode.go. -func isEmptyValue(v reflect.Value) bool { +func isEmptyValue(v reflect.Value, shouldDereference bool) bool { switch v.Kind() { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 @@ -50,7 +50,10 @@ func isEmptyValue(v reflect.Value) bool { if v.IsNil() { return true } - return isEmptyValue(v.Elem()) + if shouldDereference { + return isEmptyValue(v.Elem(), shouldDereference) + } + return false case reflect.Func: return v.IsNil() case reflect.Invalid: