Skip to content

Commit

Permalink
bigquery: support array query parameters
Browse files Browse the repository at this point in the history
Change-Id: I396aeed685b421d3c3ab07a552e28e4eded8f064
Reviewed-on: https://code-review.googlesource.com/9558
Reviewed-by: Ross Light <[email protected]>
  • Loading branch information
jba committed Nov 30, 2016
1 parent 2861f2e commit a64eb5d
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 33 deletions.
49 changes: 36 additions & 13 deletions bigquery/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package bigquery
import (
"encoding/base64"
"fmt"
"reflect"
"time"

bq "google.golang.org/api/bigquery/v2"
Expand All @@ -34,23 +35,32 @@ var (
timestampParamType = &bq.QueryParameterType{Type: "TIMESTAMP"}
)

func paramType(x interface{}) (*bq.QueryParameterType, error) {
switch x.(type) {
case int, int8, int16, int32, int64, uint8, uint16, uint32:
var timeType = reflect.TypeOf(time.Time{})

func paramType(t reflect.Type) (*bq.QueryParameterType, error) {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64ParamType, nil
case float32, float64:
case reflect.Float32, reflect.Float64:
return float64ParamType, nil
case bool:
case reflect.Bool:
return boolParamType, nil
case string:
case reflect.String:
return stringParamType, nil
case time.Time:
case reflect.Slice, reflect.Array:
if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
return bytesParamType, nil
}
et, err := paramType(t.Elem())
if err != nil {
return nil, err
}
return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil
}
if t == timeType {
return timestampParamType, nil
case []byte:
return bytesParamType, nil
default:
return nil, fmt.Errorf("Go type %T cannot be represented as a parameter type", x)
}
return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t)
}

func paramValue(x interface{}) (bq.QueryParameterValue, error) {
Expand All @@ -63,7 +73,20 @@ func paramValue(x interface{}) (bq.QueryParameterValue, error) {
return sval(base64.StdEncoding.EncodeToString(x)), nil
case time.Time:
return sval(x.Format(timestampFormat)), nil
default:
return sval(fmt.Sprint(x)), nil
}
t := reflect.TypeOf(x)
switch t.Kind() {
case reflect.Slice, reflect.Array:
var vals []*bq.QueryParameterValue
v := reflect.ValueOf(x)
for i := 0; i < v.Len(); i++ {
val, err := paramValue(v.Index(i).Interface())
if err != nil {
return bq.QueryParameterValue{}, err
}
vals = append(vals, &val)
}
return bq.QueryParameterValue{ArrayValues: vals}, nil
}
return sval(fmt.Sprint(x)), nil
}
89 changes: 70 additions & 19 deletions bigquery/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
package bigquery

import (
"bytes"
"context"
"errors"
"math"
"reflect"
"testing"
Expand Down Expand Up @@ -60,7 +60,31 @@ func TestParamValueScalar(t *testing.T) {
}
}

func TestParamTypeScalar(t *testing.T) {
func TestParamValueArray(t *testing.T) {
for _, test := range []struct {
val interface{}
want []string
}{
{[]int(nil), []string{}},
{[]int{}, []string{}},
{[]int{1, 2}, []string{"1", "2"}},
{[3]int{1, 2, 3}, []string{"1", "2", "3"}},
} {
got, err := paramValue(test.val)
if err != nil {
t.Fatal(err)
}
var want bq.QueryParameterValue
for _, s := range test.want {
want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s})
}
if !reflect.DeepEqual(got, want) {
t.Errorf("%#v:\ngot %+v\nwant %+v", test.val, got, want)
}
}
}

func TestParamType(t *testing.T) {
for _, test := range []struct {
val interface{}
want *bq.QueryParameterType
Expand All @@ -75,42 +99,71 @@ func TestParamTypeScalar(t *testing.T) {
{"string", stringParamType},
{time.Now(), timestampParamType},
{[]byte("foo"), bytesParamType},
{[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}},
{[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}},
} {
got, err := paramType(test.val)
got, err := paramType(reflect.TypeOf(test.val))
if err != nil {
t.Fatal(err)
}
if got != test.want {
if !reflect.DeepEqual(got, test.want) {
t.Errorf("%v (%T): got %v, want %v", test.val, test.val, got, test.want)
}
}
}

func TestIntegration_ScalarParam(t *testing.T) {
ctx := context.Background()
c := getClient(t)
for _, test := range scalarTests {
q := c.Query("select ?")
q.Parameters = []QueryParameter{{Value: test.val}}
it, err := q.Read(ctx)
got, err := paramRoundTrip(c, test.val)
if err != nil {
t.Fatal(err)
}
var val []Value
err = it.Next(&val)
if !equal(got, test.val) {
t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val)
}
}
}

func TestIntegration_ArrayParam(t *testing.T) {
c := getClient(t)
for _, test := range []struct {
val interface{}
want interface{}
}{
{[]int(nil), []Value(nil)},
{[]int{}, []Value(nil)},
{[]int{1, 2}, []Value{int64(1), int64(2)}},
{[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}},
} {
got, err := paramRoundTrip(c, test.val)
if err != nil {
t.Fatal(err)
}
if len(val) != 1 {
t.Fatalf("got %d values, want 1", len(val))
}
got := val[0]
if !equal(got, test.val) {
t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val)
if !equal(got, test.want) {
t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.want, test.want)
}
}
}

func paramRoundTrip(c *Client, x interface{}) (Value, error) {
q := c.Query("select ?")
q.Parameters = []QueryParameter{{Value: x}}
it, err := q.Read(context.Background())
if err != nil {
return nil, err
}
var val []Value
err = it.Next(&val)
if err != nil {
return nil, err
}
if len(val) != 1 {
return nil, errors.New("wrong number of values")
}
return val[0], nil
}

func equal(x1, x2 interface{}) bool {
if reflect.TypeOf(x1) != reflect.TypeOf(x2) {
return false
Expand All @@ -124,9 +177,7 @@ func equal(x1, x2 interface{}) bool {
case time.Time:
// BigQuery is only accurate to the microsecond.
return x1.Round(time.Microsecond).Equal(x2.(time.Time).Round(time.Microsecond))
case []byte:
return bytes.Equal(x1, x2.([]byte))
default:
return x1 == x2
return reflect.DeepEqual(x1, x2)
}
}
4 changes: 3 additions & 1 deletion bigquery/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package bigquery

import (
"reflect"

"golang.org/x/net/context"
bq "google.golang.org/api/bigquery/v2"
)
Expand Down Expand Up @@ -209,7 +211,7 @@ func (q *QueryConfig) populateJobQueryConfig(conf *bq.JobConfigurationQuery) err
if err != nil {
return err
}
pt, err := paramType(p.Value)
pt, err := paramType(reflect.TypeOf(p.Value))
if err != nil {
return err
}
Expand Down

0 comments on commit a64eb5d

Please sign in to comment.