Skip to content

Commit

Permalink
Fix SetCustomHeaders / WithCustomHeaders (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
averche authored Sep 27, 2023
1 parent c59c70c commit b367e47
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 2 deletions.
6 changes: 5 additions & 1 deletion client_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,12 @@ func (c *Client) newRequest(
}

// populate request headers
if headers.customHeaders != nil {
req.Header = headers.customHeaders
}

if headers.userAgent != "" {
req.Header.Set("User-Agent", headers.userAgent)
req.Header.Add("User-Agent", headers.userAgent)
}

if headers.token != "" {
Expand Down
96 changes: 96 additions & 0 deletions client_requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
package vault

import (
"context"
"io"
"net/http"
"net/url"
"strings"
"testing"

Expand Down Expand Up @@ -65,3 +69,95 @@ func Fuzz_v1Path(f *testing.F) {
}
})
}

func Test_Client_newRequest(t *testing.T) {
// helper to read and close the request body
readClose := func(body io.ReadCloser) string {
b, err := io.ReadAll(body)
if err != nil {
return ""
}
body.Close()

return string(b)
}

cases := map[string]struct {
method string
path string
body io.Reader
parameters url.Values
headers requestHeaders
expect func(t *testing.T, request *http.Request)
}{
"simple": {
method: http.MethodGet,
path: "/some/path",

expect: func(t *testing.T, request *http.Request) {
assert.Equal(t, http.MethodGet, request.Method)
assert.Equal(t, "/some/path", request.URL.Path)
},
},

"with-body": {
method: http.MethodPatch,
path: "/some/path",
body: strings.NewReader("{some body}"),

expect: func(t *testing.T, request *http.Request) {
assert.Equal(t, http.MethodPatch, request.Method)
assert.Equal(t, "/some/path", request.URL.Path)
assert.Equal(t, "{some body}", readClose(request.Body))
},
},

"with-parameters": {
method: http.MethodPost,
path: "/some/path",
parameters: url.Values{"foo": {"bar"}},

expect: func(t *testing.T, request *http.Request) {
assert.Equal(t, http.MethodPost, request.Method)
assert.Equal(t, "/some/path", request.URL.Path)
assert.Equal(t, url.Values{"foo": {"bar"}}, request.URL.Query())
},
},

"with-custom-headers": {
method: http.MethodPut,
path: "/some/path",
headers: requestHeaders{
customHeaders: http.Header{
"Content-Type": {"text/html"},
},
},

expect: func(t *testing.T, request *http.Request) {
assert.Equal(t, http.MethodPut, request.Method)
assert.Equal(t, "/some/path", request.URL.Path)
assert.Equal(t, []string{"text/html"}, request.Header.Values("Content-Type"))
},
},
}

client, err := newClient(DefaultConfiguration())
require.NoError(t, err)

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
request, err := client.newRequest(
context.Background(),
tc.method,
tc.path,
tc.body,
tc.parameters,
tc.headers,
)

require.NoError(t, err)

tc.expect(t, request)
})
}
}
6 changes: 5 additions & 1 deletion request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,11 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) {
}

// in case of key collisions, the rhs keys will take precedence
maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders)
if lhs.headers.customHeaders != nil {
maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders)
} else {
lhs.headers.customHeaders = rhs.headers.customHeaders
}

lhs.requestCallbacks = append(
lhs.requestCallbacks,
Expand Down
15 changes: 15 additions & 0 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ func Test_mergeRequestModifiers_overwrite(t *testing.T) {
rhs: requestModifiers{headers: requestHeaders{namespace: "namespace-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs", namespace: "namespace-rhs"}},
},
"custom-headers-in-lhs": {
lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}},
rhs: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}},
},
"custom-headers-in-rhs": {
lhs: requestModifiers{},
rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}},
},
"custom-headers-in-both": {
lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}},
rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}, "Content-Length": []string{"123"}}}},
},
}

for name, tc := range cases {
Expand Down

0 comments on commit b367e47

Please sign in to comment.