Skip to content

Commit

Permalink
Merge pull request #113 from joshhardy/fix-deepcopy-unexported-fields
Browse files Browse the repository at this point in the history
Copy un-exported struct fields in DeepCopy
  • Loading branch information
jinzhu authored Dec 13, 2021
2 parents 5de5170 + b369e8a commit 633a171
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 2 deletions.
24 changes: 22 additions & 2 deletions copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)

// check source
if source.IsValid() {
copyUnexportedStructFields(dest, source)

// Copy from source field to dest field or method
fromTypeFields := deepFields(fromType)
for _, field := range fromTypeFields {
Expand Down Expand Up @@ -334,6 +336,24 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
return
}

func copyUnexportedStructFields(to, from reflect.Value) {
if from.Kind() != reflect.Struct || to.Kind() != reflect.Struct || !from.Type().AssignableTo(to.Type()) {
return
}

// create a shallow copy of 'to' to get all fields
tmp := indirect(reflect.New(to.Type()))
tmp.Set(from)

// revert exported fields
for i := 0; i < to.NumField(); i++ {
if tmp.Field(i).CanSet() {
tmp.Field(i).Set(to.Field(i))
}
}
to.Set(tmp)
}

func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
if !ignoreEmpty {
return false
Expand All @@ -352,10 +372,10 @@ func deepFields(reflectType reflect.Type) []reflect.StructField {
// field name. It is empty for upper case (exported) field names.
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
if v.PkgPath == "" {
fields = append(fields, v)
if v.Anonymous {
// also consider fields of anonymous fields as fields of the root
fields = append(fields, deepFields(v.Type)...)
} else {
fields = append(fields, v)
}
}
}
Expand Down
178 changes: 178 additions & 0 deletions copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -1292,3 +1293,180 @@ func TestDeepCopyInterface(t *testing.T) {
t.Errorf("to value failed to be deep copied")
}
}

func TestDeepCopyTime(t *testing.T) {
type embedT1 struct {
T5 time.Time
}

type embedT2 struct {
T6 *time.Time
}

var (
from struct {
T1 time.Time
T2 *time.Time

T3 *time.Time
T4 time.Time
T5 time.Time
T6 time.Time
}

to struct {
T1 time.Time
T2 *time.Time

T3 time.Time
T4 *time.Time
embedT1
embedT2
}
)

t1 := time.Now()
from.T1 = t1
t2 := t1.Add(time.Second)
from.T2 = &t2
t3 := t2.Add(time.Second)
from.T3 = &t3
t4 := t3.Add(time.Second)
from.T4 = t4
t5 := t4.Add(time.Second)
from.T5 = t5
t6 := t5.Add(time.Second)
from.T6 = t6

err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true})
if err != nil {
t.Error("Should not raise error")
}

if !to.T1.Equal(from.T1) {
t.Errorf("Field T1 should be copied")
}
if !to.T2.Equal(*from.T2) {
t.Errorf("Field T2 should be copied")
}
if !to.T3.Equal(*from.T3) {
t.Errorf("Field T3 should be copied")
}
if !to.T4.Equal(from.T4) {
t.Errorf("Field T4 should be copied")
}
if !to.T5.Equal(from.T5) {
t.Errorf("Field T5 should be copied")
}
if !to.T6.Equal(from.T6) {
t.Errorf("Field T6 should be copied")
}
}


func TestNestedPrivateData(t *testing.T) {
type hasPrivate struct {
data int
}

type hasMembers struct {
Member hasPrivate
}

src := hasMembers{
Member: hasPrivate{
data: 42,
},
}
var shallow hasMembers
err := copier.Copy(&shallow, &src)
if err != nil {
t.Errorf("could not complete shallow copy")
}
if !reflect.DeepEqual(&src, &shallow) {
t.Errorf("shallow copy faild")
}

var deep hasMembers
err = copier.CopyWithOption(&deep, &src, copier.Option{DeepCopy: true})
if err != nil {
t.Errorf("could not complete deep copy")
}
if !reflect.DeepEqual(&src, &deep) {
t.Errorf("deep copy faild")
}

if !reflect.DeepEqual(&shallow, &deep) {
t.Errorf("unexpected difference between shallow and deep copy")
}
}


func TestDeepMapCopyTime(t *testing.T) {
t1 := time.Now()
t2 := t1.Add(time.Second)
from := []map[string]interface{}{
{
"t1": t1,
"t2": &t2,
},
}
to := make([]map[string]interface{}, len(from))

err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true})
if err != nil {
t.Error("should not error")
}
if len(to) != len(from) {
t.Errorf("slice should be copied")
}
if !to[0]["t1"].(time.Time).Equal(from[0]["t1"].(time.Time)) {
t.Errorf("nested time ptr should be copied")
}
if !to[0]["t2"].(*time.Time).Equal(*from[0]["t2"].(*time.Time)) {
t.Errorf("nested time ptr should be copied")
}
}

func TestCopySimpleTime(t *testing.T) {
from := time.Now()
to := time.Time{}

err := copier.Copy(&to, from)
if err != nil {
t.Error("should not error")
}
if !from.Equal(to) {
t.Errorf("to (%v) value should equal from (%v) value", to, from)
}
}

func TestDeepCopySimpleTime(t *testing.T) {
from := time.Now()
to := time.Time{}

err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true})
if err != nil {
t.Error("should not error")
}
if !from.Equal(to) {
t.Errorf("to (%v) value should equal from (%v) value", to, from)
}
}

type TimeWrapper struct{
time.Time
}

func TestDeepCopyAnonymousFieldTime(t *testing.T) {
from := TimeWrapper{time.Now()}
to := TimeWrapper{}

err := copier.CopyWithOption(&to, from, copier.Option{DeepCopy: true})
if err != nil {
t.Error("should not error")
}
if !from.Time.Equal(to.Time) {
t.Errorf("to (%v) value should equal from (%v) value", to.Time, from.Time)
}
}

0 comments on commit 633a171

Please sign in to comment.