forked from gocraft/dbr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load.go
144 lines (126 loc) · 3.1 KB
/
load.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package dbr
import (
	"database/sql"
	"errors"
	"reflect"
)
// interfaceLoader carries a destination pointer together with the concrete
// type Load should allocate for every row, overriding the element type that
// would otherwise be derived from the destination itself.
type interfaceLoader struct {
	v   interface{}  // destination pointer handed on to Load
	typ reflect.Type // concrete per-row type; nil when concreteType was nil
}

// InterfaceLoader wraps value so that, when the result is passed to Load,
// each row is allocated as the dynamic type of concreteType instead of the
// destination's own element type. This lets a destination such as
// *[]interface{} be filled with values of a specific concrete type.
//
// If concreteType is nil, reflect.TypeOf yields a nil type and Load behaves
// as if value had been passed directly.
func InterfaceLoader(value, concreteType interface{}) interface{} {
	loader := interfaceLoader{
		v:   value,
		typ: reflect.TypeOf(concreteType),
	}
	return loader
}
// Load loads any value from sql.Rows and returns the number of rows loaded.
// rows is always closed before Load returns.
//
// value must be a non-nil pointer, or the result of InterfaceLoader wrapping
// one; otherwise ErrInvalidPointer is returned. The pointed-to value can be:
//
// 1. simple type like int64, string, etc. Only the first row is loaded.
//
// 2. sql.Scanner, which allows loading with custom types.
//
// 3. slice; every row is appended as a new element. A []byte destination
// is treated as a simple value, not as a slice of rows.
//
// 4. map; the first column from SQL result loaded to the key,
// and the rest of columns will be loaded into the value.
// This is useful to dedup SQL result with first column.
// The map is replaced with a freshly made one before loading, and the
// result must have at least one column to serve as the key.
//
// 5. map of slice; like map, values with the same key are
// collected with a slice.
//
// When value comes from InterfaceLoader, each row is allocated as the
// concrete type given there rather than the destination's element type.
// For a single-value destination the allocated row is stored into it when
// its type is assignable to the destination's type.
func Load(rows *sql.Rows, value interface{}) (int, error) {
	defer rows.Close()

	column, err := rows.Columns()
	if err != nil {
		return 0, err
	}
	ptr := make([]interface{}, len(column))

	var v reflect.Value
	var elemType reflect.Type
	if il, ok := value.(interfaceLoader); ok {
		v = reflect.ValueOf(il.v)
		elemType = il.typ
	} else {
		v = reflect.ValueOf(value)
	}
	if v.Kind() != reflect.Ptr || v.IsNil() {
		return 0, ErrInvalidPointer
	}
	v = v.Elem()

	isScanner := v.Addr().Type().Implements(typeScanner)
	// []byte is a single value (e.g. a BLOB column), not a slice of rows.
	isSlice := v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 && !isScanner
	isMap := v.Kind() == reflect.Map && !isScanner
	isMapOfSlices := isMap && v.Type().Elem().Kind() == reflect.Slice && v.Type().Elem().Elem().Kind() != reflect.Uint8
	if isMap {
		// The first column is the map key; without one, column[:1] below
		// would panic on a zero-column result set.
		if len(column) == 0 {
			return 0, errors.New("dbr: loading into a map requires at least one column")
		}
		v.Set(reflect.MakeMap(v.Type()))
	}

	s := newTagStore()
	count := 0
	for rows.Next() {
		var elem, keyElem reflect.Value
		if elemType != nil {
			elem = reflectAlloc(elemType)
		} else if isMapOfSlices {
			elem = reflectAlloc(v.Type().Elem().Elem())
		} else if isSlice || isMap {
			elem = reflectAlloc(v.Type().Elem())
		} else {
			elem = v
		}

		if isMap {
			if err := s.findPtr(elem, column[1:], ptr[1:]); err != nil {
				return 0, err
			}
			keyElem = reflectAlloc(v.Type().Key())
			if err := s.findPtr(keyElem, column[:1], ptr[:1]); err != nil {
				return 0, err
			}
		} else {
			if err := s.findPtr(elem, column, ptr); err != nil {
				return 0, err
			}
		}

		// Slots findPtr left nil are scanned into a no-op Scanner; all slots
		// are reset to nil afterwards so the next row starts clean.
		for i := range ptr {
			if ptr[i] == nil {
				ptr[i] = dummyDest
			}
		}
		if err := rows.Scan(ptr...); err != nil {
			return 0, err
		}
		for i := range ptr {
			ptr[i] = nil
		}
		count++

		if isSlice {
			v.Set(reflect.Append(v, elem))
		} else if isMapOfSlices {
			group := v.MapIndex(keyElem)
			if !group.IsValid() {
				group = reflect.Zero(v.Type().Elem())
			}
			v.SetMapIndex(keyElem, reflect.Append(group, elem))
		} else if isMap {
			v.SetMapIndex(keyElem, elem)
		} else {
			// Single value. With an InterfaceLoader, elem is a fresh
			// allocation rather than v itself, so it must be stored back or
			// the scanned row would be lost.
			if elemType != nil && elem.Type().AssignableTo(v.Type()) {
				v.Set(elem)
			}
			// Not inside a switch, so this break exits the rows loop.
			break
		}
	}
	return count, rows.Err()
}
// reflectAlloc returns a freshly allocated value suitable as a scan target
// for typ. For a pointer type *T it returns a new non-nil *T. For any other
// type T it returns an addressable, settable zero value of T.
func reflectAlloc(typ reflect.Type) reflect.Value {
	switch typ.Kind() {
	case reflect.Ptr:
		// Allocate the pointee so the result points at real storage.
		return reflect.New(typ.Elem())
	default:
		// New(T).Elem() yields an addressable T rather than a *T.
		return reflect.New(typ).Elem()
	}
}
// dummyScanner is a sql.Scanner that accepts and discards any source value.
type dummyScanner struct{}

// Scan implements sql.Scanner by ignoring its argument and always succeeding.
func (dummyScanner) Scan(_ interface{}) error { return nil }

var (
	// dummyDest is the shared discard target for columns without a destination.
	dummyDest sql.Scanner = dummyScanner{}

	// typeScanner is the reflect.Type of the sql.Scanner interface, used for
	// Implements checks.
	typeScanner = reflect.TypeOf(new(sql.Scanner)).Elem()
)