From b367e47fc88e23a785f16e07ec6fae0c87b370f1 Mon Sep 17 00:00:00 2001 From: Anton Averchenkov <84287187+averche@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:10:09 -0400 Subject: [PATCH] Fix SetCustomHeaders / WithCustomHeaders (#237) --- client_requests.go | 6 ++- client_requests_test.go | 96 +++++++++++++++++++++++++++++++++++++++ request_modifiers.go | 6 ++- request_modifiers_test.go | 15 ++++++ 4 files changed, 121 insertions(+), 2 deletions(-) diff --git a/client_requests.go b/client_requests.go index 96ad7f5a..f27374ef 100644 --- a/client_requests.go +++ b/client_requests.go @@ -251,8 +251,12 @@ func (c *Client) newRequest( } // populate request headers + if headers.customHeaders != nil { + req.Header = headers.customHeaders + } + if headers.userAgent != "" { - req.Header.Set("User-Agent", headers.userAgent) + req.Header.Add("User-Agent", headers.userAgent) } if headers.token != "" { diff --git a/client_requests_test.go b/client_requests_test.go index a5400935..7c73dccd 100644 --- a/client_requests_test.go +++ b/client_requests_test.go @@ -4,6 +4,10 @@ package vault import ( + "context" + "io" + "net/http" + "net/url" "strings" "testing" @@ -65,3 +69,95 @@ func Fuzz_v1Path(f *testing.F) { } }) } + +func Test_Client_newRequest(t *testing.T) { + // helper to read and close the request body + readClose := func(body io.ReadCloser) string { + b, err := io.ReadAll(body) + if err != nil { + return "" + } + body.Close() + + return string(b) + } + + cases := map[string]struct { + method string + path string + body io.Reader + parameters url.Values + headers requestHeaders + expect func(t *testing.T, request *http.Request) + }{ + "simple": { + method: http.MethodGet, + path: "/some/path", + + expect: func(t *testing.T, request *http.Request) { + assert.Equal(t, http.MethodGet, request.Method) + assert.Equal(t, "/some/path", request.URL.Path) + }, + }, + + "with-body": { + method: http.MethodPatch, + path: "/some/path", + body: strings.NewReader("{some body}"), + + expect: func(t *testing.T, request *http.Request) { + assert.Equal(t, http.MethodPatch, request.Method) + assert.Equal(t, "/some/path", request.URL.Path) + assert.Equal(t, "{some body}", readClose(request.Body)) + }, + }, + + "with-parameters": { + method: http.MethodPost, + path: "/some/path", + parameters: url.Values{"foo": {"bar"}}, + + expect: func(t *testing.T, request *http.Request) { + assert.Equal(t, http.MethodPost, request.Method) + assert.Equal(t, "/some/path", request.URL.Path) + assert.Equal(t, url.Values{"foo": {"bar"}}, request.URL.Query()) + }, + }, + + "with-custom-headers": { + method: http.MethodPut, + path: "/some/path", + headers: requestHeaders{ + customHeaders: http.Header{ + "Content-Type": {"text/html"}, + }, + }, + + expect: func(t *testing.T, request *http.Request) { + assert.Equal(t, http.MethodPut, request.Method) + assert.Equal(t, "/some/path", request.URL.Path) + assert.Equal(t, []string{"text/html"}, request.Header.Values("Content-Type")) + }, + }, + } + + client, err := newClient(DefaultConfiguration()) + require.NoError(t, err) + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + request, err := client.newRequest( + context.Background(), + tc.method, + tc.path, + tc.body, + tc.parameters, + tc.headers, + ) + + require.NoError(t, err) + + tc.expect(t, request) + }) + } +} diff --git a/request_modifiers.go b/request_modifiers.go index ec7f7b72..4278ebf7 100644 --- a/request_modifiers.go +++ b/request_modifiers.go @@ -279,7 +279,11 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) { } // in case of key collisions, the rhs keys will take precedence - maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders) + if lhs.headers.customHeaders != nil { + maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders) + } else { + lhs.headers.customHeaders = rhs.headers.customHeaders + } lhs.requestCallbacks = append( lhs.requestCallbacks, diff --git a/request_modifiers_test.go b/request_modifiers_test.go index 1d6453c7..e4f6ad5f 100644 --- a/request_modifiers_test.go +++ b/request_modifiers_test.go @@ -121,6 +121,21 @@ func Test_mergeRequestModifiers_overwrite(t *testing.T) { rhs: requestModifiers{headers: requestHeaders{namespace: "namespace-rhs"}}, expected: requestModifiers{headers: requestHeaders{token: "token-lhs", namespace: "namespace-rhs"}}, }, + "custom-headers-in-lhs": { + lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}}, + rhs: requestModifiers{}, + expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}}, + }, + "custom-headers-in-rhs": { + lhs: requestModifiers{}, + rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}}, + expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}}, + }, + "custom-headers-in-both": { + lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}}}}, + rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Length": []string{"123"}}}}, + expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"Content-Type": []string{"image/png"}, "Content-Length": []string{"123"}}}}, + }, } for name, tc := range cases {