Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate providerID for user-assigned IDs in webhook #3618

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions api/v1beta1/azuremachine_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ import (
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"k8s.io/apimachinery/pkg/util/validation/field"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
)

// ValidateAzureMachineSpec check for validation errors of azuremachine.spec.
// ValidateAzureMachineSpec checks an AzureMachineSpec and returns any validation errors.
func ValidateAzureMachineSpec(spec AzureMachineSpec) field.ErrorList {
var allErrs field.ErrorList

Expand Down Expand Up @@ -128,9 +129,19 @@ func ValidateSystemAssignedIdentity(identityType VMIdentity, oldIdentity, newIde
func ValidateUserAssignedIdentity(identityType VMIdentity, userAssignedIdentities []UserAssignedIdentity, fldPath *field.Path) field.ErrorList {
allErrs := field.ErrorList{}

if identityType == VMIdentityUserAssigned && len(userAssignedIdentities) == 0 {
allErrs = append(allErrs, field.Required(fldPath, "must be specified for the 'UserAssigned' identity type"))
if identityType == VMIdentityUserAssigned {
if len(userAssignedIdentities) == 0 {
allErrs = append(allErrs, field.Required(fldPath, "must be specified for the 'UserAssigned' identity type"))
}
for _, identity := range userAssignedIdentities {
if identity.ProviderID != "" {
if _, err := azureutil.ParseResourceID(identity.ProviderID); err != nil {
allErrs = append(allErrs, field.Invalid(fldPath, identity.ProviderID, "must be a valid Azure resource ID"))
}
}
}
}

return allErrs
}

Expand Down
69 changes: 69 additions & 0 deletions api/v1beta1/azuremachine_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,75 @@ func TestAzureMachine_ValidateSystemAssignedIdentityRole(t *testing.T) {
}
}

func TestAzureMachine_ValidateUserAssignedIdentity(t *testing.T) {
g := NewWithT(t)

tests := []struct {
name string
idType VMIdentity
identities []UserAssignedIdentity
wantErr bool
}{
{
name: "empty identity list",
idType: VMIdentityUserAssigned,
identities: []UserAssignedIdentity{},
wantErr: true,
},
{
name: "invalid: providerID must start with slash",
idType: VMIdentityUserAssigned,
identities: []UserAssignedIdentity{
{
ProviderID: "subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265",
},
},
wantErr: true,
},
{
name: "invalid: providerID must start with subscriptions or providers",
idType: VMIdentityUserAssigned,
identities: []UserAssignedIdentity{
{
ProviderID: "azure:///prescriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265",
},
},
wantErr: true,
},
{
name: "valid",
idType: VMIdentityUserAssigned,
identities: []UserAssignedIdentity{
{
ProviderID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265",
},
},
wantErr: false,
},
{
name: "valid with provider prefix",
idType: VMIdentityUserAssigned,
identities: []UserAssignedIdentity{
{
ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-20202-control-plane-7w265",
},
},
wantErr: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
errs := ValidateUserAssignedIdentity(tc.idType, tc.identities, field.NewPath("userAssignedIdentities"))
if tc.wantErr {
g.Expect(errs).NotTo(BeEmpty())
} else {
g.Expect(errs).To(BeEmpty())
}
})
}
}

func TestAzureMachine_ValidateDataDisksUpdate(t *testing.T) {
g := NewWithT(t)

Expand Down
7 changes: 5 additions & 2 deletions api/v1beta1/azuremachine_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ func TestAzureMachine_ValidateCreate(t *testing.T) {
wantErr: true,
},
{
name: "azuremachine with list of user-assigned identities",
machine: createMachineWithUserAssignedIdentities([]UserAssignedIdentity{{ProviderID: "azure:///123"}, {ProviderID: "azure:///456"}}),
name: "azuremachine with list of user-assigned identities",
machine: createMachineWithUserAssignedIdentities([]UserAssignedIdentity{
{ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-12345-control-plane-9d5x5"},
{ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-12345-control-plane-a1b2c"},
}),
wantErr: false,
},
{
Expand Down
5 changes: 4 additions & 1 deletion api/v1beta1/azuremachinetemplate_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ func TestAzureMachineTemplate_ValidateCreate(t *testing.T) {
{
name: "azuremachinetemplate with list of user-assigned identities",
machineTemplate: createAzureMachineTemplateFromMachine(
createMachineWithUserAssignedIdentities([]UserAssignedIdentity{{ProviderID: "azure:///123"}, {ProviderID: "azure:///456"}}),
createMachineWithUserAssignedIdentities([]UserAssignedIdentity{
{ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-09091-control-plane-f1b2c"},
{ProviderID: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-resource-group/providers/Microsoft.Compute/virtualMachines/default-09091-control-plane-9a8b7"},
}),
),
wantErr: false,
},
Expand Down
4 changes: 2 additions & 2 deletions azure/converters/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute"
"github.com/pkg/errors"
infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1"
"sigs.k8s.io/cluster-api-provider-azure/azure"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
)

// ErrUserAssignedIdentitiesNotFound is the error thrown when user assigned identities is not passed with the identity type being UserAssigned.
Expand Down Expand Up @@ -84,5 +84,5 @@ func UserAssignedIdentitiesToVMSSSDK(identities []infrav1.UserAssignedIdentity)

// sanitized removes "azure://" prefix from the given id.
func sanitized(id string) string {
return strings.TrimPrefix(id, azure.ProviderIDPrefix)
return strings.TrimPrefix(id, azureutil.ProviderIDPrefix)
}
13 changes: 0 additions & 13 deletions azure/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package azure
import (
"fmt"
"net/http"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/go-autorest/autorest"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
"sigs.k8s.io/cluster-api-provider-azure/version"
Expand Down Expand Up @@ -93,12 +91,6 @@ const (
bootstrapSentinelFile = "/run/cluster-api/bootstrap-success.complete"
)

const (
// ProviderIDPrefix will be appended to the beginning of Azure resource IDs to form the Kubernetes Provider ID.
// NOTE: this format matches the 2 slashes format used in cloud-provider and cluster-autoscaler.
ProviderIDPrefix = "azure://"
)

const (
// CustomHeaderPrefix is the prefix of annotations that enable additional cluster / node pool features.
// Whatever follows the prefix will be passed as a header to cluster/node pool creation/update requests.
Expand Down Expand Up @@ -363,8 +355,3 @@ func msCorrelationIDSendDecorator(snd autorest.Sender) autorest.Sender {
return snd.Do(r)
})
}

// ParseResourceID parses a string to an *arm.ResourceID, first removing any "azure://" prefix.
func ParseResourceID(id string) (*arm.ResourceID, error) {
return arm.ParseResourceID(strings.TrimPrefix(id, ProviderIDPrefix))
}
52 changes: 0 additions & 52 deletions azure/defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,55 +117,3 @@ func TestMSCorrelationIDSendDecorator(t *testing.T) {
receivedReq.Header.Get(string(tele.CorrIDKeyVal)),
).To(Equal(string(corrID)))
}

func TestParseResourceID(t *testing.T) {
g := NewWithT(t)

tests := []struct {
name string
id string
expectedName string
errExpected bool
}{
{
name: "invalid",
id: "invalid",
expectedName: "",
errExpected: true,
},
{
name: "invalid: must start with slash",
id: "subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm",
expectedName: "",
errExpected: true,
},
{
name: "invalid: must start with subscriptions or providers",
id: "/prescriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm",
expectedName: "",
errExpected: true,
},
{
name: "valid",
id: "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm",
expectedName: "vm",
},
{
name: "valid with provider prefix",
id: "azure:///subscriptions/123/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm",
expectedName: "vm",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resourceID, err := ParseResourceID(tt.id)
if tt.errExpected {
g.Expect(err).To(HaveOccurred())
} else {
g.Expect(err).NotTo(HaveOccurred())
g.Expect(resourceID.Name).To(Equal(tt.expectedName))
}
})
}
}
3 changes: 2 additions & 1 deletion azure/scope/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachineimages"
"sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines"
"sigs.k8s.io/cluster-api-provider-azure/azure/services/vmextensions"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
"sigs.k8s.io/cluster-api-provider-azure/util/futures"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
Expand Down Expand Up @@ -452,7 +453,7 @@ func (m *MachineScope) Role() string {

// GetVMID returns the AzureMachine instance id by parsing the scope's providerID.
func (m *MachineScope) GetVMID() string {
resourceID, err := azure.ParseResourceID(m.ProviderID())
resourceID, err := azureutil.ParseResourceID(m.ProviderID())
if err != nil {
return ""
}
Expand Down
5 changes: 3 additions & 2 deletions azure/scope/machinepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets"
"sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachineimages"
infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
"sigs.k8s.io/cluster-api-provider-azure/util/futures"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
Expand Down Expand Up @@ -154,7 +155,7 @@ func (m *MachinePoolScope) Name() string {

// ProviderID returns the AzureMachinePool ID by parsing Spec.ProviderID.
func (m *MachinePoolScope) ProviderID() string {
resourceID, err := azure.ParseResourceID(m.AzureMachinePool.Spec.ProviderID)
resourceID, err := azureutil.ParseResourceID(m.AzureMachinePool.Spec.ProviderID)
if err != nil {
return ""
}
Expand Down Expand Up @@ -376,7 +377,7 @@ func (m *MachinePoolScope) createMachine(ctx context.Context, machine azure.VMSS
ctx, _, done := tele.StartSpanWithLogger(ctx, "scope.MachinePoolScope.createMachine")
defer done()

parsed, err := azure.ParseResourceID(machine.ID)
parsed, err := azureutil.ParseResourceID(machine.ID)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to parse resource id %q", machine.ID))
}
Expand Down
3 changes: 2 additions & 1 deletion azure/services/identities/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi"
"github.com/Azure/go-autorest/autorest"
"sigs.k8s.io/cluster-api-provider-azure/azure"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
)

Expand Down Expand Up @@ -62,7 +63,7 @@ func (ac *AzureClient) GetClientID(ctx context.Context, providerID string) (stri
ctx, _, done := tele.StartSpanWithLogger(ctx, "identities.GetClientID")
defer done()

parsed, err := azure.ParseResourceID(providerID)
parsed, err := azureutil.ParseResourceID(providerID)
if err != nil {
return "", err
}
Expand Down
3 changes: 2 additions & 1 deletion azure/services/natgateways/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1"
"sigs.k8s.io/cluster-api-provider-azure/azure"
"sigs.k8s.io/cluster-api-provider-azure/azure/converters"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
)

// NatGatewaySpec defines the specification for a NAT gateway.
Expand Down Expand Up @@ -96,7 +97,7 @@ func hasPublicIP(natGateway network.NatGateway, publicIPName string) bool {
}

for _, publicIP := range *natGateway.PublicIPAddresses {
resource, err := azure.ParseResourceID(*publicIP.ID)
resource, err := azureutil.ParseResourceID(*publicIP.ID)
if err != nil {
continue
}
Expand Down
3 changes: 2 additions & 1 deletion azure/services/scalesets/scalesets.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sigs.k8s.io/cluster-api-provider-azure/azure"
"sigs.k8s.io/cluster-api-provider-azure/azure/converters"
"sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
"sigs.k8s.io/cluster-api-provider-azure/util/generators"
"sigs.k8s.io/cluster-api-provider-azure/util/slice"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
Expand Down Expand Up @@ -111,7 +112,7 @@ func (s *Service) Reconcile(ctx context.Context) (retErr error) {

if fetchedVMSS != nil {
// Transform the VMSS resource representation to conform to the cloud-provider-azure representation
providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + fetchedVMSS.ID)
providerID, err := azprovider.ConvertResourceGroupNameToLower(azureutil.ProviderIDPrefix + fetchedVMSS.ID)
if err != nil {
log.Error(err, "failed to parse VMSS ID", "ID", fetchedVMSS.ID)
}
Expand Down
7 changes: 4 additions & 3 deletions azure/services/scalesets/scalesets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus"
"sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets/mock_scalesets"
gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock"
azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
)

Expand Down Expand Up @@ -1560,7 +1561,7 @@ func setupDefaultVMSSInProgressOperationDoneExpectations(s *mock_scalesets.MockS
m.ListInstances(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(instances, nil).AnyTimes()
s.MaxSurge().Return(1, nil)
s.SetVMSSState(gomock.Any())
s.SetProviderID(azure.ProviderIDPrefix + *createdVMSS.ID)
s.SetProviderID(azureutil.ProviderIDPrefix + *createdVMSS.ID)
}

func setupDefaultVMSSStartCreatingExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder, m *mock_scalesets.MockClientMockRecorder) {
Expand All @@ -1577,7 +1578,7 @@ func setupCreatingSucceededExpectations(s *mock_scalesets.MockScaleSetScopeMockR
m.Get(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(vmss, nil)
m.ListInstances(gomockinternal.AContext(), defaultResourceGroup, defaultVMSSName).Return(newDefaultInstances(), nil).AnyTimes()
s.SetVMSSState(gomock.Any())
s.SetProviderID(azure.ProviderIDPrefix + *vmss.ID)
s.SetProviderID(azureutil.ProviderIDPrefix + *vmss.ID)
}

func setupDefaultVMSSExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder) {
Expand Down Expand Up @@ -1641,7 +1642,7 @@ func setupVMSSExpectationsWithoutVMImage(s *mock_scalesets.MockScaleSetScopeMock

func setupDefaultVMSSUpdateExpectations(s *mock_scalesets.MockScaleSetScopeMockRecorder) {
setupUpdateVMSSExpectations(s)
s.SetProviderID(azure.ProviderIDPrefix + "subscriptions/1234/resourceGroups/my_resource_group/providers/Microsoft.Compute/virtualMachines/my-vm")
s.SetProviderID(azureutil.ProviderIDPrefix + "subscriptions/1234/resourceGroups/my_resource_group/providers/Microsoft.Compute/virtualMachines/my-vm")
s.GetLongRunningOperationState(defaultVMSSName, serviceName, infrav1.PutFuture).Return(nil)
s.GetLongRunningOperationState(defaultVMSSName, serviceName, infrav1.PatchFuture).Return(nil)
s.MaxSurge().Return(1, nil)
Expand Down
Loading