diff --git a/merge.go b/merge.go index f8de6c5..9c1ba3b 100644 --- a/merge.go +++ b/merge.go @@ -28,6 +28,7 @@ func hasExportedField(dst reflect.Value) (exported bool) { type Config struct { Overwrite bool AppendSlice bool + TypeCheck bool Transformers Transformers overwriteWithEmptyValue bool } @@ -41,6 +42,7 @@ type Transformers interface { // short circuiting on recursive types. func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, config *Config) (err error) { overwrite := config.Overwrite + typeCheck := config.TypeCheck overwriteWithEmptySrc := config.overwriteWithEmptyValue config.overwriteWithEmptyValue = false @@ -129,10 +131,13 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co } if (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { + if typeCheck && srcSlice.Type() != dstSlice.Type() { + return fmt.Errorf("cannot override two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) + } dstSlice = srcSlice } else if config.AppendSlice { if srcSlice.Type() != dstSlice.Type() { - return fmt.Errorf("cannot append two slice with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) + return fmt.Errorf("cannot append two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) } dstSlice = reflect.AppendSlice(dstSlice, srcSlice) } @@ -228,11 +233,16 @@ func WithOverride(config *Config) { config.Overwrite = true } -// WithAppendSlice will make merge append slices instead of overwriting it +// WithAppendSlice will make merge append slices instead of overwriting it. func WithAppendSlice(config *Config) { config.AppendSlice = true } +// WithTypeCheck will make merge check types while overwriting it (must be used with WithOverride). +func WithTypeCheck(config *Config) { + config.TypeCheck = true +} + func merge(dst, src interface{}, opts ...func(*Config)) error { var ( vDst, vSrc reflect.Value diff --git a/mergo_test.go b/mergo_test.go index 1f2ab0b..2b7ba17 100644 --- a/mergo_test.go +++ b/mergo_test.go @@ -8,6 +8,7 @@ package mergo import ( "io/ioutil" "reflect" + "strings" "testing" "time" @@ -734,23 +735,43 @@ func TestBooleanPointer(t *testing.T) { } func TestMergeMapWithInnerSliceOfDifferentType(t *testing.T) { - src := map[string]interface{}{ - "foo": []string{"a", "b"}, - } - dst := map[string]interface{}{ - "foo": []int{1, 2}, + testCases := []struct { + name string + options []func(*Config) + err string + }{ + { + "With override and append slice", + []func(*Config){WithOverride, WithAppendSlice}, + "cannot append two slices with different type", + }, + { + "With override and type check", + []func(*Config){WithOverride, WithTypeCheck}, + "cannot override two slices with different type", + }, } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + src := map[string]interface{}{ + "foo": []string{"a", "b"}, + } + dst := map[string]interface{}{ + "foo": []int{1, 2}, + } - if err := Merge(&src, &dst, WithOverride, WithAppendSlice); err == nil { - t.Fatal("expected an error, got nothing") + if err := Merge(&src, &dst, tc.options...); err == nil || !strings.Contains(err.Error(), tc.err) { + t.Fatalf("expected %q, got %q", tc.err, err) + } + }) } } -func TestMergeSliceDifferentType(t *testing.T) { +func TestMergeSlicesIsNotSupported(t *testing.T) { src := []string{"a", "b"} dst := []int{1, 2} - if err := Merge(&src, &dst, WithOverride, WithAppendSlice); err == nil { - t.Fatal("expected an error, got nothing") + if err := Merge(&src, &dst, WithOverride, WithAppendSlice); err != ErrNotSupported { + t.Fatalf("expected %q, got %q", ErrNotSupported, err) } }