diff --git a/pkg/auth/azure_auth.go b/pkg/auth/azure_auth.go index 730a0a2c5b..c641f5cd03 100644 --- a/pkg/auth/azure_auth.go +++ b/pkg/auth/azure_auth.go @@ -100,6 +100,17 @@ func GetServicePrincipalToken(config *AzureAuthConfig, env *azure.Environment, r } if len(config.UserAssignedIdentityID) > 0 { klog.V(4).Info("azure: using User Assigned MSI ID to retrieve access token") + resourceID, err := azure.ParseResourceID(config.UserAssignedIdentityID) + if err == nil && + strings.EqualFold(resourceID.Provider, "Microsoft.ManagedIdentity") && + strings.EqualFold(resourceID.ResourceType, "userAssignedIdentities") { + klog.V(4).Info("azure: User Assigned MSI ID is resource ID") + return adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, + resource, + config.UserAssignedIdentityID) + } + + klog.V(4).Info("azure: User Assigned MSI ID is client ID. Resource ID parsing error: %+v", err) return adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, config.UserAssignedIdentityID) diff --git a/pkg/auth/azure_auth_test.go b/pkg/auth/azure_auth_test.go index 89aebdb47a..435b4956b1 100644 --- a/pkg/auth/azure_auth_test.go +++ b/pkg/auth/azure_auth_test.go @@ -65,13 +65,13 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { configs := []*AzureAuthConfig{ { UseManagedIdentityExtension: true, - UserAssignedIdentityID: "UserAssignedIdentityID", + UserAssignedIdentityID: "00000000-0000-0000-0000-000000000000", }, // The Azure service principal is ignored when // UseManagedIdentityExtension is set to true { UseManagedIdentityExtension: true, - UserAssignedIdentityID: "UserAssignedIdentityID", + UserAssignedIdentityID: "00000000-0000-0000-0000-000000000000", TenantID: "TenantID", AADClientID: "AADClientID", AADClientSecret: "AADClientSecret", @@ -110,6 +110,55 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { } } +func TestGetServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) { + configs := []*AzureAuthConfig{ + { + UseManagedIdentityExtension: true, + UserAssignedIdentityID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/ua", + }, + // The Azure service principal is ignored when + // UseManagedIdentityExtension is set to true + { + UseManagedIdentityExtension: true, + UserAssignedIdentityID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/ua", + TenantID: "TenantID", + AADClientID: "AADClientID", + AADClientSecret: "AADClientSecret", + }, + } + env := &azure.PublicCloud + + // msiEndpointEnv and msiSecretEnv are required because autorest/adal library requires IMDS endpoint to be available. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("{}")) + assert.NoError(t, err) + })) + originalEnv := os.Getenv(msiEndpointEnv) + originalSecret := os.Getenv(msiSecretEnv) + os.Setenv(msiEndpointEnv, server.URL) + os.Setenv(msiSecretEnv, "secret") + defer func() { + server.Close() + os.Setenv(msiEndpointEnv, originalEnv) + os.Setenv(msiSecretEnv, originalSecret) + }() + + for _, config := range configs { + token, err := GetServicePrincipalToken(config, env, "") + assert.NoError(t, err) + + msiEndpoint, err := adal.GetMSIVMEndpoint() + assert.NoError(t, err) + + spt, err := adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, + env.ServiceManagementEndpoint, config.UserAssignedIdentityID) + assert.NoError(t, err) + assert.Equal(t, token, spt) + } +} + func TestGetServicePrincipalTokenFromMSI(t *testing.T) { configs := []*AzureAuthConfig{ {