diff --git a/private/protocol/json/jsonutil/build.go b/private/protocol/json/jsonutil/build.go
index 2aec80661a4..12e814ddf25 100644
--- a/private/protocol/json/jsonutil/build.go
+++ b/private/protocol/json/jsonutil/build.go
@@ -4,7 +4,6 @@ package jsonutil
import (
"bytes"
"encoding/base64"
- "encoding/json"
"fmt"
"math"
"reflect"
@@ -16,6 +15,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)
+const (
+ floatNaN = "NaN"
+ floatInf = "Infinity"
+ floatNegInf = "-Infinity"
+)
+
var timeType = reflect.ValueOf(time.Time{}).Type()
var byteSliceType = reflect.ValueOf([]byte{}).Type()
@@ -211,10 +216,16 @@ func buildScalar(v reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) erro
buf.Write(strconv.AppendInt(scratch[:0], value.Int(), 10))
case reflect.Float64:
f := value.Float()
- if math.IsInf(f, 0) || math.IsNaN(f) {
- return &json.UnsupportedValueError{Value: v, Str: strconv.FormatFloat(f, 'f', -1, 64)}
+ switch {
+ case math.IsNaN(f):
+ writeString(floatNaN, buf)
+ case math.IsInf(f, 1):
+ writeString(floatInf, buf)
+ case math.IsInf(f, -1):
+ writeString(floatNegInf, buf)
+ default:
+ buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
}
- buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
default:
switch converted := value.Interface().(type) {
case time.Time:
diff --git a/private/protocol/json/jsonutil/build_test.go b/private/protocol/json/jsonutil/build_test.go
index 66392b658e0..353228cf7d4 100644
--- a/private/protocol/json/jsonutil/build_test.go
+++ b/private/protocol/json/jsonutil/build_test.go
@@ -2,6 +2,7 @@ package jsonutil_test
import (
"encoding/json"
+ "math"
"strings"
"testing"
"time"
@@ -41,41 +42,48 @@ var jsonTests = []struct {
err string
}{
{
- J{},
- `{}`,
- ``,
+ in: J{},
+ out: `{}`,
},
{
- J{
+ in: J{
S: S("str"),
SS: []string{"A", "B", "C"},
D: D(123),
F: F(4.56),
T: T(time.Unix(987, 0)),
},
- `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
- ``,
+ out: `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
},
{
- J{
+ in: J{
S: S(`"''"`),
},
- `{"S":"\"''\""}`,
- ``,
+ out: `{"S":"\"''\""}`,
},
{
- J{
+ in: J{
S: S("\x00føø\u00FF\n\\\"\r\t\b\f"),
},
- `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
- ``,
+ out: `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
},
{
- J{
- F: F(4.56 / zero),
+ in: J{
+ F: F(math.NaN()),
},
- "",
- `json: unsupported value: +Inf`,
+ out: `{"F":"NaN"}`,
+ },
+ {
+ in: J{
+ F: F(math.Inf(1)),
+ },
+ out: `{"F":"Infinity"}`,
+ },
+ {
+ in: J{
+ F: F(math.Inf(-1)),
+ },
+ out: `{"F":"-Infinity"}`,
},
}
diff --git a/private/protocol/json/jsonutil/unmarshal.go b/private/protocol/json/jsonutil/unmarshal.go
index 8b2c9bbeba0..f9334879b80 100644
--- a/private/protocol/json/jsonutil/unmarshal.go
+++ b/private/protocol/json/jsonutil/unmarshal.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
+ "math"
"math/big"
"reflect"
"strings"
@@ -258,6 +259,18 @@ func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag
return err
}
value.Set(reflect.ValueOf(v))
+ case *float64:
+ // These are regular strings when parsed by encoding/json's unmarshaler.
+ switch {
+ case strings.EqualFold(d, floatNaN):
+ value.Set(reflect.ValueOf(aws.Float64(math.NaN())))
+ case strings.EqualFold(d, floatInf):
+ value.Set(reflect.ValueOf(aws.Float64(math.Inf(1))))
+ case strings.EqualFold(d, floatNegInf):
+ value.Set(reflect.ValueOf(aws.Float64(math.Inf(-1))))
+ default:
+ return fmt.Errorf("unknown JSON number value: %s", d)
+ }
default:
return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
}
diff --git a/private/protocol/json/jsonutil/unmarshal_test.go b/private/protocol/json/jsonutil/unmarshal_test.go
index a91b46c34fa..1e7ec3615ef 100644
--- a/private/protocol/json/jsonutil/unmarshal_test.go
+++ b/private/protocol/json/jsonutil/unmarshal_test.go
@@ -5,6 +5,7 @@ package jsonutil_test
import (
"bytes"
+ "math"
"reflect"
"testing"
"time"
@@ -21,9 +22,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
}
cases := map[string]struct {
- JSON string
- Value input
- Expected input
+ JSON string
+ Value input
+ Expected input
+ ExpectedFn func(*testing.T, input)
}{
"seconds precision": {
JSON: `{"timeField":1597094942}`,
@@ -106,6 +108,29 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
FloatField: aws.Float64(123456789.123),
},
},
+ "float64 field NaN": {
+ JSON: `{"floatField":"NaN"}`,
+ ExpectedFn: func(t *testing.T, input input) {
+ if input.FloatField == nil {
+ t.Fatal("expect non nil float64")
+ }
+ if e, a := true, math.IsNaN(*input.FloatField); e != a {
+ t.Errorf("expect %v, got %v", e, a)
+ }
+ },
+ },
+ "float64 field Infinity": {
+ JSON: `{"floatField":"Infinity"}`,
+ Expected: input{
+ FloatField: aws.Float64(math.Inf(1)),
+ },
+ },
+ "float64 field -Infinity": {
+ JSON: `{"floatField":"-Infinity"}`,
+ Expected: input{
+ FloatField: aws.Float64(math.Inf(-1)),
+ },
+ },
}
for name, tt := range cases {
@@ -114,6 +139,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
if err != nil {
t.Errorf("expect no error, got %v", err)
}
+ if tt.ExpectedFn != nil {
+ tt.ExpectedFn(t, tt.Value)
+ return
+ }
if e, a := tt.Expected, tt.Value; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
diff --git a/private/protocol/query/queryutil/queryutil.go b/private/protocol/query/queryutil/queryutil.go
index 75866d01218..058334053c2 100644
--- a/private/protocol/query/queryutil/queryutil.go
+++ b/private/protocol/query/queryutil/queryutil.go
@@ -3,6 +3,7 @@ package queryutil
import (
"encoding/base64"
"fmt"
+ "math"
"net/url"
"reflect"
"sort"
@@ -13,6 +14,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)
+const (
+ floatNaN = "NaN"
+ floatInf = "Infinity"
+ floatNegInf = "-Infinity"
+)
+
// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
@@ -228,9 +235,32 @@ func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, ta
case int:
v.Set(name, strconv.Itoa(value))
case float64:
- v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
+ var str string
+ switch {
+ case math.IsNaN(value):
+ str = floatNaN
+ case math.IsInf(value, 1):
+ str = floatInf
+ case math.IsInf(value, -1):
+ str = floatNegInf
+ default:
+ str = strconv.FormatFloat(value, 'f', -1, 64)
+ }
+ v.Set(name, str)
case float32:
- v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
+ asFloat64 := float64(value)
+ var str string
+ switch {
+ case math.IsNaN(asFloat64):
+ str = floatNaN
+ case math.IsInf(asFloat64, 1):
+ str = floatInf
+ case math.IsInf(asFloat64, -1):
+ str = floatNegInf
+ default:
+ str = strconv.FormatFloat(asFloat64, 'f', -1, 32)
+ }
+ v.Set(name, str)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
format := tag.Get("timestampFormat")
diff --git a/private/protocol/rest/build.go b/private/protocol/rest/build.go
index 63f66af2c62..1d273ff0ec6 100644
--- a/private/protocol/rest/build.go
+++ b/private/protocol/rest/build.go
@@ -6,6 +6,7 @@ import (
"encoding/base64"
"fmt"
"io"
+ "math"
"net/http"
"net/url"
"path"
@@ -20,6 +21,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)
+const (
+ floatNaN = "NaN"
+ floatInf = "Infinity"
+ floatNegInf = "-Infinity"
+)
+
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
@@ -302,7 +309,16 @@ func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
- str = strconv.FormatFloat(value, 'f', -1, 64)
+ switch {
+ case math.IsNaN(value):
+ str = floatNaN
+ case math.IsInf(value, 1):
+ str = floatInf
+ case math.IsInf(value, -1):
+ str = floatNegInf
+ default:
+ str = strconv.FormatFloat(value, 'f', -1, 64)
+ }
case time.Time:
format := tag.Get("timestampFormat")
if len(format) == 0 {
diff --git a/private/protocol/rest/build_test.go b/private/protocol/rest/build_test.go
index ca06460384f..5b48482bc7d 100644
--- a/private/protocol/rest/build_test.go
+++ b/private/protocol/rest/build_test.go
@@ -1,6 +1,10 @@
+//go:build go1.7
+// +build go1.7
+
package rest
import (
+ "math"
"net/http"
"net/url"
"reflect"
@@ -175,3 +179,103 @@ func TestListOfEnums(t *testing.T) {
}
}
}
+
+func TestMarshalFloat64(t *testing.T) {
+ cases := map[string]struct {
+ Input interface{}
+ URL string
+ ExpectedHeader http.Header
+ ExpectedURL string
+ WantErr bool
+ }{
+ "header float values": {
+ Input: &struct {
+ Float *float64 `location:"header" locationName:"x-amz-float"`
+ FloatInf *float64 `location:"header" locationName:"x-amz-float-inf"`
+ FloatNegInf *float64 `location:"header" locationName:"x-amz-float-neg-inf"`
+ FloatNaN *float64 `location:"header" locationName:"x-amz-float-nan"`
+ }{
+ Float: aws.Float64(123456789.123),
+ FloatInf: aws.Float64(math.Inf(1)),
+ FloatNegInf: aws.Float64(math.Inf(-1)),
+ FloatNaN: aws.Float64(math.NaN()),
+ },
+ URL: "https://example.com/",
+ ExpectedHeader: map[string][]string{
+ "X-Amz-Float": {"123456789.123"},
+ "X-Amz-Float-Inf": {"Infinity"},
+ "X-Amz-Float-Neg-Inf": {"-Infinity"},
+ "X-Amz-Float-Nan": {"NaN"},
+ },
+ ExpectedURL: "https://example.com/",
+ },
+ "path float values": {
+ Input: &struct {
+ Float *float64 `location:"uri" locationName:"float"`
+ FloatInf *float64 `location:"uri" locationName:"floatInf"`
+ FloatNegInf *float64 `location:"uri" locationName:"floatNegInf"`
+ FloatNaN *float64 `location:"uri" locationName:"floatNaN"`
+ }{
+ Float: aws.Float64(123456789.123),
+ FloatInf: aws.Float64(math.Inf(1)),
+ FloatNegInf: aws.Float64(math.Inf(-1)),
+ FloatNaN: aws.Float64(math.NaN()),
+ },
+ URL: "https://example.com/{float}/{floatInf}/{floatNegInf}/{floatNaN}",
+ ExpectedHeader: map[string][]string{},
+ ExpectedURL: "https://example.com/123456789.123/Infinity/-Infinity/NaN",
+ },
+ "query float values": {
+ Input: &struct {
+ Float *float64 `location:"querystring" locationName:"x-amz-float"`
+ FloatInf *float64 `location:"querystring" locationName:"x-amz-float-inf"`
+ FloatNegInf *float64 `location:"querystring" locationName:"x-amz-float-neg-inf"`
+ FloatNaN *float64 `location:"querystring" locationName:"x-amz-float-nan"`
+ }{
+ Float: aws.Float64(123456789.123),
+ FloatInf: aws.Float64(math.Inf(1)),
+ FloatNegInf: aws.Float64(math.Inf(-1)),
+ FloatNaN: aws.Float64(math.NaN()),
+ },
+ URL: "https://example.com/",
+ ExpectedHeader: map[string][]string{},
+ ExpectedURL: "https://example.com/?x-amz-float=123456789.123&x-amz-float-inf=Infinity&x-amz-float-nan=NaN&x-amz-float-neg-inf=-Infinity",
+ },
+ }
+
+ for name, tt := range cases {
+ t.Run(name, func(t *testing.T) {
+ req := &request.Request{
+ HTTPRequest: &http.Request{
+ URL: func() *url.URL {
+ u, err := url.Parse(tt.URL)
+ if err != nil {
+ panic(err)
+ }
+ return u
+ }(),
+ Header: map[string][]string{},
+ },
+ Params: tt.Input,
+ }
+
+ Build(req)
+
+ if (req.Error != nil) != (tt.WantErr) {
+ t.Fatalf("WantErr(%t) got %v", tt.WantErr, req.Error)
+ }
+
+ if tt.WantErr {
+ return
+ }
+
+ if e, a := tt.ExpectedHeader, req.HTTPRequest.Header; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v, got %v", e, a)
+ }
+
+ if e, a := tt.ExpectedURL, req.HTTPRequest.URL.String(); e != a {
+ t.Errorf("expect %v, got %v", e, a)
+ }
+ })
+ }
+}
diff --git a/private/protocol/rest/unmarshal.go b/private/protocol/rest/unmarshal.go
index cdef403e219..79fcf1699b7 100644
--- a/private/protocol/rest/unmarshal.go
+++ b/private/protocol/rest/unmarshal.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "math"
"net/http"
"reflect"
"strconv"
@@ -231,9 +232,20 @@ func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) erro
}
v.Set(reflect.ValueOf(&i))
case *float64:
- f, err := strconv.ParseFloat(header, 64)
- if err != nil {
- return err
+ var f float64
+ switch {
+ case strings.EqualFold(header, floatNaN):
+ f = math.NaN()
+ case strings.EqualFold(header, floatInf):
+ f = math.Inf(1)
+ case strings.EqualFold(header, floatNegInf):
+ f = math.Inf(-1)
+ default:
+ var err error
+ f, err = strconv.ParseFloat(header, 64)
+ if err != nil {
+ return err
+ }
}
v.Set(reflect.ValueOf(&f))
case *time.Time:
diff --git a/private/protocol/rest/unmarshal_test.go b/private/protocol/rest/unmarshal_test.go
new file mode 100644
index 00000000000..a3ad179b383
--- /dev/null
+++ b/private/protocol/rest/unmarshal_test.go
@@ -0,0 +1,86 @@
+//go:build go1.7
+// +build go1.7
+
+package rest
+
+import (
+ "bytes"
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/request"
+ "io/ioutil"
+ "math"
+ "net/http"
+ "reflect"
+ "testing"
+)
+
+func TestUnmarshalFloat64(t *testing.T) {
+ cases := map[string]struct {
+ Headers http.Header
+ OutputFn func() (interface{}, func(*testing.T, interface{}))
+ WantErr bool
+ }{
+ "header float values": {
+ OutputFn: func() (interface{}, func(*testing.T, interface{})) {
+ type output struct {
+ Float *float64 `location:"header" locationName:"x-amz-float"`
+ FloatInf *float64 `location:"header" locationName:"x-amz-float-inf"`
+ FloatNegInf *float64 `location:"header" locationName:"x-amz-float-neg-inf"`
+ FloatNaN *float64 `location:"header" locationName:"x-amz-float-nan"`
+ }
+
+ return &output{}, func(t *testing.T, out interface{}) {
+ o, ok := out.(*output)
+ if !ok {
+ t.Errorf("expect %T, got %T", (*output)(nil), out)
+ }
+ if e, a := aws.Float64(123456789.123), o.Float; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v, got %v", e, a)
+ }
+ if a := aws.Float64Value(o.FloatInf); !math.IsInf(a, 1) {
+ t.Errorf("expect infinity, got %v", a)
+ }
+ if a := aws.Float64Value(o.FloatNegInf); !math.IsInf(a, -1) {
+ t.Errorf("expect infinity, got %v", a)
+ }
+ if a := aws.Float64Value(o.FloatNaN); !math.IsNaN(a) {
+ t.Errorf("expect infinity, got %v", a)
+ }
+ }
+ },
+ Headers: map[string][]string{
+ "X-Amz-Float": {"123456789.123"},
+ "X-Amz-Float-Inf": {"Infinity"},
+ "X-Amz-Float-Neg-Inf": {"-Infinity"},
+ "X-Amz-Float-Nan": {"NaN"},
+ },
+ },
+ }
+
+ for name, tt := range cases {
+ t.Run(name, func(t *testing.T) {
+ output, expectFn := tt.OutputFn()
+
+ req := &request.Request{
+ Data: output,
+ HTTPResponse: &http.Response{
+ StatusCode: 200,
+ Body: ioutil.NopCloser(bytes.NewReader(nil)),
+ Header: tt.Headers,
+ },
+ }
+
+ if (req.Error != nil) != tt.WantErr {
+ t.Fatalf("WantErr(%v) != %v", tt.WantErr, req.Error)
+ }
+
+ if tt.WantErr {
+ return
+ }
+
+ UnmarshalMeta(req)
+
+ expectFn(t, output)
+ })
+ }
+}
diff --git a/private/protocol/xml/xmlutil/build.go b/private/protocol/xml/xmlutil/build.go
index 2fbb93ae76a..58c12bd8ccb 100644
--- a/private/protocol/xml/xmlutil/build.go
+++ b/private/protocol/xml/xmlutil/build.go
@@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/xml"
"fmt"
+ "math"
"reflect"
"sort"
"strconv"
@@ -14,6 +15,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)
+const (
+ floatNaN = "NaN"
+ floatInf = "Infinity"
+ floatNegInf = "-Infinity"
+)
+
// BuildXML will serialize params into an xml.Encoder. Error will be returned
// if the serialization of any of the params or nested values fails.
func BuildXML(params interface{}, e *xml.Encoder) error {
@@ -275,6 +282,7 @@ func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect
// Error will be returned if the value type is unsupported.
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
var str string
+
switch converted := value.Interface().(type) {
case string:
str = converted
@@ -289,9 +297,29 @@ func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag refl
case int:
str = strconv.Itoa(converted)
case float64:
- str = strconv.FormatFloat(converted, 'f', -1, 64)
+ switch {
+ case math.IsNaN(converted):
+ str = floatNaN
+ case math.IsInf(converted, 1):
+ str = floatInf
+ case math.IsInf(converted, -1):
+ str = floatNegInf
+ default:
+ str = strconv.FormatFloat(converted, 'f', -1, 64)
+ }
case float32:
- str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
+ // The SDK doesn't render float32 values in types, only float64. This case would never be hit currently.
+ asFloat64 := float64(converted)
+ switch {
+ case math.IsNaN(asFloat64):
+ str = floatNaN
+ case math.IsInf(asFloat64, 1):
+ str = floatInf
+ case math.IsInf(asFloat64, -1):
+ str = floatNegInf
+ default:
+ str = strconv.FormatFloat(asFloat64, 'f', -1, 32)
+ }
case time.Time:
format := tag.Get("timestampFormat")
if len(format) == 0 {
diff --git a/private/protocol/xml/xmlutil/build_test.go b/private/protocol/xml/xmlutil/build_test.go
index f1b305368c3..de4b927242e 100644
--- a/private/protocol/xml/xmlutil/build_test.go
+++ b/private/protocol/xml/xmlutil/build_test.go
@@ -6,6 +6,7 @@ package xmlutil
import (
"bytes"
"encoding/xml"
+ "math"
"testing"
"github.com/aws/aws-sdk-go/aws"
@@ -14,9 +15,10 @@ import (
type implicitPayload struct {
_ struct{} `type:"structure"`
- StrVal *string `type:"string"`
- Second *nestedType `type:"structure"`
- Third *nestedType `type:"structure"`
+ StrVal *string `type:"string"`
+ FloatVal *float64 `type:"double"`
+ Second *nestedType `type:"structure"`
+ Third *nestedType `type:"structure"`
}
type namedImplicitPayload struct {
@@ -160,6 +162,30 @@ func TestBuildXML(t *testing.T) {
},
Expect: "this
string
has
escapable
characters",
},
+ "float value": {
+ Input: &implicitPayload{
+ FloatVal: aws.Float64(123456789.123),
+ },
+ Expect: "123456789.123",
+ },
+ "infinity float value": {
+ Input: &implicitPayload{
+ FloatVal: aws.Float64(math.Inf(1)),
+ },
+ Expect: "Infinity",
+ },
+ "negative infinity float value": {
+ Input: &implicitPayload{
+ FloatVal: aws.Float64(math.Inf(-1)),
+ },
+ Expect: "-Infinity",
+ },
+ "NaN float value": {
+ Input: &implicitPayload{
+ FloatVal: aws.Float64(math.NaN()),
+ },
+ Expect: "NaN",
+ },
}
for name, c := range cases {
diff --git a/private/protocol/xml/xmlutil/unmarshal.go b/private/protocol/xml/xmlutil/unmarshal.go
index 107c053f8ac..44a580a940b 100644
--- a/private/protocol/xml/xmlutil/unmarshal.go
+++ b/private/protocol/xml/xmlutil/unmarshal.go
@@ -6,6 +6,7 @@ import (
"encoding/xml"
"fmt"
"io"
+ "math"
"reflect"
"strconv"
"strings"
@@ -276,9 +277,20 @@ func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
}
r.Set(reflect.ValueOf(&v))
case *float64:
- v, err := strconv.ParseFloat(node.Text, 64)
- if err != nil {
- return err
+ var v float64
+ switch {
+ case strings.EqualFold(node.Text, floatNaN):
+ v = math.NaN()
+ case strings.EqualFold(node.Text, floatInf):
+ v = math.Inf(1)
+ case strings.EqualFold(node.Text, floatNegInf):
+ v = math.Inf(-1)
+ default:
+ var err error
+ v, err = strconv.ParseFloat(node.Text, 64)
+ if err != nil {
+ return err
+ }
}
r.Set(reflect.ValueOf(&v))
case *time.Time:
diff --git a/private/protocol/xml/xmlutil/unmarshal_test.go b/private/protocol/xml/xmlutil/unmarshal_test.go
index 1185d232855..9c6ef9611b4 100644
--- a/private/protocol/xml/xmlutil/unmarshal_test.go
+++ b/private/protocol/xml/xmlutil/unmarshal_test.go
@@ -1,10 +1,15 @@
+//go:build go1.7
+// +build go1.7
+
package xmlutil
import (
"encoding/xml"
"fmt"
"io"
+ "math"
"reflect"
+ "strconv"
"strings"
"testing"
@@ -30,6 +35,7 @@ type mockOutput struct {
_ struct{} `type:"structure"`
String *string `type:"string"`
Integer *int64 `type:"integer"`
+ Float *float64 `type:"double"`
Nested *mockNestedStruct `type:"structure"`
List []*mockListElem `locationName:"List" locationNameList:"Elem" type:"list"`
Closed *mockClosedTags `type:"structure"`
@@ -56,7 +62,14 @@ type mockNestedListElem struct {
}
func TestUnmarshal(t *testing.T) {
- const xmlBodyStr = `
+
+ cases := []struct {
+ Body string
+ Expect mockOutput
+ ExpectFn func(t *testing.T, actual mockOutput)
+ }{
+ {
+ Body: `
string value
123
@@ -73,39 +86,77 @@ func TestUnmarshal(t *testing.T) {
elem string value
-`
-
- expect := mockOutput{
- String: aws.String("string value"),
- Integer: aws.Int64(123),
- Closed: &mockClosedTags{
- Attr: aws.String("attr value"),
+`,
+ Expect: mockOutput{
+ String: aws.String("string value"),
+ Integer: aws.Int64(123),
+ Closed: &mockClosedTags{
+ Attr: aws.String("attr value"),
+ },
+ Nested: &mockNestedStruct{
+ NestedString: aws.String("nested string value"),
+ NestedInt: aws.Int64(321),
+ },
+ List: []*mockListElem{
+ {
+ String: aws.String("elem string value"),
+ NestedElem: &mockNestedListElem{
+ String: aws.String("nested elem string value"),
+ Type: aws.String("type"),
+ },
+ },
+ },
+ },
},
- Nested: &mockNestedStruct{
- NestedString: aws.String("nested string value"),
- NestedInt: aws.Int64(321),
+ {
+ Body: `123456789.123`,
+ Expect: mockOutput{Float: aws.Float64(123456789.123)},
},
- List: []*mockListElem{
- {
- String: aws.String("elem string value"),
- NestedElem: &mockNestedListElem{
- String: aws.String("nested elem string value"),
- Type: aws.String("type"),
- },
+ {
+ Body: `Infinity`,
+ ExpectFn: func(t *testing.T, actual mockOutput) {
+ if a := aws.Float64Value(actual.Float); !math.IsInf(a, 1) {
+ t.Errorf("expect infinity, got %v", a)
+ }
+ },
+ },
+ {
+ Body: `-Infinity`,
+ ExpectFn: func(t *testing.T, actual mockOutput) {
+ if a := aws.Float64Value(actual.Float); !math.IsInf(a, -1) {
+ t.Errorf("expect -infinity, got %v", a)
+ }
+ },
+ },
+ {
+ Body: `NaN`,
+ ExpectFn: func(t *testing.T, actual mockOutput) {
+ if a := aws.Float64Value(actual.Float); !math.IsNaN(a) {
+ t.Errorf("expect NaN, got %v", a)
+ }
},
},
}
- actual := mockOutput{}
- decoder := xml.NewDecoder(strings.NewReader(xmlBodyStr))
- err := UnmarshalXML(&actual, decoder, "")
- if err != nil {
- t.Fatalf("expect no error, got %v", err)
- }
-
- if !reflect.DeepEqual(expect, actual) {
- t.Errorf("expect unmarshal to match\nExpect: %s\nActual: %s",
- awsutil.Prettify(expect), awsutil.Prettify(actual))
+ for i, tt := range cases {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ actual := mockOutput{}
+ decoder := xml.NewDecoder(strings.NewReader(tt.Body))
+ err := UnmarshalXML(&actual, decoder, "")
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ if tt.ExpectFn != nil {
+ tt.ExpectFn(t, actual)
+ return
+ }
+
+ if !reflect.DeepEqual(tt.Expect, actual) {
+ t.Errorf("expect unmarshal to match\nExpect: %s\nActual: %s",
+ awsutil.Prettify(tt.Expect), awsutil.Prettify(actual))
+ }
+ })
}
}