Skip to content

Commit

Permalink
Merge pull request #74 from timsolov/master
Browse files Browse the repository at this point in the history
Full support sql.Null* structs on both sides. Copy to Scanner interface and from Valuer interface.
  • Loading branch information
jinzhu authored Jan 15, 2021
2 parents e9a219c + 916e662 commit d7c71e4
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 3 deletions.
59 changes: 57 additions & 2 deletions copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package copier

import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -333,8 +334,22 @@ func set(to, from reflect.Value, deepCopy bool) bool {
to.Set(reflect.Zero(to.Type()))
return true
} else if to.IsNil() {
// `from` -> `to`
// sql.NullString -> *string
if fromValuer, ok := driverValuer(from); ok {
v, err := fromValuer.Value()
if err != nil {
return false
}
// if `from` is not valid do nothing with `to`
if v == nil {
return true
}
}
// allocate new `to` variable with default value (eg. *string -> new(string))
to.Set(reflect.New(to.Type().Elem()))
}
// depointer `to`
to = to.Elem()
}

Expand All @@ -351,10 +366,39 @@ func set(to, from reflect.Value, deepCopy bool) bool {

if from.Type().ConvertibleTo(to.Type()) {
to.Set(from.Convert(to.Type()))
} else if scanner, ok := to.Addr().Interface().(sql.Scanner); ok {
if err := scanner.Scan(from.Interface()); err != nil {
} else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok {
// `from` -> `to`
// *string -> sql.NullString
if from.Kind() == reflect.Ptr {
// if `from` is nil do nothing with `to`
if from.IsNil() {
return true
}
// depointer `from`
from = indirect(from)
}
// `from` -> `to`
// string -> sql.NullString
// set `to` by invoking method Scan(`from`)
err := toScanner.Scan(from.Interface())
if err != nil {
return false
}
} else if fromValuer, ok := driverValuer(from); ok {
// `from` -> `to`
// sql.NullString -> string
v, err := fromValuer.Value()
if err != nil {
return false
}
// if `from` is not valid do nothing with `to`
if v == nil {
return true
}
rv := reflect.ValueOf(v)
if rv.Type().AssignableTo(to.Type()) {
to.Set(rv)
}
} else if from.Kind() == reflect.Ptr {
return set(to, from.Elem(), deepCopy)
} else {
Expand Down Expand Up @@ -412,3 +456,14 @@ func checkBitFlags(flagsList map[string]uint8) (err error) {
}
return
}

func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {

if !v.CanAddr() {
i, ok = v.Interface().(driver.Valuer)
return
}

i, ok = v.Addr().Interface().(driver.Valuer)
return
}
75 changes: 74 additions & 1 deletion copier_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package copier_test

import (
"database/sql"
"errors"
"testing"
"time"
Expand Down Expand Up @@ -171,7 +172,7 @@ func TestCopyFromStructToSlice(t *testing.T) {
}

func TestCopyFromSliceToSlice(t *testing.T) {
users := []User{User{Name: "Jinzhu", Age: 18, Role: "Admin", Notes: []string{"hello world"}}, User{Name: "Jinzhu2", Age: 22, Role: "Dev", Notes: []string{"hello world", "hello"}}}
users := []User{{Name: "Jinzhu", Age: 18, Role: "Admin", Notes: []string{"hello world"}}, {Name: "Jinzhu2", Age: 22, Role: "Dev", Notes: []string{"hello world", "hello"}}}
employees := []Employee{}

if copier.Copy(&employees, users); len(employees) != 2 {
Expand Down Expand Up @@ -1111,3 +1112,75 @@ func TestScanner(t *testing.T) {
t.Errorf("Field V should be copied")
}
}

func TestScanFromPtrToSqlNullable(t *testing.T) {

var (
from struct {
S string
Sptr *string
T1 sql.NullTime
T2 sql.NullTime
T3 *time.Time
}

to struct {
S sql.NullString
Sptr sql.NullString
T1 time.Time
T2 *time.Time
T3 sql.NullTime
}

s string

err error
)

s = "test"
from.S = s
from.Sptr = &s

if from.T1.Valid || from.T2.Valid {
t.Errorf("Must be not valid")
}

err = copier.Copy(&to, from)
if err != nil {
t.Error("Should not raise error")
}

if !to.T1.IsZero() {
t.Errorf("to.T1 should be Zero but %v", to.T1)
}

if to.T2 != nil && !to.T2.IsZero() {
t.Errorf("to.T2 should be Zero but %v", to.T2)
}

now := time.Now()

from.T1.Scan(now)
from.T2.Scan(now)

err = copier.Copy(&to, from)
if err != nil {
t.Error("Should not raise error")
}

if to.S.String != from.S {
t.Errorf("Field S should be copied")
}

if to.Sptr.String != *from.Sptr {
t.Errorf("Field Sptr should be copied")
}

if from.T1.Time != to.T1 {
t.Errorf("Fields T1 fields should be equal")
}

if from.T2.Time != *to.T2 {
t.Errorf("Fields T2 fields should be equal")
}
}

0 comments on commit d7c71e4

Please sign in to comment.