From 4f1c6c181a2b1424905975e006977bf8b1e957c8 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Fri, 29 Dec 2023 12:50:26 +0530 Subject: [PATCH 1/9] feat(spanner): add SelectAll method to decode from Spanner iterator.Rows to golang struct --- spanner/read.go | 6 ++ spanner/row.go | 129 +++++++++++++++++++++++++ spanner/value.go | 12 ++- spanner/value_benchmarks_test.go | 159 +++++++++++++++++++++++++++++++ 4 files changed, 302 insertions(+), 4 deletions(-) diff --git a/spanner/read.go b/spanner/read.go index 2651ac9535de..7ea0837add06 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -90,6 +90,12 @@ func streamWithReplaceSessionFunc( } } +type Iterator interface { + Next() (*Row, error) + Do(f func(r *Row) error) error + Stop() +} + // RowIterator is an iterator over Rows. type RowIterator struct { // The plan for the query. Available after RowIterator.Next returns diff --git a/spanner/row.go b/spanner/row.go index 4dca9be91f0e..37232f0eff75 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -249,6 +249,14 @@ func errColNotFound(n string) error { return spannerErrorf(codes.NotFound, "column %q not found", n) } +func errNotASlicePointer() error { + return spannerErrorf(codes.InvalidArgument, "destination must be a pointer to a slice") +} + +func errTooManyColumns() error { + return spannerErrorf(codes.InvalidArgument, "too many columns returned for primitive slice") +} + // ColumnByName fetches the value from the named column, decoding it into ptr. // See the Row documentation for the list of acceptable argument types. func (r *Row) ColumnByName(name string, ptr interface{}) error { @@ -378,3 +386,124 @@ func (r *Row) ToStructLenient(p interface{}) error { true, ) } + +// SelectAll scans rows into a slice (v) +func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { + if rows == nil { + return fmt.Errorf("rows is nil") + } + if v == nil { + return fmt.Errorf("p is nil") + } + vType := reflect.TypeOf(v) + if k := vType.Kind(); k != reflect.Ptr { + return errToStructArgType(v) + } + sliceType := vType.Elem() + if reflect.Slice != sliceType.Kind() { + return errNotASlicePointer() + } + sliceVal := reflect.Indirect(reflect.ValueOf(v)) + itemType := sliceType.Elem() + s := &decodeSetting{} + for _, opt := range options { + opt.Apply(s) + } + + isPrimitive := itemType.Kind() != reflect.Struct + var pointers []interface{} + var err error + if err := rows.Do(func(row *Row) error { + sliceItem := reflect.New(itemType).Elem() + if len(pointers) == 0 { + if isPrimitive { + if len(row.fields) > 1 { + return errTooManyColumns() + } + pointers = []interface{}{sliceItem.Addr().Interface()} + } else { + if pointers, err = structPointers(sliceItem, row.fields, s.Lenient); err != nil { + return err + } + } + } + if len(pointers) == 0 { + return nil + } + err := row.Columns(pointers...) + if err != nil { + return err + } + if len(pointers) > 0 { + dst := sliceItem.Addr().Interface() + for i, p := range pointers { + reflect.ValueOf(dst).Elem().Field(i).Set(reflect.ValueOf(p).Elem()) + } + } + sliceVal.Set(reflect.Append(sliceVal, sliceItem)) + return nil + }); err != nil { + return err + } + return nil +} + +func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, strict bool) ([]interface{}, error) { + pointers := make([]interface{}, 0, len(cols)) + fieldTag := make(map[string]reflect.Value, len(cols)) + initFieldTag(sliceItem, &fieldTag) + + for _, colName := range cols { + var fieldVal reflect.Value + if v, ok := fieldTag[colName.GetName()]; ok { + fieldVal = v + } else { + if strict { + return nil, errNoOrDupGoField(sliceItem, colName.GetName()) + } else { + fieldVal = sliceItem.FieldByName(colName.GetName()) + } + } + if !fieldVal.IsValid() || !fieldVal.CanSet() { + // have to add if we found a column because Scan() requires + // len(cols) arguments or it will error. This way we can scan to + // a useless pointer + var nothing interface{} + pointers = append(pointers, ¬hing) + continue + } + + pointers = append(pointers, fieldVal.Addr().Interface()) + } + return pointers, nil +} + +// Initialization the tags from struct. +func initFieldTag(sliceItem reflect.Value, fieldTagMap *map[string]reflect.Value) { + typ := sliceItem.Type() + + for i := 0; i < sliceItem.NumField(); i++ { + fieldType := typ.Field(i) + exported := (fieldType.PkgPath == "") + // If a named field is unexported, ignore it. An anonymous + // unexported field is processed, because it may contain + // exported fields, which are visible. + if !exported && !fieldType.Anonymous { + continue + } + if fieldType.Type.Kind() == reflect.Struct { + // found an embedded struct + sliceItemOfAnonymous := sliceItem.Field(i) + initFieldTag(sliceItemOfAnonymous, fieldTagMap) + continue + } + name, keep, _, _ := spannerTagParser(fieldType.Tag) + if !keep { + continue + } + if name == "" { + name = fieldType.Name + } + (*fieldTagMap)[name] = sliceItem.Field(i) + } +} diff --git a/spanner/value.go b/spanner/value.go index d0360c5c16c0..919c9f396b40 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -1032,7 +1032,7 @@ func parseNullTime(v *proto3.Value, p *NullTime, code sppb.TypeCode, isNull bool // decodeValue decodes a protobuf Value into a pointer to a Go value, as // specified by sppb.Type. -func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeOptions) error { +func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...DecodeOptions) error { if v == nil { return errNilSrc() } @@ -3198,8 +3198,8 @@ type decodeSetting struct { Lenient bool } -// decodeOptions is the interface to change decode struct settings -type decodeOptions interface { +// DecodeOptions is the interface to change decode struct settings +type DecodeOptions interface { Apply(s *decodeSetting) } @@ -3209,6 +3209,10 @@ func (w withLenient) Apply(s *decodeSetting) { s.Lenient = w.lenient } +func WithLenient() DecodeOptions { + return withLenient{lenient: true} +} + // decodeStruct decodes proto3.ListValue pb into struct referenced by pointer // ptr, according to // the structural information given in sppb.StructType ty. @@ -3253,7 +3257,7 @@ func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}, le // We don't allow duplicated field name. return errDupSpannerField(f.Name, ty) } - opts := []decodeOptions{withLenient{lenient: lenient}} + opts := []DecodeOptions{withLenient{lenient: lenient}} // Try to decode a single field. if err := decodeValue(pb.Values[i], f.Type, v.FieldByIndex(sf.Index).Addr().Interface(), opts...); err != nil { return errDecodeStructField(ty, f.Name, err) diff --git a/spanner/value_benchmarks_test.go b/spanner/value_benchmarks_test.go index 7a9334b6a797..ce1eceff6dd6 100644 --- a/spanner/value_benchmarks_test.go +++ b/spanner/value_benchmarks_test.go @@ -15,6 +15,7 @@ package spanner import ( + "fmt" "reflect" "strconv" "testing" @@ -22,6 +23,7 @@ import ( "cloud.google.com/go/civil" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" proto3 "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/api/iterator" ) func BenchmarkEncodeIntArray(b *testing.B) { @@ -230,3 +232,160 @@ func decodeArrayReflect(pb *proto3.ListValue, name string, typ *sppb.Type, aptr } return nil } + +func BenchmarkScan100RowsUsingSelectAll(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockIterator(b, rows) + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + if err := SelectAll(&it, &res); err != nil { + b.Fatal(err) + } + _ = res + } +} + +func BenchmarkScan100RowsUsingToStruct(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockIterator(b, rows) + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + } + err = row.ToStruct(&r) + if err != nil { + b.Fatal(err) + } + res = append(res, r) + } + it.Stop() + _ = res + } +} + +func BenchmarkScan100RowsUsingColumns(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockIterator(b, rows) + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + } + err = row.Columns(&r.ID, &r.Name) + if err != nil { + b.Fatal(err) + } + res = append(res, r) + } + it.Stop() + _ = res + } +} + +func mockIterator[T any](t testing.TB, rows []T) *mockIteratorImpl { + var v T + var colNames []string + numCols := reflect.TypeOf(v).NumField() + for i := 0; i < numCols; i++ { + f := reflect.TypeOf(v).Field(i) + colNames = append(colNames, f.Name) + } + var srows []*Row + for _, e := range rows { + var vs []any + for f := 0; f < numCols; f++ { + v := reflect.ValueOf(e).Field(f).Interface() + vs = append(vs, v) + } + row, err := NewRow(colNames, vs) + if err != nil { + t.Fatal(err) + } + srows = append(srows, row) + } + return &mockIteratorImpl{rows: srows} +} + +type mockIteratorImpl struct { + rows []*Row +} + +func (i *mockIteratorImpl) Next() (*Row, error) { + if len(i.rows) == 0 { + return nil, iterator.Done + } + row := i.rows[0] + i.rows = i.rows[1:] + return row, nil +} + +func (i *mockIteratorImpl) Stop() { + i.rows = nil +} + +func (i *mockIteratorImpl) Do(f func(*Row) error) error { + defer i.Stop() + for _, row := range i.rows { + err := f(row) + if err != nil { + return err + } + } + return nil +} From 627af8204168cb3a9de046bd60d7c08211ea4f97 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Fri, 29 Dec 2023 14:36:41 +0530 Subject: [PATCH 2/9] fix go vet --- spanner/read.go | 1 + spanner/row.go | 3 +-- spanner/value.go | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spanner/read.go b/spanner/read.go index 7ea0837add06..e172f0811953 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -90,6 +90,7 @@ func streamWithReplaceSessionFunc( } } +// Iterator is an interface for iterating over Rows. type Iterator interface { Next() (*Row, error) Do(f func(r *Row) error) error diff --git a/spanner/row.go b/spanner/row.go index 37232f0eff75..33860cb147b7 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -460,9 +460,8 @@ func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, stri } else { if strict { return nil, errNoOrDupGoField(sliceItem, colName.GetName()) - } else { - fieldVal = sliceItem.FieldByName(colName.GetName()) } + fieldVal = sliceItem.FieldByName(colName.GetName()) } if !fieldVal.IsValid() || !fieldVal.CanSet() { // have to add if we found a column because Scan() requires diff --git a/spanner/value.go b/spanner/value.go index 919c9f396b40..83a8132f0592 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -3209,6 +3209,7 @@ func (w withLenient) Apply(s *decodeSetting) { s.Lenient = w.lenient } +// WithLenient returns a DecodeOptions that allows decoding into a struct with missing fields in database. func WithLenient() DecodeOptions { return withLenient{lenient: true} } From 9613853ccf45bd77beb9148fb6c0b7ba88d5f7db Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Tue, 2 Jan 2024 19:08:51 +0530 Subject: [PATCH 3/9] incorporate suggestions --- spanner/row.go | 26 +++++++++++--------------- spanner/value_benchmarks_test.go | 3 +++ 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/spanner/row.go b/spanner/row.go index 33860cb147b7..936876af2f13 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -412,20 +412,22 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { isPrimitive := itemType.Kind() != reflect.Struct var pointers []interface{} - var err error - if err := rows.Do(func(row *Row) error { + isFistRow := true + return rows.Do(func(row *Row) error { sliceItem := reflect.New(itemType).Elem() - if len(pointers) == 0 { + if isFistRow { if isPrimitive { if len(row.fields) > 1 { return errTooManyColumns() } pointers = []interface{}{sliceItem.Addr().Interface()} } else { + var err error if pointers, err = structPointers(sliceItem, row.fields, s.Lenient); err != nil { return err } } + isFistRow = false } if len(pointers) == 0 { return nil @@ -434,18 +436,13 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { if err != nil { return err } - if len(pointers) > 0 { - dst := sliceItem.Addr().Interface() - for i, p := range pointers { - reflect.ValueOf(dst).Elem().Field(i).Set(reflect.ValueOf(p).Elem()) - } + dst := reflect.ValueOf(sliceItem.Addr().Interface()).Elem() + for i, p := range pointers { + dst.Field(i).Set(reflect.ValueOf(p).Elem()) } sliceVal.Set(reflect.Append(sliceVal, sliceItem)) return nil - }); err != nil { - return err - } - return nil + }) } func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, strict bool) ([]interface{}, error) { @@ -464,11 +461,10 @@ func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, stri fieldVal = sliceItem.FieldByName(colName.GetName()) } if !fieldVal.IsValid() || !fieldVal.CanSet() { - // have to add if we found a column because Scan() requires + // have to add if we found a column because Columns() requires // len(cols) arguments or it will error. This way we can scan to // a useless pointer - var nothing interface{} - pointers = append(pointers, ¬hing) + pointers = append(pointers, nil) continue } diff --git a/spanner/value_benchmarks_test.go b/spanner/value_benchmarks_test.go index ce1eceff6dd6..4e49350e9e3e 100644 --- a/spanner/value_benchmarks_test.go +++ b/spanner/value_benchmarks_test.go @@ -245,6 +245,7 @@ func BenchmarkScan100RowsUsingSelectAll(b *testing.B) { }{int64(i), fmt.Sprintf("name-%d", i)}) } src := mockIterator(b, rows) + b.ResetTimer() for n := 0; n < b.N; n++ { it := *src var res []struct { @@ -270,6 +271,7 @@ func BenchmarkScan100RowsUsingToStruct(b *testing.B) { }{int64(i), fmt.Sprintf("name-%d", i)}) } src := mockIterator(b, rows) + b.ResetTimer() for n := 0; n < b.N; n++ { it := *src var res []struct { @@ -310,6 +312,7 @@ func BenchmarkScan100RowsUsingColumns(b *testing.B) { }{int64(i), fmt.Sprintf("name-%d", i)}) } src := mockIterator(b, rows) + b.ResetTimer() for n := 0; n < b.N; n++ { it := *src var res []struct { From 7750ed0d81fe0dd72d4b19154691cacdaa55d0f3 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Tue, 2 Jan 2024 22:33:20 +0530 Subject: [PATCH 4/9] preallocate if returned rows count is known --- spanner/read.go | 20 ++++++++++++++++++++ spanner/row.go | 20 +++++++++++++++++--- spanner/value_benchmarks_test.go | 4 ++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/spanner/read.go b/spanner/read.go index e172f0811953..815865d7c2b6 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -21,6 +21,7 @@ import ( "context" "io" "log" + "strconv" "sync/atomic" "time" @@ -95,6 +96,7 @@ type Iterator interface { Next() (*Row, error) Do(f func(r *Row) error) error Stop() + RowsReturned() int64 } // RowIterator is an iterator over Rows. @@ -128,6 +130,24 @@ type RowIterator struct { sawStats bool } +func (r *RowIterator) RowsReturned() int64 { + if r.sawStats && r.QueryStats != nil && r.QueryStats["rows_returned"] != nil { + switch r.QueryStats["rows_returned"].(type) { + case float64: + return r.QueryStats["rows_returned"].(int64) + case string: + v, err := strconv.Atoi(r.QueryStats["rows_returned"].(string)) + if err != nil { + return -1 + } + return int64(v) + default: + return -1 + } + } + return -1 +} + // Next returns the next result. Its second return value is iterator.Done if // there are no more results. Once Next returns Done, all subsequent calls // will return Done. diff --git a/spanner/row.go b/spanner/row.go index 936876af2f13..46662cfe9d15 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -413,9 +413,16 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { isPrimitive := itemType.Kind() != reflect.Struct var pointers []interface{} isFistRow := true + rowIndex := -1 return rows.Do(func(row *Row) error { sliceItem := reflect.New(itemType).Elem() if isFistRow { + nRows := rows.RowsReturned() + if nRows != -1 { + sliceVal = reflect.MakeSlice(sliceType, int(nRows), int(nRows)) + reflect.ValueOf(v).Elem().Set(sliceVal) + rowIndex++ + } if isPrimitive { if len(row.fields) > 1 { return errTooManyColumns() @@ -436,11 +443,18 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { if err != nil { return err } - dst := reflect.ValueOf(sliceItem.Addr().Interface()).Elem() for i, p := range pointers { - dst.Field(i).Set(reflect.ValueOf(p).Elem()) + if p == nil { + continue + } + sliceItem.Field(i).Set(reflect.ValueOf(p).Elem()) + } + if rowIndex >= 0 { + sliceVal.Index(rowIndex).Set(sliceItem) + rowIndex++ + } else { + sliceVal.Set(reflect.Append(sliceVal, sliceItem)) } - sliceVal.Set(reflect.Append(sliceVal, sliceItem)) return nil }) } diff --git a/spanner/value_benchmarks_test.go b/spanner/value_benchmarks_test.go index 4e49350e9e3e..ce68f0e88cc0 100644 --- a/spanner/value_benchmarks_test.go +++ b/spanner/value_benchmarks_test.go @@ -378,6 +378,10 @@ func (i *mockIteratorImpl) Next() (*Row, error) { return row, nil } +func (i *mockIteratorImpl) RowsReturned() int64 { + return int64(len(i.rows)) +} + func (i *mockIteratorImpl) Stop() { i.rows = nil } From 6d7cca8184bf3a09ee41384973d6392894cdd135 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Tue, 2 Jan 2024 22:43:23 +0530 Subject: [PATCH 5/9] fix go vet --- spanner/read.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spanner/read.go b/spanner/read.go index 815865d7c2b6..8e419dc9d792 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -130,6 +130,9 @@ type RowIterator struct { sawStats bool } +// RowsReturned returns the number of rows returned by the query. If the query +// was a DML statement, the number of rows affected is returned. If the query +// was a PDML statement, the number of rows affected is a lower bound. func (r *RowIterator) RowsReturned() int64 { if r.sawStats && r.QueryStats != nil && r.QueryStats["rows_returned"] != nil { switch r.QueryStats["rows_returned"].(type) { From 4997f5ef89f08d39965b4504247e38c260738775 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Fri, 5 Jan 2024 23:03:08 +0530 Subject: [PATCH 6/9] incorporate suggestions --- spanner/read.go | 22 ++++++------ spanner/row.go | 92 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/spanner/read.go b/spanner/read.go index 8e419dc9d792..65948f6c08a6 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -130,22 +130,24 @@ type RowIterator struct { sawStats bool } -// RowsReturned returns the number of rows returned by the query. If the query -// was a DML statement, the number of rows affected is returned. If the query -// was a PDML statement, the number of rows affected is a lower bound. +// RowsReturned returns, a lower bound on the number of rows returned by the query. +// Currently, this requires the query to be executed with query stats enabled. +// +// If the query was a DML statement, the number of rows affected is returned. +// If the query was a PDML statement, the number of rows affected is a lower bound. +// If the query was executed without query stats enabled, or if it is otherwise +// impossible to determine the number of rows in the resultset, -1 is returned. func (r *RowIterator) RowsReturned() int64 { if r.sawStats && r.QueryStats != nil && r.QueryStats["rows_returned"] != nil { - switch r.QueryStats["rows_returned"].(type) { + switch rowsReturned := r.QueryStats["rows_returned"].(type) { case float64: - return r.QueryStats["rows_returned"].(int64) + return int64(rowsReturned) case string: - v, err := strconv.Atoi(r.QueryStats["rows_returned"].(string)) + v, err := strconv.ParseInt(rowsReturned, 10, 64) if err != nil { - return -1 + v = -1 } - return int64(v) - default: - return -1 + return v } } return -1 diff --git a/spanner/row.go b/spanner/row.go index 46662cfe9d15..58bb69159c2e 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -252,6 +252,9 @@ func errColNotFound(n string) error { func errNotASlicePointer() error { return spannerErrorf(codes.InvalidArgument, "destination must be a pointer to a slice") } +func errNilSlicePointer() error { + return spannerErrorf(codes.InvalidArgument, "destination must be a non nil pointer") +} func errTooManyColumns() error { return spannerErrorf(codes.InvalidArgument, "too many columns returned for primitive slice") @@ -387,7 +390,24 @@ func (r *Row) ToStructLenient(p interface{}) error { ) } -// SelectAll scans rows into a slice (v) +// SelectAll iterates all rows to the end. After iterating it closes the rows, +// and propagates any errors that could pop up. +// It expects that destination should be a slice. For each row it scans data and appends it to the destination slice. +// SelectAll supports both types of slices: slice of structs by a pointer and slice of structs by value, +// for example: +// +// type Singer struct { +// ID string +// Name string +// } +// +// var singersByPtr []*Singer +// var singersByValue []Singer +// +// Both singersByPtr and singersByValue are valid destinations for SelectAll function. +// +// Before starting, SelectAll resets the destination slice, +// so if it's not empty it will overwrite all existing elements. func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { if rows == nil { return fmt.Errorf("rows is nil") @@ -395,16 +415,36 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { if v == nil { return fmt.Errorf("p is nil") } - vType := reflect.TypeOf(v) - if k := vType.Kind(); k != reflect.Ptr { - return errToStructArgType(v) + dstVal := reflect.ValueOf(v) + if !dstVal.IsValid() || (dstVal.Kind() == reflect.Ptr && dstVal.IsNil()) { + return errNilSlicePointer() + } + if dstVal.Kind() != reflect.Ptr { + return errNotASlicePointer() } - sliceType := vType.Elem() - if reflect.Slice != sliceType.Kind() { + dstVal = dstVal.Elem() + dstType := dstVal.Type() + if k := dstType.Kind(); k != reflect.Slice { return errNotASlicePointer() } - sliceVal := reflect.Indirect(reflect.ValueOf(v)) - itemType := sliceType.Elem() + + itemType := dstType.Elem() + var itemByPtr bool + // If it's a slice of pointers to structs, + // we handle it the same way as it would be slice of struct by value + // and dereference pointers to values, + // because eventually we work with fields. + // But if it's a slice of primitive type e.g. or []string or []*string, + // we must leave and pass elements as is. + if itemType.Kind() == reflect.Ptr { + elementBaseTypeElem := itemType.Elem() + if elementBaseTypeElem.Kind() == reflect.Struct { + itemType = elementBaseTypeElem + itemByPtr = true + } + } + // Make sure slice is empty. + dstVal.Set(dstVal.Slice(0, 0)) s := &decodeSetting{} for _, opt := range options { opt.Apply(s) @@ -412,16 +452,17 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { isPrimitive := itemType.Kind() != reflect.Struct var pointers []interface{} - isFistRow := true - rowIndex := -1 + isFirstRow := true return rows.Do(func(row *Row) error { - sliceItem := reflect.New(itemType).Elem() - if isFistRow { + sliceItem := reflect.New(itemType) + if isFirstRow { + defer func() { + isFirstRow = false + }() nRows := rows.RowsReturned() if nRows != -1 { - sliceVal = reflect.MakeSlice(sliceType, int(nRows), int(nRows)) - reflect.ValueOf(v).Elem().Set(sliceVal) - rowIndex++ + // nRows is lower bound of the number of rows returned by the query. + dstVal.Set(reflect.MakeSlice(dstType, 0, int(nRows))) } if isPrimitive { if len(row.fields) > 1 { @@ -430,11 +471,10 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { pointers = []interface{}{sliceItem.Addr().Interface()} } else { var err error - if pointers, err = structPointers(sliceItem, row.fields, s.Lenient); err != nil { + if pointers, err = structPointers(sliceItem.Elem(), row.fields, s.Lenient); err != nil { return err } } - isFistRow = false } if len(pointers) == 0 { return nil @@ -447,14 +487,22 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { if p == nil { continue } - sliceItem.Field(i).Set(reflect.ValueOf(p).Elem()) + sliceItem.Elem().Field(i).Set(reflect.ValueOf(p).Elem()) } - if rowIndex >= 0 { - sliceVal.Index(rowIndex).Set(sliceItem) - rowIndex++ + var elemVal reflect.Value + if itemByPtr { + if isFirstRow { + // create a new pointer to the struct with all the values copied from sliceIte + // because same underlying pointers array will be used for next rows + elemVal = reflect.New(itemType) + elemVal.Elem().Set(sliceItem.Elem()) + } else { + elemVal = sliceItem + } } else { - sliceVal.Set(reflect.Append(sliceVal, sliceItem)) + elemVal = sliceItem.Elem() } + dstVal.Set(reflect.Append(dstVal, elemVal)) return nil }) } From 9fa72ab619557ff44289e4e397209e39671f0474 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Fri, 5 Jan 2024 23:22:09 +0530 Subject: [PATCH 7/9] allocate when rowsReturned is lowerbound --- spanner/row.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/spanner/row.go b/spanner/row.go index 58bb69159c2e..9f609c0da3c8 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -453,16 +453,19 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { isPrimitive := itemType.Kind() != reflect.Struct var pointers []interface{} isFirstRow := true + rowIndex := int64(-1) + var rowsReturned int64 return rows.Do(func(row *Row) error { sliceItem := reflect.New(itemType) if isFirstRow { defer func() { isFirstRow = false }() - nRows := rows.RowsReturned() - if nRows != -1 { + rowsReturned = rows.RowsReturned() + if rowsReturned != -1 { // nRows is lower bound of the number of rows returned by the query. - dstVal.Set(reflect.MakeSlice(dstType, 0, int(nRows))) + dstVal.Set(reflect.MakeSlice(dstType, int(rowsReturned), int(rowsReturned))) + rowIndex++ } if isPrimitive { if len(row.fields) > 1 { @@ -502,7 +505,12 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { } else { elemVal = sliceItem.Elem() } - dstVal.Set(reflect.Append(dstVal, elemVal)) + if rowIndex >= 0 && rowsReturned > rowIndex { + dstVal.Index(int(rowIndex)).Set(elemVal) + rowIndex++ + } else { + dstVal.Set(reflect.Append(dstVal, elemVal)) + } return nil }) } From 8702fc361a7740f3d00a0a3186cf1b588f8d9018 Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Mon, 15 Jan 2024 20:19:57 +0530 Subject: [PATCH 8/9] incorporate changes and add benchmark to compare test runs for 5 fields struct --- spanner/go.mod | 1 + spanner/go.sum | 1 + spanner/mocks.go | 95 ++++++++++ spanner/read.go | 29 +-- spanner/row.go | 75 ++++---- spanner/row_test.go | 293 +++++++++++++++++++++++++++++++ spanner/value_benchmarks_test.go | 92 +++++++++- 7 files changed, 509 insertions(+), 77 deletions(-) create mode 100644 spanner/mocks.go diff --git a/spanner/go.mod b/spanner/go.mod index 3ce2be8f4a25..e644dd057fb5 100644 --- a/spanner/go.mod +++ b/spanner/go.mod @@ -41,6 +41,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect go.opentelemetry.io/otel v1.21.0 // indirect diff --git a/spanner/go.sum b/spanner/go.sum index 2a6a530af4ca..6de839bc4b32 100644 --- a/spanner/go.sum +++ b/spanner/go.sum @@ -87,6 +87,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/spanner/mocks.go b/spanner/mocks.go new file mode 100644 index 000000000000..f7553f217e8c --- /dev/null +++ b/spanner/mocks.go @@ -0,0 +1,95 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by mockery v2.40.1. DO NOT EDIT. + +package spanner + +import ( + mock "github.com/stretchr/testify/mock" + "google.golang.org/api/iterator" +) + +// mockRowIterator is an autogenerated mock type for the mockRowIterator type +type mockRowIterator struct { + mock.Mock +} + +// Next provides a mock function with given fields: +func (_m *mockRowIterator) Next() (*Row, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Next") + } + + var r0 *Row + var r1 error + if rf, ok := ret.Get(0).(func() (*Row, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *Row); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Row) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *mockRowIterator) Do(f func(r *Row) error) error { + defer _m.Stop() + for { + row, err := _m.Next() + switch err { + case iterator.Done: + return nil + case nil: + if err = f(row); err != nil { + return err + } + default: + return err + } + } +} + +// Stop provides a mock function with given fields: +func (_m *mockRowIterator) Stop() { + _m.Called() +} + +// newRowIterator creates a new instance of mockRowIterator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newRowIterator(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRowIterator { + mock := &mockRowIterator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/spanner/read.go b/spanner/read.go index 65948f6c08a6..93895c91d50c 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -21,7 +21,6 @@ import ( "context" "io" "log" - "strconv" "sync/atomic" "time" @@ -91,12 +90,11 @@ func streamWithReplaceSessionFunc( } } -// Iterator is an interface for iterating over Rows. -type Iterator interface { +// rowIterator is an interface for iterating over Rows. +type rowIterator interface { Next() (*Row, error) Do(f func(r *Row) error) error Stop() - RowsReturned() int64 } // RowIterator is an iterator over Rows. @@ -130,28 +128,7 @@ type RowIterator struct { sawStats bool } -// RowsReturned returns, a lower bound on the number of rows returned by the query. -// Currently, this requires the query to be executed with query stats enabled. -// -// If the query was a DML statement, the number of rows affected is returned. -// If the query was a PDML statement, the number of rows affected is a lower bound. -// If the query was executed without query stats enabled, or if it is otherwise -// impossible to determine the number of rows in the resultset, -1 is returned. -func (r *RowIterator) RowsReturned() int64 { - if r.sawStats && r.QueryStats != nil && r.QueryStats["rows_returned"] != nil { - switch rowsReturned := r.QueryStats["rows_returned"].(type) { - case float64: - return int64(rowsReturned) - case string: - v, err := strconv.ParseInt(rowsReturned, 10, 64) - if err != nil { - v = -1 - } - return v - } - } - return -1 -} +var _ rowIterator = (*RowIterator)(nil) // Next returns the next result. Its second return value is iterator.Done if // there are no more results. Once Next returns Done, all subsequent calls diff --git a/spanner/row.go b/spanner/row.go index 9f609c0da3c8..fa86b9cd97b5 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -252,6 +252,7 @@ func errColNotFound(n string) error { func errNotASlicePointer() error { return spannerErrorf(codes.InvalidArgument, "destination must be a pointer to a slice") } + func errNilSlicePointer() error { return spannerErrorf(codes.InvalidArgument, "destination must be a non nil pointer") } @@ -390,10 +391,10 @@ func (r *Row) ToStructLenient(p interface{}) error { ) } -// SelectAll iterates all rows to the end. After iterating it closes the rows, -// and propagates any errors that could pop up. -// It expects that destination should be a slice. For each row it scans data and appends it to the destination slice. -// SelectAll supports both types of slices: slice of structs by a pointer and slice of structs by value, +// SelectAll iterates all rows to the end. After iterating it closes the rows +// and propagates any errors that could pop up with destination slice partially filled. +// It expects that destination should be a slice. For each row, it scans data and appends it to the destination slice. +// SelectAll supports both types of slices: slice of pointers and slice of structs or primitives by value, // for example: // // type Singer struct { @@ -406,16 +407,19 @@ func (r *Row) ToStructLenient(p interface{}) error { // // Both singersByPtr and singersByValue are valid destinations for SelectAll function. // -// Before starting, SelectAll resets the destination slice, -// so if it's not empty it will overwrite all existing elements. -func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { +// custom setting such as lenient can be passed as an option using DecodeOptions +// example: to ignore extra columns in the row +// +// var singersByPtr []*Singer +// err := spanner.SelectAll(row, &singersByPtr, spanner.WithLenient()) +func SelectAll(rows rowIterator, destination interface{}, options ...DecodeOptions) error { if rows == nil { return fmt.Errorf("rows is nil") } - if v == nil { - return fmt.Errorf("p is nil") + if destination == nil { + return fmt.Errorf("destination is nil") } - dstVal := reflect.ValueOf(v) + dstVal := reflect.ValueOf(destination) if !dstVal.IsValid() || (dstVal.Kind() == reflect.Ptr && dstVal.IsNil()) { return errNilSlicePointer() } @@ -443,8 +447,6 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { itemByPtr = true } } - // Make sure slice is empty. - dstVal.Set(dstVal.Slice(0, 0)) s := &decodeSetting{} for _, opt := range options { opt.Apply(s) @@ -453,44 +455,36 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { isPrimitive := itemType.Kind() != reflect.Struct var pointers []interface{} isFirstRow := true - rowIndex := int64(-1) - var rowsReturned int64 + var err error return rows.Do(func(row *Row) error { sliceItem := reflect.New(itemType) - if isFirstRow { + if isFirstRow && !isPrimitive { defer func() { isFirstRow = false }() - rowsReturned = rows.RowsReturned() - if rowsReturned != -1 { - // nRows is lower bound of the number of rows returned by the query. - dstVal.Set(reflect.MakeSlice(dstType, int(rowsReturned), int(rowsReturned))) - rowIndex++ + if pointers, err = structPointers(sliceItem.Elem(), row.fields, s.Lenient); err != nil { + return err } - if isPrimitive { - if len(row.fields) > 1 { - return errTooManyColumns() - } - pointers = []interface{}{sliceItem.Addr().Interface()} - } else { - var err error - if pointers, err = structPointers(sliceItem.Elem(), row.fields, s.Lenient); err != nil { - return err - } + } else if isPrimitive { + if len(row.fields) > 1 && !s.Lenient { + return errTooManyColumns() } + pointers = []interface{}{sliceItem.Interface()} } if len(pointers) == 0 { return nil } - err := row.Columns(pointers...) + err = row.Columns(pointers...) if err != nil { return err } - for i, p := range pointers { - if p == nil { - continue + if !isPrimitive { + for i, p := range pointers { + if p == nil { + continue + } + sliceItem.Elem().Field(i).Set(reflect.ValueOf(p).Elem()) } - sliceItem.Elem().Field(i).Set(reflect.ValueOf(p).Elem()) } var elemVal reflect.Value if itemByPtr { @@ -505,17 +499,12 @@ func SelectAll(rows Iterator, v interface{}, options ...DecodeOptions) error { } else { elemVal = sliceItem.Elem() } - if rowIndex >= 0 && rowsReturned > rowIndex { - dstVal.Index(int(rowIndex)).Set(elemVal) - rowIndex++ - } else { - dstVal.Set(reflect.Append(dstVal, elemVal)) - } + dstVal.Set(reflect.Append(dstVal, elemVal)) return nil }) } -func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, strict bool) ([]interface{}, error) { +func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, lenient bool) ([]interface{}, error) { pointers := make([]interface{}, 0, len(cols)) fieldTag := make(map[string]reflect.Value, len(cols)) initFieldTag(sliceItem, &fieldTag) @@ -525,7 +514,7 @@ func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, stri if v, ok := fieldTag[colName.GetName()]; ok { fieldVal = v } else { - if strict { + if !lenient { return nil, errNoOrDupGoField(sliceItem, colName.GetName()) } fieldVal = sliceItem.FieldByName(colName.GetName()) diff --git a/spanner/row_test.go b/spanner/row_test.go index b4e7a2395635..0f03b5de94c7 100644 --- a/spanner/row_test.go +++ b/spanner/row_test.go @@ -18,6 +18,7 @@ package spanner import ( "encoding/base64" + "errors" "fmt" "reflect" "strconv" @@ -31,6 +32,7 @@ import ( proto "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" "github.com/google/go-cmp/cmp" + "google.golang.org/api/iterator" ) var ( @@ -1991,3 +1993,294 @@ func BenchmarkColumn(b *testing.B) { } } } + +func TestSelectAll(t *testing.T) { + skipForPGTest(t) + type args struct { + destination interface{} + options []DecodeOptions + mock func(mockIterator *mockRowIterator) + } + type testStruct struct { + Col1 int64 + Col2 float64 + Col3 string + } + tests := []struct { + name string + args args + wantErr bool + want interface{} + }{ + { + name: "success: using slice of primitives", + args: args{ + destination: &[]string{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]string{"value", "value2"}, + }, + { + name: "success: using slice of pointer to primitives", + args: args{ + destination: &[]*string{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*string{stringPointer("value"), stringPointer("value2")}, + }, + { + name: "success: using slice of structs", + args: args{ + destination: &[]testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }, + }, + { + name: "success: using slice of pointer to structs", + args: args{ + destination: &[]*testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }}, + { + name: "success: when spanner row contains more columns than declared in Go struct but called WithLenient", + args: args{ + destination: &[]*testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value"), stringProto("value4")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + options: []DecodeOptions{WithLenient()}, + }, + want: &[]*testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + }, + }, + { + name: "success: using prefilled destination should append to the destination", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }}, + { + name: "failure: in case of error destination will have the partial result", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, errors.New("some error")) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + {Col1: 1, Col2: 1.1, Col3: "value"}, + }, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns than declared in Go struct", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + }, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns and destination is primitive slice", + args: args{ + destination: &[]int64{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]int64{}, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns and destination is primitive slice using WithLenient", + args: args{ + destination: &[]int64{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + options: []DecodeOptions{WithLenient()}, + }, + want: &[]int64{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockIterator := &mockRowIterator{} + if tt.args.mock != nil { + tt.args.mock(mockIterator) + } + if err := SelectAll(mockIterator, tt.args.destination, tt.args.options...); (err != nil) != tt.wantErr { + t.Errorf("SelectAll() error = %v, wantErr %v", err, tt.wantErr) + } + if !testEqual(tt.args.destination, tt.want) { + t.Errorf("SelectAll() = %v, want %v", tt.args.destination, tt.want) + } + }) + } +} + +func stringPointer(s string) *string { + return &s +} diff --git a/spanner/value_benchmarks_test.go b/spanner/value_benchmarks_test.go index ce68f0e88cc0..3a9687199f76 100644 --- a/spanner/value_benchmarks_test.go +++ b/spanner/value_benchmarks_test.go @@ -16,6 +16,7 @@ package spanner import ( "fmt" + "math" "reflect" "strconv" "testing" @@ -233,6 +234,85 @@ func decodeArrayReflect(pb *proto3.ListValue, name string, typ *sppb.Type, aptr return nil } +func BenchmarkScan(b *testing.B) { + scanMethods := []string{"row.Column()", "row.ToStruct()", "row.SelectAll()"} + for _, method := range scanMethods { + for k := 0.; k <= 20; k++ { + n := int(math.Pow(2, k)) + b.Run(fmt.Sprintf("%s/%d", method, n), func(b *testing.B) { + b.StopTimer() + var rows []struct { + ID int64 + Name string + Active bool + City string + State string + } + for i := 0; i < n; i++ { + rows = append(rows, struct { + ID int64 + Name string + Active bool + City string + State string + }{int64(i), fmt.Sprintf("name-%d", i), true, "city", "state"}) + } + src := mockBenchmarkIterator(b, rows) + for i := 0; i < b.N; i++ { + it := *src + var res []struct { + ID int64 + Name string + Active bool + City string + State string + } + b.StartTimer() + switch method { + case "row.SelectAll()": + if err := SelectAll(&it, &res); err != nil { + b.Fatal(err) + } + _ = res + break + default: + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + Active bool + City string + State string + } + if method == "row.Column()" { + err = row.Columns(&r.ID, &r.Name, &r.Active, &r.City, &r.State) + if err != nil { + b.Fatal(err) + } + } else { + err = row.ToStruct(&r) + if err != nil { + b.Fatal(err) + } + } + res = append(res, r) + } + it.Stop() + _ = res + } + + } + }) + } + } +} + func BenchmarkScan100RowsUsingSelectAll(b *testing.B) { var rows []struct { ID int64 @@ -244,7 +324,7 @@ func BenchmarkScan100RowsUsingSelectAll(b *testing.B) { Name string }{int64(i), fmt.Sprintf("name-%d", i)}) } - src := mockIterator(b, rows) + src := mockBenchmarkIterator(b, rows) b.ResetTimer() for n := 0; n < b.N; n++ { it := *src @@ -270,7 +350,7 @@ func BenchmarkScan100RowsUsingToStruct(b *testing.B) { Name string }{int64(i), fmt.Sprintf("name-%d", i)}) } - src := mockIterator(b, rows) + src := mockBenchmarkIterator(b, rows) b.ResetTimer() for n := 0; n < b.N; n++ { it := *src @@ -311,7 +391,7 @@ func BenchmarkScan100RowsUsingColumns(b *testing.B) { Name string }{int64(i), fmt.Sprintf("name-%d", i)}) } - src := mockIterator(b, rows) + src := mockBenchmarkIterator(b, rows) b.ResetTimer() for n := 0; n < b.N; n++ { it := *src @@ -341,7 +421,7 @@ func BenchmarkScan100RowsUsingColumns(b *testing.B) { } } -func mockIterator[T any](t testing.TB, rows []T) *mockIteratorImpl { +func mockBenchmarkIterator[T any](t testing.TB, rows []T) *mockIteratorImpl { var v T var colNames []string numCols := reflect.TypeOf(v).NumField() @@ -378,10 +458,6 @@ func (i *mockIteratorImpl) Next() (*Row, error) { return row, nil } -func (i *mockIteratorImpl) RowsReturned() int64 { - return int64(len(i.rows)) -} - func (i *mockIteratorImpl) Stop() { i.rows = nil } From 50d1c37cc241119f31926910fba20ed0762b81fb Mon Sep 17 00:00:00 2001 From: rahul yadav Date: Thu, 18 Jan 2024 10:16:02 +0530 Subject: [PATCH 9/9] incorporate suggestions --- spanner/read.go | 1 + spanner/row.go | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/spanner/read.go b/spanner/read.go index 93895c91d50c..50578b740fa3 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -128,6 +128,7 @@ type RowIterator struct { sawStats bool } +// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface. var _ rowIterator = (*RowIterator)(nil) // Next returns the next result. Its second return value is iterator.Done if diff --git a/spanner/row.go b/spanner/row.go index fa86b9cd97b5..83a3628de4e8 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -407,8 +407,8 @@ func (r *Row) ToStructLenient(p interface{}) error { // // Both singersByPtr and singersByValue are valid destinations for SelectAll function. // -// custom setting such as lenient can be passed as an option using DecodeOptions -// example: to ignore extra columns in the row +// Add the option `spanner.WithLenient()` to instruct SelectAll to ignore additional columns in the rows that are not present in the destination struct. +// example: // // var singersByPtr []*Singer // err := spanner.SelectAll(row, &singersByPtr, spanner.WithLenient()) @@ -479,17 +479,18 @@ func SelectAll(rows rowIterator, destination interface{}, options ...DecodeOptio return err } if !isPrimitive { + e := sliceItem.Elem() for i, p := range pointers { if p == nil { continue } - sliceItem.Elem().Field(i).Set(reflect.ValueOf(p).Elem()) + e.Field(i).Set(reflect.ValueOf(p).Elem()) } } var elemVal reflect.Value if itemByPtr { if isFirstRow { - // create a new pointer to the struct with all the values copied from sliceIte + // create a new pointer to the struct with all the values copied from sliceItem // because same underlying pointers array will be used for next rows elemVal = reflect.New(itemType) elemVal.Elem().Set(sliceItem.Elem())