From f4d41a915e3569db347e61e69aa2152ea8300c6a Mon Sep 17 00:00:00 2001 From: Andrew O Hart Date: Wed, 16 Oct 2024 23:19:05 +0100 Subject: [PATCH] Fixes for windows tests --- apps/managedidentity/managedidentity.go | 38 ++++++-------------- apps/managedidentity/managedidentity_test.go | 18 ++++++---- 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index fa503369..fc7fceea 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -70,14 +70,10 @@ const ( identityServerThumbprintEnvVar = "IDENTITY_SERVER_THUMBPRINT" ) -var getAzureArcPlatformPath = func() string { - switch runtime.GOOS { +var getAzureArcPlatformPath = func(platform string) string { + switch platform { case "windows": - programData := os.Getenv("ProgramData") - if programData == "" { - return "" - } - return fmt.Sprintf("%s%s", programData, windowsTokenPath) + return fmt.Sprintf("%s%s", os.Getenv("ProgramData"), windowsTokenPath) case "linux": return linuxTokenPath default: @@ -85,14 +81,10 @@ var getAzureArcPlatformPath = func() string { } } -var getAzureArcFilePath = func() string { - switch runtime.GOOS { +var getAzureArcFilePath = func(platform string) string { + switch platform { case "windows": - programFiles := os.Getenv("ProgramFiles") - if programFiles == "" { - return "" - } - return fmt.Sprintf("%s%s", programFiles, windowsHimdsPath) + return fmt.Sprintf("%s%s", os.Getenv("ProgramFiles"), windowsHimdsPath) case "linux": return linuxHimdsPath default: @@ -100,16 +92,6 @@ var getAzureArcFilePath = func() string { } } -// var supportedAzureArcPlatforms = map[string]string{ -// "windows": fmt.Sprintf("%s%s", os.Getenv("ProgramData"), windowsTokenPath), -// "linux": linuxTokenPath, -// } - -// var azureArcOsToFileMap = map[string]string{ -// "windows": fmt.Sprintf("%s%s", os.Getenv("ProgramFiles"), windowsHimdsPath), -// "linux": linuxHimdsPath, -// } - type Source string type ID interface { @@ -255,7 +237,7 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options case errors.CallErr: switch callErr.Resp.StatusCode { case http.StatusUnauthorized: - response, err := client.handleAzureArcResponse(ctx, callErr.Resp, resource) + response, err := client.handleAzureArcResponse(ctx, callErr.Resp, resource, runtime.GOOS) if err != nil { return base.AuthResult{}, err } @@ -397,7 +379,7 @@ func validateAzureArcEnvironment(identityEndpoint, imdsEndpoint string, platform return true } - himdsFilePath := getAzureArcFilePath() + himdsFilePath := getAzureArcFilePath(platform) if himdsFilePath != "" && fileExists(himdsFilePath) { return true @@ -406,7 +388,7 @@ func validateAzureArcEnvironment(identityEndpoint, imdsEndpoint string, platform return false } -func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Response, resource string) (accesstokens.TokenResponse, error) { +func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Response, resource string, platform string) (accesstokens.TokenResponse, error) { if response.StatusCode == http.StatusUnauthorized { wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName) @@ -415,7 +397,7 @@ func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Resp } // check if the platform is supported - expectedSecretFilePath := getAzureArcPlatformPath() + expectedSecretFilePath := getAzureArcPlatformPath(platform) if expectedSecretFilePath == "" { return accesstokens.TokenResponse{}, errors.New("platform not supported") } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index c20f5deb..596f8003 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -88,12 +88,10 @@ func createMockFile(t *testing.T, path string, size int64) { if err := os.MkdirAll(dir, 0755); err != nil { t.Fatalf("failed to create directory: %v", err) } - f, err := os.Create(path) if err != nil { t.Fatalf("failed to create file: %v", err) } - if size > 0 { if err := f.Truncate(size); err != nil { t.Fatalf("failed to truncate file: %v", err) @@ -105,7 +103,6 @@ func createMockFile(t *testing.T, path string, size int64) { func getMockFilePath(t *testing.T) (string, error) { tempDir := t.TempDir() mockFilePath := filepath.Join(tempDir, "AzureConnectedMachineAgent") - return mockFilePath, nil } @@ -134,11 +131,20 @@ func unsetEnvVars(t *testing.T) { t.Setenv(msiEndpointEnvVar, "") } +func setCustomAzureArcPlatformPath(path string) { + originalFunc := getAzureArcFilePath + defer func() { getAzureArcFilePath = originalFunc }() + + getAzureArcPlatformPath = func(platform string) string { + return path + } +} + func setCustomAzureArcFilePath(path string) { originalFunc := getAzureArcFilePath defer func() { getAzureArcFilePath = originalFunc }() - getAzureArcFilePath = func() string { + getAzureArcFilePath = func(platform string) string { return path } } @@ -570,7 +576,7 @@ func Test_handleAzureArcResponse(t *testing.T) { if tc.createMockFile { expectedFilePath := filepath.Join(testCaseFilePath) mockFilePath := filepath.Join(expectedFilePath, "secret.key") - setCustomAzureArcFilePath(mockFilePath) + setCustomAzureArcPlatformPath(expectedFilePath) if tc.name == "Invalid secret file size" { createMockFile(t, mockFilePath, 5000) @@ -588,7 +594,7 @@ func Test_handleAzureArcResponse(t *testing.T) { contextToUse = nil } - _, err := client.handleAzureArcResponse(contextToUse, response, "") + _, err := client.handleAzureArcResponse(contextToUse, response, "", tc.platform) if err == nil || err.Error() != tc.expectedError { t.Fatalf("expected error %v, got %v", tc.expectedError, err)