From f8480ad5412fede962a958f957038f288cc7f8c5 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 21 Oct 2024 21:16:52 +0100 Subject: [PATCH 01/19] Implemented Retry Policy --- apps/managedidentity/managedidentity.go | 67 ++++++++++++++++--- apps/managedidentity/managedidentity_test.go | 70 +++++++++++++++++++- 2 files changed, 126 insertions(+), 11 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index afa028db..6633cc90 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -16,7 +16,9 @@ import ( "io" "net/http" "net/url" + "slices" "strings" + "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" @@ -50,8 +52,28 @@ const ( imdsAPIVersion = "2018-02-01" systemAssignedManagedIdentity = "system_assigned_managed_identity" + defaultRetryCount = 3 ) +// IMDS docs recommend retrying 404, 410, 429 and 5xx +// https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling +var retryStatusCodes = []int{ + http.StatusNotFound, // 404 + http.StatusGone, // 410 + http.StatusTooManyRequests, // 429 // retry after. + http.StatusInternalServerError, // 500 + http.StatusNotImplemented, // 501 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusHTTPVersionNotSupported, // 505 + http.StatusVariantAlsoNegotiates, // 506 + http.StatusInsufficientStorage, // 507 + http.StatusLoopDetected, // 508 + http.StatusNotExtended, // 510 + http.StatusNetworkAuthenticationRequired, // 511 +} + type Source string type ID interface { @@ -75,13 +97,14 @@ func SystemAssigned() ID { var cacheManager *storage.Manager = storage.New(nil) type Client struct { - httpClient ops.HTTPClient - miType ID - // source Source reenable when required in future sources + httpClient ops.HTTPClient + miType ID + retryPolicyDisabled bool } type ClientOptions struct { - httpClient ops.HTTPClient + httpClient ops.HTTPClient + retryPolicyDiabled bool } type AcquireTokenOptions struct { @@ -107,13 +130,20 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { } } +func WithRetryPolicyDisabled() ClientOption { + return func(o *ClientOptions) { + o.retryPolicyDiabled = true + } +} + // Client to be used to acquire tokens for managed identity. // ID: [SystemAssigned], [UserAssignedClientID], [UserAssignedResourceID], [UserAssignedObjectID] // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { opts := ClientOptions{ - httpClient: shared.DefaultClient, + httpClient: shared.DefaultClient, + retryPolicyDiabled: false, } for _, option := range options { option(&opts) @@ -136,8 +166,9 @@ func New(id ID, options ...ClientOption) (Client, error) { return Client{}, fmt.Errorf("unsupported type %T", id) } client := Client{ - miType: id, - httpClient: opts.httpClient, + miType: id, + httpClient: opts.httpClient, + retryPolicyDisabled: opts.retryPolicyDiabled, } return client, nil } @@ -178,8 +209,24 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s return req, nil } +func retry(attempts int, delay time.Duration, c ops.HTTPClient, req *http.Request) (*http.Response, error) { + resp, err := c.Do(req) + if err == nil && !slices.Contains(retryStatusCodes, resp.StatusCode) { + return resp, nil // Success + } + if attempts-1 < 1 { + return resp, nil + } + time.Sleep(delay) + return retry(attempts-1, delay, c, req) +} + func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { - resp, err := client.httpClient.Do(req) + retryCount := 1 // defaul Count + if !client.retryPolicyDisabled { + retryCount = defaultRetryCount + } + resp, err := retry(retryCount, time.Second, client.httpClient, req) //client.httpClient.Do(req) if err != nil { return accesstokens.TokenResponse{}, err } @@ -254,10 +301,10 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options if err != nil { return base.AuthResult{}, err } - return authResultFromToken(fakeAuthParams, tokenResponse) + return client.authResultFromToken(fakeAuthParams, tokenResponse) } -func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { +func (c Client) authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { if cacheManager == nil { return base.AuthResult{}, fmt.Errorf("cache instance is nil") } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index e19a66c0..2ed08071 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -102,7 +102,7 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), mock.WithBody(responseBody)) - client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) if err != nil { t.Fatal(err) } @@ -128,6 +128,74 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } } +func Test_RetryPolicy_For_AcquireToken_Request(t *testing.T) { + t.Run("Testing retry policy with With One failure", func(t *testing.T) { + fakeClient := mock.Client{} + responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + errorRetry := 2 + errorRetryCounter := 0 + fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + errorRetryCounter += 1 + })) + fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + errorRetryCounter += 1 + })) + fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithBody(responseBody)) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeClient)) + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatalf("should have encountered the error") + } + if resp.AccessToken != "" { + t.Fatalf("wanted %q, got %q", "", resp.AccessToken) + } + if errorRetryCounter != errorRetry { + t.Fatalf("expected Number of retry of %d, got %d", errorRetry, errorRetryCounter) + } + }) +} + +func Test_RetryPolicy_For_AcquireToken_Request_MaxTries(t *testing.T) { + t.Run("Testing retry policy with With One failure", func(t *testing.T) { + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + errorRetry := 4 + errorRetryCounter := 0 + for i := 0; i < errorRetry; i++ { + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + errorRetryCounter += 1 + })) + } + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatalf("should have encountered the error") + } + if resp.AccessToken != "" { + t.Fatalf("wanted %q, got %q", "", resp.AccessToken) + } + if errorRetryCounter != defaultRetryCount { + t.Fatalf("expected Number of retry of %d, got %d", errorRetry, errorRetryCounter) + } + }) +} + func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { testCases := []resourceTestData{ {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: SystemAssigned()}, From 10df4c9fa25dfad85f57c699ff8957bd4309f995 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 22 Oct 2024 23:16:37 +0100 Subject: [PATCH 02/19] Fixed the tests --- apps/managedidentity/managedidentity.go | 5 +- apps/managedidentity/managedidentity_test.go | 145 ++++++++++--------- 2 files changed, 82 insertions(+), 68 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 6633cc90..3d75b98a 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -268,7 +268,6 @@ func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRe // Options: [WithClaims] func (client Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { o := AcquireTokenOptions{} - for _, option := range options { option(&o) } @@ -301,10 +300,10 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options if err != nil { return base.AuthResult{}, err } - return client.authResultFromToken(fakeAuthParams, tokenResponse) + return authResultFromToken(fakeAuthParams, tokenResponse) } -func (c Client) authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { +func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { if cacheManager == nil { return base.AuthResult{}, fmt.Errorf("cache instance is nil") } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 2ed08071..c645dd7d 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,9 +5,11 @@ package managedidentity import ( "context" "encoding/json" + "fmt" "net/http" "net/url" "strings" + "sync" "testing" "time" @@ -95,6 +97,7 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { for _, testCase := range testCases { t.Run(http.StatusText(testCase.code), func(t *testing.T) { + fakeErrorClient := mock.Client{} responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) if err != nil { @@ -128,72 +131,84 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } } -func Test_RetryPolicy_For_AcquireToken_Request(t *testing.T) { - t.Run("Testing retry policy with With One failure", func(t *testing.T) { - fakeClient := mock.Client{} - responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") - if err != nil { - t.Fatalf("error while forming json response : %s", err.Error()) - } - errorRetry := 2 - errorRetryCounter := 0 - fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), - mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { - errorRetryCounter += 1 - })) - fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), - mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { - errorRetryCounter += 1 - })) - fakeClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), - mock.WithBody(responseBody)) - client, err := New(SystemAssigned(), WithHTTPClient(&fakeClient)) - if err != nil { - t.Fatal(err) - } - resp, err := client.AcquireToken(context.Background(), resource) - if err == nil { - t.Fatalf("should have encountered the error") - } - if resp.AccessToken != "" { - t.Fatalf("wanted %q, got %q", "", resp.AccessToken) - } - if errorRetryCounter != errorRetry { - t.Fatalf("expected Number of retry of %d, got %d", errorRetry, errorRetryCounter) - } - }) -} +func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { + testCases := []struct { + numberOfFails int + expectedFail bool + disableRetry bool + }{ + {numberOfFails: 1, expectedFail: false, disableRetry: false}, + {numberOfFails: 1, expectedFail: true, disableRetry: true}, + {numberOfFails: 1, expectedFail: true, disableRetry: true}, + {numberOfFails: 2, expectedFail: false, disableRetry: false}, + {numberOfFails: 3, expectedFail: true, disableRetry: false}, + } + for _, testCase := range testCases { + t.Run(fmt.Sprintf("Testing retry policy with %d ", testCase.numberOfFails), func(t *testing.T) { + var wg sync.WaitGroup + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + errorRetryCounter := 0 + wg.Add(testCase.numberOfFails) + for i := 0; i < testCase.numberOfFails; i++ { + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + errorRetryCounter++ + wg.Done() + })) + } + if !testCase.expectedFail { + wg.Add(1) + successRespBody, err := getSuccessfulResponse(resource) + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusAccepted), + mock.WithBody(successRespBody), mock.WithCallback(func(r *http.Request) { + wg.Done() + })) + } + var client Client + if testCase.disableRetry { + client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) + } else { + client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) -func Test_RetryPolicy_For_AcquireToken_Request_MaxTries(t *testing.T) { - t.Run("Testing retry policy with With One failure", func(t *testing.T) { - fakeErrorClient := mock.Client{} - responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") - if err != nil { - t.Fatalf("error while forming json response : %s", err.Error()) - } - errorRetry := 4 - errorRetryCounter := 0 - for i := 0; i < errorRetry; i++ { - fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), - mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { - errorRetryCounter += 1 - })) - } - client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) - if err != nil { - t.Fatal(err) - } - resp, err := client.AcquireToken(context.Background(), resource) - if err == nil { - t.Fatalf("should have encountered the error") - } - if resp.AccessToken != "" { - t.Fatalf("wanted %q, got %q", "", resp.AccessToken) - } - if errorRetryCounter != defaultRetryCount { - t.Fatalf("expected Number of retry of %d, got %d", errorRetry, errorRetryCounter) - } - }) + } + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource, WithClaims("noCache")) + wg.Wait() + if testCase.expectedFail { + if err == nil { + t.Fatalf("should have encountered the error") + } + if resp.AccessToken != "" { + t.Fatalf("accesstoken should be empty") + } + } else { + if err != nil { + t.Fatalf("should have encountered the error") + } + if resp.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, resp.AccessToken) + } + } + if testCase.disableRetry { + if errorRetryCounter != 1 { + t.Fatalf("expected Number of retry of 1, got %d", errorRetryCounter) + } + } else { + if errorRetryCounter != testCase.numberOfFails && testCase.numberOfFails < 3 { + t.Fatalf("expected Number of retry of %d, got %d", testCase.numberOfFails, errorRetryCounter) + } + } + }) + } } func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { From 1782b07b14d3e8ca19025e473c1355f9b32d2518 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 23 Oct 2024 15:46:57 +0100 Subject: [PATCH 03/19] Updated the retry policy to respect context Added test to check request body is same for each request. --- apps/managedidentity/managedidentity.go | 47 +++++++--- apps/managedidentity/managedidentity_test.go | 90 ++++++++++++++++++-- 2 files changed, 115 insertions(+), 22 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 3d75b98a..00a33a6d 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -16,7 +16,6 @@ import ( "io" "net/http" "net/url" - "slices" "strings" "time" @@ -142,8 +141,7 @@ func WithRetryPolicyDisabled() ClientOption { // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { opts := ClientOptions{ - httpClient: shared.DefaultClient, - retryPolicyDiabled: false, + httpClient: shared.DefaultClient, } for _, option := range options { option(&opts) @@ -209,24 +207,47 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s return req, nil } -func retry(attempts int, delay time.Duration, c ops.HTTPClient, req *http.Request) (*http.Response, error) { - resp, err := c.Do(req) - if err == nil && !slices.Contains(retryStatusCodes, resp.StatusCode) { - return resp, nil // Success +// Contains checks if the element is present in the list. +func contains[T comparable](list []T, element T) bool { + for _, v := range list { + if v == element { + return true + } } - if attempts-1 < 1 { - return resp, nil + return false +} + +// retry performs an HTTP request with retries based on the provided options. +func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + for attempt := 0; attempt < maxRetries; attempt++ { + tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15) + defer tryCancel() + cloneReq := req.Clone(tryCtx) + resp, err = c.Do(cloneReq) + if err == nil && !contains(retryStatusCodes, resp.StatusCode) { + return resp, nil + } + if attempt == maxRetries-1 { + return resp, err + } + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + delay := time.Duration(time.Second) + time.Sleep(delay) } - time.Sleep(delay) - return retry(attempts-1, delay, c, req) + return resp, err } func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { - retryCount := 1 // defaul Count + retryCount := 1 if !client.retryPolicyDisabled { retryCount = defaultRetryCount } - resp, err := retry(retryCount, time.Second, client.httpClient, req) //client.httpClient.Do(req) + resp, err := retry(retryCount, client.httpClient, req) if err != nil { return accesstokens.TokenResponse{}, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index c645dd7d..01df31a0 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -3,13 +3,14 @@ package managedidentity import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" - "sync" "testing" "time" @@ -131,6 +132,84 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } } +func TestRetryFunction(t *testing.T) { + tests := []struct { + name string + mockResponses []struct { + body string + statusCode int + } + expectedStatus int + expectedBody string + maxRetries int + requestBody string + }{ + { + name: "Successful Request", + mockResponses: []struct { + body string + statusCode int + }{ + {"Failed", http.StatusInternalServerError}, + {"Success", http.StatusOK}, + }, + expectedStatus: http.StatusOK, + expectedBody: "Success", + maxRetries: 3, + requestBody: "Test Body", + }, + { + name: "Max Retries Reached", + mockResponses: []struct { + body string + statusCode int + }{ + {"Error", http.StatusInternalServerError}, + {"Error", http.StatusInternalServerError}, + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "Error", + maxRetries: 2, + requestBody: "Test Body", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mock.Client{} + for _, resp := range tt.mockResponses { + body := bytes.NewBufferString(resp.body) + mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode)) + } + reqBody := bytes.NewBufferString(tt.requestBody) + req, _ := http.NewRequest("POST", "https://example.com", reqBody) + finalResp, err := retry(tt.maxRetries, mockClient, req) + if finalResp.StatusCode != tt.expectedStatus { + t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode) + } + bodyBytes, err := io.ReadAll(finalResp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + finalResp.Body.Close() // Close the body after reading + if string(bodyBytes) != tt.expectedBody { + t.Fatalf("Expected body %q, got %q", tt.expectedBody, bodyBytes) + } + if req.Body != nil { + reqBodyBytes, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("Failed to read request body: %v", err) + } + req.Body.Close() + + if string(reqBodyBytes) != tt.requestBody { + t.Fatalf("Expected request body %q, got %q", tt.requestBody, reqBodyBytes) + } + } + }) + } +} + func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { testCases := []struct { numberOfFails int @@ -145,31 +224,25 @@ func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { } for _, testCase := range testCases { t.Run(fmt.Sprintf("Testing retry policy with %d ", testCase.numberOfFails), func(t *testing.T) { - var wg sync.WaitGroup fakeErrorClient := mock.Client{} responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } errorRetryCounter := 0 - wg.Add(testCase.numberOfFails) for i := 0; i < testCase.numberOfFails; i++ { fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { errorRetryCounter++ - wg.Done() })) } if !testCase.expectedFail { - wg.Add(1) successRespBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusAccepted), - mock.WithBody(successRespBody), mock.WithCallback(func(r *http.Request) { - wg.Done() - })) + mock.WithBody(successRespBody)) } var client Client if testCase.disableRetry { @@ -182,7 +255,6 @@ func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { t.Fatal(err) } resp, err := client.AcquireToken(context.Background(), resource, WithClaims("noCache")) - wg.Wait() if testCase.expectedFail { if err == nil { t.Fatalf("should have encountered the error") From a60fd6d76012a1774e1d28515444cb64c9cc1499 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 24 Oct 2024 13:19:26 +0100 Subject: [PATCH 04/19] Updated the variable name to remove negation --- apps/managedidentity/managedidentity.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 00a33a6d..2670802a 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -96,14 +96,14 @@ func SystemAssigned() ID { var cacheManager *storage.Manager = storage.New(nil) type Client struct { - httpClient ops.HTTPClient - miType ID - retryPolicyDisabled bool + httpClient ops.HTTPClient + miType ID + retryPolicyEnabled bool } type ClientOptions struct { httpClient ops.HTTPClient - retryPolicyDiabled bool + retryPolicyEnabled bool } type AcquireTokenOptions struct { @@ -131,7 +131,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { func WithRetryPolicyDisabled() ClientOption { return func(o *ClientOptions) { - o.retryPolicyDiabled = true + o.retryPolicyEnabled = false } } @@ -141,7 +141,8 @@ func WithRetryPolicyDisabled() ClientOption { // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { opts := ClientOptions{ - httpClient: shared.DefaultClient, + httpClient: shared.DefaultClient, + retryPolicyEnabled: true, } for _, option := range options { option(&opts) @@ -164,9 +165,9 @@ func New(id ID, options ...ClientOption) (Client, error) { return Client{}, fmt.Errorf("unsupported type %T", id) } client := Client{ - miType: id, - httpClient: opts.httpClient, - retryPolicyDisabled: opts.retryPolicyDiabled, + miType: id, + httpClient: opts.httpClient, + retryPolicyEnabled: opts.retryPolicyEnabled, } return client, nil } @@ -243,9 +244,9 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, } func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { - retryCount := 1 - if !client.retryPolicyDisabled { - retryCount = defaultRetryCount + retryCount := 3 + if !client.retryPolicyEnabled { + retryCount = 1 } resp, err := retry(retryCount, client.httpClient, req) if err != nil { From 716cbce6fc84a99fa6a7ff7e89264d47b1804b78 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 29 Oct 2024 11:49:30 +0000 Subject: [PATCH 05/19] Update managedidentity.go Updating the error code for retry --- apps/managedidentity/managedidentity.go | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 2670802a..22a84c72 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -54,23 +54,15 @@ const ( defaultRetryCount = 3 ) -// IMDS docs recommend retrying 404, 410, 429 and 5xx +// IMDS docs recommend retrying 408, 429, 500, 502, 503, 504 // https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling var retryStatusCodes = []int{ - http.StatusNotFound, // 404 - http.StatusGone, // 410 - http.StatusTooManyRequests, // 429 // retry after. - http.StatusInternalServerError, // 500 - http.StatusNotImplemented, // 501 - http.StatusBadGateway, // 502 - http.StatusServiceUnavailable, // 503 - http.StatusGatewayTimeout, // 504 - http.StatusHTTPVersionNotSupported, // 505 - http.StatusVariantAlsoNegotiates, // 506 - http.StatusInsufficientStorage, // 507 - http.StatusLoopDetected, // 508 - http.StatusNotExtended, // 510 - http.StatusNetworkAuthenticationRequired, // 511 + http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 } type Source string From 8319e22623f7c50ecae1b122fe156191159df836 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 29 Oct 2024 11:59:50 +0000 Subject: [PATCH 06/19] Update managedidentity_test.go --- apps/managedidentity/managedidentity_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 01df31a0..a4271dc5 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -184,6 +184,9 @@ func TestRetryFunction(t *testing.T) { reqBody := bytes.NewBufferString(tt.requestBody) req, _ := http.NewRequest("POST", "https://example.com", reqBody) finalResp, err := retry(tt.maxRetries, mockClient, req) + if err != nil { + t.Fatalf("error was not expected %s", err) + } if finalResp.StatusCode != tt.expectedStatus { t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode) } From 31fc7a074d23c4e67093810b27dad7b085403585 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 1 Nov 2024 00:48:04 +0000 Subject: [PATCH 07/19] Update managedidentity.go Updated the comment --- apps/managedidentity/managedidentity.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 22a84c72..1adf6e36 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -54,8 +54,7 @@ const ( defaultRetryCount = 3 ) -// IMDS docs recommend retrying 408, 429, 500, 502, 503, 504 -// https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling +// retry on these codes var retryStatusCodes = []int{ http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 From 54ff161eb6e996ec05893a76dd49d52116622424 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 6 Nov 2024 19:02:39 +0000 Subject: [PATCH 08/19] Update managedidentity.go Removed the condition and copied no matter the body size. --- apps/managedidentity/managedidentity.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 1adf6e36..e4a75572 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -224,10 +224,9 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, if attempt == maxRetries-1 { return resp, err } - if resp != nil && resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + delay := time.Duration(time.Second) time.Sleep(delay) } From 4e2ed034d1675a7847c1b18729ab006699746dd6 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 7 Nov 2024 17:52:01 +0000 Subject: [PATCH 09/19] Added a context exit for request. Added a context cancel --- apps/managedidentity/managedidentity.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index e4a75572..0c213af1 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -224,11 +224,16 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, if attempt == maxRetries-1 { return resp, err } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - - delay := time.Duration(time.Second) - time.Sleep(delay) + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + select { + case <-time.After(time.Second): + case <-req.Context().Done(): + err = req.Context().Err() + return resp, err + } } return resp, err } From c2252212c67b40150526506dea6796030fad2901 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 15 Nov 2024 11:06:15 +0000 Subject: [PATCH 10/19] Added a source based retry --- apps/managedidentity/managedidentity.go | 24 +++++++++++-- apps/managedidentity/managedidentity_test.go | 36 ++++++++++++++++++-- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index f434e427..6deff040 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -54,6 +54,18 @@ const ( defaultRetryCount = 3 ) +var retryCodesForIMDS = []int{ + http.StatusNotFound, // 404 + http.StatusGone, // 410 + http.StatusNotImplemented, // 501 + http.StatusHTTPVersionNotSupported, // 505 + http.StatusVariantAlsoNegotiates, // 506 + http.StatusInsufficientStorage, // 507 + http.StatusLoopDetected, // 508 + http.StatusNotExtended, // 510 + http.StatusNetworkAuthenticationRequired, // 511 +} + // retry on these codes var retryStatusCodes = []int{ http.StatusRequestTimeout, // 408 @@ -89,6 +101,7 @@ var cacheManager *storage.Manager = storage.New(nil) type Client struct { httpClient ops.HTTPClient miType ID + source Source retryPolicyEnabled bool } @@ -159,6 +172,7 @@ func New(id ID, options ...ClientOption) (Client, error) { miType: id, httpClient: opts.httpClient, retryPolicyEnabled: opts.retryPolicyEnabled, + source: DefaultToIMDS, } return client, nil } @@ -209,7 +223,7 @@ func contains[T comparable](list []T, element T) bool { } // retry performs an HTTP request with retries based on the provided options. -func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, error) { +func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http.Response, error) { var resp *http.Response var err error for attempt := 0; attempt < maxRetries; attempt++ { @@ -217,7 +231,11 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request) (*http.Response, defer tryCancel() cloneReq := req.Clone(tryCtx) resp, err = c.Do(cloneReq) - if err == nil && !contains(retryStatusCodes, resp.StatusCode) { + retrylist := retryStatusCodes + if s == DefaultToIMDS { + retrylist = append(retrylist, retryCodesForIMDS...) + } + if err == nil && !contains(retrylist, resp.StatusCode) { return resp, nil } if attempt == maxRetries-1 { @@ -242,7 +260,7 @@ func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRe if !client.retryPolicyEnabled { retryCount = 1 } - resp, err := retry(retryCount, client.httpClient, req) + resp, err := retry(retryCount, client.httpClient, req, client.source) if err != nil { return accesstokens.TokenResponse{}, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 7b9bdc8d..1e8f1609 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -143,6 +143,7 @@ func TestRetryFunction(t *testing.T) { expectedBody string maxRetries int requestBody string + source Source }{ { name: "Successful Request", @@ -157,6 +158,22 @@ func TestRetryFunction(t *testing.T) { expectedBody: "Success", maxRetries: 3, requestBody: "Test Body", + source: AzureArc, + }, + { + name: "Successful Request", + mockResponses: []struct { + body string + statusCode int + }{ + {"Failed", http.StatusNotFound}, + {"Success", http.StatusOK}, + }, + expectedStatus: http.StatusOK, + expectedBody: "Success", + maxRetries: 3, + requestBody: "Test Body", + source: DefaultToIMDS, }, { name: "Max Retries Reached", @@ -171,6 +188,22 @@ func TestRetryFunction(t *testing.T) { expectedBody: "Error", maxRetries: 2, requestBody: "Test Body", + source: AzureArc, + }, + { + name: "Max Retries Reached", + mockResponses: []struct { + body string + statusCode int + }{ + {"Error", http.StatusNotFound}, + {"Error", http.StatusInternalServerError}, + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "Error", + maxRetries: 2, + requestBody: "Test Body", + source: DefaultToIMDS, }, } @@ -183,7 +216,7 @@ func TestRetryFunction(t *testing.T) { } reqBody := bytes.NewBufferString(tt.requestBody) req, _ := http.NewRequest("POST", "https://example.com", reqBody) - finalResp, err := retry(tt.maxRetries, mockClient, req) + finalResp, err := retry(tt.maxRetries, mockClient, req, tt.source) if err != nil { t.Fatalf("error was not expected %s", err) } @@ -252,7 +285,6 @@ func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) } else { client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) - } if err != nil { t.Fatal(err) From e5c8bc7354bb3d81e7e613ab4f5cebad3ee416d6 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 15 Nov 2024 11:36:30 +0000 Subject: [PATCH 11/19] Updated status code list. --- apps/managedidentity/managedidentity.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 6deff040..5e5b357a 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -55,9 +55,9 @@ const ( ) var retryCodesForIMDS = []int{ - http.StatusNotFound, // 404 http.StatusGone, // 410 http.StatusNotImplemented, // 501 + http.StatusBadGateway, // 502 http.StatusHTTPVersionNotSupported, // 505 http.StatusVariantAlsoNegotiates, // 506 http.StatusInsufficientStorage, // 507 @@ -68,10 +68,10 @@ var retryCodesForIMDS = []int{ // retry on these codes var retryStatusCodes = []int{ + http.StatusNotFound, // 404 http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 - http.StatusBadGateway, // 502 http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 } From 4f3b41453ebcadd9c0a76c3e7097a971796fd20c Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 15 Nov 2024 11:38:43 +0000 Subject: [PATCH 12/19] Updated comment. --- apps/managedidentity/managedidentity.go | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 5e5b357a..4b14ca48 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -54,6 +54,7 @@ const ( defaultRetryCount = 3 ) +// retry codes for IMDS var retryCodesForIMDS = []int{ http.StatusGone, // 410 http.StatusNotImplemented, // 501 From abf8b864e6238e5787a6df40a4371b4d9a26347f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 19 Nov 2024 10:09:50 +0000 Subject: [PATCH 13/19] Updated retry code logic. --- apps/managedidentity/managedidentity.go | 9 +++++++-- apps/managedidentity/managedidentity_test.go | 1 - 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index d68eb1d4..76f9306c 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -79,9 +79,15 @@ const ( // retry codes for IMDS var retryCodesForIMDS = []int{ + http.StatusNotFound, // 404 + http.StatusRequestTimeout, // 408 http.StatusGone, // 410 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 http.StatusNotImplemented, // 501 http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 http.StatusHTTPVersionNotSupported, // 505 http.StatusVariantAlsoNegotiates, // 506 http.StatusInsufficientStorage, // 507 @@ -98,7 +104,6 @@ var retryStatusCodes = []int{ http.StatusInternalServerError, // 500 http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 - } var getAzureArcPlatformPath = func(platform string) string { @@ -381,7 +386,7 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http resp, err = c.Do(cloneReq) retrylist := retryStatusCodes if s == DefaultToIMDS { - retrylist = append(retrylist, retryCodesForIMDS...) + retrylist = retryCodesForIMDS } if err == nil && !contains(retrylist, resp.StatusCode) { return resp, nil diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index efc24cf6..adff7c10 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -593,7 +593,6 @@ func TestAzureArcPlatformSupported(t *testing.T) { } result, err := client.AcquireToken(context.Background(), resource) if err == nil || !strings.Contains(err.Error(), "platform not supported") { - println(result.AccessToken) t.Fatalf(`expected error: "%v" got error: "%v"`, "platform not supported", err) } From 1cd9908f2027a20fa10408243fb233ae41fb619e Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Thu, 21 Nov 2024 17:56:17 +0000 Subject: [PATCH 14/19] Update apps/managedidentity/managedidentity.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity.go | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 76f9306c..4787399f 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -102,6 +102,7 @@ var retryStatusCodes = []int{ http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 } From 73fda09b4f88a6584058c2a0e69c47292ec47f47 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 21 Nov 2024 17:57:20 +0000 Subject: [PATCH 15/19] updated code for comments --- apps/managedidentity/managedidentity.go | 30 +++++++++----------- apps/managedidentity/managedidentity_test.go | 6 +++- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 76f9306c..a68ef0c0 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -98,7 +98,6 @@ var retryCodesForIMDS = []int{ // retry on these codes var retryStatusCodes = []int{ - http.StatusNotFound, // 404 http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 @@ -376,28 +375,25 @@ func contains[T comparable](list []T, element T) bool { } // retry performs an HTTP request with retries based on the provided options. -func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http.Response, error) { +func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error) { var resp *http.Response var err error for attempt := 0; attempt < maxRetries; attempt++ { tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15) defer tryCancel() + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } cloneReq := req.Clone(tryCtx) - resp, err = c.Do(cloneReq) + resp, err = c.httpClient.Do(cloneReq) retrylist := retryStatusCodes - if s == DefaultToIMDS { + if c.source == DefaultToIMDS { retrylist = retryCodesForIMDS } if err == nil && !contains(retrylist, resp.StatusCode) { return resp, nil } - if attempt == maxRetries-1 { - return resp, err - } - if resp != nil && resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } select { case <-time.After(time.Second): case <-req.Context().Done(): @@ -408,13 +404,15 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http return resp, err } -func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { +func (c Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { r := accesstokens.TokenResponse{} - retryCount := defaultRetryCount - if !client.retryPolicyEnabled { - retryCount = 1 + var resp *http.Response + var err error + if c.retryPolicyEnabled { + resp, err = c.retry(defaultRetryCount, req) + } else { + resp, err = c.httpClient.Do(req) } - resp, err := retry(retryCount, client.httpClient, req, client.source) if err != nil { return r, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index adff7c10..6150071d 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -228,9 +228,13 @@ func TestRetryFunction(t *testing.T) { body := bytes.NewBufferString(resp.body) mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode)) } + client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled()) + if err != nil { + t.Fatal(err) + } reqBody := bytes.NewBufferString(tt.requestBody) req, _ := http.NewRequest("POST", "https://example.com", reqBody) - finalResp, err := retry(tt.maxRetries, mockClient, req, tt.source) + finalResp, err := client.retry(tt.maxRetries, req) if err != nil { t.Fatalf("error was not expected %s", err) } From 23290c4c40b5441eb7d704212844e23da7c9ac87 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 25 Nov 2024 17:25:50 +0000 Subject: [PATCH 16/19] Updated tests --- apps/managedidentity/managedidentity.go | 1 - apps/managedidentity/managedidentity_test.go | 24 +++++++------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 39b38a23..235c4b67 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -80,7 +80,6 @@ const ( // retry codes for IMDS var retryCodesForIMDS = []int{ http.StatusNotFound, // 404 - http.StatusRequestTimeout, // 408 http.StatusGone, // 410 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 6150071d..04534676 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -233,10 +233,13 @@ func TestRetryFunction(t *testing.T) { t.Fatal(err) } reqBody := bytes.NewBufferString(tt.requestBody) - req, _ := http.NewRequest("POST", "https://example.com", reqBody) + req, err := http.NewRequest("POST", "https://example.com", reqBody) + if err != nil { + t.Fatal(err) + } finalResp, err := client.retry(tt.maxRetries, req) if err != nil { - t.Fatalf("error was not expected %s", err) + t.Fatal(err) } if finalResp.StatusCode != tt.expectedStatus { t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode) @@ -245,26 +248,15 @@ func TestRetryFunction(t *testing.T) { if err != nil { t.Fatalf("Failed to read response body: %v", err) } - finalResp.Body.Close() // Close the body after reading + finalResp.Body.Close() if string(bodyBytes) != tt.expectedBody { t.Fatalf("Expected body %q, got %q", tt.expectedBody, bodyBytes) } - if req.Body != nil { - reqBodyBytes, err := io.ReadAll(req.Body) - if err != nil { - t.Fatalf("Failed to read request body: %v", err) - } - req.Body.Close() - - if string(reqBodyBytes) != tt.requestBody { - t.Fatalf("Expected request body %q, got %q", tt.requestBody, reqBodyBytes) - } - } }) } } -func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { +func Test_RetryPolicy_For_AcquireToken(t *testing.T) { testCases := []struct { numberOfFails int expectedFail bool @@ -317,7 +309,7 @@ func Test_RetryPolicy_For_AcquireToken_Failure(t *testing.T) { } } else { if err != nil { - t.Fatalf("should have encountered the error") + t.Fatal(err) } if resp.AccessToken != token { t.Fatalf("wanted %q, got %q", token, resp.AccessToken) From a81a5beb27ca68be58324a841bcfbc10216344e6 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:49:59 +0000 Subject: [PATCH 17/19] Update apps/managedidentity/managedidentity_test.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 04534676..e9f7d88c 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -319,8 +319,7 @@ func Test_RetryPolicy_For_AcquireToken(t *testing.T) { if errorRetryCounter != 1 { t.Fatalf("expected Number of retry of 1, got %d", errorRetryCounter) } - } else { - if errorRetryCounter != testCase.numberOfFails && testCase.numberOfFails < 3 { + } else if errorRetryCounter != testCase.numberOfFails && testCase.numberOfFails < 3 { t.Fatalf("expected Number of retry of %d, got %d", testCase.numberOfFails, errorRetryCounter) } } From 86d023d2def13188360d565866202556b135251d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 26 Nov 2024 12:55:18 +0000 Subject: [PATCH 18/19] updated tests and comments --- apps/managedidentity/managedidentity.go | 6 ++---- apps/managedidentity/managedidentity_test.go | 7 +------ 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 235c4b67..601ddcf2 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -77,7 +77,6 @@ const ( defaultRetryCount = 3 ) -// retry codes for IMDS var retryCodesForIMDS = []int{ http.StatusNotFound, // 404 http.StatusGone, // 410 @@ -95,7 +94,6 @@ var retryCodesForIMDS = []int{ http.StatusNetworkAuthenticationRequired, // 511 } -// retry on these codes var retryStatusCodes = []int{ http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 @@ -364,7 +362,7 @@ func authResultFromToken(authParams authority.AuthParams, token accesstokens.Tok return ar, err } -// Contains checks if the element is present in the list. +// contains checks if the element is present in the list. func contains[T comparable](list []T, element T) bool { for _, v := range list { if v == element { @@ -382,7 +380,7 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error) tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15) defer tryCancel() if resp != nil && resp.Body != nil { - io.Copy(io.Discard, resp.Body) + _, _ = io.Copy(io.Discard, resp.Body) resp.Body.Close() } cloneReq := req.Clone(tryCtx) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 04534676..d769ed75 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -156,7 +156,6 @@ func TestRetryFunction(t *testing.T) { expectedStatus int expectedBody string maxRetries int - requestBody string source Source }{ { @@ -171,7 +170,6 @@ func TestRetryFunction(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: "Success", maxRetries: 3, - requestBody: "Test Body", source: AzureArc, }, { @@ -186,7 +184,6 @@ func TestRetryFunction(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: "Success", maxRetries: 3, - requestBody: "Test Body", source: DefaultToIMDS, }, { @@ -201,7 +198,6 @@ func TestRetryFunction(t *testing.T) { expectedStatus: http.StatusInternalServerError, expectedBody: "Error", maxRetries: 2, - requestBody: "Test Body", source: AzureArc, }, { @@ -216,7 +212,6 @@ func TestRetryFunction(t *testing.T) { expectedStatus: http.StatusInternalServerError, expectedBody: "Error", maxRetries: 2, - requestBody: "Test Body", source: DefaultToIMDS, }, } @@ -232,7 +227,7 @@ func TestRetryFunction(t *testing.T) { if err != nil { t.Fatal(err) } - reqBody := bytes.NewBufferString(tt.requestBody) + reqBody := bytes.NewBufferString("Test Body") req, err := http.NewRequest("POST", "https://example.com", reqBody) if err != nil { t.Fatal(err) From 3fd8ad544c5b25cf772b6c3284d6a59d8b014762 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 26 Nov 2024 12:56:20 +0000 Subject: [PATCH 19/19] Update managedidentity_test.go --- apps/managedidentity/managedidentity_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 23c94d64..a13edebc 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -315,8 +315,7 @@ func Test_RetryPolicy_For_AcquireToken(t *testing.T) { t.Fatalf("expected Number of retry of 1, got %d", errorRetryCounter) } } else if errorRetryCounter != testCase.numberOfFails && testCase.numberOfFails < 3 { - t.Fatalf("expected Number of retry of %d, got %d", testCase.numberOfFails, errorRetryCounter) - } + t.Fatalf("expected Number of retry of %d, got %d", testCase.numberOfFails, errorRetryCounter) } }) }