diff --git a/private/protocol/json/jsonutil/build.go b/private/protocol/json/jsonutil/build.go index 2aec80661a4..12e814ddf25 100644 --- a/private/protocol/json/jsonutil/build.go +++ b/private/protocol/json/jsonutil/build.go @@ -4,7 +4,6 @@ package jsonutil import ( "bytes" "encoding/base64" - "encoding/json" "fmt" "math" "reflect" @@ -16,6 +15,12 @@ import ( "github.com/aws/aws-sdk-go/private/protocol" ) +const ( + floatNaN = "NaN" + floatInf = "Infinity" + floatNegInf = "-Infinity" +) + var timeType = reflect.ValueOf(time.Time{}).Type() var byteSliceType = reflect.ValueOf([]byte{}).Type() @@ -211,10 +216,16 @@ func buildScalar(v reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) erro buf.Write(strconv.AppendInt(scratch[:0], value.Int(), 10)) case reflect.Float64: f := value.Float() - if math.IsInf(f, 0) || math.IsNaN(f) { - return &json.UnsupportedValueError{Value: v, Str: strconv.FormatFloat(f, 'f', -1, 64)} + switch { + case math.IsNaN(f): + writeString(floatNaN, buf) + case math.IsInf(f, 1): + writeString(floatInf, buf) + case math.IsInf(f, -1): + writeString(floatNegInf, buf) + default: + buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64)) } - buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64)) default: switch converted := value.Interface().(type) { case time.Time: diff --git a/private/protocol/json/jsonutil/build_test.go b/private/protocol/json/jsonutil/build_test.go index 66392b658e0..353228cf7d4 100644 --- a/private/protocol/json/jsonutil/build_test.go +++ b/private/protocol/json/jsonutil/build_test.go @@ -2,6 +2,7 @@ package jsonutil_test import ( "encoding/json" + "math" "strings" "testing" "time" @@ -41,41 +42,48 @@ var jsonTests = []struct { err string }{ { - J{}, - `{}`, - ``, + in: J{}, + out: `{}`, }, { - J{ + in: J{ S: S("str"), SS: []string{"A", "B", "C"}, D: D(123), F: F(4.56), T: T(time.Unix(987, 0)), }, - `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`, - ``, + out: `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`, }, { - J{ + in: J{ S: S(`"''"`), }, - `{"S":"\"''\""}`, - ``, + out: `{"S":"\"''\""}`, }, { - J{ + in: J{ S: S("\x00føø\u00FF\n\\\"\r\t\b\f"), }, - `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`, - ``, + out: `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`, }, { - J{ - F: F(4.56 / zero), + in: J{ + F: F(math.NaN()), }, - "", - `json: unsupported value: +Inf`, + out: `{"F":"NaN"}`, + }, + { + in: J{ + F: F(math.Inf(1)), + }, + out: `{"F":"Infinity"}`, + }, + { + in: J{ + F: F(math.Inf(-1)), + }, + out: `{"F":"-Infinity"}`, }, } diff --git a/private/protocol/json/jsonutil/unmarshal.go b/private/protocol/json/jsonutil/unmarshal.go index 8b2c9bbeba0..f9334879b80 100644 --- a/private/protocol/json/jsonutil/unmarshal.go +++ b/private/protocol/json/jsonutil/unmarshal.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "math" "math/big" "reflect" "strings" @@ -258,6 +259,18 @@ func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag return err } value.Set(reflect.ValueOf(v)) + case *float64: + // These are regular strings when parsed by encoding/json's unmarshaler. + switch { + case strings.EqualFold(d, floatNaN): + value.Set(reflect.ValueOf(aws.Float64(math.NaN()))) + case strings.EqualFold(d, floatInf): + value.Set(reflect.ValueOf(aws.Float64(math.Inf(1)))) + case strings.EqualFold(d, floatNegInf): + value.Set(reflect.ValueOf(aws.Float64(math.Inf(-1)))) + default: + return fmt.Errorf("unknown JSON number value: %s", d) + } default: return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type()) } diff --git a/private/protocol/json/jsonutil/unmarshal_test.go b/private/protocol/json/jsonutil/unmarshal_test.go index a91b46c34fa..1e7ec3615ef 100644 --- a/private/protocol/json/jsonutil/unmarshal_test.go +++ b/private/protocol/json/jsonutil/unmarshal_test.go @@ -5,6 +5,7 @@ package jsonutil_test import ( "bytes" + "math" "reflect" "testing" "time" @@ -21,9 +22,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) { } cases := map[string]struct { - JSON string - Value input - Expected input + JSON string + Value input + Expected input + ExpectedFn func(*testing.T, input) }{ "seconds precision": { JSON: `{"timeField":1597094942}`, @@ -106,6 +108,29 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) { FloatField: aws.Float64(123456789.123), }, }, + "float64 field NaN": { + JSON: `{"floatField":"NaN"}`, + ExpectedFn: func(t *testing.T, input input) { + if input.FloatField == nil { + t.Fatal("expect non nil float64") + } + if e, a := true, math.IsNaN(*input.FloatField); e != a { + t.Errorf("expect %v, got %v", e, a) + } + }, + }, + "float64 field Infinity": { + JSON: `{"floatField":"Infinity"}`, + Expected: input{ + FloatField: aws.Float64(math.Inf(1)), + }, + }, + "float64 field -Infinity": { + JSON: `{"floatField":"-Infinity"}`, + Expected: input{ + FloatField: aws.Float64(math.Inf(-1)), + }, + }, } for name, tt := range cases { @@ -114,6 +139,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) { if err != nil { t.Errorf("expect no error, got %v", err) } + if tt.ExpectedFn != nil { + tt.ExpectedFn(t, tt.Value) + return + } if e, a := tt.Expected, tt.Value; !reflect.DeepEqual(e, a) { t.Errorf("expect %v, got %v", e, a) } diff --git a/private/protocol/query/queryutil/queryutil.go b/private/protocol/query/queryutil/queryutil.go index 75866d01218..058334053c2 100644 --- a/private/protocol/query/queryutil/queryutil.go +++ b/private/protocol/query/queryutil/queryutil.go @@ -3,6 +3,7 @@ package queryutil import ( "encoding/base64" "fmt" + "math" "net/url" "reflect" "sort" @@ -13,6 +14,12 @@ import ( "github.com/aws/aws-sdk-go/private/protocol" ) +const ( + floatNaN = "NaN" + floatInf = "Infinity" + floatNegInf = "-Infinity" +) + // Parse parses an object i and fills a url.Values object. The isEC2 flag // indicates if this is the EC2 Query sub-protocol. func Parse(body url.Values, i interface{}, isEC2 bool) error { @@ -228,9 +235,32 @@ func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, ta case int: v.Set(name, strconv.Itoa(value)) case float64: - v.Set(name, strconv.FormatFloat(value, 'f', -1, 64)) + var str string + switch { + case math.IsNaN(value): + str = floatNaN + case math.IsInf(value, 1): + str = floatInf + case math.IsInf(value, -1): + str = floatNegInf + default: + str = strconv.FormatFloat(value, 'f', -1, 64) + } + v.Set(name, str) case float32: - v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32)) + asFloat64 := float64(value) + var str string + switch { + case math.IsNaN(asFloat64): + str = floatNaN + case math.IsInf(asFloat64, 1): + str = floatInf + case math.IsInf(asFloat64, -1): + str = floatNegInf + default: + str = strconv.FormatFloat(asFloat64, 'f', -1, 32) + } + v.Set(name, str) case time.Time: const ISO8601UTC = "2006-01-02T15:04:05Z" format := tag.Get("timestampFormat") diff --git a/private/protocol/rest/build.go b/private/protocol/rest/build.go index 63f66af2c62..1d273ff0ec6 100644 --- a/private/protocol/rest/build.go +++ b/private/protocol/rest/build.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "fmt" "io" + "math" "net/http" "net/url" "path" @@ -20,6 +21,12 @@ import ( "github.com/aws/aws-sdk-go/private/protocol" ) +const ( + floatNaN = "NaN" + floatInf = "Infinity" + floatNegInf = "-Infinity" +) + // Whether the byte value can be sent without escaping in AWS URLs var noEscape [256]bool @@ -302,7 +309,16 @@ func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error) case int64: str = strconv.FormatInt(value, 10) case float64: - str = strconv.FormatFloat(value, 'f', -1, 64) + switch { + case math.IsNaN(value): + str = floatNaN + case math.IsInf(value, 1): + str = floatInf + case math.IsInf(value, -1): + str = floatNegInf + default: + str = strconv.FormatFloat(value, 'f', -1, 64) + } case time.Time: format := tag.Get("timestampFormat") if len(format) == 0 { diff --git a/private/protocol/rest/build_test.go b/private/protocol/rest/build_test.go index ca06460384f..5b48482bc7d 100644 --- a/private/protocol/rest/build_test.go +++ b/private/protocol/rest/build_test.go @@ -1,6 +1,10 @@ +//go:build go1.7 +// +build go1.7 + package rest import ( + "math" "net/http" "net/url" "reflect" @@ -175,3 +179,103 @@ func TestListOfEnums(t *testing.T) { } } } + +func TestMarshalFloat64(t *testing.T) { + cases := map[string]struct { + Input interface{} + URL string + ExpectedHeader http.Header + ExpectedURL string + WantErr bool + }{ + "header float values": { + Input: &struct { + Float *float64 `location:"header" locationName:"x-amz-float"` + FloatInf *float64 `location:"header" locationName:"x-amz-float-inf"` + FloatNegInf *float64 `location:"header" locationName:"x-amz-float-neg-inf"` + FloatNaN *float64 `location:"header" locationName:"x-amz-float-nan"` + }{ + Float: aws.Float64(123456789.123), + FloatInf: aws.Float64(math.Inf(1)), + FloatNegInf: aws.Float64(math.Inf(-1)), + FloatNaN: aws.Float64(math.NaN()), + }, + URL: "https://example.com/", + ExpectedHeader: map[string][]string{ + "X-Amz-Float": {"123456789.123"}, + "X-Amz-Float-Inf": {"Infinity"}, + "X-Amz-Float-Neg-Inf": {"-Infinity"}, + "X-Amz-Float-Nan": {"NaN"}, + }, + ExpectedURL: "https://example.com/", + }, + "path float values": { + Input: &struct { + Float *float64 `location:"uri" locationName:"float"` + FloatInf *float64 `location:"uri" locationName:"floatInf"` + FloatNegInf *float64 `location:"uri" locationName:"floatNegInf"` + FloatNaN *float64 `location:"uri" locationName:"floatNaN"` + }{ + Float: aws.Float64(123456789.123), + FloatInf: aws.Float64(math.Inf(1)), + FloatNegInf: aws.Float64(math.Inf(-1)), + FloatNaN: aws.Float64(math.NaN()), + }, + URL: "https://example.com/{float}/{floatInf}/{floatNegInf}/{floatNaN}", + ExpectedHeader: map[string][]string{}, + ExpectedURL: "https://example.com/123456789.123/Infinity/-Infinity/NaN", + }, + "query float values": { + Input: &struct { + Float *float64 `location:"querystring" locationName:"x-amz-float"` + FloatInf *float64 `location:"querystring" locationName:"x-amz-float-inf"` + FloatNegInf *float64 `location:"querystring" locationName:"x-amz-float-neg-inf"` + FloatNaN *float64 `location:"querystring" locationName:"x-amz-float-nan"` + }{ + Float: aws.Float64(123456789.123), + FloatInf: aws.Float64(math.Inf(1)), + FloatNegInf: aws.Float64(math.Inf(-1)), + FloatNaN: aws.Float64(math.NaN()), + }, + URL: "https://example.com/", + ExpectedHeader: map[string][]string{}, + ExpectedURL: "https://example.com/?x-amz-float=123456789.123&x-amz-float-inf=Infinity&x-amz-float-nan=NaN&x-amz-float-neg-inf=-Infinity", + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + req := &request.Request{ + HTTPRequest: &http.Request{ + URL: func() *url.URL { + u, err := url.Parse(tt.URL) + if err != nil { + panic(err) + } + return u + }(), + Header: map[string][]string{}, + }, + Params: tt.Input, + } + + Build(req) + + if (req.Error != nil) != (tt.WantErr) { + t.Fatalf("WantErr(%t) got %v", tt.WantErr, req.Error) + } + + if tt.WantErr { + return + } + + if e, a := tt.ExpectedHeader, req.HTTPRequest.Header; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + + if e, a := tt.ExpectedURL, req.HTTPRequest.URL.String(); e != a { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} diff --git a/private/protocol/rest/unmarshal.go b/private/protocol/rest/unmarshal.go index cdef403e219..79fcf1699b7 100644 --- a/private/protocol/rest/unmarshal.go +++ b/private/protocol/rest/unmarshal.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net/http" "reflect" "strconv" @@ -231,9 +232,20 @@ func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) erro } v.Set(reflect.ValueOf(&i)) case *float64: - f, err := strconv.ParseFloat(header, 64) - if err != nil { - return err + var f float64 + switch { + case strings.EqualFold(header, floatNaN): + f = math.NaN() + case strings.EqualFold(header, floatInf): + f = math.Inf(1) + case strings.EqualFold(header, floatNegInf): + f = math.Inf(-1) + default: + var err error + f, err = strconv.ParseFloat(header, 64) + if err != nil { + return err + } } v.Set(reflect.ValueOf(&f)) case *time.Time: diff --git a/private/protocol/rest/unmarshal_test.go b/private/protocol/rest/unmarshal_test.go new file mode 100644 index 00000000000..a3ad179b383 --- /dev/null +++ b/private/protocol/rest/unmarshal_test.go @@ -0,0 +1,86 @@ +//go:build go1.7 +// +build go1.7 + +package rest + +import ( + "bytes" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "io/ioutil" + "math" + "net/http" + "reflect" + "testing" +) + +func TestUnmarshalFloat64(t *testing.T) { + cases := map[string]struct { + Headers http.Header + OutputFn func() (interface{}, func(*testing.T, interface{})) + WantErr bool + }{ + "header float values": { + OutputFn: func() (interface{}, func(*testing.T, interface{})) { + type output struct { + Float *float64 `location:"header" locationName:"x-amz-float"` + FloatInf *float64 `location:"header" locationName:"x-amz-float-inf"` + FloatNegInf *float64 `location:"header" locationName:"x-amz-float-neg-inf"` + FloatNaN *float64 `location:"header" locationName:"x-amz-float-nan"` + } + + return &output{}, func(t *testing.T, out interface{}) { + o, ok := out.(*output) + if !ok { + t.Errorf("expect %T, got %T", (*output)(nil), out) + } + if e, a := aws.Float64(123456789.123), o.Float; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + if a := aws.Float64Value(o.FloatInf); !math.IsInf(a, 1) { + t.Errorf("expect infinity, got %v", a) + } + if a := aws.Float64Value(o.FloatNegInf); !math.IsInf(a, -1) { + t.Errorf("expect infinity, got %v", a) + } + if a := aws.Float64Value(o.FloatNaN); !math.IsNaN(a) { + t.Errorf("expect infinity, got %v", a) + } + } + }, + Headers: map[string][]string{ + "X-Amz-Float": {"123456789.123"}, + "X-Amz-Float-Inf": {"Infinity"}, + "X-Amz-Float-Neg-Inf": {"-Infinity"}, + "X-Amz-Float-Nan": {"NaN"}, + }, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + output, expectFn := tt.OutputFn() + + req := &request.Request{ + Data: output, + HTTPResponse: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(nil)), + Header: tt.Headers, + }, + } + + if (req.Error != nil) != tt.WantErr { + t.Fatalf("WantErr(%v) != %v", tt.WantErr, req.Error) + } + + if tt.WantErr { + return + } + + UnmarshalMeta(req) + + expectFn(t, output) + }) + } +} diff --git a/private/protocol/xml/xmlutil/build.go b/private/protocol/xml/xmlutil/build.go index 2fbb93ae76a..58c12bd8ccb 100644 --- a/private/protocol/xml/xmlutil/build.go +++ b/private/protocol/xml/xmlutil/build.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/xml" "fmt" + "math" "reflect" "sort" "strconv" @@ -14,6 +15,12 @@ import ( "github.com/aws/aws-sdk-go/private/protocol" ) +const ( + floatNaN = "NaN" + floatInf = "Infinity" + floatNegInf = "-Infinity" +) + // BuildXML will serialize params into an xml.Encoder. Error will be returned // if the serialization of any of the params or nested values fails. func BuildXML(params interface{}, e *xml.Encoder) error { @@ -275,6 +282,7 @@ func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect // Error will be returned if the value type is unsupported. func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { var str string + switch converted := value.Interface().(type) { case string: str = converted @@ -289,9 +297,29 @@ func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag refl case int: str = strconv.Itoa(converted) case float64: - str = strconv.FormatFloat(converted, 'f', -1, 64) + switch { + case math.IsNaN(converted): + str = floatNaN + case math.IsInf(converted, 1): + str = floatInf + case math.IsInf(converted, -1): + str = floatNegInf + default: + str = strconv.FormatFloat(converted, 'f', -1, 64) + } case float32: - str = strconv.FormatFloat(float64(converted), 'f', -1, 32) + // The SDK doesn't render float32 values in types, only float64. This case would never be hit currently. + asFloat64 := float64(converted) + switch { + case math.IsNaN(asFloat64): + str = floatNaN + case math.IsInf(asFloat64, 1): + str = floatInf + case math.IsInf(asFloat64, -1): + str = floatNegInf + default: + str = strconv.FormatFloat(asFloat64, 'f', -1, 32) + } case time.Time: format := tag.Get("timestampFormat") if len(format) == 0 { diff --git a/private/protocol/xml/xmlutil/build_test.go b/private/protocol/xml/xmlutil/build_test.go index f1b305368c3..de4b927242e 100644 --- a/private/protocol/xml/xmlutil/build_test.go +++ b/private/protocol/xml/xmlutil/build_test.go @@ -6,6 +6,7 @@ package xmlutil import ( "bytes" "encoding/xml" + "math" "testing" "github.com/aws/aws-sdk-go/aws" @@ -14,9 +15,10 @@ import ( type implicitPayload struct { _ struct{} `type:"structure"` - StrVal *string `type:"string"` - Second *nestedType `type:"structure"` - Third *nestedType `type:"structure"` + StrVal *string `type:"string"` + FloatVal *float64 `type:"double"` + Second *nestedType `type:"structure"` + Third *nestedType `type:"structure"` } type namedImplicitPayload struct { @@ -160,6 +162,30 @@ func TestBuildXML(t *testing.T) { }, Expect: "this string has escapable characters", }, + "float value": { + Input: &implicitPayload{ + FloatVal: aws.Float64(123456789.123), + }, + Expect: "123456789.123", + }, + "infinity float value": { + Input: &implicitPayload{ + FloatVal: aws.Float64(math.Inf(1)), + }, + Expect: "Infinity", + }, + "negative infinity float value": { + Input: &implicitPayload{ + FloatVal: aws.Float64(math.Inf(-1)), + }, + Expect: "-Infinity", + }, + "NaN float value": { + Input: &implicitPayload{ + FloatVal: aws.Float64(math.NaN()), + }, + Expect: "NaN", + }, } for name, c := range cases { diff --git a/private/protocol/xml/xmlutil/unmarshal.go b/private/protocol/xml/xmlutil/unmarshal.go index 107c053f8ac..44a580a940b 100644 --- a/private/protocol/xml/xmlutil/unmarshal.go +++ b/private/protocol/xml/xmlutil/unmarshal.go @@ -6,6 +6,7 @@ import ( "encoding/xml" "fmt" "io" + "math" "reflect" "strconv" "strings" @@ -276,9 +277,20 @@ func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { } r.Set(reflect.ValueOf(&v)) case *float64: - v, err := strconv.ParseFloat(node.Text, 64) - if err != nil { - return err + var v float64 + switch { + case strings.EqualFold(node.Text, floatNaN): + v = math.NaN() + case strings.EqualFold(node.Text, floatInf): + v = math.Inf(1) + case strings.EqualFold(node.Text, floatNegInf): + v = math.Inf(-1) + default: + var err error + v, err = strconv.ParseFloat(node.Text, 64) + if err != nil { + return err + } } r.Set(reflect.ValueOf(&v)) case *time.Time: diff --git a/private/protocol/xml/xmlutil/unmarshal_test.go b/private/protocol/xml/xmlutil/unmarshal_test.go index 1185d232855..9c6ef9611b4 100644 --- a/private/protocol/xml/xmlutil/unmarshal_test.go +++ b/private/protocol/xml/xmlutil/unmarshal_test.go @@ -1,10 +1,15 @@ +//go:build go1.7 +// +build go1.7 + package xmlutil import ( "encoding/xml" "fmt" "io" + "math" "reflect" + "strconv" "strings" "testing" @@ -30,6 +35,7 @@ type mockOutput struct { _ struct{} `type:"structure"` String *string `type:"string"` Integer *int64 `type:"integer"` + Float *float64 `type:"double"` Nested *mockNestedStruct `type:"structure"` List []*mockListElem `locationName:"List" locationNameList:"Elem" type:"list"` Closed *mockClosedTags `type:"structure"` @@ -56,7 +62,14 @@ type mockNestedListElem struct { } func TestUnmarshal(t *testing.T) { - const xmlBodyStr = ` + + cases := []struct { + Body string + Expect mockOutput + ExpectFn func(t *testing.T, actual mockOutput) + }{ + { + Body: ` string value 123 @@ -73,39 +86,77 @@ func TestUnmarshal(t *testing.T) { elem string value -` - - expect := mockOutput{ - String: aws.String("string value"), - Integer: aws.Int64(123), - Closed: &mockClosedTags{ - Attr: aws.String("attr value"), +`, + Expect: mockOutput{ + String: aws.String("string value"), + Integer: aws.Int64(123), + Closed: &mockClosedTags{ + Attr: aws.String("attr value"), + }, + Nested: &mockNestedStruct{ + NestedString: aws.String("nested string value"), + NestedInt: aws.Int64(321), + }, + List: []*mockListElem{ + { + String: aws.String("elem string value"), + NestedElem: &mockNestedListElem{ + String: aws.String("nested elem string value"), + Type: aws.String("type"), + }, + }, + }, + }, }, - Nested: &mockNestedStruct{ - NestedString: aws.String("nested string value"), - NestedInt: aws.Int64(321), + { + Body: `123456789.123`, + Expect: mockOutput{Float: aws.Float64(123456789.123)}, }, - List: []*mockListElem{ - { - String: aws.String("elem string value"), - NestedElem: &mockNestedListElem{ - String: aws.String("nested elem string value"), - Type: aws.String("type"), - }, + { + Body: `Infinity`, + ExpectFn: func(t *testing.T, actual mockOutput) { + if a := aws.Float64Value(actual.Float); !math.IsInf(a, 1) { + t.Errorf("expect infinity, got %v", a) + } + }, + }, + { + Body: `-Infinity`, + ExpectFn: func(t *testing.T, actual mockOutput) { + if a := aws.Float64Value(actual.Float); !math.IsInf(a, -1) { + t.Errorf("expect -infinity, got %v", a) + } + }, + }, + { + Body: `NaN`, + ExpectFn: func(t *testing.T, actual mockOutput) { + if a := aws.Float64Value(actual.Float); !math.IsNaN(a) { + t.Errorf("expect NaN, got %v", a) + } }, }, } - actual := mockOutput{} - decoder := xml.NewDecoder(strings.NewReader(xmlBodyStr)) - err := UnmarshalXML(&actual, decoder, "") - if err != nil { - t.Fatalf("expect no error, got %v", err) - } - - if !reflect.DeepEqual(expect, actual) { - t.Errorf("expect unmarshal to match\nExpect: %s\nActual: %s", - awsutil.Prettify(expect), awsutil.Prettify(actual)) + for i, tt := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + actual := mockOutput{} + decoder := xml.NewDecoder(strings.NewReader(tt.Body)) + err := UnmarshalXML(&actual, decoder, "") + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if tt.ExpectFn != nil { + tt.ExpectFn(t, actual) + return + } + + if !reflect.DeepEqual(tt.Expect, actual) { + t.Errorf("expect unmarshal to match\nExpect: %s\nActual: %s", + awsutil.Prettify(tt.Expect), awsutil.Prettify(actual)) + } + }) } }