diff --git a/spanner/key.go b/spanner/key.go index 1e689d49dc05..e6ecfe444417 100644 --- a/spanner/key.go +++ b/spanner/key.go @@ -83,9 +83,7 @@ func keyPartValue(part interface{}) (pb *proto3.Value, err error) { pb, _, err = encodeValue(int64(v)) case uint32: pb, _, err = encodeValue(int64(v)) - case float32: - pb, _, err = encodeValue(float64(v)) - case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric: + case int64, float64, float32, NullInt64, NullFloat64, NullFloat32, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric: pb, _, err = encodeValue(v) case Encoder: part, err = v.EncodeSpanner() diff --git a/spanner/key_test.go b/spanner/key_test.go index bac0fc13f5d7..4e8adb135d76 100644 --- a/spanner/key_test.go +++ b/spanner/key_test.go @@ -163,6 +163,16 @@ func TestKey(t *testing.T) { wantProto: listValueProto(nullProto()), wantStr: "()", }, + { + k: Key{NullFloat32{3.14, true}}, + wantProto: listValueProto(floatProto(float64(float32(3.14)))), + wantStr: "(3.14)", + }, + { + k: Key{NullFloat32{2.0, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "()", + }, { k: Key{NullBool{true, true}}, wantProto: listValueProto(boolProto(true)), diff --git a/spanner/protoutils.go b/spanner/protoutils.go index c15a6eaf5b1d..83af997ca841 100644 --- a/spanner/protoutils.go +++ b/spanner/protoutils.go @@ -57,6 +57,14 @@ func intType() *sppb.Type { return &sppb.Type{Code: sppb.TypeCode_INT64} } +func float32Proto(n float32) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_NumberValue{NumberValue: float64(n)}} +} + +func float32Type() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_FLOAT32} +} + func floatProto(n float64) *proto3.Value { return &proto3.Value{Kind: &proto3.Value_NumberValue{NumberValue: n}} } diff --git a/spanner/row.go b/spanner/row.go index d546f240724e..5905bc9b2299 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -57,6 +57,8 @@ import ( // *[]int64, *[]NullInt64 - INT64 ARRAY // *bool(not NULL), *NullBool - BOOL // *[]bool, *[]NullBool - BOOL ARRAY +// *float32(not NULL), *NullFloat32 - FLOAT32 +// *[]float32, *[]NullFloat32 - FLOAT32 ARRAY // *float64(not NULL), *NullFloat64 - FLOAT64 // *[]float64, *[]NullFloat64 - FLOAT64 ARRAY // *big.Rat(not NULL), *NullNumeric - NUMERIC diff --git a/spanner/row_test.go b/spanner/row_test.go index 466c9d5626b1..221757bed0eb 100644 --- a/spanner/row_test.go +++ b/spanner/row_test.go @@ -66,6 +66,11 @@ var ( {Name: "NULL_FLOAT64", Type: floatType()}, {Name: "FLOAT64_ARRAY", Type: listType(floatType())}, {Name: "NULL_FLOAT64_ARRAY", Type: listType(floatType())}, + // FLOAT32 / FLOAT32 ARRAY + {Name: "FLOAT32", Type: float32Type()}, + {Name: "NULL_FLOAT32", Type: float32Type()}, + {Name: "FLOAT32_ARRAY", Type: listType(float32Type())}, + {Name: "NULL_FLOAT32_ARRAY", Type: listType(float32Type())}, // TIMESTAMP / TIMESTAMP ARRAY {Name: "TIMESTAMP", Type: timeType()}, {Name: "NULL_TIMESTAMP", Type: timeType()}, @@ -84,7 +89,8 @@ var ( structType( mkField("Col1", intType()), mkField("Col2", floatType()), - mkField("Col3", stringType()), + mkField("Col3", float32Type()), + mkField("Col4", stringType()), ), ), }, @@ -94,7 +100,8 @@ var ( structType( mkField("Col1", intType()), mkField("Col2", floatType()), - mkField("Col3", stringType()), + mkField("Col3", float32Type()), + mkField("Col4", stringType()), ), ), }, @@ -125,6 +132,11 @@ var ( nullProto(), listProto(nullProto(), nullProto(), floatProto(1.7)), nullProto(), + // FLOAT32 / FLOAT32 ARRAY + float32Proto(0.3), + nullProto(), + listProto(nullProto(), nullProto(), float32Proto(0.3)), + nullProto(), // TIMESTAMP / TIMESTAMP ARRAY timeProto(tm), nullProto(), @@ -138,7 +150,7 @@ var ( // STRUCT ARRAY listProto( nullProto(), - listProto(intProto(3), floatProto(33.3), stringProto("three")), + listProto(intProto(3), floatProto(33.3), float32Proto(0.3), stringProto("three")), nullProto(), ), nullProto(), @@ -177,6 +189,11 @@ func TestColumnValues(t *testing.T) { {NullFloat64{}}, {[]NullFloat64{{}, {}, {1.7, true}}}, {[]NullFloat64(nil)}, + // FLOAT32 / FLOAT64 ARRAY + {float32(0.3), NullFloat32{0.3, true}}, + {NullFloat32{}}, + {[]NullFloat32{{}, {}, {float32(0.3), true}}}, + {[]NullFloat32(nil)}, // TIMESTAMP / TIMESTAMP ARRAY {tm, NullTime{tm, true}}, {NullTime{}}, @@ -192,13 +209,15 @@ func TestColumnValues(t *testing.T) { []*struct { Col1 NullInt64 Col2 NullFloat64 - Col3 string + Col3 NullFloat32 + Col4 string }{ nil, { NullInt64{3, true}, NullFloat64{33.3, true}, + NullFloat32{0.3, true}, "three", }, nil, @@ -210,11 +229,13 @@ func TestColumnValues(t *testing.T) { fields: []*sppb.StructType_Field{ mkField("Col1", intType()), mkField("Col2", floatType()), - mkField("Col3", stringType()), + mkField("Col3", float32Type()), + mkField("Col4", stringType()), }, vals: []*proto3.Value{ intProto(3), floatProto(33.3), + float32Proto(0.3), stringProto("three"), }, }, @@ -227,7 +248,8 @@ func TestColumnValues(t *testing.T) { []*struct { Col1 NullInt64 Col2 NullFloat64 - Col3 string + Col3 NullFloat32 + Col4 string }(nil), []NullRow(nil), }, @@ -311,32 +333,37 @@ func TestNilDst(t *testing.T) { structType( mkField("Col1", intType()), mkField("Col2", floatType()), + mkField("Col3", float32Type()), ), ), }, }, []*proto3.Value{listProto( - listProto(intProto(3), floatProto(33.3)), + listProto(intProto(3), floatProto(33.3), float32Proto(0.3)), )}, }, (*[]*struct { Col1 int Col2 float64 + Col3 float32 })(nil), errDecodeColumn(0, errNilDst((*[]*struct { Col1 int Col2 float64 + Col3 float32 })(nil))), (*struct { StructArray []*struct { Col1 int Col2 float64 + Col3 float32 } `spanner:"STRUCT_ARRAY"` })(nil), errNilDst((*struct { StructArray []*struct { Col1 int Col2 float64 + Col3 float32 } `spanner:"STRUCT_ARRAY"` })(nil)), }, @@ -399,6 +426,10 @@ func TestNullTypeErr(t *testing.T) { "NULL_FLOAT64", proto.Float64(0.0), }, + { + "NULL_FLOAT32", + proto.Float32(0.0), + }, { "NULL_TIMESTAMP", &tm, @@ -857,6 +888,50 @@ func TestBrokenRow(t *testing.T) { proto.Float64(0), errDecodeColumn(0, errUnexpectedFloat64Str("nan")), }, + { + // Field specifies FLOAT32 type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: float32Type()}, + }, + []*proto3.Value{{Kind: (*proto3.Value_NumberValue)(nil)}}, + }, + &NullFloat32{1.0, true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_NumberValue)(nil)}, "Number")), + }, + { + // Field specifies FLOAT32 type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: float32Type()}, + }, + []*proto3.Value{boolProto(true)}, + }, + &NullFloat32{1.0, true}, + errDecodeColumn(0, errSrcVal(boolProto(true), "Number")), + }, + { + // Field specifies FLOAT32 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: float32Type()}, + }, + []*proto3.Value{stringProto("nan")}, + }, + &NullFloat32{}, + errDecodeColumn(0, errUnexpectedFloat32Str("nan")), + }, + { + // Field specifies FLOAT32 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + {Name: "Col0", Type: float32Type()}, + }, + []*proto3.Value{stringProto("nan")}, + }, + proto.Float32(0), + errDecodeColumn(0, errUnexpectedFloat32Str("nan")), + }, { // Field specifies BYTES type, value is having a nil Kind. &Row{ @@ -1531,6 +1606,11 @@ func TestToStruct(t *testing.T) { NullFloat64 NullFloat64 `spanner:"NULL_FLOAT64"` Float64Array []NullFloat64 `spanner:"FLOAT64_ARRAY"` NullFloat64Array []NullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` + // FLOAT32 / FLOAT32 ARRAY + Float32 float32 `spanner:"FLOAT32"` + NullFloat32 NullFloat32 `spanner:"NULL_FLOAT32"` + Float32Array []NullFloat32 `spanner:"FLOAT32_ARRAY"` + NullFloat32Array []NullFloat32 `spanner:"NULL_FLOAT32_ARRAY"` // TIMESTAMP / TIMESTAMP ARRAY Timestamp time.Time `spanner:"TIMESTAMP"` NullTimestamp NullTime `spanner:"NULL_TIMESTAMP"` @@ -1546,12 +1626,14 @@ func TestToStruct(t *testing.T) { StructArray []*struct { Col1 int64 Col2 float64 - Col3 string + Col3 float32 + Col4 string } `spanner:"STRUCT_ARRAY"` NullStructArray []*struct { Col1 int64 Col2 float64 - Col3 string + Col3 float32 + Col4 string } `spanner:"NULL_STRUCT_ARRAY"` }{ {}, // got @@ -1581,6 +1663,11 @@ func TestToStruct(t *testing.T) { NullFloat64{}, []NullFloat64{{}, {}, {1.7, true}}, []NullFloat64(nil), + // FLOAT32 / FLOAT32 ARRAY + float32(0.3), + NullFloat32{}, + []NullFloat32{{}, {}, {float32(0.3), true}}, + []NullFloat32(nil), // TIMESTAMP / TIMESTAMP ARRAY tm, NullTime{}, @@ -1595,17 +1682,19 @@ func TestToStruct(t *testing.T) { []*struct { Col1 int64 Col2 float64 - Col3 string + Col3 float32 + Col4 string }{ nil, - {3, 33.3, "three"}, + {3, 33.3, float32(0.3), "three"}, nil, }, []*struct { Col1 int64 Col2 float64 - Col3 string + Col3 float32 + Col4 string }(nil), }, // want } @@ -1632,7 +1721,9 @@ func TestToStructWithCustomTypes(t *testing.T) { type CustomBool bool type CustomNullBool NullBool type CustomFloat64 float64 + type CustomFloat32 float32 type CustomNullFloat64 NullFloat64 + type CustomNullFloat32 NullFloat32 type CustomTime time.Time type CustomNullTime NullTime type CustomDate civil.Date @@ -1665,6 +1756,11 @@ func TestToStructWithCustomTypes(t *testing.T) { NullFloat64 CustomNullFloat64 `spanner:"NULL_FLOAT64"` Float64Array []CustomNullFloat64 `spanner:"FLOAT64_ARRAY"` NullFloat64Array []CustomNullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` + // FLOAT32 / FLOAT32 ARRAY + Float32 CustomFloat32 `spanner:"FLOAT32"` + NullFloat32 CustomNullFloat32 `spanner:"NULL_FLOAT32"` + Float32Array []CustomNullFloat32 `spanner:"FLOAT32_ARRAY"` + NullFloat32Array []CustomNullFloat32 `spanner:"NULL_FLOAT32_ARRAY"` // TIMESTAMP / TIMESTAMP ARRAY Timestamp CustomTime `spanner:"TIMESTAMP"` NullTimestamp CustomNullTime `spanner:"NULL_TIMESTAMP"` @@ -1680,12 +1776,14 @@ func TestToStructWithCustomTypes(t *testing.T) { StructArray []*struct { Col1 CustomInt64 Col2 CustomFloat64 - Col3 CustomString + Col3 CustomFloat32 + Col4 CustomString } `spanner:"STRUCT_ARRAY"` NullStructArray []*struct { Col1 CustomInt64 Col2 CustomFloat64 - Col3 CustomString + Col3 CustomFloat32 + Col4 CustomString } `spanner:"NULL_STRUCT_ARRAY"` }{ {}, // got @@ -1715,6 +1813,11 @@ func TestToStructWithCustomTypes(t *testing.T) { CustomNullFloat64{}, []CustomNullFloat64{{}, {}, {1.7, true}}, []CustomNullFloat64(nil), + // FLOAT32 / FLOAT32 ARRAY + 0.3, + CustomNullFloat32{}, + []CustomNullFloat32{{}, {}, {0.3, true}}, + []CustomNullFloat32(nil), // TIMESTAMP / TIMESTAMP ARRAY CustomTime(tm), CustomNullTime{}, @@ -1729,17 +1832,19 @@ func TestToStructWithCustomTypes(t *testing.T) { []*struct { Col1 CustomInt64 Col2 CustomFloat64 - Col3 CustomString + Col3 CustomFloat32 + Col4 CustomString }{ nil, - {3, 33.3, "three"}, + {3, 33.3, 0.3, "three"}, nil, }, []*struct { Col1 CustomInt64 Col2 CustomFloat64 - Col3 CustomString + Col3 CustomFloat32 + Col4 CustomString }(nil), }, // want } diff --git a/spanner/statement_test.go b/spanner/statement_test.go index 4ef7ea10c037..20f92c33e2d0 100644 --- a/spanner/statement_test.go +++ b/spanner/statement_test.go @@ -100,6 +100,19 @@ func TestConvertParams(t *testing.T) { {[]NullFloat64(nil), nullProto(), listType(floatType())}, {[]NullFloat64{}, listProto(), listType(floatType())}, {[]NullFloat64{{2.72, true}, {}}, listProto(floatProto(2.72), nullProto()), listType(floatType())}, + // float32 + {float32(0.0), float32Proto(0.0), float32Type()}, + {float32(math.Inf(1)), float32Proto(float32(math.Inf(1))), float32Type()}, + {float32(math.Inf(-1)), float32Proto(float32(math.Inf(-1))), float32Type()}, + {float32(math.NaN()), float32Proto(float32(math.NaN())), float32Type()}, + {NullFloat32{3.14, true}, float32Proto(3.14), float32Type()}, + {NullFloat32{-99.99, false}, nullProto(), float32Type()}, + {[]float32(nil), nullProto(), listType(float32Type())}, + {[]float32{}, listProto(), listType(float32Type())}, + {[]float32{3.14, float32(math.Inf(1))}, listProto(float32Proto(3.14), float32Proto(float32(math.Inf(1)))), listType(float32Type())}, + {[]NullFloat32(nil), nullProto(), listType(float32Type())}, + {[]NullFloat32{}, listProto(), listType(float32Type())}, + {[]NullFloat32{{3.14, true}, {}}, listProto(float32Proto(3.14), nullProto()), listType(float32Type())}, // string {"", stringProto(""), stringType()}, {"foo", stringProto("foo"), stringType()}, @@ -178,11 +191,14 @@ func TestConvertParams(t *testing.T) { if test.wantType.Code == floatType().Code && proto.MarshalTextString(gotParamField) == proto.MarshalTextString(test.wantField) { continue } - t.Errorf("%#v: got %v, want %v\n", test.val, gotParamField, test.wantField) + if test.wantType.Code == float32Type().Code && proto.MarshalTextString(gotParamField) == proto.MarshalTextString(test.wantField) { + continue + } + t.Errorf("%#v:\n got: %v\nwant: %v\n", test.val, gotParamField, test.wantField) } gotParamType := gotParamTypes["var"] if !proto.Equal(gotParamType, test.wantType) { - t.Errorf("%#v: got %v, want %v\n", test.val, gotParamType, test.wantField) + t.Errorf("%#v:\n got: %v\nwant: %v\n", test.val, gotParamType, test.wantField) } } } diff --git a/spanner/value.go b/spanner/value.go index 83a8132f0592..1d4e4ea67a3a 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -427,6 +427,86 @@ func (n NullFloat64) GormDataType() string { return "FLOAT64" } +// NullFloat32 represents a Cloud Spanner FLOAT32 that may be NULL. +type NullFloat32 struct { + Float32 float32 // Float32 contains the value when it is non-NULL, and zero when NULL. + Valid bool // Valid is true if FLOAT32 is not NULL. +} + +// IsNull implements NullableValue.IsNull for NullFloat32. +func (n NullFloat32) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for NullFloat32 +func (n NullFloat32) String() string { + if !n.Valid { + return nullString + } + return fmt.Sprintf("%v", n.Float32) +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for NullFloat32. +func (n NullFloat32) MarshalJSON() ([]byte, error) { + return nulljson(n.Valid, n.Float32) +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat32. +func (n *NullFloat32) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.Float32 = float32(0) + n.Valid = false + return nil + } + num, err := strconv.ParseFloat(string(payload), 32) + if err != nil { + return fmt.Errorf("payload cannot be converted to float32: got %v", string(payload)) + } + n.Float32 = float32(num) + n.Valid = true + return nil +} + +// Value implements the driver.Valuer interface. +func (n NullFloat32) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Float32, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullFloat32) Scan(value interface{}) error { + if value == nil { + n.Float32, n.Valid = 0, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullFloat32: %v", p) + case *float32: + n.Float32 = *p + case float32: + n.Float32 = p + case *NullFloat32: + n.Float32 = p.Float32 + n.Valid = p.Valid + case NullFloat32: + n.Float32 = p.Float32 + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullFloat32) GormDataType() string { + return "FLOAT32" +} + // NullBool represents a Cloud Spanner BOOL that may be NULL. type NullBool struct { Bool bool // Bool contains the value when it is non-NULL, and false when NULL. @@ -1498,6 +1578,102 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...DecodeO return err } *p = y + case *float32: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + return errDstNotForNull(ptr) + } + x, err := getFloat32Value(v) + if err != nil { + return err + } + *p = x + case *NullFloat32, **float32: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + switch sp := ptr.(type) { + case *NullFloat32: + *sp = NullFloat32{} + case **float32: + *sp = nil + } + break + } + x, err := getFloat32Value(v) + if err != nil { + return err + } + switch sp := ptr.(type) { + case *NullFloat32: + sp.Valid = true + sp.Float32 = x + case **float32: + *sp = &x + } + case *[]NullFloat32, *[]*float32: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + switch sp := ptr.(type) { + case *[]NullFloat32: + *sp = nil + case *[]*float32: + *sp = nil + } + break + } + x, err := getListValue(v) + if err != nil { + return err + } + switch sp := ptr.(type) { + case *[]NullFloat32: + y, err := decodeNullFloat32Array(x) + if err != nil { + return err + } + *sp = y + case *[]*float32: + y, err := decodeFloat32PointerArray(x) + if err != nil { + return err + } + *sp = y + } + case *[]float32: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeFloat32Array(x) + if err != nil { + return err + } + *p = y case *big.Rat: if code != sppb.TypeCode_NUMERIC { return errTypeMismatch(code, acode, ptr) @@ -2000,6 +2176,7 @@ const ( spannerTypeNonNullInt64 spannerTypeNonNullBool spannerTypeNonNullFloat64 + spannerTypeNonNullFloat32 spannerTypeNonNullNumeric spannerTypeNonNullTime spannerTypeNonNullDate @@ -2007,6 +2184,7 @@ const ( spannerTypeNullInt64 spannerTypeNullBool spannerTypeNullFloat64 + spannerTypeNullFloat32 spannerTypeNullTime spannerTypeNullDate spannerTypeNullNumeric @@ -2018,6 +2196,7 @@ const ( spannerTypeArrayOfNonNullInt64 spannerTypeArrayOfNonNullBool spannerTypeArrayOfNonNullFloat64 + spannerTypeArrayOfNonNullFloat32 spannerTypeArrayOfNonNullNumeric spannerTypeArrayOfNonNullTime spannerTypeArrayOfNonNullDate @@ -2025,6 +2204,7 @@ const ( spannerTypeArrayOfNullInt64 spannerTypeArrayOfNullBool spannerTypeArrayOfNullFloat64 + spannerTypeArrayOfNullFloat32 spannerTypeArrayOfNullNumeric spannerTypeArrayOfNullJSON spannerTypeArrayOfNullTime @@ -2037,7 +2217,7 @@ const ( // Spanner. func (d decodableSpannerType) supportsNull() bool { switch d { - case spannerTypeNonNullString, spannerTypeNonNullInt64, spannerTypeNonNullBool, spannerTypeNonNullFloat64, spannerTypeNonNullTime, spannerTypeNonNullDate, spannerTypeNonNullNumeric: + case spannerTypeNonNullString, spannerTypeNonNullInt64, spannerTypeNonNullBool, spannerTypeNonNullFloat64, spannerTypeNonNullFloat32, spannerTypeNonNullTime, spannerTypeNonNullDate, spannerTypeNonNullNumeric: return false default: return true @@ -2058,6 +2238,7 @@ var typeOfNullString = reflect.TypeOf(NullString{}) var typeOfNullInt64 = reflect.TypeOf(NullInt64{}) var typeOfNullBool = reflect.TypeOf(NullBool{}) var typeOfNullFloat64 = reflect.TypeOf(NullFloat64{}) +var typeOfNullFloat32 = reflect.TypeOf(NullFloat32{}) var typeOfNullTime = reflect.TypeOf(NullTime{}) var typeOfNullDate = reflect.TypeOf(NullDate{}) var typeOfNullNumeric = reflect.TypeOf(NullNumeric{}) @@ -2088,6 +2269,8 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { return spannerTypeNonNullInt64 case reflect.Bool: return spannerTypeNonNullBool + case reflect.Float32: + return spannerTypeNonNullFloat32 case reflect.Float64: return spannerTypeNonNullFloat64 case reflect.Ptr: @@ -2124,6 +2307,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullFloat64) { return spannerTypeNullFloat64 } + if t.ConvertibleTo(typeOfNullFloat32) { + return spannerTypeNullFloat32 + } if t.ConvertibleTo(typeOfNullTime) { return spannerTypeNullTime } @@ -2157,6 +2343,8 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { return spannerTypeArrayOfNonNullBool case reflect.Float64: return spannerTypeArrayOfNonNullFloat64 + case reflect.Float32: + return spannerTypeArrayOfNonNullFloat32 case reflect.Ptr: t := val.Type().Elem() if t.ConvertibleTo(typeOfNullNumeric) { @@ -2185,6 +2373,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullFloat64) { return spannerTypeArrayOfNullFloat64 } + if t.ConvertibleTo(typeOfNullFloat32) { + return spannerTypeArrayOfNullFloat32 + } if t.ConvertibleTo(typeOfNullTime) { return spannerTypeArrayOfNullTime } @@ -2321,6 +2512,23 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } else { result = &NullFloat64{x, !isNull} } + case spannerTypeNonNullFloat32, spannerTypeNullFloat32: + if code != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + result = &NullFloat32{} + break + } + x, err := getFloat32Value(v) + if err != nil { + return err + } + if dsc == spannerTypeNonNullFloat32 { + result = &x + } else { + result = &NullFloat32{x, !isNull} + } case spannerTypeNonNullNumeric, spannerTypeNullNumeric: if code != sppb.TypeCode_NUMERIC { return errTypeMismatch(code, acode, ptr) @@ -2478,6 +2686,23 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb return err } result = y + case spannerTypeArrayOfNonNullFloat32, spannerTypeArrayOfNullFloat32: + if acode != sppb.TypeCode_FLOAT32 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + ptr = nil + return nil + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, float32Type(), "FLOAT32") + if err != nil { + return err + } + result = y case spannerTypeArrayOfNonNullFloat64, spannerTypeArrayOfNullFloat64: if acode != sppb.TypeCode_FLOAT64 { return errTypeMismatch(code, acode, ptr) @@ -2714,6 +2939,40 @@ func getFloat64Value(v *proto3.Value) (float64, error) { return 0, errSrcVal(v, "Number") } +// errUnexpectedFloat32Str returns error for decoder getting an unexpected +// string for representing special float values. +func errUnexpectedFloat32Str(s string) error { + return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for float32 number", s) +} + +// getFloat32Value returns the float32 value encoded in proto3.Value v whose +// kind is proto3.Value_NumberValue / proto3.Value_StringValue. +// Cloud Spanner uses string to encode NaN, Infinity and -Infinity. +func getFloat32Value(v *proto3.Value) (float32, error) { + switch x := v.GetKind().(type) { + case *proto3.Value_NumberValue: + if x == nil { + break + } + return float32(x.NumberValue), nil + case *proto3.Value_StringValue: + if x == nil { + break + } + switch x.StringValue { + case "NaN": + return float32(math.NaN()), nil + case "Infinity": + return float32(math.Inf(1)), nil + case "-Infinity": + return float32(math.Inf(-1)), nil + default: + return 0, errUnexpectedFloat32Str(x.StringValue) + } + } + return 0, errSrcVal(v, "Number") +} + // errNilListValue returns error for unexpected nil ListValue in decoding Cloud Spanner ARRAYs. func errNilListValue(sqlType string) error { return spannerErrorf(codes.FailedPrecondition, "unexpected nil ListValue in decoding %v array", sqlType) @@ -2914,6 +3173,48 @@ func decodeFloat64Array(pb *proto3.ListValue) ([]float64, error) { return a, nil } +// decodeNullFloat32Array decodes proto3.ListValue pb into a NullFloat32 slice. +func decodeNullFloat32Array(pb *proto3.ListValue) ([]NullFloat32, error) { + if pb == nil { + return nil, errNilListValue("FLOAT32") + } + a := make([]NullFloat32, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, float32Type(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "FLOAT32", err) + } + } + return a, nil +} + +// decodeFloat32PointerArray decodes proto3.ListValue pb into a *float32 slice. +func decodeFloat32PointerArray(pb *proto3.ListValue) ([]*float32, error) { + if pb == nil { + return nil, errNilListValue("FLOAT32") + } + a := make([]*float32, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, float32Type(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "FLOAT32", err) + } + } + return a, nil +} + +// decodeFloat32Array decodes proto3.ListValue pb into a float32 slice. +func decodeFloat32Array(pb *proto3.ListValue) ([]float32, error) { + if pb == nil { + return nil, errNilListValue("FLOAT32") + } + a := make([]float32, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, float32Type(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "FLOAT32", err) + } + } + return a, nil +} + // decodeNullNumericArray decodes proto3.ListValue pb into a NullNumeric slice. func decodeNullNumericArray(pb *proto3.ListValue) ([]NullNumeric, error) { if pb == nil { @@ -3543,6 +3844,43 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(floatType()) + case float32: + pb.Kind = &proto3.Value_NumberValue{NumberValue: float64(v)} + pt = float32Type() + case []float32: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(float32Type()) + case NullFloat32: + if v.Valid { + return encodeValue(v.Float32) + } + pt = float32Type() + case []NullFloat32: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(float32Type()) + case *float32: + if v != nil { + return encodeValue(*v) + } + pt = float32Type() + case []*float32: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(float32Type()) case big.Rat: switch LossOfPrecisionHandling { case NumericError: @@ -3800,6 +4138,10 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int destination = reflect.Indirect(reflect.New(reflect.TypeOf(float64(0.0)))) case spannerTypeNullFloat64: destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullFloat64{}))) + case spannerTypeNonNullFloat32: + destination = reflect.Indirect(reflect.New(reflect.TypeOf(float32(0.0)))) + case spannerTypeNullFloat32: + destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullFloat32{}))) case spannerTypeNonNullTime: destination = reflect.Indirect(reflect.New(reflect.TypeOf(time.Time{}))) case spannerTypeNullTime: @@ -3863,6 +4205,16 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int return []NullFloat64(nil), nil } destination = reflect.MakeSlice(reflect.TypeOf([]NullFloat64{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) + case spannerTypeArrayOfNonNullFloat32: + if reflect.ValueOf(v).IsNil() { + return []float32(nil), nil + } + destination = reflect.MakeSlice(reflect.TypeOf([]float32{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) + case spannerTypeArrayOfNullFloat32: + if reflect.ValueOf(v).IsNil() { + return []NullFloat32(nil), nil + } + destination = reflect.MakeSlice(reflect.TypeOf([]NullFloat32{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) case spannerTypeArrayOfNonNullTime: if reflect.ValueOf(v).IsNil() { return []time.Time(nil), nil @@ -4042,6 +4394,7 @@ func isSupportedMutationType(v interface{}) bool { int, []int, int64, *int64, []int64, []*int64, NullInt64, []NullInt64, bool, *bool, []bool, []*bool, NullBool, []NullBool, float64, *float64, []float64, []*float64, NullFloat64, []NullFloat64, + float32, *float32, []float32, []*float32, NullFloat32, []NullFloat32, time.Time, *time.Time, []time.Time, []*time.Time, NullTime, []NullTime, civil.Date, *civil.Date, []civil.Date, []*civil.Date, NullDate, []NullDate, big.Rat, *big.Rat, []big.Rat, []*big.Rat, NullNumeric, []NullNumeric, diff --git a/spanner/value_test.go b/spanner/value_test.go index ca85544b94a3..71208e1f4b98 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -264,6 +264,8 @@ func TestEncodeValue(t *testing.T) { var bNilPtr *bool fValue := 3.14 var fNilPtr *float64 + f32Value := float32(3.14) + var f32NilPtr *float32 tValue := t1 var tNilPtr *time.Time dValue := d1 @@ -279,6 +281,7 @@ func TestEncodeValue(t *testing.T) { tInt = intType() tBool = boolType() tFloat = floatType() + tFloat32 = float32Type() tBytes = bytesType() tTime = timeType() tDate = dateType() @@ -341,6 +344,17 @@ func TestEncodeValue(t *testing.T) { {[]float64{3.141, 0.618, math.Inf(-1)}, listProto(floatProto(3.141), floatProto(0.618), floatProto(math.Inf(-1))), listType(tFloat), "[]float64"}, {[]NullFloat64{{3.141, true}, {0.618, false}}, listProto(floatProto(3.141), nullProto()), listType(tFloat), "[]NullFloat64"}, {[]*float64{&fValue, fNilPtr}, listProto(floatProto(3.14), nullProto()), listType(tFloat), "[]NullFloat64"}, + // FLOAT32 / FLOAT32 ARRAY + {float32(3.14), float32Proto(3.14), tFloat32, "float32"}, + {NullFloat32{3.14, true}, float32Proto(3.14), tFloat32, "NullFloat32 with value"}, + {NullFloat32{float32(math.Inf(1)), true}, float32Proto(float32(math.Inf(1))), tFloat32, "NullFloat32 with infinity"}, + {NullFloat32{3.14, false}, nullProto(), tFloat32, "NullFloat32 with null"}, + {&f32Value, float32Proto(3.14), tFloat32, "*float32 with value"}, + {f32NilPtr, nullProto(), tFloat32, "*float32 with null"}, + {[]float32(nil), nullProto(), listType(tFloat32), "null []float32"}, + {[]float32{3.14, 0.618, float32(math.Inf(-1))}, listProto(float32Proto(3.14), float32Proto(0.618), float32Proto(float32(math.Inf(-1)))), listType(tFloat32), "[]float32"}, + {[]NullFloat32{{3.14, true}, {0.618, false}}, listProto(float32Proto(3.14), nullProto()), listType(tFloat32), "[]NullFloat"}, + {[]*float32{&f32Value, f32NilPtr}, listProto(float32Proto(3.14), nullProto()), listType(tFloat32), "[]NullFloat32"}, // NUMERIC / NUMERIC ARRAY {*numValuePtr, numericProto(numValuePtr), tNumeric, "big.Rat"}, {numValuePtr, numericProto(numValuePtr), tNumeric, "*big.Rat"}, @@ -925,6 +939,7 @@ func TestEncodeStructValueBasicFields(t *testing.T) { type CustomInt64 int64 type CustomBool bool type CustomFloat64 float64 + type CustomFloat32 float32 type CustomTime time.Time type CustomDate civil.Date @@ -932,6 +947,7 @@ func TestEncodeStructValueBasicFields(t *testing.T) { type CustomNullInt64 NullInt64 type CustomNullBool NullBool type CustomNullFloat64 NullFloat64 + type CustomNullFloat32 NullFloat32 type CustomNullTime NullTime type CustomNullDate NullDate @@ -939,6 +955,7 @@ func TestEncodeStructValueBasicFields(t *testing.T) { iValue := int64(300) bValue := false fValue := 3.45 + f32Value := float32(3.14) tValue := t1 dValue := d1 @@ -947,6 +964,7 @@ func TestEncodeStructValueBasicFields(t *testing.T) { mkField("Intf", intType()), mkField("Boolf", boolType()), mkField("Floatf", floatType()), + mkField("Float32f", float32Type()), mkField("Bytef", bytesType()), mkField("Timef", timeType()), mkField("Datef", dateType())) @@ -955,19 +973,21 @@ func TestEncodeStructValueBasicFields(t *testing.T) { { "Basic types.", struct { - Stringf string - Intf int - Boolf bool - Floatf float64 - Bytef []byte - Timef time.Time - Datef civil.Date - }{"abc", 300, false, 3.45, []byte("foo"), t1, d1}, + Stringf string + Intf int + Boolf bool + Floatf float64 + Float32f float32 + Bytef []byte + Timef time.Time + Datef civil.Date + }{"abc", 300, false, 3.45, float32(3.14), []byte("foo"), t1, d1}, listProto( stringProto("abc"), intProto(300), boolProto(false), floatProto(3.45), + float32Proto(3.14), bytesProto([]byte("foo")), timeProto(t1), dateProto(d1)), @@ -976,19 +996,21 @@ func TestEncodeStructValueBasicFields(t *testing.T) { { "Pointers to basic types.", struct { - Stringf *string - Intf *int64 - Boolf *bool - Floatf *float64 - Bytef []byte - Timef *time.Time - Datef *civil.Date - }{&sValue, &iValue, &bValue, &fValue, []byte("foo"), &tValue, &dValue}, + Stringf *string + Intf *int64 + Boolf *bool + Floatf *float64 + Float32f *float32 + Bytef []byte + Timef *time.Time + Datef *civil.Date + }{&sValue, &iValue, &bValue, &fValue, &f32Value, []byte("foo"), &tValue, &dValue}, listProto( stringProto("abc"), intProto(300), boolProto(false), floatProto(3.45), + float32Proto(3.14), bytesProto([]byte("foo")), timeProto(t1), dateProto(d1)), @@ -997,14 +1019,15 @@ func TestEncodeStructValueBasicFields(t *testing.T) { { "Pointers to basic types with null values.", struct { - Stringf *string - Intf *int64 - Boolf *bool - Floatf *float64 - Bytef []byte - Timef *time.Time - Datef *civil.Date - }{nil, nil, nil, nil, nil, nil, nil}, + Stringf *string + Intf *int64 + Boolf *bool + Floatf *float64 + Float32f *float32 + Bytef []byte + Timef *time.Time + Datef *civil.Date + }{nil, nil, nil, nil, nil, nil, nil, nil}, listProto( nullProto(), nullProto(), @@ -1012,25 +1035,28 @@ func TestEncodeStructValueBasicFields(t *testing.T) { nullProto(), nullProto(), nullProto(), + nullProto(), nullProto()), StructTypeProto, }, { "Basic custom types.", struct { - Stringf CustomString - Intf CustomInt64 - Boolf CustomBool - Floatf CustomFloat64 - Bytef CustomBytes - Timef CustomTime - Datef CustomDate - }{"abc", 300, false, 3.45, []byte("foo"), CustomTime(t1), CustomDate(d1)}, + Stringf CustomString + Intf CustomInt64 + Boolf CustomBool + Floatf CustomFloat64 + Float32f CustomFloat32 + Bytef CustomBytes + Timef CustomTime + Datef CustomDate + }{"abc", 300, false, 3.45, CustomFloat32(3.14), []byte("foo"), CustomTime(t1), CustomDate(d1)}, listProto( stringProto("abc"), intProto(300), boolProto(false), floatProto(3.45), + float32Proto(3.14), bytesProto([]byte("foo")), timeProto(t1), dateProto(d1)), @@ -1039,18 +1065,20 @@ func TestEncodeStructValueBasicFields(t *testing.T) { { "Basic types null values.", struct { - Stringf NullString - Intf NullInt64 - Boolf NullBool - Floatf NullFloat64 - Bytef []byte - Timef NullTime - Datef NullDate + Stringf NullString + Intf NullInt64 + Boolf NullBool + Floatf NullFloat64 + Float32f NullFloat32 + Bytef []byte + Timef NullTime + Datef NullDate }{ NullString{"abc", false}, NullInt64{4, false}, NullBool{false, false}, NullFloat64{5.6, false}, + NullFloat32{3.14, false}, nil, NullTime{t1, false}, NullDate{d1, false}, @@ -1062,24 +1090,27 @@ func TestEncodeStructValueBasicFields(t *testing.T) { nullProto(), nullProto(), nullProto(), + nullProto(), nullProto()), StructTypeProto, }, { "Basic custom types null values.", struct { - Stringf CustomNullString - Intf CustomNullInt64 - Boolf CustomNullBool - Floatf CustomNullFloat64 - Bytef CustomBytes - Timef CustomNullTime - Datef CustomNullDate + Stringf CustomNullString + Intf CustomNullInt64 + Boolf CustomNullBool + Floatf CustomNullFloat64 + Float32f CustomNullFloat32 + Bytef CustomBytes + Timef CustomNullTime + Datef CustomNullDate }{ CustomNullString{"abc", false}, CustomNullInt64{4, false}, CustomNullBool{false, false}, CustomNullFloat64{5.6, false}, + CustomNullFloat32{3.14, false}, nil, CustomNullTime{t1, false}, CustomNullDate{d1, false}, @@ -1091,6 +1122,7 @@ func TestEncodeStructValueBasicFields(t *testing.T) { nullProto(), nullProto(), nullProto(), + nullProto(), nullProto()), StructTypeProto, }, @@ -1421,6 +1453,10 @@ func TestDecodeValue(t *testing.T) { var fNilPtr *float64 f2Value := 6.626 + f32Value := float32(3.14) + var f32NilPtr *float32 + f32Value2 := float32(6.626) + numValuePtr := big.NewRat(12345, 1e3) var numNilPtr *big.Rat num2ValuePtr := big.NewRat(12345, 1e4) @@ -1506,6 +1542,21 @@ func TestDecodeValue(t *testing.T) { // FLOAT64 ARRAY with []*float64 {desc: "decode ARRAY to []*float64", proto: listProto(floatProto(fValue), nullProto(), floatProto(f2Value)), protoType: listType(floatType()), want: []*float64{&fValue, nil, &f2Value}}, {desc: "decode NULL to []*float64", proto: nullProto(), protoType: listType(floatType()), want: []*float64(nil)}, + // FLOAT32 + {desc: "decode FLOAT32 to float32", proto: float32Proto(3.14), protoType: float32Type(), want: float32(3.14)}, + {desc: "decode NULL to float32", proto: nullProto(), protoType: float32Type(), want: 0.00, wantErr: true}, + {desc: "decode FLOAT32 to *float32", proto: float32Proto(3.14), protoType: float32Type(), want: &f32Value}, + {desc: "decode NULL to *float32", proto: nullProto(), protoType: float32Type(), want: f32NilPtr}, + {desc: "decode FLOAT32 to NullFloat32", proto: float32Proto(3.14), protoType: float32Type(), want: NullFloat32{3.14, true}}, + {desc: "decode NULL to NullFloat32", proto: nullProto(), protoType: float32Type(), want: NullFloat32{}}, + // FLOAT64 ARRAY with []NullFloat32 + {desc: "decode ARRAY to []NullFloat32", proto: listProto(float32Proto(float32(math.Inf(1))), float32Proto(float32(math.Inf(-1))), nullProto(), float32Proto(3.1)), protoType: listType(float32Type()), want: []NullFloat32{{float32(math.Inf(1)), true}, {float32(math.Inf(-1)), true}, {}, {3.1, true}}}, + {desc: "decode NULL to []NullFloat32", proto: nullProto(), protoType: listType(float32Type()), want: []NullFloat32(nil)}, + // FLOAT32 ARRAY with []float32 + {desc: "decode ARRAY to []float32", proto: listProto(float32Proto(float32(math.Inf(1))), float32Proto(float32(math.Inf(-1))), float32Proto(3.1)), protoType: listType(float32Type()), want: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), 3.1}}, + // FLOAT64 ARRAY with []*float32 + {desc: "decode ARRAY to []*float32", proto: listProto(float32Proto(f32Value), nullProto(), float32Proto(f32Value2)), protoType: listType(float32Type()), want: []*float32{&f32Value, nil, &f32Value2}}, + {desc: "decode NULL to []*float32", proto: nullProto(), protoType: listType(float32Type()), want: []*float32(nil)}, // NUMERIC {desc: "decode NUMERIC to big.Rat", proto: numericProto(numValuePtr), protoType: numericType(), want: *numValuePtr}, {desc: "decode NUMERIC to NullNumeric", proto: numericProto(numValuePtr), protoType: numericType(), want: NullNumeric{*numValuePtr, true}}, @@ -1892,6 +1943,7 @@ func TestGetDecodableSpannerType(t *testing.T) { type CustomInt64 int64 type CustomBool bool type CustomFloat64 float64 + type CustomFloat32 float32 type CustomTime time.Time type CustomDate civil.Date type CustomNumeric big.Rat @@ -1900,6 +1952,7 @@ func TestGetDecodableSpannerType(t *testing.T) { type CustomNullInt64 NullInt64 type CustomNullBool NullBool type CustomNullFloat64 NullFloat64 + type CustomNullFloat32 NullFloat32 type CustomNullTime NullTime type CustomNullDate NullDate type CustomNullNumeric NullNumeric @@ -1921,12 +1974,14 @@ func TestGetDecodableSpannerType(t *testing.T) { {int64(123), spannerTypeNonNullInt64}, {true, spannerTypeNonNullBool}, {3.14, spannerTypeNonNullFloat64}, + {float32(3.14), spannerTypeNonNullFloat32}, {time.Now(), spannerTypeNonNullTime}, {civil.DateOf(time.Now()), spannerTypeNonNullDate}, {NullString{}, spannerTypeNullString}, {NullInt64{}, spannerTypeNullInt64}, {NullBool{}, spannerTypeNullBool}, {NullFloat64{}, spannerTypeNullFloat64}, + {NullFloat32{}, spannerTypeNullFloat32}, {NullTime{}, spannerTypeNullTime}, {NullDate{}, spannerTypeNullDate}, {*big.NewRat(1234, 1000), spannerTypeNonNullNumeric}, @@ -1939,12 +1994,14 @@ func TestGetDecodableSpannerType(t *testing.T) { {[]int64{int64(123)}, spannerTypeArrayOfNonNullInt64}, {[]bool{true}, spannerTypeArrayOfNonNullBool}, {[]float64{3.14}, spannerTypeArrayOfNonNullFloat64}, + {[]float32{3.14}, spannerTypeArrayOfNonNullFloat32}, {[]time.Time{time.Now()}, spannerTypeArrayOfNonNullTime}, {[]civil.Date{civil.DateOf(time.Now())}, spannerTypeArrayOfNonNullDate}, {[]NullString{}, spannerTypeArrayOfNullString}, {[]NullInt64{}, spannerTypeArrayOfNullInt64}, {[]NullBool{}, spannerTypeArrayOfNullBool}, {[]NullFloat64{}, spannerTypeArrayOfNullFloat64}, + {[]NullFloat32{}, spannerTypeArrayOfNullFloat32}, {[]NullTime{}, spannerTypeArrayOfNullTime}, {[]NullDate{}, spannerTypeArrayOfNullDate}, {[]big.Rat{}, spannerTypeArrayOfNonNullNumeric}, @@ -1955,6 +2012,7 @@ func TestGetDecodableSpannerType(t *testing.T) { {CustomInt64(-100), spannerTypeNonNullInt64}, {CustomBool(true), spannerTypeNonNullBool}, {CustomFloat64(3.141592), spannerTypeNonNullFloat64}, + {CustomFloat32(3.141592), spannerTypeNonNullFloat32}, {CustomTime(time.Now()), spannerTypeNonNullTime}, {CustomDate(civil.DateOf(time.Now())), spannerTypeNonNullDate}, {CustomNumeric(*big.NewRat(1234, 1000)), spannerTypeNonNullNumeric}, @@ -1963,6 +2021,7 @@ func TestGetDecodableSpannerType(t *testing.T) { {[]CustomInt64{}, spannerTypeArrayOfNonNullInt64}, {[]CustomBool{}, spannerTypeArrayOfNonNullBool}, {[]CustomFloat64{}, spannerTypeArrayOfNonNullFloat64}, + {[]CustomFloat32{}, spannerTypeArrayOfNonNullFloat32}, {[]CustomTime{}, spannerTypeArrayOfNonNullTime}, {[]CustomDate{}, spannerTypeArrayOfNonNullDate}, {[]CustomNumeric{}, spannerTypeArrayOfNonNullNumeric}, @@ -1971,6 +2030,7 @@ func TestGetDecodableSpannerType(t *testing.T) { {CustomNullInt64{}, spannerTypeNullInt64}, {CustomNullBool{}, spannerTypeNullBool}, {CustomNullFloat64{}, spannerTypeNullFloat64}, + {CustomNullFloat32{}, spannerTypeNullFloat32}, {CustomNullTime{}, spannerTypeNullTime}, {CustomNullDate{}, spannerTypeNullDate}, {CustomNullNumeric{}, spannerTypeNullNumeric}, @@ -1979,6 +2039,7 @@ func TestGetDecodableSpannerType(t *testing.T) { {[]CustomNullInt64{}, spannerTypeArrayOfNullInt64}, {[]CustomNullBool{}, spannerTypeArrayOfNullBool}, {[]CustomNullFloat64{}, spannerTypeArrayOfNullFloat64}, + {[]CustomNullFloat32{}, spannerTypeArrayOfNullFloat32}, {[]CustomNullTime{}, spannerTypeArrayOfNullTime}, {[]CustomNullDate{}, spannerTypeArrayOfNullDate}, {[]CustomNullNumeric{}, spannerTypeArrayOfNullNumeric}, @@ -2047,6 +2108,52 @@ func TestNaN(t *testing.T) { } } +// Test Float32 NaN encoding/decoding. +func TestFloat32NaN(t *testing.T) { + // Decode NaN value. + f := float32(0.0) + nf := NullFloat32{} + // To float32 + if err := decodeValue(float32Proto(float32(math.NaN())), float32Type(), &f); err != nil { + t.Errorf("decodeValue returns %q for %v, want nil", err, float32Proto(float32(math.NaN()))) + } + if !math.IsNaN(float64(f)) { + t.Errorf("f = %v, want %v", f, math.NaN()) + } + // To NullFloat32 + if err := decodeValue(float32Proto(float32(math.NaN())), float32Type(), &nf); err != nil { + t.Errorf("decodeValue returns %q for %v, want nil", err, float32Proto(float32(math.NaN()))) + } + if !math.IsNaN(float64(nf.Float32)) || !nf.Valid { + t.Errorf("f = %v, want %v", f, NullFloat32{float32(math.NaN()), true}) + } + // Encode NaN value + // From float32 + v, _, err := encodeValue(float32(math.NaN())) + if err != nil { + t.Errorf("encodeValue returns %q for NaN, want nil", err) + } + x, ok := v.GetKind().(*proto3.Value_NumberValue) + if !ok { + t.Errorf("incorrect type for v.GetKind(): %T, want *proto3.Value_NumberValue", v.GetKind()) + } + if !math.IsNaN(x.NumberValue) { + t.Errorf("x.NumberValue = %v, want %v", x.NumberValue, math.NaN()) + } + // From NullFloat32 + v, _, err = encodeValue(NullFloat32{float32(math.NaN()), true}) + if err != nil { + t.Errorf("encodeValue returns %q for NaN, want nil", err) + } + x, ok = v.GetKind().(*proto3.Value_NumberValue) + if !ok { + t.Errorf("incorrect type for v.GetKind(): %T, want *proto3.Value_NumberValue", v.GetKind()) + } + if !math.IsNaN(x.NumberValue) { + t.Errorf("x.NumberValue = %v, want %v", x.NumberValue, math.NaN()) + } +} + func TestGenericColumnValue(t *testing.T) { for _, test := range []struct { in GenericColumnValue @@ -2703,6 +2810,15 @@ func TestJSONMarshal_NullTypes(t *testing.T) { {input: NullFloat64{}, expect: "null"}, }, }, + { + "NullFloat32", + []testcase{ + {input: NullFloat32{float32(3.14), true}, expect: "3.14"}, + {input: &NullFloat32{float32(123.123), true}, expect: "123.123"}, + {input: &NullFloat32{float32(123.123), false}, expect: "null"}, + {input: NullFloat32{}, expect: "null"}, + }, + }, { "NullBool", []testcase{ @@ -2819,6 +2935,16 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { {input: []byte(`"hello`), got: NullFloat64{}, isNull: true, expect: nullString, expectError: true}, }, }, + { + "NullFloat32", + []testcase{ + {input: []byte("3.14"), got: NullFloat32{}, isNull: false, expect: "3.14", expectError: false}, + {input: []byte("null"), got: NullFloat32{}, isNull: true, expect: nullString, expectError: false}, + {input: nil, got: NullFloat32{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(""), got: NullFloat32{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(`"hello`), got: NullFloat32{}, isNull: true, expect: nullString, expectError: true}, + }, + }, { "NullBool", []testcase{ @@ -2893,6 +3019,9 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { case NullFloat64: err := json.Unmarshal(tc.input, &v) expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) + case NullFloat32: + err := json.Unmarshal(tc.input, &v) + expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) case NullBool: err := json.Unmarshal(tc.input, &v) expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError)