diff --git a/azure/services/identities/client.go b/azure/services/identities/client.go index ca406d2914f4..d8958c161af9 100644 --- a/azure/services/identities/client.go +++ b/azure/services/identities/client.go @@ -51,6 +51,19 @@ func NewClient(auth azure.Authorizer) (Client, error) { return &AzureClient{factory.NewUserAssignedIdentitiesClient()}, nil } +// NewClientBySub creates a new MSI client with a given subscriptionID +func NewClientBySub(auth azure.Authorizer, subscriptionID string) (Client, error) { + opts, err := azure.ARMClientOptions(auth.CloudEnvironment()) + if err != nil { + return nil, errors.Wrap(err, "failed to create identities client options") + } + factory, err := armmsi.NewClientFactory(subscriptionID, auth.Token(), opts) + if err != nil { + return nil, errors.Wrap(err, "failed to create armmsi client factory") + } + return &AzureClient{factory.NewUserAssignedIdentitiesClient()}, nil +} + // Get returns a managed service identity. func (ac *AzureClient) Get(ctx context.Context, resourceGroupName, name string) (armmsi.Identity, error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "identities.AzureClient.Get") diff --git a/azure/services/virtualmachines/virtualmachines.go b/azure/services/virtualmachines/virtualmachines.go index 9179a11dedda..3c62482e6d48 100644 --- a/azure/services/virtualmachines/virtualmachines.go +++ b/azure/services/virtualmachines/virtualmachines.go @@ -176,7 +176,15 @@ func (s *Service) checkUserAssignedIdentities(ctx context.Context, specIdentitie // Create a map of the expected identities. The ProviderID is converted to match the format of the VM identity. for _, expectedIdentity := range specIdentities { - expectedClientID, err := s.identitiesGetter.GetClientID(ctx, expectedIdentity.ProviderID) + var identitiesClient identities.Client = s.identitiesGetter + parsed, err := azureutil.ParseResourceID(expectedIdentity.ProviderID) + if err != nil { + return err + } + if parsed.SubscriptionID != s.Scope.SubscriptionID() { + identitiesClient, err = identities.NewClientBySub(s.Scope, parsed.SubscriptionID) + } + expectedClientID, err := identitiesClient.GetClientID(ctx, expectedIdentity.ProviderID) if err != nil { return errors.Wrap(err, "failed to get client ID") }