From 73fda09b4f88a6584058c2a0e69c47292ec47f47 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 21 Nov 2024 17:57:20 +0000 Subject: [PATCH] updated code for comments --- apps/managedidentity/managedidentity.go | 30 +++++++++----------- apps/managedidentity/managedidentity_test.go | 6 +++- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 76f9306c..a68ef0c0 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -98,7 +98,6 @@ var retryCodesForIMDS = []int{ // retry on these codes var retryStatusCodes = []int{ - http.StatusNotFound, // 404 http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 @@ -376,28 +375,25 @@ func contains[T comparable](list []T, element T) bool { } // retry performs an HTTP request with retries based on the provided options. -func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http.Response, error) { +func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error) { var resp *http.Response var err error for attempt := 0; attempt < maxRetries; attempt++ { tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15) defer tryCancel() + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } cloneReq := req.Clone(tryCtx) - resp, err = c.Do(cloneReq) + resp, err = c.httpClient.Do(cloneReq) retrylist := retryStatusCodes - if s == DefaultToIMDS { + if c.source == DefaultToIMDS { retrylist = retryCodesForIMDS } if err == nil && !contains(retrylist, resp.StatusCode) { return resp, nil } - if attempt == maxRetries-1 { - return resp, err - } - if resp != nil && resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } select { case <-time.After(time.Second): case <-req.Context().Done(): @@ -408,13 +404,15 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http return resp, err } -func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { +func (c Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { r := accesstokens.TokenResponse{} - retryCount := defaultRetryCount - if !client.retryPolicyEnabled { - retryCount = 1 + var resp *http.Response + var err error + if c.retryPolicyEnabled { + resp, err = c.retry(defaultRetryCount, req) + } else { + resp, err = c.httpClient.Do(req) } - resp, err := retry(retryCount, client.httpClient, req, client.source) if err != nil { return r, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index adff7c10..6150071d 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -228,9 +228,13 @@ func TestRetryFunction(t *testing.T) { body := bytes.NewBufferString(resp.body) mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode)) } + client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled()) + if err != nil { + t.Fatal(err) + } reqBody := bytes.NewBufferString(tt.requestBody) req, _ := http.NewRequest("POST", "https://example.com", reqBody) - finalResp, err := retry(tt.maxRetries, mockClient, req, tt.source) + finalResp, err := client.retry(tt.maxRetries, req) if err != nil { t.Fatalf("error was not expected %s", err) }