diff --git a/spanner/go.mod b/spanner/go.mod index ae0aab442e6e..dd6489a4e874 100644 --- a/spanner/go.mod +++ b/spanner/go.mod @@ -41,6 +41,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect go.opentelemetry.io/otel v1.21.0 // indirect diff --git a/spanner/go.sum b/spanner/go.sum index f214e3f4d209..fd6ebf95e9bd 100644 --- a/spanner/go.sum +++ b/spanner/go.sum @@ -87,6 +87,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/spanner/mocks.go b/spanner/mocks.go new file mode 100644 index 000000000000..f7553f217e8c --- /dev/null +++ b/spanner/mocks.go @@ -0,0 +1,95 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by mockery v2.40.1. DO NOT EDIT. + +package spanner + +import ( + mock "github.com/stretchr/testify/mock" + "google.golang.org/api/iterator" +) + +// mockRowIterator is an autogenerated mock type for the mockRowIterator type +type mockRowIterator struct { + mock.Mock +} + +// Next provides a mock function with given fields: +func (_m *mockRowIterator) Next() (*Row, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Next") + } + + var r0 *Row + var r1 error + if rf, ok := ret.Get(0).(func() (*Row, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *Row); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Row) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *mockRowIterator) Do(f func(r *Row) error) error { + defer _m.Stop() + for { + row, err := _m.Next() + switch err { + case iterator.Done: + return nil + case nil: + if err = f(row); err != nil { + return err + } + default: + return err + } + } +} + +// Stop provides a mock function with given fields: +func (_m *mockRowIterator) Stop() { + _m.Called() +} + +// newRowIterator creates a new instance of mockRowIterator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newRowIterator(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRowIterator { + mock := &mockRowIterator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/spanner/read.go b/spanner/read.go index 2651ac9535de..50578b740fa3 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -90,6 +90,13 @@ func streamWithReplaceSessionFunc( } } +// rowIterator is an interface for iterating over Rows. +type rowIterator interface { + Next() (*Row, error) + Do(f func(r *Row) error) error + Stop() +} + // RowIterator is an iterator over Rows. type RowIterator struct { // The plan for the query. Available after RowIterator.Next returns @@ -121,6 +128,9 @@ type RowIterator struct { sawStats bool } +// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface. +var _ rowIterator = (*RowIterator)(nil) + // Next returns the next result. Its second return value is iterator.Done if // there are no more results. Once Next returns Done, all subsequent calls // will return Done. diff --git a/spanner/row.go b/spanner/row.go index 4dca9be91f0e..83a3628de4e8 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -249,6 +249,18 @@ func errColNotFound(n string) error { return spannerErrorf(codes.NotFound, "column %q not found", n) } +func errNotASlicePointer() error { + return spannerErrorf(codes.InvalidArgument, "destination must be a pointer to a slice") +} + +func errNilSlicePointer() error { + return spannerErrorf(codes.InvalidArgument, "destination must be a non nil pointer") +} + +func errTooManyColumns() error { + return spannerErrorf(codes.InvalidArgument, "too many columns returned for primitive slice") +} + // ColumnByName fetches the value from the named column, decoding it into ptr. // See the Row documentation for the list of acceptable argument types. func (r *Row) ColumnByName(name string, ptr interface{}) error { @@ -378,3 +390,175 @@ func (r *Row) ToStructLenient(p interface{}) error { true, ) } + +// SelectAll iterates all rows to the end. After iterating it closes the rows +// and propagates any errors that could pop up with destination slice partially filled. +// It expects that destination should be a slice. For each row, it scans data and appends it to the destination slice. +// SelectAll supports both types of slices: slice of pointers and slice of structs or primitives by value, +// for example: +// +// type Singer struct { +// ID string +// Name string +// } +// +// var singersByPtr []*Singer +// var singersByValue []Singer +// +// Both singersByPtr and singersByValue are valid destinations for SelectAll function. +// +// Add the option `spanner.WithLenient()` to instruct SelectAll to ignore additional columns in the rows that are not present in the destination struct. +// example: +// +// var singersByPtr []*Singer +// err := spanner.SelectAll(row, &singersByPtr, spanner.WithLenient()) +func SelectAll(rows rowIterator, destination interface{}, options ...DecodeOptions) error { + if rows == nil { + return fmt.Errorf("rows is nil") + } + if destination == nil { + return fmt.Errorf("destination is nil") + } + dstVal := reflect.ValueOf(destination) + if !dstVal.IsValid() || (dstVal.Kind() == reflect.Ptr && dstVal.IsNil()) { + return errNilSlicePointer() + } + if dstVal.Kind() != reflect.Ptr { + return errNotASlicePointer() + } + dstVal = dstVal.Elem() + dstType := dstVal.Type() + if k := dstType.Kind(); k != reflect.Slice { + return errNotASlicePointer() + } + + itemType := dstType.Elem() + var itemByPtr bool + // If it's a slice of pointers to structs, + // we handle it the same way as it would be slice of struct by value + // and dereference pointers to values, + // because eventually we work with fields. + // But if it's a slice of primitive type e.g. or []string or []*string, + // we must leave and pass elements as is. + if itemType.Kind() == reflect.Ptr { + elementBaseTypeElem := itemType.Elem() + if elementBaseTypeElem.Kind() == reflect.Struct { + itemType = elementBaseTypeElem + itemByPtr = true + } + } + s := &decodeSetting{} + for _, opt := range options { + opt.Apply(s) + } + + isPrimitive := itemType.Kind() != reflect.Struct + var pointers []interface{} + isFirstRow := true + var err error + return rows.Do(func(row *Row) error { + sliceItem := reflect.New(itemType) + if isFirstRow && !isPrimitive { + defer func() { + isFirstRow = false + }() + if pointers, err = structPointers(sliceItem.Elem(), row.fields, s.Lenient); err != nil { + return err + } + } else if isPrimitive { + if len(row.fields) > 1 && !s.Lenient { + return errTooManyColumns() + } + pointers = []interface{}{sliceItem.Interface()} + } + if len(pointers) == 0 { + return nil + } + err = row.Columns(pointers...) + if err != nil { + return err + } + if !isPrimitive { + e := sliceItem.Elem() + for i, p := range pointers { + if p == nil { + continue + } + e.Field(i).Set(reflect.ValueOf(p).Elem()) + } + } + var elemVal reflect.Value + if itemByPtr { + if isFirstRow { + // create a new pointer to the struct with all the values copied from sliceItem + // because same underlying pointers array will be used for next rows + elemVal = reflect.New(itemType) + elemVal.Elem().Set(sliceItem.Elem()) + } else { + elemVal = sliceItem + } + } else { + elemVal = sliceItem.Elem() + } + dstVal.Set(reflect.Append(dstVal, elemVal)) + return nil + }) +} + +func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, lenient bool) ([]interface{}, error) { + pointers := make([]interface{}, 0, len(cols)) + fieldTag := make(map[string]reflect.Value, len(cols)) + initFieldTag(sliceItem, &fieldTag) + + for _, colName := range cols { + var fieldVal reflect.Value + if v, ok := fieldTag[colName.GetName()]; ok { + fieldVal = v + } else { + if !lenient { + return nil, errNoOrDupGoField(sliceItem, colName.GetName()) + } + fieldVal = sliceItem.FieldByName(colName.GetName()) + } + if !fieldVal.IsValid() || !fieldVal.CanSet() { + // have to add if we found a column because Columns() requires + // len(cols) arguments or it will error. This way we can scan to + // a useless pointer + pointers = append(pointers, nil) + continue + } + + pointers = append(pointers, fieldVal.Addr().Interface()) + } + return pointers, nil +} + +// Initialization the tags from struct. +func initFieldTag(sliceItem reflect.Value, fieldTagMap *map[string]reflect.Value) { + typ := sliceItem.Type() + + for i := 0; i < sliceItem.NumField(); i++ { + fieldType := typ.Field(i) + exported := (fieldType.PkgPath == "") + // If a named field is unexported, ignore it. An anonymous + // unexported field is processed, because it may contain + // exported fields, which are visible. + if !exported && !fieldType.Anonymous { + continue + } + if fieldType.Type.Kind() == reflect.Struct { + // found an embedded struct + sliceItemOfAnonymous := sliceItem.Field(i) + initFieldTag(sliceItemOfAnonymous, fieldTagMap) + continue + } + name, keep, _, _ := spannerTagParser(fieldType.Tag) + if !keep { + continue + } + if name == "" { + name = fieldType.Name + } + (*fieldTagMap)[name] = sliceItem.Field(i) + } +} diff --git a/spanner/row_test.go b/spanner/row_test.go index b4e7a2395635..0f03b5de94c7 100644 --- a/spanner/row_test.go +++ b/spanner/row_test.go @@ -18,6 +18,7 @@ package spanner import ( "encoding/base64" + "errors" "fmt" "reflect" "strconv" @@ -31,6 +32,7 @@ import ( proto "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" "github.com/google/go-cmp/cmp" + "google.golang.org/api/iterator" ) var ( @@ -1991,3 +1993,294 @@ func BenchmarkColumn(b *testing.B) { } } } + +func TestSelectAll(t *testing.T) { + skipForPGTest(t) + type args struct { + destination interface{} + options []DecodeOptions + mock func(mockIterator *mockRowIterator) + } + type testStruct struct { + Col1 int64 + Col2 float64 + Col3 string + } + tests := []struct { + name string + args args + wantErr bool + want interface{} + }{ + { + name: "success: using slice of primitives", + args: args{ + destination: &[]string{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]string{"value", "value2"}, + }, + { + name: "success: using slice of pointer to primitives", + args: args{ + destination: &[]*string{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: stringType()}, + }, + []*proto3.Value{stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*string{stringPointer("value"), stringPointer("value2")}, + }, + { + name: "success: using slice of structs", + args: args{ + destination: &[]testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }, + }, + { + name: "success: using slice of pointer to structs", + args: args{ + destination: &[]*testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }}, + { + name: "success: when spanner row contains more columns than declared in Go struct but called WithLenient", + args: args{ + destination: &[]*testStruct{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value"), stringProto("value4")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + options: []DecodeOptions{WithLenient()}, + }, + want: &[]*testStruct{ + {Col1: 1, Col2: 1.1, Col3: "value"}, + }, + }, + { + name: "success: using prefilled destination should append to the destination", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(2), floatProto(2.2), stringProto("value2")}, + }, nil) + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + {Col1: 1, Col2: 1.1, Col3: "value"}, + {Col1: 2, Col2: 2.2, Col3: "value2"}, + }}, + { + name: "failure: in case of error destination will have the partial result", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, errors.New("some error")) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + {Col1: 1, Col2: 1.1, Col3: "value"}, + }, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns than declared in Go struct", + args: args{ + destination: &[]*testStruct{{Col1: 3, Col2: 3.3, Col3: "value3"}}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]*testStruct{ + {Col1: 3, Col2: 3.3, Col3: "value3"}, + }, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns and destination is primitive slice", + args: args{ + destination: &[]int64{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + }, + want: &[]int64{}, + wantErr: true, + }, + { + name: "failure: when spanner row contains more columns and destination is primitive slice using WithLenient", + args: args{ + destination: &[]int64{}, + mock: func(mockIterator *mockRowIterator) { + mockIterator.On("Next").Once().Return(&Row{ + []*sppb.StructType_Field{ + {Name: "Col1", Type: intType()}, + {Name: "Col2", Type: floatType()}, + {Name: "Col3", Type: stringType()}, + {Name: "Col4", Type: stringType()}, + }, + []*proto3.Value{intProto(1), floatProto(1.1), stringProto("value")}, + }, nil) + // failure case + mockIterator.On("Next").Once().Return(nil, iterator.Done) + mockIterator.On("Stop").Once().Return(nil) + }, + options: []DecodeOptions{WithLenient()}, + }, + want: &[]int64{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockIterator := &mockRowIterator{} + if tt.args.mock != nil { + tt.args.mock(mockIterator) + } + if err := SelectAll(mockIterator, tt.args.destination, tt.args.options...); (err != nil) != tt.wantErr { + t.Errorf("SelectAll() error = %v, wantErr %v", err, tt.wantErr) + } + if !testEqual(tt.args.destination, tt.want) { + t.Errorf("SelectAll() = %v, want %v", tt.args.destination, tt.want) + } + }) + } +} + +func stringPointer(s string) *string { + return &s +} diff --git a/spanner/value.go b/spanner/value.go index d0360c5c16c0..83a8132f0592 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -1032,7 +1032,7 @@ func parseNullTime(v *proto3.Value, p *NullTime, code sppb.TypeCode, isNull bool // decodeValue decodes a protobuf Value into a pointer to a Go value, as // specified by sppb.Type. -func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeOptions) error { +func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...DecodeOptions) error { if v == nil { return errNilSrc() } @@ -3198,8 +3198,8 @@ type decodeSetting struct { Lenient bool } -// decodeOptions is the interface to change decode struct settings -type decodeOptions interface { +// DecodeOptions is the interface to change decode struct settings +type DecodeOptions interface { Apply(s *decodeSetting) } @@ -3209,6 +3209,11 @@ func (w withLenient) Apply(s *decodeSetting) { s.Lenient = w.lenient } +// WithLenient returns a DecodeOptions that allows decoding into a struct with missing fields in database. +func WithLenient() DecodeOptions { + return withLenient{lenient: true} +} + // decodeStruct decodes proto3.ListValue pb into struct referenced by pointer // ptr, according to // the structural information given in sppb.StructType ty. @@ -3253,7 +3258,7 @@ func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}, le // We don't allow duplicated field name. return errDupSpannerField(f.Name, ty) } - opts := []decodeOptions{withLenient{lenient: lenient}} + opts := []DecodeOptions{withLenient{lenient: lenient}} // Try to decode a single field. if err := decodeValue(pb.Values[i], f.Type, v.FieldByIndex(sf.Index).Addr().Interface(), opts...); err != nil { return errDecodeStructField(ty, f.Name, err) diff --git a/spanner/value_benchmarks_test.go b/spanner/value_benchmarks_test.go index 7a9334b6a797..3a9687199f76 100644 --- a/spanner/value_benchmarks_test.go +++ b/spanner/value_benchmarks_test.go @@ -15,6 +15,8 @@ package spanner import ( + "fmt" + "math" "reflect" "strconv" "testing" @@ -22,6 +24,7 @@ import ( "cloud.google.com/go/civil" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" proto3 "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/api/iterator" ) func BenchmarkEncodeIntArray(b *testing.B) { @@ -230,3 +233,242 @@ func decodeArrayReflect(pb *proto3.ListValue, name string, typ *sppb.Type, aptr } return nil } + +func BenchmarkScan(b *testing.B) { + scanMethods := []string{"row.Column()", "row.ToStruct()", "row.SelectAll()"} + for _, method := range scanMethods { + for k := 0.; k <= 20; k++ { + n := int(math.Pow(2, k)) + b.Run(fmt.Sprintf("%s/%d", method, n), func(b *testing.B) { + b.StopTimer() + var rows []struct { + ID int64 + Name string + Active bool + City string + State string + } + for i := 0; i < n; i++ { + rows = append(rows, struct { + ID int64 + Name string + Active bool + City string + State string + }{int64(i), fmt.Sprintf("name-%d", i), true, "city", "state"}) + } + src := mockBenchmarkIterator(b, rows) + for i := 0; i < b.N; i++ { + it := *src + var res []struct { + ID int64 + Name string + Active bool + City string + State string + } + b.StartTimer() + switch method { + case "row.SelectAll()": + if err := SelectAll(&it, &res); err != nil { + b.Fatal(err) + } + _ = res + break + default: + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + Active bool + City string + State string + } + if method == "row.Column()" { + err = row.Columns(&r.ID, &r.Name, &r.Active, &r.City, &r.State) + if err != nil { + b.Fatal(err) + } + } else { + err = row.ToStruct(&r) + if err != nil { + b.Fatal(err) + } + } + res = append(res, r) + } + it.Stop() + _ = res + } + + } + }) + } + } +} + +func BenchmarkScan100RowsUsingSelectAll(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockBenchmarkIterator(b, rows) + b.ResetTimer() + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + if err := SelectAll(&it, &res); err != nil { + b.Fatal(err) + } + _ = res + } +} + +func BenchmarkScan100RowsUsingToStruct(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockBenchmarkIterator(b, rows) + b.ResetTimer() + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + } + err = row.ToStruct(&r) + if err != nil { + b.Fatal(err) + } + res = append(res, r) + } + it.Stop() + _ = res + } +} + +func BenchmarkScan100RowsUsingColumns(b *testing.B) { + var rows []struct { + ID int64 + Name string + } + for i := 0; i < 100; i++ { + rows = append(rows, struct { + ID int64 + Name string + }{int64(i), fmt.Sprintf("name-%d", i)}) + } + src := mockBenchmarkIterator(b, rows) + b.ResetTimer() + for n := 0; n < b.N; n++ { + it := *src + var res []struct { + ID int64 + Name string + } + for { + row, err := it.Next() + if err == iterator.Done { + break + } else if err != nil { + b.Fatal(err) + } + var r struct { + ID int64 + Name string + } + err = row.Columns(&r.ID, &r.Name) + if err != nil { + b.Fatal(err) + } + res = append(res, r) + } + it.Stop() + _ = res + } +} + +func mockBenchmarkIterator[T any](t testing.TB, rows []T) *mockIteratorImpl { + var v T + var colNames []string + numCols := reflect.TypeOf(v).NumField() + for i := 0; i < numCols; i++ { + f := reflect.TypeOf(v).Field(i) + colNames = append(colNames, f.Name) + } + var srows []*Row + for _, e := range rows { + var vs []any + for f := 0; f < numCols; f++ { + v := reflect.ValueOf(e).Field(f).Interface() + vs = append(vs, v) + } + row, err := NewRow(colNames, vs) + if err != nil { + t.Fatal(err) + } + srows = append(srows, row) + } + return &mockIteratorImpl{rows: srows} +} + +type mockIteratorImpl struct { + rows []*Row +} + +func (i *mockIteratorImpl) Next() (*Row, error) { + if len(i.rows) == 0 { + return nil, iterator.Done + } + row := i.rows[0] + i.rows = i.rows[1:] + return row, nil +} + +func (i *mockIteratorImpl) Stop() { + i.rows = nil +} + +func (i *mockIteratorImpl) Do(f func(*Row) error) error { + defer i.Stop() + for _, row := range i.rows { + err := f(row) + if err != nil { + return err + } + } + return nil +}