Skip to content

Commit

Permalink
updated code for comments
Browse files Browse the repository at this point in the history
  • Loading branch information
4gust committed Nov 21, 2024
1 parent abf8b86 commit 73fda09
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
30 changes: 14 additions & 16 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ var retryCodesForIMDS = []int{

// retry on these codes
var retryStatusCodes = []int{
http.StatusNotFound, // 404
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
Expand Down Expand Up @@ -376,28 +375,25 @@ func contains[T comparable](list []T, element T) bool {
}

// retry performs an HTTP request with retries based on the provided options.
func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http.Response, error) {
func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error) {
var resp *http.Response
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15)
defer tryCancel()
if resp != nil && resp.Body != nil {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}
cloneReq := req.Clone(tryCtx)
resp, err = c.Do(cloneReq)
resp, err = c.httpClient.Do(cloneReq)
retrylist := retryStatusCodes
if s == DefaultToIMDS {
if c.source == DefaultToIMDS {
retrylist = retryCodesForIMDS
}
if err == nil && !contains(retrylist, resp.StatusCode) {
return resp, nil
}
if attempt == maxRetries-1 {
return resp, err
}
if resp != nil && resp.Body != nil {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}
select {
case <-time.After(time.Second):
case <-req.Context().Done():
Expand All @@ -408,13 +404,15 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http
return resp, err
}

func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) {
func (c Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) {
r := accesstokens.TokenResponse{}
retryCount := defaultRetryCount
if !client.retryPolicyEnabled {
retryCount = 1
var resp *http.Response
var err error
if c.retryPolicyEnabled {
resp, err = c.retry(defaultRetryCount, req)
} else {
resp, err = c.httpClient.Do(req)
}
resp, err := retry(retryCount, client.httpClient, req, client.source)
if err != nil {
return r, err
}
Expand Down
6 changes: 5 additions & 1 deletion apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ func TestRetryFunction(t *testing.T) {
body := bytes.NewBufferString(resp.body)
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode))
}
client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled())
if err != nil {
t.Fatal(err)
}
reqBody := bytes.NewBufferString(tt.requestBody)
req, _ := http.NewRequest("POST", "https://example.com", reqBody)
finalResp, err := retry(tt.maxRetries, mockClient, req, tt.source)
finalResp, err := client.retry(tt.maxRetries, req)
if err != nil {
t.Fatalf("error was not expected %s", err)
}
Expand Down

0 comments on commit 73fda09

Please sign in to comment.