From 621b4394d19da3998eaddf15d16a007c5b920494 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Mon, 2 Dec 2019 16:06:18 -0800 Subject: [PATCH] REST Encoder Implementation V2 --- aws/protocol/rest/encode.go | 79 ++++++++++++ aws/protocol/rest/encode_test.go | 56 ++++++++ aws/protocol/rest/header.go | 103 +++++++++++++++ aws/protocol/rest/header_test.go | 213 +++++++++++++++++++++++++++++++ aws/protocol/rest/query.go | 80 ++++++++++++ aws/protocol/rest/query_test.go | 153 ++++++++++++++++++++++ aws/protocol/rest/shared_test.go | 36 ++++++ aws/protocol/rest/uri.go | 71 +++++++++++ aws/protocol/rest/uri_test.go | 119 +++++++++++++++++ private/protocol/path_replace.go | 8 +- 10 files changed, 915 insertions(+), 3 deletions(-) create mode 100644 aws/protocol/rest/encode.go create mode 100644 aws/protocol/rest/encode_test.go create mode 100644 aws/protocol/rest/header.go create mode 100644 aws/protocol/rest/header_test.go create mode 100644 aws/protocol/rest/query.go create mode 100644 aws/protocol/rest/query_test.go create mode 100644 aws/protocol/rest/shared_test.go create mode 100644 aws/protocol/rest/uri.go create mode 100644 aws/protocol/rest/uri_test.go diff --git a/aws/protocol/rest/encode.go b/aws/protocol/rest/encode.go new file mode 100644 index 00000000000..06e6ff576c4 --- /dev/null +++ b/aws/protocol/rest/encode.go @@ -0,0 +1,79 @@ +package rest + +import ( + "net/http" + "net/url" + "strings" +) + +// An Encoder provides encoding of REST URI path, query, and header components +// of an HTTP request. Can also encode a stream as the payload. +// +// Does not support SetFields. +type Encoder struct { + req *http.Request + + path, rawPath, pathBuffer []byte + + query url.Values + header http.Header +} + +// NewEncoder creates a new encoder from the passed in request. All query and +// header values will be added on top of the request's existing values. Overwriting +// duplicate values. +func NewEncoder(req *http.Request) *Encoder { + e := &Encoder{ + req: req, + + path: []byte(req.URL.Path), + rawPath: []byte(req.URL.Path), + query: req.URL.Query(), + header: req.Header, + } + + return e +} + +// Encode returns a REST protocol encoder for encoding HTTP bindings +// Returns any error if one occurred during encoding. +func (e *Encoder) Encode() error { + e.req.URL.Path, e.req.URL.RawPath = string(e.path), string(e.rawPath) + e.req.URL.RawQuery = e.query.Encode() + e.req.Header = e.header + + return nil +} + +// AddHeader returns a HeaderValue for appending to the given header name +func (e *Encoder) AddHeader(key string) HeaderValue { + return newHeaderValue(e.header, key, true) +} + +// SetHeader returns a HeaderValue for setting the given header name +func (e *Encoder) SetHeader(key string) HeaderValue { + return newHeaderValue(e.header, key, false) +} + +// Headers returns a Header used encoding headers with the given prefix +func (e *Encoder) Headers(prefix string) Headers { + return Headers{ + header: e.header, + prefix: strings.TrimSpace(prefix), + } +} + +// SetURI returns a URIValue used for setting the given path key +func (e *Encoder) SetURI(key string) URIValue { + return newURIValue(&e.path, &e.rawPath, &e.pathBuffer, key) +} + +// SetQuery returns a QueryValue used for setting the given query key +func (e *Encoder) SetQuery(key string) QueryValue { + return newQueryValue(e.query, key, false) +} + +// AddQuery returns a QueryValue used for appending the given query key +func (e *Encoder) AddQuery(key string) QueryValue { + return newQueryValue(e.query, key, true) +} diff --git a/aws/protocol/rest/encode_test.go b/aws/protocol/rest/encode_test.go new file mode 100644 index 00000000000..625d7962fba --- /dev/null +++ b/aws/protocol/rest/encode_test.go @@ -0,0 +1,56 @@ +package rest + +import ( + "net/http" + "net/url" + "reflect" + "testing" +) + +func TestEncoder(t *testing.T) { + actual := http.Request{ + Header: http.Header{ + "custom-user-header": {"someValue"}, + }, + URL: &url.URL{ + Path: "/some/{pathKey}/path", + RawQuery: "someExistingKeys=foobar", + }, + } + + expected := http.Request{ + Header: map[string][]string{ + "custom-user-header": {"someValue"}, + "x-amzn-header-foo": {"someValue"}, + "x-amzn-meta-foo": {"someValue"}, + }, + URL: &url.URL{ + Path: "/some/someValue/path", + RawPath: "/some/someValue/path", + RawQuery: "someExistingKeys=foobar&someKey=someValue&someKey=otherValue", + }, + } + + encoder := NewEncoder(&actual) + + // Headers + encoder.AddHeader("x-amzn-header-foo").String("someValue") + encoder.Headers("x-amzn-meta-").AddHeader("foo").String("someValue") + + // Query + encoder.SetQuery("someKey").String("someValue") + encoder.AddQuery("someKey").String("otherValue") + + // URI + if err := encoder.SetURI("pathKey").String("someValue"); err != nil { + t.Errorf("expected no err, but got %v", err) + } + + if err := encoder.Encode(); err != nil { + t.Errorf("expected no err, but got %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected %v, but got %v", expected, actual) + } +} diff --git a/aws/protocol/rest/header.go b/aws/protocol/rest/header.go new file mode 100644 index 00000000000..94917e969c0 --- /dev/null +++ b/aws/protocol/rest/header.go @@ -0,0 +1,103 @@ +package rest + +import ( + "encoding/base64" + "net/http" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +// Headers is used to encode header keys using a provided prefix +type Headers struct { + header http.Header + prefix string +} + +// AddHeader returns a HeaderValue used to append values to prefix+key +func (h Headers) AddHeader(key string) HeaderValue { + return h.newHeaderValue(key, true) +} + +// SetHeader returns a HeaderValue used to set the value of prefix+key +func (h Headers) SetHeader(key string) HeaderValue { + return h.newHeaderValue(key, false) +} + +func (h Headers) newHeaderValue(key string, append bool) HeaderValue { + return newHeaderValue(h.header, h.prefix+strings.TrimSpace(key), append) +} + +// HeaderValue is used to encode values to an HTTP header +type HeaderValue struct { + header http.Header + key string + append bool +} + +func newHeaderValue(header http.Header, key string, append bool) HeaderValue { + return HeaderValue{header: header, key: strings.TrimSpace(key), append: append} +} + +func (h HeaderValue) modifyHeader(value string) { + lk := strings.ToLower(h.key) + + val := h.header[lk] + + if h.append { + val = append(val, value) + } else { + val = append(val[:0], value) + } + + h.header[lk] = val +} + +// String encodes the value v as the header string value +func (h HeaderValue) String(v string) { + h.modifyHeader(v) +} + +// Integer encodes the value v as the header string value +func (h HeaderValue) Integer(v int64) { + h.modifyHeader(strconv.FormatInt(v, 10)) +} + +// Boolean encodes the value v as a header string value +func (h HeaderValue) Boolean(v bool) { + h.modifyHeader(strconv.FormatBool(v)) +} + +// Float encodes the value v as a header string value +func (h HeaderValue) Float(v float64) { + h.modifyHeader(strconv.FormatFloat(v, 'f', -1, 64)) +} + +// Time encodes the value v using the format name as a header string value +func (h HeaderValue) Time(t time.Time, format string) error { + value, err := protocol.FormatTime(format, t) + if err != nil { + return err + } + h.modifyHeader(value) + return nil +} + +// ByteSlice encodes the value v as a base64 header string value +func (h HeaderValue) ByteSlice(v []byte) { + encodeToString := base64.StdEncoding.EncodeToString(v) + h.modifyHeader(encodeToString) +} + +// JSONValue encodes the value v as a base64 header string value +func (h HeaderValue) JSONValue(v aws.JSONValue) error { + encodedValue, err := protocol.EncodeJSONValue(v, protocol.Base64Escape) + if err != nil { + return err + } + h.modifyHeader(encodedValue) + return nil +} diff --git a/aws/protocol/rest/header_test.go b/aws/protocol/rest/header_test.go new file mode 100644 index 00000000000..db847eb2a99 --- /dev/null +++ b/aws/protocol/rest/header_test.go @@ -0,0 +1,213 @@ +package rest + +import ( + "fmt" + "net/http" + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +func TestHeaderValue(t *testing.T) { + const keyName = "test-key" + const expectedKeyName = keyName + + cases := map[string]struct { + header http.Header + args []interface{} + append bool + expected http.Header + }{ + "set string": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{"string value"}, + expected: map[string][]string{ + expectedKeyName: {"string value"}, + }, + }, + "set float64": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{3.14159}, + expected: map[string][]string{ + expectedKeyName: {"3.14159"}, + }, + }, + "set bool": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{true}, + expected: map[string][]string{ + expectedKeyName: {"true"}, + }, + }, + "set json": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{aws.JSONValue{"jsonKey": "jsonValue"}}, + expected: map[string][]string{ + expectedKeyName: {"eyJqc29uS2V5IjoianNvblZhbHVlIn0="}, + }, + }, + "set time": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{time.Unix(0, 0), protocol.ISO8601TimeFormatName}, + expected: map[string][]string{ + expectedKeyName: {"1970-01-01T00:00:00Z"}, + }, + }, + "set byte slice": { + header: http.Header{expectedKeyName: []string{"foobar"}}, + args: []interface{}{[]byte("baz")}, + expected: map[string][]string{ + expectedKeyName: {"YmF6"}, + }, + }, + "add string": { + header: http.Header{expectedKeyName: []string{"other string"}}, + args: []interface{}{"string value"}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"other string", "string value"}, + }, + }, + "add float64": { + header: http.Header{expectedKeyName: []string{"1.61803"}}, + args: []interface{}{3.14159}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"1.61803", "3.14159"}, + }, + }, + "add bool": { + header: http.Header{expectedKeyName: []string{"false"}}, + args: []interface{}{true}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"false", "true"}, + }, + }, + "add json": { + header: http.Header{expectedKeyName: []string{`eyJzb21lS2V5Ijoic29tZVZhbHVlIn0=`}}, + args: []interface{}{aws.JSONValue{"jsonKey": "jsonValue"}}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"eyJzb21lS2V5Ijoic29tZVZhbHVlIn0=", "eyJqc29uS2V5IjoianNvblZhbHVlIn0="}, + }, + }, + "add time": { + header: http.Header{expectedKeyName: []string{"1991-09-17T00:00:00Z"}}, + args: []interface{}{time.Unix(0, 0), protocol.ISO8601TimeFormatName}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"1991-09-17T00:00:00Z", "1970-01-01T00:00:00Z"}, + }, + }, + "add byte slice": { + header: http.Header{expectedKeyName: []string{"YmFy"}}, + args: []interface{}{[]byte("baz")}, + append: true, + expected: map[string][]string{ + expectedKeyName: {"YmFy", "YmF6"}, + }, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + if tt.header == nil { + tt.header = http.Header{} + } + + hv := newHeaderValue(tt.header, keyName, tt.append) + + if err := setHeader(hv, tt.args); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if e, a := tt.expected, hv.header; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, got %v", e, a) + } + }) + } +} + +func TestHeaders(t *testing.T) { + const prefix = "x-amzn-meta-" + cases := map[string]struct { + headers http.Header + values map[string]string + append bool + expected http.Header + }{ + "set": { + headers: http.Header{ + "x-amzn-meta-foo": {"bazValue"}, + }, + values: map[string]string{ + "foo": "fooValue", + " bar ": "barValue", + }, + expected: http.Header{ + "x-amzn-meta-foo": {"fooValue"}, + "x-amzn-meta-bar": {"barValue"}, + }, + }, + "add": { + headers: http.Header{ + "x-amzn-meta-foo": {"bazValue"}, + }, + values: map[string]string{ + "foo": "fooValue", + " bar ": "barValue", + }, + append: true, + expected: http.Header{ + "x-amzn-meta-foo": {"bazValue", "fooValue"}, + "x-amzn-meta-bar": {"barValue"}, + }, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + headers := Headers{header: tt.headers, prefix: prefix} + + var f func(key string) HeaderValue + if tt.append { + f = headers.AddHeader + } else { + f = headers.SetHeader + } + + for key, value := range tt.values { + f(key).String(value) + } + + if e, a := tt.expected, tt.headers; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but got %v", e, a) + } + }) + } +} + +func setHeader(hv HeaderValue, args []interface{}) error { + value := args[0] + + switch value.(type) { + case string: + return reflectCall(reflect.ValueOf(hv.String), args) + case float64: + return reflectCall(reflect.ValueOf(hv.Float), args) + case bool: + return reflectCall(reflect.ValueOf(hv.Boolean), args) + case aws.JSONValue: + return reflectCall(reflect.ValueOf(hv.JSONValue), args) + case time.Time: + return reflectCall(reflect.ValueOf(hv.Time), args) + case []byte: + return reflectCall(reflect.ValueOf(hv.ByteSlice), args) + default: + return fmt.Errorf("unhandled header value type") + } +} diff --git a/aws/protocol/rest/query.go b/aws/protocol/rest/query.go new file mode 100644 index 00000000000..8fcf6a862a6 --- /dev/null +++ b/aws/protocol/rest/query.go @@ -0,0 +1,80 @@ +package rest + +import ( + "encoding/base64" + "net/url" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +// QueryValue is used to encode query key values +type QueryValue struct { + query url.Values + key string + append bool +} + +func newQueryValue(query url.Values, key string, append bool) QueryValue { + return QueryValue{ + query: query, + key: key, + append: append, + } +} + +func (qv QueryValue) updateKey(value string) { + if qv.append { + qv.query.Add(qv.key, value) + } else { + qv.query.Set(qv.key, value) + } +} + +// String encodes the value v as a query string value +func (qv QueryValue) String(v string) { + qv.updateKey(v) +} + +// Integer encodes the value v as a query string value +func (qv QueryValue) Integer(v int64) { + qv.updateKey(strconv.FormatInt(v, 10)) +} + +// Boolean encodes the value v as a query string value +func (qv QueryValue) Boolean(v bool) { + qv.updateKey(strconv.FormatBool(v)) +} + +// Float encodes the value v as a query string value +func (qv QueryValue) Float(v float64) { + qv.updateKey(strconv.FormatFloat(v, 'f', -1, 64)) +} + +// Time encodes the value v using the format name as a query string value +func (qv QueryValue) Time(v time.Time, format string) error { + value, err := protocol.FormatTime(format, v) + if err != nil { + return err + } + qv.updateKey(value) + return nil +} + +// ByteSlice encodes the value v as a base64 query string value +func (qv QueryValue) ByteSlice(v []byte) { + encodeToString := base64.StdEncoding.EncodeToString(v) + qv.updateKey(encodeToString) +} + +// JSONValue encodes the value v using the format name as a query string value +func (qv QueryValue) JSONValue(v aws.JSONValue) error { + encodeJSONValue, err := protocol.EncodeJSONValue(v, protocol.NoEscape) + if err != nil { + return err + } + qv.updateKey(encodeJSONValue) + return nil +} diff --git a/aws/protocol/rest/query_test.go b/aws/protocol/rest/query_test.go new file mode 100644 index 00000000000..8b61dd7bfd0 --- /dev/null +++ b/aws/protocol/rest/query_test.go @@ -0,0 +1,153 @@ +package rest + +import ( + "fmt" + "net/url" + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +func TestQueryValue(t *testing.T) { + const queryKey = "someKey" + + cases := map[string]struct { + values url.Values + args []interface{} + append bool + expected url.Values + }{ + "set string": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{"string value"}, + expected: map[string][]string{ + queryKey: {"string value"}, + }, + }, + "set float64": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{3.14159}, + expected: map[string][]string{ + queryKey: {"3.14159"}, + }, + }, + "set bool": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{true}, + expected: map[string][]string{ + queryKey: {"true"}, + }, + }, + "set json": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{aws.JSONValue{"jsonKey": "jsonValue"}}, + expected: map[string][]string{ + queryKey: {`{"jsonKey":"jsonValue"}`}, + }, + }, + "set time": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{time.Unix(0, 0), protocol.ISO8601TimeFormatName}, + expected: map[string][]string{ + queryKey: {"1970-01-01T00:00:00Z"}, + }, + }, + "set byte slice": { + values: url.Values{queryKey: []string{"foobar"}}, + args: []interface{}{[]byte("baz")}, + expected: map[string][]string{ + queryKey: {"YmF6"}, + }, + }, + "add string": { + values: url.Values{queryKey: []string{"other string"}}, + args: []interface{}{"string value"}, + append: true, + expected: map[string][]string{ + queryKey: {"other string", "string value"}, + }, + }, + "add float64": { + values: url.Values{queryKey: []string{"1.61803"}}, + args: []interface{}{3.14159}, + append: true, + expected: map[string][]string{ + queryKey: {"1.61803", "3.14159"}, + }, + }, + "add bool": { + values: url.Values{queryKey: []string{"false"}}, + args: []interface{}{true}, + append: true, + expected: map[string][]string{ + queryKey: {"false", "true"}, + }, + }, + "add json": { + values: url.Values{queryKey: []string{`{"someKey":"someValue"}`}}, + args: []interface{}{aws.JSONValue{"jsonKey": "jsonValue"}}, + append: true, + expected: map[string][]string{ + queryKey: {`{"someKey":"someValue"}`, `{"jsonKey":"jsonValue"}`}, + }, + }, + "add time": { + values: url.Values{queryKey: []string{"1991-09-17T00:00:00Z"}}, + args: []interface{}{time.Unix(0, 0), protocol.ISO8601TimeFormatName}, + append: true, + expected: map[string][]string{ + queryKey: {"1991-09-17T00:00:00Z", "1970-01-01T00:00:00Z"}, + }, + }, + "add byte slice": { + values: url.Values{queryKey: []string{"YmFy"}}, + args: []interface{}{[]byte("baz")}, + append: true, + expected: map[string][]string{ + queryKey: {"YmFy", "YmF6"}, + }, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + if tt.values == nil { + tt.values = url.Values{} + } + + qv := newQueryValue(tt.values, queryKey, tt.append) + + if err := setQueryValue(qv, tt.args); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if e, a := tt.expected, qv.query; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, got %v", e, a) + } + }) + } +} + +func setQueryValue(qv QueryValue, args []interface{}) error { + value := args[0] + + switch value.(type) { + case string: + return reflectCall(reflect.ValueOf(qv.String), args) + case float64: + return reflectCall(reflect.ValueOf(qv.Float), args) + case bool: + return reflectCall(reflect.ValueOf(qv.Boolean), args) + case aws.JSONValue: + return reflectCall(reflect.ValueOf(qv.JSONValue), args) + case time.Time: + return reflectCall(reflect.ValueOf(qv.Time), args) + case []byte: + return reflectCall(reflect.ValueOf(qv.ByteSlice), args) + default: + return fmt.Errorf("unhandled query value type") + } +} diff --git a/aws/protocol/rest/shared_test.go b/aws/protocol/rest/shared_test.go new file mode 100644 index 00000000000..7b9eed5cfc1 --- /dev/null +++ b/aws/protocol/rest/shared_test.go @@ -0,0 +1,36 @@ +package rest + +import ( + "fmt" + "reflect" +) + +func reflectCall(funcValue reflect.Value, args []interface{}) error { + argValues := make([]reflect.Value, len(args)) + + for i, v := range args { + value := reflect.ValueOf(v) + argValues[i] = value + } + + retValues := funcValue.Call(argValues) + if len(retValues) > 0 { + errValue := retValues[0] + + if typeName := errValue.Type().Name(); typeName != "error" { + panic(fmt.Sprintf("expected first return argument to be error but got %v", typeName)) + } + + if errValue.IsNil() { + return nil + } + + if err, ok := errValue.Interface().(error); ok { + return err + } + + panic(fmt.Sprintf("expected %v to return error type, but got %v", funcValue.Type().String(), retValues[0].Type().String())) + } + + return nil +} diff --git a/aws/protocol/rest/uri.go b/aws/protocol/rest/uri.go new file mode 100644 index 00000000000..46add5ce540 --- /dev/null +++ b/aws/protocol/rest/uri.go @@ -0,0 +1,71 @@ +package rest + +import ( + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +// URIValue is used to encode named URI parameters +type URIValue struct { + path, rawPath, buffer *[]byte + + key string +} + +func newURIValue(path *[]byte, rawPath *[]byte, buffer *[]byte, key string) URIValue { + return URIValue{path: path, rawPath: rawPath, buffer: buffer, key: key} +} + +func (u URIValue) modifyURI(value string) (err error) { + *u.path, *u.buffer, err = protocol.ReplacePathElement(*u.path, *u.buffer, u.key, value, false) + *u.rawPath, *u.buffer, err = protocol.ReplacePathElement(*u.rawPath, *u.buffer, u.key, value, true) + return err +} + +// String encodes the value v as a URI string value +func (u URIValue) String(v string) error { + return u.modifyURI(v) +} + +// Integer encodes the value v as a URI string value +func (u URIValue) Integer(v int64) error { + return u.modifyURI(strconv.FormatInt(v, 10)) +} + +// Boolean encodes the value v as a URI string value +func (u URIValue) Boolean(v bool) error { + return u.modifyURI(strconv.FormatBool(v)) +} + +// Float encodes the value v as a URI string value +func (u URIValue) Float(v float64) error { + return u.modifyURI(strconv.FormatFloat(v, 'f', -1, 64)) +} + +// Time encodes the value v using the format name as a URI string value +func (u URIValue) Time(v time.Time, format string) error { + value, err := protocol.FormatTime(format, v) + if err != nil { + return err + } + + return u.modifyURI(value) +} + +// ByteSlice encodes the value v as a base64 URI string value +func (u URIValue) ByteSlice(v []byte) error { + return u.modifyURI(string(v)) +} + +// JSONValue encodes the value v as a URI string value +func (u URIValue) JSONValue(v aws.JSONValue) error { + encodeJSONValue, err := protocol.EncodeJSONValue(v, protocol.NoEscape) + if err != nil { + return err + } + + return u.modifyURI(encodeJSONValue) +} diff --git a/aws/protocol/rest/uri_test.go b/aws/protocol/rest/uri_test.go new file mode 100644 index 00000000000..9e2494f6d22 --- /dev/null +++ b/aws/protocol/rest/uri_test.go @@ -0,0 +1,119 @@ +package rest + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/private/protocol" +) + +func TestURIValue(t *testing.T) { + const uriKey = "someKey" + const path = "/some/{someKey}/{path+}" + + type expected struct { + path string + raw string + } + + cases := map[string]struct { + path string + args []interface{} + expected expected + }{ + "string": { + path: path, + args: []interface{}{"someValue"}, + expected: expected{ + path: "/some/someValue/{path+}", + raw: "/some/someValue/{path+}", + }, + }, + "float64": { + path: path, + args: []interface{}{3.14159}, + expected: expected{ + path: "/some/3.14159/{path+}", + raw: "/some/3.14159/{path+}", + }, + }, + "bool": { + path: path, + args: []interface{}{true}, + expected: expected{ + path: "/some/true/{path+}", + raw: "/some/true/{path+}", + }, + }, + "json": { + path: path, + args: []interface{}{aws.JSONValue{"jsonKey": "jsonValue"}}, + expected: expected{ + path: `/some/{"jsonKey":"jsonValue"}/{path+}`, + raw: "/some/%7B%22jsonKey%22%3A%22jsonValue%22%7D/{path+}", + }, + }, + "time": { + path: path, + args: []interface{}{time.Unix(0, 0), protocol.ISO8601TimeFormatName}, + expected: expected{ + path: "/some/1970-01-01T00:00:00Z/{path+}", + raw: "/some/1970-01-01T00%3A00%3A00Z/{path+}", + }, + }, + "byte slice": { + path: path, + args: []interface{}{[]byte("baz")}, + expected: expected{ + path: "/some/baz/{path+}", + raw: "/some/baz/{path+}", + }, + }, + } + + buffer := make([]byte, 1024) + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + pBytes, rBytes := []byte(tt.path), []byte(tt.path) + + uv := newURIValue(&pBytes, &rBytes, &buffer, uriKey) + + if err := setURI(uv, tt.args); err != nil { + t.Fatalf("expected no error, %v", err) + } + + if e, a := tt.expected.path, string(pBytes); e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := tt.expected.raw, string(rBytes); e != a { + t.Errorf("expected %v, got %v", e, a) + } + }) + } +} + +func setURI(uv URIValue, args []interface{}) error { + value := args[0] + + switch value.(type) { + case string: + return reflectCall(reflect.ValueOf(uv.String), args) + case float64: + return reflectCall(reflect.ValueOf(uv.Float), args) + case bool: + return reflectCall(reflect.ValueOf(uv.Boolean), args) + case aws.JSONValue: + return reflectCall(reflect.ValueOf(uv.JSONValue), args) + case time.Time: + return reflectCall(reflect.ValueOf(uv.Time), args) + case []byte: + return reflectCall(reflect.ValueOf(uv.ByteSlice), args) + default: + return fmt.Errorf("unhandled value type") + } +} diff --git a/private/protocol/path_replace.go b/private/protocol/path_replace.go index ae64a17dbf8..d9fcb34d291 100644 --- a/private/protocol/path_replace.go +++ b/private/protocol/path_replace.go @@ -43,12 +43,14 @@ func (r *PathReplace) Encode() (path string, rawPath string) { // ReplaceElement replaces a single element in the path string. func (r *PathReplace) ReplaceElement(key, val string) (err error) { - r.path, r.fieldBuf, err = replacePathElement(r.path, r.fieldBuf, key, val, false) - r.rawPath, r.fieldBuf, err = replacePathElement(r.rawPath, r.fieldBuf, key, val, true) + r.path, r.fieldBuf, err = ReplacePathElement(r.path, r.fieldBuf, key, val, false) + r.rawPath, r.fieldBuf, err = ReplacePathElement(r.rawPath, r.fieldBuf, key, val, true) return err } -func replacePathElement(path, fieldBuf []byte, key, val string, escape bool) ([]byte, []byte, error) { +// ReplacePathElement replaces a single element in the path []byte. +// Escape is used to control whether the value will be escaped using Amazon path escape style. +func ReplacePathElement(path, fieldBuf []byte, key, val string, escape bool) ([]byte, []byte, error) { fieldBuf = bufCap(fieldBuf, len(key)+3) // { [+] } fieldBuf = append(fieldBuf, uriTokenStart) fieldBuf = append(fieldBuf, key...)