From 8e39922546de92aaee560449a2772f0c7a8b8c0e Mon Sep 17 00:00:00 2001 From: Matt Boersma Date: Thu, 8 Jun 2023 14:32:55 -0600 Subject: [PATCH] Validate providerID for user-assigned IDs in webhook --- api/v1beta1/azuremachine_validation.go | 17 ++++- api/v1beta1/azuremachine_validation_test.go | 69 +++++++++++++++++++ api/v1beta1/azuremachine_webhook_test.go | 7 +- .../azuremachinetemplate_webhook_test.go | 5 +- azure/converters/identity.go | 4 +- azure/defaults.go | 13 ---- azure/defaults_test.go | 52 -------------- azure/scope/machine.go | 3 +- azure/scope/machinepool.go | 5 +- azure/services/identities/client.go | 3 +- azure/services/natgateways/spec.go | 3 +- azure/services/scalesets/scalesets.go | 3 +- azure/services/scalesets/scalesets_test.go | 7 +- azure/services/scalesetvms/scalesetvms.go | 7 +- azure/services/virtualmachines/client.go | 3 +- .../virtualmachines/virtualmachines.go | 3 +- azure/types.go | 5 +- .../azuremanagedmachinepool_reconciler.go | 3 +- .../v1beta1/azuremachinepool_webhook_test.go | 7 +- test/e2e/azure_edgezone.go | 4 +- test/e2e/azure_logcollector.go | 6 +- test/e2e/azure_privatecluster.go | 4 +- test/e2e/azure_vmextensions.go | 5 +- util/azure/azure.go | 12 ++++ util/azure/azure_test.go | 52 ++++++++++++++ 25 files changed, 200 insertions(+), 102 deletions(-) diff --git a/api/v1beta1/azuremachine_validation.go b/api/v1beta1/azuremachine_validation.go index e741838714d..df0b2974c21 100644 --- a/api/v1beta1/azuremachine_validation.go +++ b/api/v1beta1/azuremachine_validation.go @@ -24,9 +24,10 @@ import ( "github.com/google/uuid" "golang.org/x/crypto/ssh" "k8s.io/apimachinery/pkg/util/validation/field" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" ) -// ValidateAzureMachineSpec check for validation errors of azuremachine.spec. +// ValidateAzureMachineSpec checks an AzureMachineSpec and returns any validation errors. func ValidateAzureMachineSpec(spec AzureMachineSpec) field.ErrorList { var allErrs field.ErrorList @@ -128,9 +129,19 @@ func ValidateSystemAssignedIdentity(identityType VMIdentity, oldIdentity, newIde func ValidateUserAssignedIdentity(identityType VMIdentity, userAssignedIdentities []UserAssignedIdentity, fldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} - if identityType == VMIdentityUserAssigned && len(userAssignedIdentities) == 0 { - allErrs = append(allErrs, field.Required(fldPath, "must be specified for the 'UserAssigned' identity type")) + if identityType == VMIdentityUserAssigned { + if len(userAssignedIdentities) == 0 { + allErrs = append(allErrs, field.Required(fldPath, "must be specified for the 'UserAssigned' identity type")) + } + for _, identity := range userAssignedIdentities { + if identity.ProviderID != "" { + if _, err := azureutil.ParseResourceID(identity.ProviderID); err != nil { + allErrs = append(allErrs, field.Invalid(fldPath, identity.ProviderID, "must be a valid Azure resource ID")) + } + } + } } + return allErrs } diff --git a/api/v1beta1/azuremachine_validation_test.go b/api/v1beta1/azuremachine_validation_test.go index 4443e6b6e31..ae4fa9a9213 100644 --- a/api/v1beta1/azuremachine_validation_test.go +++ b/api/v1beta1/azuremachine_validation_test.go @@ -580,6 +580,75 @@ func TestAzureMachine_ValidateSystemAssignedIdentityRole(t *testing.T) { } } +func TestAzureMachine_ValidateUserAssignedIdentity(t *testing.T) { + g := NewWithT(t) + + tests := []struct { + name string + idType VMIdentity + identities []UserAssignedIdentity + wantErr bool + }{ + { + name: "empty identity list", + idType: VMIdentityUserAssigned, + identities: []UserAssignedIdentity{}, + wantErr: true, + }, + { + name: "invalid: providerID must start with slash", + idType: VMIdentityUserAssigned, + identities: []UserAssignedIdentity{ + { + ProviderID: "subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265", + }, + }, + wantErr: true, + }, + { + name: "invalid: providerID must start with subscriptions or providers", + idType: VMIdentityUserAssigned, + identities: []UserAssignedIdentity{ + { + ProviderID: "azure:///prescriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265", + }, + }, + wantErr: true, + }, + { + name: "valid", + idType: VMIdentityUserAssigned, + identities: []UserAssignedIdentity{ + { + ProviderID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265", + }, + }, + wantErr: false, + }, + { + name: "valid with provider prefix", + idType: VMIdentityUserAssigned, + identities: []UserAssignedIdentity{ + { + ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265", + }, + }, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := ValidateUserAssignedIdentity(tc.idType, tc.identities, field.NewPath("userAssignedIdentities")) + if tc.wantErr { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + func TestAzureMachine_ValidateDataDisksUpdate(t *testing.T) { g := NewWithT(t) diff --git a/api/v1beta1/azuremachine_webhook_test.go b/api/v1beta1/azuremachine_webhook_test.go index bb0541ba70a..2bc90a231de 100644 --- a/api/v1beta1/azuremachine_webhook_test.go +++ b/api/v1beta1/azuremachine_webhook_test.go @@ -90,8 +90,11 @@ func TestAzureMachine_ValidateCreate(t *testing.T) { wantErr: true, }, { - name: "azuremachine with list of user-assigned identities", - machine: createMachineWithUserAssignedIdentities([]UserAssignedIdentity{{ProviderID: "azure:///123"}, {ProviderID: "azure:///456"}}), + name: "azuremachine with list of user-assigned identities", + machine: createMachineWithUserAssignedIdentities([]UserAssignedIdentity{ + {ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-12345-control-plane-9d5x5"}, + {ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-12345-control-plane-a1b2c"}, + }), wantErr: false, }, { diff --git a/api/v1beta1/azuremachinetemplate_webhook_test.go b/api/v1beta1/azuremachinetemplate_webhook_test.go index 8554e35ee48..18b165316f1 100644 --- a/api/v1beta1/azuremachinetemplate_webhook_test.go +++ b/api/v1beta1/azuremachinetemplate_webhook_test.go @@ -102,7 +102,10 @@ func TestAzureMachineTemplate_ValidateCreate(t *testing.T) { { name: "azuremachinetemplate with list of user-assigned identities", machineTemplate: createAzureMachineTemplateFromMachine( - createMachineWithUserAssignedIdentities([]UserAssignedIdentity{{ProviderID: "azure:///123"}, {ProviderID: "azure:///456"}}), + createMachineWithUserAssignedIdentities([]UserAssignedIdentity{ + {ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-09091-control-plane-f1b2c"}, + {ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-09091-control-plane-9a8b7"}, + }), ), wantErr: false, }, diff --git a/azure/converters/identity.go b/azure/converters/identity.go index 6993797aea2..1f06673c63e 100644 --- a/azure/converters/identity.go +++ b/azure/converters/identity.go @@ -22,7 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute" "github.com/pkg/errors" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" - "sigs.k8s.io/cluster-api-provider-azure/azure" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" ) // ErrUserAssignedIdentitiesNotFound is the error thrown when user assigned identities is not passed with the identity type being UserAssigned. @@ -84,5 +84,5 @@ func UserAssignedIdentitiesToVMSSSDK(identities []infrav1.UserAssignedIdentity) // sanitized removes "azure://" prefix from the given id. func sanitized(id string) string { - return strings.TrimPrefix(id, azure.ProviderIDPrefix) + return strings.TrimPrefix(id, azureutil.ProviderIDPrefix) } diff --git a/azure/defaults.go b/azure/defaults.go index 12fb5395600..37b2af97ed5 100644 --- a/azure/defaults.go +++ b/azure/defaults.go @@ -19,9 +19,7 @@ package azure import ( "fmt" "net/http" - "strings" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/go-autorest/autorest" "sigs.k8s.io/cluster-api-provider-azure/util/tele" "sigs.k8s.io/cluster-api-provider-azure/version" @@ -93,12 +91,6 @@ const ( bootstrapSentinelFile = "/run/cluster-api/bootstrap-success.complete" ) -const ( - // ProviderIDPrefix will be appended to the beginning of Azure resource IDs to form the Kubernetes Provider ID. - // NOTE: this format matches the 2 slashes format used in cloud-provider and cluster-autoscaler. - ProviderIDPrefix = "azure://" -) - const ( // CustomHeaderPrefix is the prefix of annotations that enable additional cluster / node pool features. // Whatever follows the prefix will be passed as a header to cluster/node pool creation/update requests. @@ -363,8 +355,3 @@ func msCorrelationIDSendDecorator(snd autorest.Sender) autorest.Sender { return snd.Do(r) }) } - -// ParseResourceID parses a string to an *arm.ResourceID, first removing any "azure://" prefix. -func ParseResourceID(id string) (*arm.ResourceID, error) { - return arm.ParseResourceID(strings.TrimPrefix(id, ProviderIDPrefix)) -} diff --git a/azure/defaults_test.go b/azure/defaults_test.go index 4a9ddc5a5cb..38288f64d4f 100644 --- a/azure/defaults_test.go +++ b/azure/defaults_test.go @@ -117,55 +117,3 @@ func TestMSCorrelationIDSendDecorator(t *testing.T) { receivedReq.Header.Get(string(tele.CorrIDKeyVal)), ).To(Equal(string(corrID))) } - -func TestParseResourceID(t *testing.T) { - g := NewWithT(t) - - tests := []struct { - name string - id string - expectedName string - errExpected bool - }{ - { - name: "invalid", - id: "invalid", - expectedName: "", - errExpected: true, - }, - { - name: "invalid: must start with slash", - id: "subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", - expectedName: "", - errExpected: true, - }, - { - name: "invalid: must start with subscriptions or providers", - id: "/prescriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", - expectedName: "", - errExpected: true, - }, - { - name: "valid", - id: "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", - expectedName: "vm", - }, - { - name: "valid with provider prefix", - id: "azure:///subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", - expectedName: "vm", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resourceID, err := ParseResourceID(tt.id) - if tt.errExpected { - g.Expect(err).To(HaveOccurred()) - } else { - g.Expect(err).NotTo(HaveOccurred()) - g.Expect(resourceID.Name).To(Equal(tt.expectedName)) - } - }) - } -} diff --git a/azure/scope/machine.go b/azure/scope/machine.go index 39052502831..3f296839ee9 100644 --- a/azure/scope/machine.go +++ b/azure/scope/machine.go @@ -39,6 +39,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachineimages" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines" "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmextensions" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/futures" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -452,7 +453,7 @@ func (m *MachineScope) Role() string { // GetVMID returns the AzureMachine instance id by parsing the scope's providerID. func (m *MachineScope) GetVMID() string { - resourceID, err := azure.ParseResourceID(m.ProviderID()) + resourceID, err := azureutil.ParseResourceID(m.ProviderID()) if err != nil { return "" } diff --git a/azure/scope/machinepool.go b/azure/scope/machinepool.go index 5b3ee9c5dd0..2070dcaa67e 100644 --- a/azure/scope/machinepool.go +++ b/azure/scope/machinepool.go @@ -36,6 +36,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachineimages" infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/futures" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -154,7 +155,7 @@ func (m *MachinePoolScope) Name() string { // ProviderID returns the AzureMachinePool ID by parsing Spec.ProviderID. func (m *MachinePoolScope) ProviderID() string { - resourceID, err := azure.ParseResourceID(m.AzureMachinePool.Spec.ProviderID) + resourceID, err := azureutil.ParseResourceID(m.AzureMachinePool.Spec.ProviderID) if err != nil { return "" } @@ -376,7 +377,7 @@ func (m *MachinePoolScope) createMachine(ctx context.Context, machine azure.VMSS ctx, _, done := tele.StartSpanWithLogger(ctx, "scope.MachinePoolScope.createMachine") defer done() - parsed, err := azure.ParseResourceID(machine.ID) + parsed, err := azureutil.ParseResourceID(machine.ID) if err != nil { return errors.Wrap(err, fmt.Sprintf("failed to parse resource id %q", machine.ID)) } diff --git a/azure/services/identities/client.go b/azure/services/identities/client.go index d9a9e8a0b5c..172852a105e 100644 --- a/azure/services/identities/client.go +++ b/azure/services/identities/client.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" "github.com/Azure/go-autorest/autorest" "sigs.k8s.io/cluster-api-provider-azure/azure" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -62,7 +63,7 @@ func (ac *AzureClient) GetClientID(ctx context.Context, providerID string) (stri ctx, _, done := tele.StartSpanWithLogger(ctx, "identities.GetClientID") defer done() - parsed, err := azure.ParseResourceID(providerID) + parsed, err := azureutil.ParseResourceID(providerID) if err != nil { return "", err } diff --git a/azure/services/natgateways/spec.go b/azure/services/natgateways/spec.go index 7b078d9e104..102574d7984 100644 --- a/azure/services/natgateways/spec.go +++ b/azure/services/natgateways/spec.go @@ -25,6 +25,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/converters" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" ) // NatGatewaySpec defines the specification for a NAT gateway. @@ -96,7 +97,7 @@ func hasPublicIP(natGateway network.NatGateway, publicIPName string) bool { } for _, publicIP := range *natGateway.PublicIPAddresses { - resource, err := azure.ParseResourceID(*publicIP.ID) + resource, err := azureutil.ParseResourceID(*publicIP.ID) if err != nil { continue } diff --git a/azure/services/scalesets/scalesets.go b/azure/services/scalesets/scalesets.go index 7aa0e90a22e..21de9c74319 100644 --- a/azure/services/scalesets/scalesets.go +++ b/azure/services/scalesets/scalesets.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/converters" "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/generators" "sigs.k8s.io/cluster-api-provider-azure/util/slice" "sigs.k8s.io/cluster-api-provider-azure/util/tele" @@ -111,7 +112,7 @@ func (s *Service) Reconcile(ctx context.Context) (retErr error) { if fetchedVMSS != nil { // Transform the VMSS resource representation to conform to the cloud-provider-azure representation - providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + fetchedVMSS.ID) + providerID, err := azprovider.ConvertResourceGroupNameToLower(azureutil.ProviderIDPrefix + fetchedVMSS.ID) if err != nil { log.Error(err, "failed to parse VMSS ID", "ID", fetchedVMSS.ID) } diff --git a/azure/services/scalesets/scalesets_test.go b/azure/services/scalesets/scalesets_test.go index d966573cb53..c79c5f5098c 100644 --- a/azure/services/scalesets/scalesets_test.go +++ b/azure/services/scalesets/scalesets_test.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets/mock_scalesets" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ) @@ -1560,7 +1561,7 @@ func setupDefaultVMSSInProgressOperationDoneExpectations(s *mock_scalesets.MockS m.ListInstances(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(instances, nil).AnyTimes() s.MaxSurge().Return(1, nil) s.SetVMSSState(gomock.Any()) - s.SetProviderID(azure.ProviderIDPrefix + *createdVMSS.ID) + s.SetProviderID(azureutil.ProviderIDPrefix + *createdVMSS.ID) } func setupDefaultVMSSStartCreatingExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder, m *mock_scalesets.MockClientMockRecorder) { @@ -1577,7 +1578,7 @@ func setupCreatingSucceededExpectations(s *mock_scalesets.MockScaleSetScopeMockR m.Get(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(vmss, nil) m.ListInstances(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(newDefaultInstances(), nil).AnyTimes() s.SetVMSSState(gomock.Any()) - s.SetProviderID(azure.ProviderIDPrefix + *vmss.ID) + s.SetProviderID(azureutil.ProviderIDPrefix + *vmss.ID) } func setupDefaultVMSSExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder) { @@ -1641,7 +1642,7 @@ func setupVMSSExpectationsWithoutVMImage(s *mock_scalesets.MockScaleSetScopeMock func setupDefaultVMSSUpdateExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder) { setupUpdateVMSSExpectations(s) - s.SetProviderID(azure.ProviderIDPrefix + "subscriptions/1234/resourceGroups/my_resource_group/providers/Microsoft.Compute/virtualMachines/my-vm") + s.SetProviderID(azureutil.ProviderIDPrefix + "subscriptions/1234/resourceGroups/my_resource_group/providers/Microsoft.Compute/virtualMachines/my-vm") s.GetLongRunningOperationState(defaultVMSSName, serviceName, infrav1.PutFuture).Return(nil) s.GetLongRunningOperationState(defaultVMSSName, serviceName, infrav1.PatchFuture).Return(nil) s.MaxSurge().Return(1, nil) diff --git a/azure/services/scalesetvms/scalesetvms.go b/azure/services/scalesetvms/scalesetvms.go index 0a9183d5b89..4ca7ffab084 100644 --- a/azure/services/scalesetvms/scalesetvms.go +++ b/azure/services/scalesetvms/scalesetvms.go @@ -28,6 +28,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/converters" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -82,7 +83,7 @@ func (s *Service) Reconcile(ctx context.Context) error { // Fetch the latest instance or VM data. AzureMachinePoolReconciler handles model mutations. if isFlex { - resourceID := strings.TrimPrefix(providerID, azure.ProviderIDPrefix) + resourceID := strings.TrimPrefix(providerID, azureutil.ProviderIDPrefix) log.V(4).Info("VMSS is flex", "vmssName", vmssName, "providerID", providerID, "resourceID", resourceID) // Using VMSS Flex, so fetch by resource ID. vm, err := s.VMClient.GetByID(ctx, resourceID) @@ -130,7 +131,7 @@ func (s *Service) Delete(ctx context.Context) error { defer done() if isFlex { - return s.deleteVMSSFlexVM(ctx, strings.TrimPrefix(providerID, azure.ProviderIDPrefix)) + return s.deleteVMSSFlexVM(ctx, strings.TrimPrefix(providerID, azureutil.ProviderIDPrefix)) } return s.deleteVMSSUniformInstance(ctx, resourceGroup, vmssName, instanceID, log) } @@ -146,7 +147,7 @@ func (s *Service) deleteVMSSFlexVM(ctx context.Context, resourceID string) error } }() - parsed, err := azure.ParseResourceID(resourceID) + parsed, err := azureutil.ParseResourceID(resourceID) if err != nil { return errors.Wrap(err, fmt.Sprintf("failed to parse resource id %q", resourceID)) } diff --git a/azure/services/virtualmachines/client.go b/azure/services/virtualmachines/client.go index 06cf8bf9614..1f2e0d321a2 100644 --- a/azure/services/virtualmachines/client.go +++ b/azure/services/virtualmachines/client.go @@ -30,6 +30,7 @@ import ( "k8s.io/utils/pointer" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -89,7 +90,7 @@ func (ac *AzureClient) GetByID(ctx context.Context, resourceID string) (compute. ctx, log, done := tele.StartSpanWithLogger(ctx, "virtualmachines.AzureClient.GetByID") defer done() - parsed, err := azure.ParseResourceID(resourceID) + parsed, err := azureutil.ParseResourceID(resourceID) if err != nil { return compute.VirtualMachine{}, errors.Wrap(err, fmt.Sprintf("failed parsing the VM resource id %q", resourceID)) } diff --git a/azure/services/virtualmachines/virtualmachines.go b/azure/services/virtualmachines/virtualmachines.go index 32ad5679871..e98bbf262ed 100644 --- a/azure/services/virtualmachines/virtualmachines.go +++ b/azure/services/virtualmachines/virtualmachines.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/identities" "sigs.k8s.io/cluster-api-provider-azure/azure/services/networkinterfaces" "sigs.k8s.io/cluster-api-provider-azure/azure/services/publicips" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -102,7 +103,7 @@ func (s *Service) Reconcile(ctx context.Context) error { } infraVM := converters.SDKToVM(vm) // Transform the VM resource representation to conform to the cloud-provider-azure representation - providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + infraVM.ID) + providerID, err := azprovider.ConvertResourceGroupNameToLower(azureutil.ProviderIDPrefix + infraVM.ID) if err != nil { return errors.Wrapf(err, "failed to parse VM ID %s", infraVM.ID) } diff --git a/azure/types.go b/azure/types.go index 873cf1cced4..4583b57981c 100644 --- a/azure/types.go +++ b/azure/types.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" ) // RoleAssignmentSpec defines the specification for a Role Assignment. @@ -149,11 +150,11 @@ func (vm VMSSVM) ProviderID() string { splitOnSlash := strings.Split(vm.ID, "/") elems := splitOnSlash[:len(splitOnSlash)-4] elems = append(elems, splitOnSlash[len(splitOnSlash)-2:]...) - return ProviderIDPrefix + strings.Join(elems, "/") + return azureutil.ProviderIDPrefix + strings.Join(elems, "/") } // ProviderID for Uniform scaleset VMs looks like this: // azure:///subscriptions//resourceGroups/my-cluster/providers/Microsoft.Compute/virtualMachineScaleSets/my-cluster-mp-0/virtualMachines/0 - return ProviderIDPrefix + vm.ID + return azureutil.ProviderIDPrefix + vm.ID } // HasLatestModelAppliedToAll returns true if all VMSS instance have the latest model applied. diff --git a/controllers/azuremanagedmachinepool_reconciler.go b/controllers/azuremanagedmachinepool_reconciler.go index 3b670fd43b5..5b660014b2a 100644 --- a/controllers/azuremanagedmachinepool_reconciler.go +++ b/controllers/azuremanagedmachinepool_reconciler.go @@ -28,6 +28,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/scope" "sigs.k8s.io/cluster-api-provider-azure/azure/services/agentpools" "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -142,7 +143,7 @@ func (s *azureManagedMachinePoolService) Reconcile(ctx context.Context) error { var providerIDs = make([]string, len(instances)) for i := 0; i < len(instances); i++ { // Transform the VMSS instance resource representation to conform to the cloud-provider-azure representation - providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + *instances[i].ID) + providerID, err := azprovider.ConvertResourceGroupNameToLower(azureutil.ProviderIDPrefix + *instances[i].ID) if err != nil { return errors.Wrapf(err, "failed to parse instance ID %s", *instances[i].ID) } diff --git a/exp/api/v1beta1/azuremachinepool_webhook_test.go b/exp/api/v1beta1/azuremachinepool_webhook_test.go index 31661d07cc0..5170c78be2d 100644 --- a/exp/api/v1beta1/azuremachinepool_webhook_test.go +++ b/exp/api/v1beta1/azuremachinepool_webhook_test.go @@ -146,8 +146,11 @@ func TestAzureMachinePool_ValidateCreate(t *testing.T) { wantErr: true, }, { - name: "azuremachinepool with user assigned identity", - amp: createMachinePoolWithUserAssignedIdentity([]string{"azure:://id1", "azure:://id2"}), + name: "azuremachinepool with user assigned identity", + amp: createMachinePoolWithUserAssignedIdentity([]string{ + "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265", + "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-a6b7d", + }), wantErr: false, }, { diff --git a/test/e2e/azure_edgezone.go b/test/e2e/azure_edgezone.go index b46409e29b5..3b8d1dba61e 100644 --- a/test/e2e/azure_edgezone.go +++ b/test/e2e/azure_edgezone.go @@ -28,7 +28,7 @@ import ( . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" - "sigs.k8s.io/cluster-api-provider-azure/azure" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/cluster-api/test/framework" "sigs.k8s.io/cluster-api/test/framework/clusterctl" @@ -85,7 +85,7 @@ func AzureEdgeZoneClusterSpec(ctx context.Context, inputGetter func() AzureEdgeZ vmClient.Authorizer = auth // get the resource group name - resource, err := azure.ParseResourceID(*machineList.Items[0].Spec.ProviderID) + resource, err := azureutil.ParseResourceID(*machineList.Items[0].Spec.ProviderID) Expect(err).NotTo(HaveOccurred()) vmListResults, err := vmClient.List(ctx, resource.ResourceGroupName, "") diff --git a/test/e2e/azure_logcollector.go b/test/e2e/azure_logcollector.go index 76260a89c0b..4d3eb51a115 100644 --- a/test/e2e/azure_logcollector.go +++ b/test/e2e/azure_logcollector.go @@ -400,7 +400,7 @@ func collectVMBootLog(ctx context.Context, am *infrav1.AzureMachine, outputPath return errors.New("AzureMachine provider ID is nil") } - resource, err := azure.ParseResourceID(*am.Spec.ProviderID) + resource, err := azureutil.ParseResourceID(*am.Spec.ProviderID) if err != nil { return errors.Wrap(err, "failed to parse resource id") } @@ -426,11 +426,11 @@ func collectVMBootLog(ctx context.Context, am *infrav1.AzureMachine, outputPath // collectVMSSBootLog collects boot logs of the scale set by using azure boot diagnostics. func collectVMSSBootLog(ctx context.Context, providerID string, outputPath string) error { - resourceID := strings.TrimPrefix(providerID, azure.ProviderIDPrefix) + resourceID := strings.TrimPrefix(providerID, azureutil.ProviderIDPrefix) v := strings.Split(resourceID, "/") instanceID := v[len(v)-1] resourceID = strings.TrimSuffix(resourceID, "/virtualMachines/"+instanceID) - resource, err := azure.ParseResourceID(resourceID) + resource, err := azureutil.ParseResourceID(resourceID) if err != nil { return errors.Wrap(err, "failed to parse resource id") } diff --git a/test/e2e/azure_privatecluster.go b/test/e2e/azure_privatecluster.go index ead2a558f16..a5009b563c6 100644 --- a/test/e2e/azure_privatecluster.go +++ b/test/e2e/azure_privatecluster.go @@ -420,7 +420,7 @@ func SetupExistingVNet(ctx context.Context, vnetCidr string, cpSubnetCidrs, node } func getAPIVersion(resourceID string) (string, error) { - parsed, err := azure.ParseResourceID(resourceID) + parsed, err := azureutil.ParseResourceID(resourceID) if err != nil { return "", errors.Wrap(err, fmt.Sprintf("unable to parse resource ID %q", resourceID)) } @@ -454,7 +454,7 @@ func getClientIDforMSI(resourceID string) string { msiClient := msi.NewUserAssignedIdentitiesClient(subscriptionID) msiClient.Authorizer = authorizer - parsed, err := azure.ParseResourceID(resourceID) + parsed, err := azureutil.ParseResourceID(resourceID) Expect(err).NotTo(HaveOccurred()) id, err := msiClient.Get(context.TODO(), parsed.ResourceGroupName, parsed.Name) diff --git a/test/e2e/azure_vmextensions.go b/test/e2e/azure_vmextensions.go index 7ed870bc38c..2b677ddf258 100644 --- a/test/e2e/azure_vmextensions.go +++ b/test/e2e/azure_vmextensions.go @@ -28,7 +28,6 @@ import ( . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" - "sigs.k8s.io/cluster-api-provider-azure/azure" infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1" azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -96,7 +95,7 @@ func AzureVMExtensionsSpec(ctx context.Context, inputGetter func() AzureVMExtens vmExtensionsClient.Authorizer = auth // get the resource group name - resource, err := azure.ParseResourceID(*machineList.Items[0].Spec.ProviderID) + resource, err := azureutil.ParseResourceID(*machineList.Items[0].Spec.ProviderID) Expect(err).NotTo(HaveOccurred()) vmListResults, err := vmClient.List(ctx, resource.ResourceGroupName, "") @@ -143,7 +142,7 @@ func AzureVMExtensionsSpec(ctx context.Context, inputGetter func() AzureVMExtens vmssExtensionsClient.Authorizer = auth // get the resource group name - resource, err := azure.ParseResourceID(machinePoolList.Items[0].Spec.ProviderID) + resource, err := azureutil.ParseResourceID(machinePoolList.Items[0].Spec.ProviderID) Expect(err).NotTo(HaveOccurred()) vmssListResults, err := vmssClient.List(ctx, resource.ResourceGroupName) diff --git a/util/azure/azure.go b/util/azure/azure.go index 33d8d9f5238..757fc51a361 100644 --- a/util/azure/azure.go +++ b/util/azure/azure.go @@ -23,6 +23,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/go-autorest/autorest" @@ -37,6 +38,12 @@ import ( // AzureSystemNodeLabelPrefix is a standard node label prefix for Azure features, e.g., kubernetes.azure.com/scalesetpriority. const AzureSystemNodeLabelPrefix = "kubernetes.azure.com" +const ( + // ProviderIDPrefix will be appended to the beginning of Azure resource IDs to form the Kubernetes Provider ID. + // NOTE: this format matches the 2 slashes format used in cloud-provider and cluster-autoscaler. + ProviderIDPrefix = "azure://" +) + // IsAzureSystemNodeLabelKey is a helper function that determines whether a node label key is an Azure "system" label. func IsAzureSystemNodeLabelKey(labelKey string) bool { return strings.HasPrefix(labelKey, AzureSystemNodeLabelPrefix) @@ -126,3 +133,8 @@ func FindParentMachinePoolWithRetry(ampName string, cli client.Client, maxAttemp return p, nil } } + +// ParseResourceID parses a string to an *arm.ResourceID, first removing any "azure://" prefix. +func ParseResourceID(id string) (*arm.ResourceID, error) { + return arm.ParseResourceID(strings.TrimPrefix(id, ProviderIDPrefix)) +} diff --git a/util/azure/azure_test.go b/util/azure/azure_test.go index ffef1c6361a..fd52dc6d0d3 100644 --- a/util/azure/azure_test.go +++ b/util/azure/azure_test.go @@ -124,3 +124,55 @@ func (m mockClient) List(ctx context.Context, list client.ObjectList, opts ...cl return nil } + +func TestParseResourceID(t *testing.T) { + g := NewWithT(t) + + tests := []struct { + name string + id string + expectedName string + errExpected bool + }{ + { + name: "invalid", + id: "invalid", + expectedName: "", + errExpected: true, + }, + { + name: "invalid: must start with slash", + id: "subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", + expectedName: "", + errExpected: true, + }, + { + name: "invalid: must start with subscriptions or providers", + id: "/prescriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", + expectedName: "", + errExpected: true, + }, + { + name: "valid", + id: "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", + expectedName: "vm", + }, + { + name: "valid with provider prefix", + id: "azure:///subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm", + expectedName: "vm", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resourceID, err := ParseResourceID(tt.id) + if tt.errExpected { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(resourceID.Name).To(Equal(tt.expectedName)) + } + }) + } +}