Skip to content

Commit

Permalink
Fixes for windows tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyOHart committed Oct 16, 2024
1 parent f8561a4 commit f4d41a9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 34 deletions.
38 changes: 10 additions & 28 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,46 +70,28 @@ const (
identityServerThumbprintEnvVar = "IDENTITY_SERVER_THUMBPRINT"
)

var getAzureArcPlatformPath = func() string {
switch runtime.GOOS {
var getAzureArcPlatformPath = func(platform string) string {
switch platform {
case "windows":
programData := os.Getenv("ProgramData")
if programData == "" {
return ""
}
return fmt.Sprintf("%s%s", programData, windowsTokenPath)
return fmt.Sprintf("%s%s", os.Getenv("ProgramData"), windowsTokenPath)
case "linux":
return linuxTokenPath
default:
return ""
}
}

var getAzureArcFilePath = func() string {
switch runtime.GOOS {
var getAzureArcFilePath = func(platform string) string {
switch platform {
case "windows":
programFiles := os.Getenv("ProgramFiles")
if programFiles == "" {
return ""
}
return fmt.Sprintf("%s%s", programFiles, windowsHimdsPath)
return fmt.Sprintf("%s%s", os.Getenv("ProgramFiles"), windowsHimdsPath)
case "linux":
return linuxHimdsPath
default:
return ""
}
}

// var supportedAzureArcPlatforms = map[string]string{
// "windows": fmt.Sprintf("%s%s", os.Getenv("ProgramData"), windowsTokenPath),
// "linux": linuxTokenPath,
// }

// var azureArcOsToFileMap = map[string]string{
// "windows": fmt.Sprintf("%s%s", os.Getenv("ProgramFiles"), windowsHimdsPath),
// "linux": linuxHimdsPath,
// }

type Source string

type ID interface {
Expand Down Expand Up @@ -255,7 +237,7 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options
case errors.CallErr:
switch callErr.Resp.StatusCode {
case http.StatusUnauthorized:
response, err := client.handleAzureArcResponse(ctx, callErr.Resp, resource)
response, err := client.handleAzureArcResponse(ctx, callErr.Resp, resource, runtime.GOOS)
if err != nil {
return base.AuthResult{}, err
}
Expand Down Expand Up @@ -397,7 +379,7 @@ func validateAzureArcEnvironment(identityEndpoint, imdsEndpoint string, platform
return true
}

himdsFilePath := getAzureArcFilePath()
himdsFilePath := getAzureArcFilePath(platform)

if himdsFilePath != "" && fileExists(himdsFilePath) {
return true
Expand All @@ -406,7 +388,7 @@ func validateAzureArcEnvironment(identityEndpoint, imdsEndpoint string, platform
return false
}

func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Response, resource string) (accesstokens.TokenResponse, error) {
func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Response, resource string, platform string) (accesstokens.TokenResponse, error) {
if response.StatusCode == http.StatusUnauthorized {
wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName)

Expand All @@ -415,7 +397,7 @@ func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Resp
}

// check if the platform is supported
expectedSecretFilePath := getAzureArcPlatformPath()
expectedSecretFilePath := getAzureArcPlatformPath(platform)
if expectedSecretFilePath == "" {
return accesstokens.TokenResponse{}, errors.New("platform not supported")
}
Expand Down
18 changes: 12 additions & 6 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,10 @@ func createMockFile(t *testing.T, path string, size int64) {
if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatalf("failed to create directory: %v", err)
}

f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}

if size > 0 {
if err := f.Truncate(size); err != nil {
t.Fatalf("failed to truncate file: %v", err)
Expand All @@ -105,7 +103,6 @@ func createMockFile(t *testing.T, path string, size int64) {
func getMockFilePath(t *testing.T) (string, error) {
tempDir := t.TempDir()
mockFilePath := filepath.Join(tempDir, "AzureConnectedMachineAgent")

return mockFilePath, nil
}

Expand Down Expand Up @@ -134,11 +131,20 @@ func unsetEnvVars(t *testing.T) {
t.Setenv(msiEndpointEnvVar, "")
}

func setCustomAzureArcPlatformPath(path string) {
originalFunc := getAzureArcFilePath
defer func() { getAzureArcFilePath = originalFunc }()

getAzureArcPlatformPath = func(platform string) string {
return path
}
}

func setCustomAzureArcFilePath(path string) {
originalFunc := getAzureArcFilePath
defer func() { getAzureArcFilePath = originalFunc }()

getAzureArcFilePath = func() string {
getAzureArcFilePath = func(platform string) string {
return path
}
}
Expand Down Expand Up @@ -570,7 +576,7 @@ func Test_handleAzureArcResponse(t *testing.T) {
if tc.createMockFile {
expectedFilePath := filepath.Join(testCaseFilePath)
mockFilePath := filepath.Join(expectedFilePath, "secret.key")
setCustomAzureArcFilePath(mockFilePath)
setCustomAzureArcPlatformPath(expectedFilePath)

if tc.name == "Invalid secret file size" {
createMockFile(t, mockFilePath, 5000)
Expand All @@ -588,7 +594,7 @@ func Test_handleAzureArcResponse(t *testing.T) {
contextToUse = nil
}

_, err := client.handleAzureArcResponse(contextToUse, response, "")
_, err := client.handleAzureArcResponse(contextToUse, response, "", tc.platform)

if err == nil || err.Error() != tc.expectedError {
t.Fatalf("expected error %v, got %v", tc.expectedError, err)
Expand Down

0 comments on commit f4d41a9

Please sign in to comment.