diff --git a/copier.go b/copier.go index 6d21da8..f1afc36 100644 --- a/copier.go +++ b/copier.go @@ -203,6 +203,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) // check source if source.IsValid() { + copyUnexportedStructFields(dest, source) + // Copy from source field to dest field or method fromTypeFields := deepFields(fromType) for _, field := range fromTypeFields { @@ -334,6 +336,24 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) return } +func copyUnexportedStructFields(to, from reflect.Value) { + if from.Kind() != reflect.Struct || to.Kind() != reflect.Struct || !from.Type().AssignableTo(to.Type()) { + return + } + + // create a shallow copy of 'to' to get all fields + tmp := indirect(reflect.New(to.Type())) + tmp.Set(from) + + // revert exported fields + for i := 0; i < to.NumField(); i++ { + if tmp.Field(i).CanSet() { + tmp.Field(i).Set(to.Field(i)) + } + } + to.Set(tmp) +} + func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool { if !ignoreEmpty { return false @@ -352,10 +372,10 @@ func deepFields(reflectType reflect.Type) []reflect.StructField { // field name. It is empty for upper case (exported) field names. // See https://golang.org/ref/spec#Uniqueness_of_identifiers if v.PkgPath == "" { + fields = append(fields, v) if v.Anonymous { + // also consider fields of anonymous fields as fields of the root fields = append(fields, deepFields(v.Type)...) - } else { - fields = append(fields, v) } } } diff --git a/copier_test.go b/copier_test.go index a3293e7..2799c85 100644 --- a/copier_test.go +++ b/copier_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "reflect" "testing" "time" @@ -1292,3 +1293,180 @@ func TestDeepCopyInterface(t *testing.T) { t.Errorf("to value failed to be deep copied") } } + +func TestDeepCopyTime(t *testing.T) { + type embedT1 struct { + T5 time.Time + } + + type embedT2 struct { + T6 *time.Time + } + + var ( + from struct { + T1 time.Time + T2 *time.Time + + T3 *time.Time + T4 time.Time + T5 time.Time + T6 time.Time + } + + to struct { + T1 time.Time + T2 *time.Time + + T3 time.Time + T4 *time.Time + embedT1 + embedT2 + } + ) + + t1 := time.Now() + from.T1 = t1 + t2 := t1.Add(time.Second) + from.T2 = &t2 + t3 := t2.Add(time.Second) + from.T3 = &t3 + t4 := t3.Add(time.Second) + from.T4 = t4 + t5 := t4.Add(time.Second) + from.T5 = t5 + t6 := t5.Add(time.Second) + from.T6 = t6 + + err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true}) + if err != nil { + t.Error("Should not raise error") + } + + if !to.T1.Equal(from.T1) { + t.Errorf("Field T1 should be copied") + } + if !to.T2.Equal(*from.T2) { + t.Errorf("Field T2 should be copied") + } + if !to.T3.Equal(*from.T3) { + t.Errorf("Field T3 should be copied") + } + if !to.T4.Equal(from.T4) { + t.Errorf("Field T4 should be copied") + } + if !to.T5.Equal(from.T5) { + t.Errorf("Field T5 should be copied") + } + if !to.T6.Equal(from.T6) { + t.Errorf("Field T6 should be copied") + } +} + + +func TestNestedPrivateData(t *testing.T) { + type hasPrivate struct { + data int + } + + type hasMembers struct { + Member hasPrivate + } + + src := hasMembers{ + Member: hasPrivate{ + data: 42, + }, + } + var shallow hasMembers + err := copier.Copy(&shallow, &src) + if err != nil { + t.Errorf("could not complete shallow copy") + } + if !reflect.DeepEqual(&src, &shallow) { + t.Errorf("shallow copy faild") + } + + var deep hasMembers + err = copier.CopyWithOption(&deep, &src, copier.Option{DeepCopy: true}) + if err != nil { + t.Errorf("could not complete deep copy") + } + if !reflect.DeepEqual(&src, &deep) { + t.Errorf("deep copy faild") + } + + if !reflect.DeepEqual(&shallow, &deep) { + t.Errorf("unexpected difference between shallow and deep copy") + } +} + + +func TestDeepMapCopyTime(t *testing.T) { + t1 := time.Now() + t2 := t1.Add(time.Second) + from := []map[string]interface{}{ + { + "t1": t1, + "t2": &t2, + }, + } + to := make([]map[string]interface{}, len(from)) + + err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true}) + if err != nil { + t.Error("should not error") + } + if len(to) != len(from) { + t.Errorf("slice should be copied") + } + if !to[0]["t1"].(time.Time).Equal(from[0]["t1"].(time.Time)) { + t.Errorf("nested time ptr should be copied") + } + if !to[0]["t2"].(*time.Time).Equal(*from[0]["t2"].(*time.Time)) { + t.Errorf("nested time ptr should be copied") + } +} + +func TestCopySimpleTime(t *testing.T) { + from := time.Now() + to := time.Time{} + + err := copier.Copy(&to, from) + if err != nil { + t.Error("should not error") + } + if !from.Equal(to) { + t.Errorf("to (%v) value should equal from (%v) value", to, from) + } +} + +func TestDeepCopySimpleTime(t *testing.T) { + from := time.Now() + to := time.Time{} + + err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true}) + if err != nil { + t.Error("should not error") + } + if !from.Equal(to) { + t.Errorf("to (%v) value should equal from (%v) value", to, from) + } +} + +type TimeWrapper struct{ + time.Time +} + +func TestDeepCopyAnonymousFieldTime(t *testing.T) { + from := TimeWrapper{time.Now()} + to := TimeWrapper{} + + err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true}) + if err != nil { + t.Error("should not error") + } + if !from.Time.Equal(to.Time) { + t.Errorf("to (%v) value should equal from (%v) value", to.Time, from.Time) + } +}