Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protocol Support for NaN/Infinity/-Infinity float values #4592

Merged
merged 1 commit into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions private/protocol/json/jsonutil/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package jsonutil
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"reflect"
Expand All @@ -16,6 +15,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

var timeType = reflect.ValueOf(time.Time{}).Type()
var byteSliceType = reflect.ValueOf([]byte{}).Type()

Expand Down Expand Up @@ -211,10 +216,16 @@ func buildScalar(v reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) erro
buf.Write(strconv.AppendInt(scratch[:0], value.Int(), 10))
case reflect.Float64:
f := value.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
return &json.UnsupportedValueError{Value: v, Str: strconv.FormatFloat(f, 'f', -1, 64)}
switch {
case math.IsNaN(f):
writeString(floatNaN, buf)
case math.IsInf(f, 1):
writeString(floatInf, buf)
case math.IsInf(f, -1):
writeString(floatNegInf, buf)
default:
buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
}
buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
default:
switch converted := value.Interface().(type) {
case time.Time:
Expand Down
40 changes: 24 additions & 16 deletions private/protocol/json/jsonutil/build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonutil_test

import (
"encoding/json"
"math"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -41,41 +42,48 @@ var jsonTests = []struct {
err string
}{
{
J{},
`{}`,
``,
in: J{},
out: `{}`,
},
{
J{
in: J{
S: S("str"),
SS: []string{"A", "B", "C"},
D: D(123),
F: F(4.56),
T: T(time.Unix(987, 0)),
},
`{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
``,
out: `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
},
{
J{
in: J{
S: S(`"''"`),
},
`{"S":"\"''\""}`,
``,
out: `{"S":"\"''\""}`,
},
{
J{
in: J{
S: S("\x00føø\u00FF\n\\\"\r\t\b\f"),
},
`{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
``,
out: `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
},
{
J{
F: F(4.56 / zero),
in: J{
F: F(math.NaN()),
},
"",
`json: unsupported value: +Inf`,
out: `{"F":"NaN"}`,
},
{
in: J{
F: F(math.Inf(1)),
},
out: `{"F":"Infinity"}`,
},
{
in: J{
F: F(math.Inf(-1)),
},
out: `{"F":"-Infinity"}`,
},
}

Expand Down
13 changes: 13 additions & 0 deletions private/protocol/json/jsonutil/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"math/big"
"reflect"
"strings"
Expand Down Expand Up @@ -258,6 +259,18 @@ func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag
return err
}
value.Set(reflect.ValueOf(v))
case *float64:
// These are regular strings when parsed by encoding/json's unmarshaler.
switch {
case strings.EqualFold(d, floatNaN):
value.Set(reflect.ValueOf(aws.Float64(math.NaN())))
case strings.EqualFold(d, floatInf):
value.Set(reflect.ValueOf(aws.Float64(math.Inf(1))))
case strings.EqualFold(d, floatNegInf):
value.Set(reflect.ValueOf(aws.Float64(math.Inf(-1))))
default:
return fmt.Errorf("unknown JSON number value: %s", d)
}
default:
return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
}
Expand Down
35 changes: 32 additions & 3 deletions private/protocol/json/jsonutil/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package jsonutil_test

import (
"bytes"
"math"
"reflect"
"testing"
"time"
Expand All @@ -21,9 +22,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
}

cases := map[string]struct {
JSON string
Value input
Expected input
JSON string
Value input
Expected input
ExpectedFn func(*testing.T, input)
}{
"seconds precision": {
JSON: `{"timeField":1597094942}`,
Expand Down Expand Up @@ -106,6 +108,29 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
FloatField: aws.Float64(123456789.123),
},
},
"float64 field NaN": {
JSON: `{"floatField":"NaN"}`,
ExpectedFn: func(t *testing.T, input input) {
if input.FloatField == nil {
t.Fatal("expect non nil float64")
}
if e, a := true, math.IsNaN(*input.FloatField); e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
},
"float64 field Infinity": {
JSON: `{"floatField":"Infinity"}`,
Expected: input{
FloatField: aws.Float64(math.Inf(1)),
},
},
"float64 field -Infinity": {
JSON: `{"floatField":"-Infinity"}`,
Expected: input{
FloatField: aws.Float64(math.Inf(-1)),
},
},
}

for name, tt := range cases {
Expand All @@ -114,6 +139,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if tt.ExpectedFn != nil {
tt.ExpectedFn(t, tt.Value)
return
}
if e, a := tt.Expected, tt.Value; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
Expand Down
34 changes: 32 additions & 2 deletions private/protocol/query/queryutil/queryutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package queryutil
import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason that the query implementation doesnt need any accompanied unit tests?

"encoding/base64"
"fmt"
"math"
"net/url"
"reflect"
"sort"
Expand All @@ -13,6 +14,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
Expand Down Expand Up @@ -228,9 +235,32 @@ func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, ta
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
var str string
switch {
case math.IsNaN(value):
str = floatNaN
case math.IsInf(value, 1):
str = floatInf
case math.IsInf(value, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(value, 'f', -1, 64)
}
v.Set(name, str)
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
asFloat64 := float64(value)
var str string
switch {
case math.IsNaN(asFloat64):
str = floatNaN
case math.IsInf(asFloat64, 1):
str = floatInf
case math.IsInf(asFloat64, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(asFloat64, 'f', -1, 32)
}
v.Set(name, str)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
format := tag.Get("timestampFormat")
Expand Down
18 changes: 17 additions & 1 deletion private/protocol/rest/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/base64"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
Expand All @@ -20,6 +21,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool

Expand Down Expand Up @@ -302,7 +309,16 @@ func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
switch {
case math.IsNaN(value):
str = floatNaN
case math.IsInf(value, 1):
str = floatInf
case math.IsInf(value, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(value, 'f', -1, 64)
}
case time.Time:
format := tag.Get("timestampFormat")
if len(format) == 0 {
Expand Down
Loading