diff --git a/.golangci.yml b/.golangci.yml index a2fd85d04c6..b2b5ba8d552 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -94,6 +94,8 @@ linters-settings: alias: infrav1alpha4exp - pkg: sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1 alias: infrav1exp + - pkg: sigs.k8s.io/cluster-api-provider-azure/util/webhook + alias: webhookutils gocritic: enabled-tags: - "experimental" diff --git a/api/v1beta1/azurecluster_webhook.go b/api/v1beta1/azurecluster_webhook.go index 9743f4f6b57..94b7b15a9c8 100644 --- a/api/v1beta1/azurecluster_webhook.go +++ b/api/v1beta1/azurecluster_webhook.go @@ -22,6 +22,7 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" ) @@ -54,25 +55,25 @@ func (c *AzureCluster) ValidateUpdate(oldRaw runtime.Object) error { var allErrs field.ErrorList old := oldRaw.(*AzureCluster) - if !reflect.DeepEqual(c.Spec.ResourceGroup, old.Spec.ResourceGroup) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "ResourceGroup"), - c.Spec.ResourceGroup, "field is immutable"), - ) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "ResourceGroup"), + old.Spec.ResourceGroup, + c.Spec.ResourceGroup); err != nil { + allErrs = append(allErrs, err) } - if !reflect.DeepEqual(c.Spec.SubscriptionID, old.Spec.SubscriptionID) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "SubscriptionID"), - c.Spec.SubscriptionID, "field is immutable"), - ) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "SubscriptionID"), + old.Spec.SubscriptionID, + c.Spec.SubscriptionID); err != nil { + allErrs = append(allErrs, err) } - if !reflect.DeepEqual(c.Spec.Location, old.Spec.Location) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "Location"), - c.Spec.Location, "field is immutable"), - ) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "Location"), + old.Spec.Location, + c.Spec.Location); err != nil { + allErrs = append(allErrs, err) } if old.Spec.ControlPlaneEndpoint.Host != "" && c.Spec.ControlPlaneEndpoint.Host != old.Spec.ControlPlaneEndpoint.Host { diff --git a/api/v1beta1/azuremachine_webhook.go b/api/v1beta1/azuremachine_webhook.go index bfc31d968a8..3bcf6690c42 100644 --- a/api/v1beta1/azuremachine_webhook.go +++ b/api/v1beta1/azuremachine_webhook.go @@ -22,6 +22,7 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" ) @@ -81,11 +82,11 @@ func (m *AzureMachine) ValidateUpdate(oldRaw runtime.Object) error { ) } - if !reflect.DeepEqual(m.Spec.RoleAssignmentName, old.Spec.RoleAssignmentName) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "roleAssignmentName"), - m.Spec.RoleAssignmentName, "field is immutable"), - ) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "RoleAssignmentName"), + old.Spec.RoleAssignmentName, + m.Spec.RoleAssignmentName); err != nil { + allErrs = append(allErrs, err) } if !reflect.DeepEqual(m.Spec.OSDisk, old.Spec.OSDisk) { @@ -102,11 +103,11 @@ func (m *AzureMachine) ValidateUpdate(oldRaw runtime.Object) error { ) } - if !reflect.DeepEqual(m.Spec.SSHPublicKey, old.Spec.SSHPublicKey) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "sshPublicKey"), - m.Spec.SSHPublicKey, "field is immutable"), - ) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "SSHPublicKey"), + old.Spec.SSHPublicKey, + m.Spec.SSHPublicKey); err != nil { + allErrs = append(allErrs, err) } if !reflect.DeepEqual(m.Spec.AllocatePublicIP, old.Spec.AllocatePublicIP) { @@ -123,11 +124,11 @@ func (m *AzureMachine) ValidateUpdate(oldRaw runtime.Object) error { ) } - if !reflect.DeepEqual(m.Spec.AcceleratedNetworking, old.Spec.AcceleratedNetworking) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("spec", "acceleratedNetworking"), - m.Spec.AcceleratedNetworking, "field is immutable"), - ) + if err := webhookutils.ValidateBoolPtrImmutable( + field.NewPath("Spec", "AcceleratedNetworking"), + old.Spec.AcceleratedNetworking, + m.Spec.AcceleratedNetworking); err != nil { + allErrs = append(allErrs, err) } if !reflect.DeepEqual(m.Spec.SpotVMOptions, old.Spec.SpotVMOptions) { diff --git a/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go b/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go index e30c19efc4a..d2ddd55a2b5 100644 --- a/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go +++ b/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go @@ -30,6 +30,7 @@ import ( kerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/validation/field" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -88,131 +89,74 @@ func (m *AzureManagedControlPlane) ValidateUpdate(oldRaw runtime.Object, client var allErrs field.ErrorList old := oldRaw.(*AzureManagedControlPlane) - if m.Name != old.Name { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Name"), - m.Name, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Name"), + old.Name, + m.Name); err != nil { + allErrs = append(allErrs, err) } - if m.Spec.SubscriptionID != old.Spec.SubscriptionID { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "SubscriptionID"), - m.Spec.SubscriptionID, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "SubscriptionID"), + old.Spec.SubscriptionID, + m.Spec.SubscriptionID); err != nil { + allErrs = append(allErrs, err) } - if m.Spec.ResourceGroupName != old.Spec.ResourceGroupName { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "ResourceGroupName"), - m.Spec.ResourceGroupName, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "ResourceGroupName"), + old.Spec.ResourceGroupName, + m.Spec.ResourceGroupName); err != nil { + allErrs = append(allErrs, err) } - if m.Spec.NodeResourceGroupName != old.Spec.NodeResourceGroupName { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "NodeResourceGroupName"), - m.Spec.NodeResourceGroupName, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "NodeResourceGroupName"), + old.Spec.NodeResourceGroupName, + m.Spec.NodeResourceGroupName); err != nil { + allErrs = append(allErrs, err) } - if m.Spec.Location != old.Spec.Location { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "Location"), - m.Spec.Location, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "Location"), + old.Spec.Location, + m.Spec.Location); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.SSHPublicKey != "" { - // Prevent SSH key modification if it was already set to some value - if m.Spec.SSHPublicKey != old.Spec.SSHPublicKey { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "SSHPublicKey"), - m.Spec.SSHPublicKey, - "field is immutable")) - } + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "SSHPublicKey"), + old.Spec.SSHPublicKey, + m.Spec.SSHPublicKey); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.DNSServiceIP != nil { - // Prevent DNSServiceIP modification if it was already set to some value - if m.Spec.DNSServiceIP == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "DNSServiceIP"), - m.Spec.DNSServiceIP, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.DNSServiceIP != *old.Spec.DNSServiceIP { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "DNSServiceIP"), - *m.Spec.DNSServiceIP, - "field is immutable")) - } + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "DNSServiceIP"), + old.Spec.DNSServiceIP, + m.Spec.DNSServiceIP); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.NetworkPlugin != nil { - // Prevent NetworkPlugin modification if it was already set to some value - if m.Spec.NetworkPlugin == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "NetworkPlugin"), - m.Spec.NetworkPlugin, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.NetworkPlugin != *old.Spec.NetworkPlugin { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "NetworkPlugin"), - *m.Spec.NetworkPlugin, - "field is immutable")) - } + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "NetworkPlugin"), + old.Spec.NetworkPlugin, + m.Spec.NetworkPlugin); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.NetworkPolicy != nil { - // Prevent NetworkPolicy modification if it was already set to some value - if m.Spec.NetworkPolicy == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "NetworkPolicy"), - m.Spec.NetworkPolicy, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.NetworkPolicy != *old.Spec.NetworkPolicy { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "NetworkPolicy"), - *m.Spec.NetworkPolicy, - "field is immutable")) - } + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "NetworkPolicy"), + old.Spec.NetworkPolicy, + m.Spec.NetworkPolicy); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.LoadBalancerSKU != nil { - // Prevent LoadBalancerSKU modification if it was already set to some value - if m.Spec.LoadBalancerSKU == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "LoadBalancerSKU"), - m.Spec.LoadBalancerSKU, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.LoadBalancerSKU != *old.Spec.LoadBalancerSKU { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "LoadBalancerSKU"), - *m.Spec.LoadBalancerSKU, - "field is immutable")) - } + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "LoadBalancerSKU"), + old.Spec.LoadBalancerSKU, + m.Spec.LoadBalancerSKU); err != nil { + allErrs = append(allErrs, err) } if old.Spec.AADProfile != nil { diff --git a/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go b/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go index b78ca9ec9f5..f341dbe7f74 100644 --- a/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go +++ b/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go @@ -353,18 +353,6 @@ func TestAzureManagedControlPlane_ValidateUpdate(t *testing.T) { amcp *AzureManagedControlPlane wantErr bool }{ - { - name: "AzureManagedControlPlane with valid SSHPublicKey", - oldAMCP: createAzureManagedControlPlane("192.168.0.0", "v1.18.0", ""), - amcp: createAzureManagedControlPlane("192.168.0.0", "v1.18.0", generateSSHPublicKey(true)), - wantErr: false, - }, - { - name: "AzureManagedControlPlane with invalid SSHPublicKey", - oldAMCP: createAzureManagedControlPlane("192.168.0.0", "v1.18.0", ""), - amcp: createAzureManagedControlPlane("192.168.0.0", "v1.18.0", generateSSHPublicKey(false)), - wantErr: true, - }, { name: "AzureManagedControlPlane with invalid serviceIP", oldAMCP: createAzureManagedControlPlane("", "v1.18.0", ""), diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go index 7186f13a231..fb44c5e7cef 100644 --- a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go +++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/maps" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -91,38 +92,25 @@ func (m *AzureManagedMachinePool) ValidateUpdate(oldRaw runtime.Object, client c err.Error())) } - if err := validateStringPtrImmutable( + if err := webhookutils.ValidateStringPtrImmutable( field.NewPath("Spec", "OSType"), old.Spec.OSType, m.Spec.OSType); err != nil { allErrs = append(allErrs, err) } - if m.Spec.SKU != old.Spec.SKU { - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "SKU"), - m.Spec.SKU, - "field is immutable")) + if err := webhookutils.ValidateStringImmutable( + field.NewPath("Spec", "SKU"), + old.Spec.SKU, + m.Spec.SKU); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.OSDiskSizeGB != nil { - // Prevent OSDiskSizeGB modification if it was already set to some value - if m.Spec.OSDiskSizeGB == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "OSDiskSizeGB"), - m.Spec.OSDiskSizeGB, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.OSDiskSizeGB != *old.Spec.OSDiskSizeGB { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "OSDiskSizeGB"), - *m.Spec.OSDiskSizeGB, - "field is immutable")) - } + if err := webhookutils.ValidateInt32PtrImmutable( + field.NewPath("Spec", "OSDiskSizeGB"), + old.Spec.OSDiskSizeGB, + m.Spec.OSDiskSizeGB); err != nil { + allErrs = append(allErrs, err) } // custom headers are immutable @@ -136,7 +124,7 @@ func (m *AzureManagedMachinePool) ValidateUpdate(oldRaw runtime.Object, client c fmt.Sprintf("annotations with '%s' prefix are immutable", azure.CustomHeaderPrefix))) } - if !ensureStringSlicesAreEqual(m.Spec.AvailabilityZones, old.Spec.AvailabilityZones) { + if !webhookutils.EnsureStringSlicesAreEquivalent(m.Spec.AvailabilityZones, old.Spec.AvailabilityZones) { allErrs = append(allErrs, field.Invalid( field.NewPath("Spec", "AvailabilityZones"), @@ -153,64 +141,41 @@ func (m *AzureManagedMachinePool) ValidateUpdate(oldRaw runtime.Object, client c } } - if old.Spec.MaxPods != nil { - // Prevent MaxPods modification if it was already set to some value - if m.Spec.MaxPods == nil { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "MaxPods"), - m.Spec.MaxPods, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.MaxPods != *old.Spec.MaxPods { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "MaxPods"), - *m.Spec.MaxPods, - "field is immutable")) - } + if err := webhookutils.ValidateInt32PtrImmutable( + field.NewPath("Spec", "MaxPods"), + old.Spec.MaxPods, + m.Spec.MaxPods); err != nil { + allErrs = append(allErrs, err) } - if old.Spec.OsDiskType != nil { - // Prevent OSDiskType modification if it was already set to some value - if m.Spec.OsDiskType == nil || to.String(m.Spec.OsDiskType) == "" { - // unsetting the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "OsDiskType"), - m.Spec.OsDiskType, - "field is immutable, unsetting is not allowed")) - } else if *m.Spec.OsDiskType != *old.Spec.OsDiskType { - // changing the field is not allowed - allErrs = append(allErrs, - field.Invalid( - field.NewPath("Spec", "OsDiskType"), - m.Spec.OsDiskType, - "field is immutable")) - } + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "OsDiskType"), + old.Spec.OsDiskType, + m.Spec.OsDiskType); err != nil { + allErrs = append(allErrs, err) } - if !reflect.DeepEqual(m.Spec.ScaleSetPriority, old.Spec.ScaleSetPriority) { - allErrs = append(allErrs, - field.Invalid(field.NewPath("Spec", "ScaleSetPriority"), - m.Spec.ScaleSetPriority, "field is immutable"), - ) + if err := webhookutils.ValidateStringPtrImmutable( + field.NewPath("Spec", "ScaleSetPriority"), + old.Spec.ScaleSetPriority, + m.Spec.ScaleSetPriority); err != nil { + allErrs = append(allErrs, err) } - if err := validateBoolPtrImmutable( + if err := webhookutils.ValidateBoolPtrImmutable( field.NewPath("Spec", "EnableUltraSSD"), old.Spec.EnableUltraSSD, m.Spec.EnableUltraSSD); err != nil { allErrs = append(allErrs, err) } - if err := validateBoolPtrImmutable( + + if err := webhookutils.ValidateBoolPtrImmutable( field.NewPath("Spec", "EnableNodePublicIP"), old.Spec.EnableNodePublicIP, m.Spec.EnableNodePublicIP); err != nil { allErrs = append(allErrs, err) } - if err := validateStringPtrImmutable( + if err := webhookutils.ValidateStringPtrImmutable( field.NewPath("Spec", "NodePublicIPPrefixID"), old.Spec.NodePublicIPPrefixID, m.Spec.NodePublicIPPrefixID); err != nil { @@ -346,57 +311,3 @@ func (m *AzureManagedMachinePool) validateEnableNodePublicIP() error { } return nil } - -func ensureStringSlicesAreEqual(a []string, b []string) bool { - if len(a) != len(b) { - return false - } - - m := map[string]bool{} - for _, v := range a { - m[v] = true - } - - for _, v := range b { - if _, ok := m[v]; !ok { - return false - } - } - return true -} - -func validateBoolPtrImmutable(path *field.Path, oldVal, newVal *bool) *field.Error { - if oldVal != nil { - // Prevent modification if it was already set to some value - if newVal == nil { - // unsetting the field is not allowed - return field.Invalid(path, newVal, "field is immutable, unsetting is not allowed") - } - if *newVal != *oldVal { - // changing the field is not allowed - return field.Invalid(path, newVal, "field is immutable") - } - } else if newVal != nil { - return field.Invalid(path, newVal, "field is immutable, setting is not allowed") - } - - return nil -} - -func validateStringPtrImmutable(path *field.Path, oldVal, newVal *string) *field.Error { - if oldVal != nil { - // Prevent modification if it was already set to some value - if newVal == nil { - // unsetting the field is not allowed - return field.Invalid(path, newVal, "field is immutable, unsetting is not allowed") - } - if *newVal != *oldVal { - // changing the field is not allowed - return field.Invalid(path, newVal, "field is immutable") - } - } else if newVal != nil { - return field.Invalid(path, newVal, "field is immutable, setting is not allowed") - } - - return nil -} diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go index 17951fb969f..5fedc1b2e71 100644 --- a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go +++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go @@ -25,6 +25,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/cluster-api-provider-azure/azure" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -557,7 +558,7 @@ func TestValidateBoolPtrImmutable(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { g := NewWithT(t) - err := validateBoolPtrImmutable(field.NewPath("test"), test.oldVal, test.newVal) + err := webhookutils.ValidateBoolPtrImmutable(field.NewPath("test"), test.oldVal, test.newVal) if test.wantErr { g.Expect(err).To(HaveOccurred()) } else { diff --git a/main.go b/main.go index 494c86fdb28..d00b4a0f519 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/pkg/coalescing" "sigs.k8s.io/cluster-api-provider-azure/pkg/ot" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" - "sigs.k8s.io/cluster-api-provider-azure/util/webhook" + webhookutils "sigs.k8s.io/cluster-api-provider-azure/util/webhook" "sigs.k8s.io/cluster-api-provider-azure/version" clusterv1alpha4 "sigs.k8s.io/cluster-api/api/v1alpha4" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -525,16 +525,16 @@ func registerWebhooks(mgr manager.Manager) { if feature.Gates.Enabled(feature.AKS) { hookServer := mgr.GetWebhookServer() - hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewMutatingWebhook( + hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhookutils.NewMutatingWebhook( &infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(), )) - hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewValidatingWebhook( + hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhookutils.NewValidatingWebhook( &infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(), )) - hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewMutatingWebhook( + hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhookutils.NewMutatingWebhook( &infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(), )) - hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewValidatingWebhook( + hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhookutils.NewValidatingWebhook( &infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(), )) } diff --git a/util/webhook/validator.go b/util/webhook/validator.go index 4d9297e04ea..fe651f52afc 100644 --- a/util/webhook/validator.go +++ b/util/webhook/validator.go @@ -20,15 +20,23 @@ import ( "context" "errors" "net/http" + "sort" admissionv1 "k8s.io/api/admission/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" ) +const ( + unsetMessage = "field is immutable, unable to set an empty value if it was already set" + setMessage = "field is immutable, unable to assign a value if it was already empty" + immutableMessage = "field is immutable" +) + // Validator defines functions for validating an operation. type Validator interface { runtime.Object @@ -136,3 +144,102 @@ func validationResponseFromStatus(allowed bool, status metav1.Status) admission. }, } } + +// EnsureStringSlicesAreEquivalent returns if two string slices have equal lengths, +// and that they have the exact same items; it does not enforce strict ordering of items. +func EnsureStringSlicesAreEquivalent(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + sort.Strings(a) + sort.Strings(b) + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +// ValidateBoolPtrImmutable validates equality across two *bools, +// and returns a meaningful error to indicate a changed value, a newly set value, or a newly unset value. +func ValidateBoolPtrImmutable(path *field.Path, oldVal, newVal *bool) *field.Error { + if oldVal != nil { + // Prevent modification if it was already set to some value + if newVal == nil { + // unsetting the field is not allowed + return field.Invalid(path, newVal, unsetMessage) + } + if *newVal != *oldVal { + // changing the field is not allowed + return field.Invalid(path, newVal, immutableMessage) + } + } else if newVal != nil { + return field.Invalid(path, newVal, setMessage) + } + + return nil +} + +// ValidateStringImmutable validates equality across two strings, +// and returns a meaningful error to indicate a changed value, a newly set value, or a newly unset value. +func ValidateStringImmutable(path *field.Path, oldVal, newVal string) *field.Error { + if oldVal != "" { + // Prevent modification if it was already set to some value + if newVal == "" { + // unsetting the field is not allowed + return field.Invalid(path, newVal, unsetMessage) + } + if newVal != oldVal { + // changing the field is not allowed + return field.Invalid(path, newVal, immutableMessage) + } + } else if newVal != "" { + return field.Invalid(path, newVal, setMessage) + } + + return nil +} + +// ValidateStringPtrImmutable validates equality across two *strings, +// and returns a meaningful error to indicate a changed value, a newly set value, or a newly unset value. +func ValidateStringPtrImmutable(path *field.Path, oldVal, newVal *string) *field.Error { + if oldVal != nil { + // Prevent modification if it was already set to some value + if newVal == nil { + // unsetting the field is not allowed + return field.Invalid(path, newVal, unsetMessage) + } + if *newVal != *oldVal { + // changing the field is not allowed + return field.Invalid(path, newVal, immutableMessage) + } + } else if newVal != nil { + return field.Invalid(path, newVal, setMessage) + } + + return nil +} + +// ValidateInt32PtrImmutable validates equality across two *int32s, +// and returns a meaningful error to indicate a changed value, a newly set value, or a newly unset value. +func ValidateInt32PtrImmutable(path *field.Path, oldVal, newVal *int32) *field.Error { + if oldVal != nil { + // Prevent modification if it was already set to some value + if newVal == nil { + // unsetting the field is not allowed + return field.Invalid(path, newVal, unsetMessage) + } + if *newVal != *oldVal { + // changing the field is not allowed + return field.Invalid(path, newVal, immutableMessage) + } + } else if newVal != nil { + return field.Invalid(path, newVal, setMessage) + } + + return nil +} diff --git a/util/webhook/validator_test.go b/util/webhook/validator_test.go new file mode 100644 index 00000000000..7cb9c55aeda --- /dev/null +++ b/util/webhook/validator_test.go @@ -0,0 +1,325 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook + +import ( + "testing" + + "github.com/Azure/go-autorest/autorest/to" + . "github.com/onsi/gomega" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +func TestValidateBoolPtrImmutable(t *testing.T) { + g := NewWithT(t) + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + path *field.Path + input1 *bool + input2 *bool + expectedOutput *field.Error + }{ + { + name: "nil", + path: testPath, + input1: nil, + input2: nil, + }, + { + name: "no change", + path: testPath, + input1: nil, + input2: nil, + }, + { + name: "can't unset", + path: testPath, + input1: to.BoolPtr(true), + input2: nil, + expectedOutput: field.Invalid(testPath, nil, unsetMessage), + }, + { + name: "can't set from empty", + path: testPath, + input1: nil, + input2: to.BoolPtr(true), + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + path: testPath, + input1: to.BoolPtr(true), + input2: to.BoolPtr(false), + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateBoolPtrImmutable(tc.path, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateStringImmutable(t *testing.T) { + g := NewWithT(t) + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + path *field.Path + input1 string + input2 string + expectedOutput *field.Error + }{ + { + name: "empty string", + path: testPath, + input1: "", + input2: "", + }, + { + name: "no change", + path: testPath, + input1: "", + input2: "", + }, + { + name: "can't unset", + path: testPath, + input1: "foo", + input2: "", + expectedOutput: field.Invalid(testPath, nil, unsetMessage), + }, + { + name: "can't set from empty", + path: testPath, + input1: "", + input2: "foo", + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + path: testPath, + input1: "foo", + input2: "bar", + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateStringImmutable(tc.path, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateStringPtrImmutable(t *testing.T) { + g := NewWithT(t) + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + path *field.Path + input1 *string + input2 *string + expectedOutput *field.Error + }{ + { + name: "nil", + path: testPath, + input1: nil, + input2: nil, + }, + { + name: "no change", + path: testPath, + input1: to.StringPtr("foo"), + input2: to.StringPtr("foo"), + }, + { + name: "can't unset", + path: testPath, + input1: to.StringPtr("foo"), + input2: nil, + expectedOutput: field.Invalid(testPath, nil, unsetMessage), + }, + { + name: "can't set from empty", + path: testPath, + input1: nil, + input2: to.StringPtr("foo"), + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + path: testPath, + input1: to.StringPtr("foo"), + input2: to.StringPtr("bar"), + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateStringPtrImmutable(tc.path, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateInt32PtrImmutable(t *testing.T) { + g := NewWithT(t) + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + path *field.Path + input1 *int32 + input2 *int32 + expectedOutput *field.Error + }{ + { + name: "nil", + path: testPath, + input1: nil, + input2: nil, + }, + { + name: "no change", + path: testPath, + input1: to.Int32Ptr(5), + input2: to.Int32Ptr(5), + }, + { + name: "can't unset", + path: testPath, + input1: to.Int32Ptr(5), + input2: nil, + expectedOutput: field.Invalid(testPath, nil, unsetMessage), + }, + { + name: "can't set from empty", + path: testPath, + input1: nil, + input2: to.Int32Ptr(5), + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + path: testPath, + input1: to.Int32Ptr(5), + input2: to.Int32Ptr(6), + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateInt32PtrImmutable(tc.path, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestEnsureStringSlicesAreEquivalent(t *testing.T) { + g := NewWithT(t) + + tests := []struct { + name string + input1 []string + input2 []string + expectedOutput bool + }{ + { + name: "nil", + input1: nil, + input2: nil, + expectedOutput: true, + }, + { + name: "no change", + input1: []string{"foo", "bar"}, + input2: []string{"foo", "bar"}, + expectedOutput: true, + }, + { + name: "different", + input1: []string{"foo", "bar"}, + input2: []string{"foo", "foo"}, + expectedOutput: false, + }, + { + name: "different order, but equal", + input1: []string{"1", "2"}, + input2: []string{"2", "1"}, + expectedOutput: true, + }, + { + name: "different lengths", + input1: []string{"foo"}, + input2: []string{"foo", "foo"}, + expectedOutput: false, + }, + { + name: "different", + input1: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}, + input2: []string{"1", "2", "3", "4", "5", "7", "8", "9"}, + expectedOutput: false, + }, + { + name: "another different variant", + input1: []string{"a", "a", "b"}, + input2: []string{"a", "b", "b"}, + expectedOutput: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ret := EnsureStringSlicesAreEquivalent(tc.input1, tc.input2) + g.Expect(ret).To(Equal(tc.expectedOutput)) + }) + } +}