From a9d16ef5a6b59325e1ee73e5700cedbe007dd2f8 Mon Sep 17 00:00:00 2001
From: Jack Francis <jackfrancis@gmail.com>
Date: Fri, 14 Oct 2022 11:48:52 -0700
Subject: [PATCH] standardize AzureManagedCluster webhooks

---
 exp/api/v1beta1/azuremachinepool_webhook.go   | 17 ++---
 .../v1beta1/azuremachinepool_webhook_test.go  | 73 +++++++++++++++++--
 .../v1beta1/azuremanagedcluster_webhook.go    | 11 ++-
 .../azuremanagedcluster_webhook_test.go       | 34 +++++++++
 .../azuremanagedcontrolplane_webhook.go       |  9 +++
 .../azuremanagedcontrolplane_webhook_test.go  | 73 +++++++++++++++----
 .../azuremanagedmachinepool_webhook.go        |  9 +++
 .../azuremanagedmachinepool_webhook_test.go   | 53 ++++++++++++--
 main.go                                       | 28 ++++---
 9 files changed, 249 insertions(+), 58 deletions(-)

diff --git a/exp/api/v1beta1/azuremachinepool_webhook.go b/exp/api/v1beta1/azuremachinepool_webhook.go
index 368012a1c755..1dcb692432ce 100644
--- a/exp/api/v1beta1/azuremachinepool_webhook.go
+++ b/exp/api/v1beta1/azuremachinepool_webhook.go
@@ -57,6 +57,14 @@ var _ webhook.Validator = &AzureMachinePool{}
 
 // ValidateCreate implements webhook.Validator so a webhook will be registered for the type.
 func (amp *AzureMachinePool) ValidateCreate() error {
+	// NOTE: AzureMachinePool is behind MachinePool feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	if !feature.Gates.Enabled(capifeature.MachinePool) {
+		return field.Forbidden(
+			field.NewPath("spec"),
+			"can be set only if the MachinePool feature flag is enabled",
+		)
+	}
 	return amp.Validate(nil)
 }
 
@@ -72,15 +80,6 @@ func (amp *AzureMachinePool) ValidateDelete() error {
 
 // Validate the Azure Machine Pool and return an aggregate error.
 func (amp *AzureMachinePool) Validate(old runtime.Object) error {
-	// NOTE: AzureMachinePool is behind MachinePool feature gate flag; the web hook
-	// must prevent creating new objects new case the feature flag is disabled.
-	if !feature.Gates.Enabled(capifeature.MachinePool) {
-		return field.Forbidden(
-			field.NewPath("spec"),
-			"can be set only if the MachinePool feature flag is enabled",
-		)
-	}
-
 	validators := []func() error{
 		amp.ValidateImage,
 		amp.ValidateTerminateNotificationTimeout,
diff --git a/exp/api/v1beta1/azuremachinepool_webhook_test.go b/exp/api/v1beta1/azuremachinepool_webhook_test.go
index 77e5bc1bf251..171ed0ea2a7f 100644
--- a/exp/api/v1beta1/azuremachinepool_webhook_test.go
+++ b/exp/api/v1beta1/azuremachinepool_webhook_test.go
@@ -36,6 +36,8 @@ import (
 
 var (
 	validSSHPublicKey = generateSSHPublicKey(true)
+	zero              = intstr.FromInt(0)
+	one               = intstr.FromInt(1)
 )
 
 func TestAzureMachinePool_ValidateCreate(t *testing.T) {
@@ -45,16 +47,16 @@ func TestAzureMachinePool_ValidateCreate(t *testing.T) {
 
 	g := NewWithT(t)
 
-	var (
-		zero = intstr.FromInt(0)
-		one  = intstr.FromInt(1)
-	)
-
 	tests := []struct {
 		name    string
 		amp     *AzureMachinePool
 		wantErr bool
 	}{
+		{
+			name:    "valid",
+			amp:     getKnownValidAzureMachinePool(),
+			wantErr: false,
+		},
 		{
 			name:    "azuremachinepool with marketplace image - full",
 			amp:     createMachinePoolWithMarketPlaceImage("PUB1234", "OFFER1234", "SKU1234", "1.0.0", to.IntPtr(10)),
@@ -378,3 +380,64 @@ func createMachinePoolWithStrategy(strategy AzureMachinePoolDeploymentStrategy)
 		},
 	}
 }
+
+func TestAzureMachinePool_ValidateCreateFailure(t *testing.T) {
+	g := NewWithT(t)
+
+	tests := []struct {
+		name      string
+		amp       *AzureMachinePool
+		deferFunc func()
+	}{
+		{
+			name:      "feature gate explicitly disabled",
+			amp:       getKnownValidAzureMachinePool(),
+			deferFunc: utilfeature.SetFeatureGateDuringTest(t, feature.Gates, capifeature.MachinePool, false),
+		},
+		{
+			name: "feature gate implicitly disabled",
+			amp:  getKnownValidAzureMachinePool(),
+			deferFunc: func() {
+				return
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			defer tc.deferFunc()
+			err := tc.amp.ValidateCreate()
+			g.Expect(err).To(HaveOccurred())
+		})
+	}
+}
+
+func getKnownValidAzureMachinePool() *AzureMachinePool {
+	image := infrav1.Image{
+		Marketplace: &infrav1.AzureMarketplaceImage{
+			ImagePlan: infrav1.ImagePlan{
+				Publisher: "PUB1234",
+				Offer:     "OFFER1234",
+				SKU:       "SKU1234",
+			},
+			Version: "1.0.0",
+		},
+	}
+	return &AzureMachinePool{
+		Spec: AzureMachinePoolSpec{
+			Template: AzureMachinePoolMachineTemplate{
+				Image:                        &image,
+				SSHPublicKey:                 validSSHPublicKey,
+				TerminateNotificationTimeout: to.IntPtr(10),
+			},
+			Identity:           infrav1.VMIdentitySystemAssigned,
+			RoleAssignmentName: string(uuid.NewUUID()),
+			Strategy: AzureMachinePoolDeploymentStrategy{
+				Type: RollingUpdateAzureMachinePoolDeploymentStrategyType,
+				RollingUpdate: &MachineRollingUpdateDeployment{
+					MaxSurge:       &zero,
+					MaxUnavailable: &one,
+				},
+			},
+		},
+	}
+}
diff --git a/exp/api/v1beta1/azuremanagedcluster_webhook.go b/exp/api/v1beta1/azuremanagedcluster_webhook.go
index e8e8cd8a826d..40e901468572 100644
--- a/exp/api/v1beta1/azuremanagedcluster_webhook.go
+++ b/exp/api/v1beta1/azuremanagedcluster_webhook.go
@@ -43,20 +43,19 @@ var _ webhook.Validator = &AzureManagedCluster{}
 
 // ValidateCreate implements webhook.Validator so a webhook will be registered for the type.
 func (r *AzureManagedCluster) ValidateCreate() error {
-	return nil
-}
-
-// ValidateUpdate implements webhook.Validator so a webhook will be registered for the type.
-func (r *AzureManagedCluster) ValidateUpdate(oldRaw runtime.Object) error {
 	// NOTE: AzureManagedCluster is behind AKS feature gate flag; the web hook
-	// must prevent creating new objects new case the feature flag is disabled.
+	// must prevent creating new objects in case the feature flag is disabled.
 	if !feature.Gates.Enabled(feature.AKS) {
 		return field.Forbidden(
 			field.NewPath("spec"),
 			"can be set only if the AKS feature flag is enabled",
 		)
 	}
+	return nil
+}
 
+// ValidateUpdate implements webhook.Validator so a webhook will be registered for the type.
+func (r *AzureManagedCluster) ValidateUpdate(oldRaw runtime.Object) error {
 	old := oldRaw.(*AzureManagedCluster)
 	var allErrs field.ErrorList
 
diff --git a/exp/api/v1beta1/azuremanagedcluster_webhook_test.go b/exp/api/v1beta1/azuremanagedcluster_webhook_test.go
index 73a8e66b8829..c4e3ba67af0e 100644
--- a/exp/api/v1beta1/azuremanagedcluster_webhook_test.go
+++ b/exp/api/v1beta1/azuremanagedcluster_webhook_test.go
@@ -131,3 +131,37 @@ func TestAzureManagedCluster_ValidateUpdate(t *testing.T) {
 		})
 	}
 }
+
+func TestAzureManagedCluster_ValidateCreateFailure(t *testing.T) {
+	g := NewWithT(t)
+
+	tests := []struct {
+		name      string
+		amc       *AzureManagedCluster
+		deferFunc func()
+	}{
+		{
+			name:      "feature gate explicitly disabled",
+			amc:       getKnownValidAzureManagedCluster(),
+			deferFunc: utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, false),
+		},
+		{
+			name: "feature gate implicitly disabled",
+			amc:  getKnownValidAzureManagedCluster(),
+			deferFunc: func() {
+				return
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			defer tc.deferFunc()
+			err := tc.amc.ValidateCreate()
+			g.Expect(err).To(HaveOccurred())
+		})
+	}
+}
+
+func getKnownValidAzureManagedCluster() *AzureManagedCluster {
+	return &AzureManagedCluster{}
+}
diff --git a/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go b/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go
index e30c19efc4ae..29e8a52957a6 100644
--- a/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go
+++ b/exp/api/v1beta1/azuremanagedcontrolplane_webhook.go
@@ -30,6 +30,7 @@ import (
 	kerrors "k8s.io/apimachinery/pkg/util/errors"
 	"k8s.io/apimachinery/pkg/util/validation/field"
 	infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1"
+	"sigs.k8s.io/cluster-api-provider-azure/feature"
 	clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
@@ -80,6 +81,14 @@ func (m *AzureManagedControlPlane) Default(_ client.Client) {
 
 // ValidateCreate implements webhook.Validator so a webhook will be registered for the type.
 func (m *AzureManagedControlPlane) ValidateCreate(client client.Client) error {
+	// NOTE: AzureManagedControlPlane is behind AKS feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	if !feature.Gates.Enabled(feature.AKS) {
+		return field.Forbidden(
+			field.NewPath("spec"),
+			"can be set only if the AKS feature flag is enabled",
+		)
+	}
 	return m.Validate(client)
 }
 
diff --git a/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go b/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go
index b78ca9ec9f5c..c317cdaac9a4 100644
--- a/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go
+++ b/exp/api/v1beta1/azuremanagedcontrolplane_webhook_test.go
@@ -22,7 +22,9 @@ import (
 	"github.com/Azure/go-autorest/autorest/to"
 	. "github.com/onsi/gomega"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	utilfeature "k8s.io/component-base/featuregate/testing"
 	"k8s.io/utils/pointer"
+	"sigs.k8s.io/cluster-api-provider-azure/feature"
 )
 
 func TestDefaultingWebhook(t *testing.T) {
@@ -77,6 +79,10 @@ func TestDefaultingWebhook(t *testing.T) {
 }
 
 func TestValidatingWebhook(t *testing.T) {
+	// NOTE: AzureManageControlPlane is behind AKS feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	defer utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, true)()
+	g := NewWithT(t)
 	tests := []struct {
 		name      string
 		amcp      AzureManagedControlPlane
@@ -239,8 +245,6 @@ func TestValidatingWebhook(t *testing.T) {
 	for _, tt := range tests {
 		tt := tt
 		t.Run(tt.name, func(t *testing.T) {
-			g := NewWithT(t)
-			t.Parallel()
 			if tt.expectErr {
 				g.Expect(tt.amcp.ValidateCreate(nil)).NotTo(Succeed())
 			} else {
@@ -251,6 +255,9 @@ func TestValidatingWebhook(t *testing.T) {
 }
 
 func TestAzureManagedControlPlane_ValidateCreate(t *testing.T) {
+	// NOTE: AzureManageControlPlane is behind AKS feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	defer utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, true)()
 	g := NewWithT(t)
 
 	tests := []struct {
@@ -260,20 +267,8 @@ func TestAzureManagedControlPlane_ValidateCreate(t *testing.T) {
 		errorLen int
 	}{
 		{
-			name: "all valid",
-			amcp: &AzureManagedControlPlane{
-				Spec: AzureManagedControlPlaneSpec{
-					DNSServiceIP: to.StringPtr("192.168.0.0"),
-					Version:      "v1.18.0",
-					SSHPublicKey: generateSSHPublicKey(true),
-					AADProfile: &AADProfile{
-						Managed: true,
-						AdminGroupObjectIDs: []string{
-							"616077a8-5db7-4c98-b856-b34619afg75h",
-						},
-					},
-				},
-			},
+			name:    "all valid",
+			amcp:    getKnownValidAzureManagedControlPlane(),
 			wantErr: false,
 		},
 		{
@@ -344,6 +339,36 @@ func TestAzureManagedControlPlane_ValidateCreate(t *testing.T) {
 	}
 }
 
+func TestAzureManagedControlPlane_ValidateCreateFailure(t *testing.T) {
+	g := NewWithT(t)
+
+	tests := []struct {
+		name      string
+		amcp      *AzureManagedControlPlane
+		deferFunc func()
+	}{
+		{
+			name:      "feature gate explicitly disabled",
+			amcp:      getKnownValidAzureManagedControlPlane(),
+			deferFunc: utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, false),
+		},
+		{
+			name: "feature gate implicitly disabled",
+			amcp: getKnownValidAzureManagedControlPlane(),
+			deferFunc: func() {
+				return
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			defer tc.deferFunc()
+			err := tc.amcp.ValidateCreate(nil)
+			g.Expect(err).To(HaveOccurred())
+		})
+	}
+}
+
 func TestAzureManagedControlPlane_ValidateUpdate(t *testing.T) {
 	g := NewWithT(t)
 
@@ -901,3 +926,19 @@ func createAzureManagedControlPlane(serviceIP, version, sshKey string) *AzureMan
 		},
 	}
 }
+
+func getKnownValidAzureManagedControlPlane() *AzureManagedControlPlane {
+	return &AzureManagedControlPlane{
+		Spec: AzureManagedControlPlaneSpec{
+			DNSServiceIP: to.StringPtr("192.168.0.0"),
+			Version:      "v1.18.0",
+			SSHPublicKey: generateSSHPublicKey(true),
+			AADProfile: &AADProfile{
+				Managed: true,
+				AdminGroupObjectIDs: []string{
+					"616077a8-5db7-4c98-b856-b34619afg75h",
+				},
+			},
+		},
+	}
+}
diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go
index 7186f13a2310..318b603a3c9c 100644
--- a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go
+++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go
@@ -29,6 +29,7 @@ import (
 	kerrors "k8s.io/apimachinery/pkg/util/errors"
 	"k8s.io/apimachinery/pkg/util/validation/field"
 	"sigs.k8s.io/cluster-api-provider-azure/azure"
+	"sigs.k8s.io/cluster-api-provider-azure/feature"
 	azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure"
 	"sigs.k8s.io/cluster-api-provider-azure/util/maps"
 	clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
@@ -59,6 +60,14 @@ func (m *AzureManagedMachinePool) Default(client client.Client) {
 
 // ValidateCreate implements webhook.Validator so a webhook will be registered for the type.
 func (m *AzureManagedMachinePool) ValidateCreate(client client.Client) error {
+	// NOTE: AzureManagedMachinePool is behind AKS feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	if !feature.Gates.Enabled(feature.AKS) {
+		return field.Forbidden(
+			field.NewPath("spec"),
+			"can be set only if the AKS feature flag is enabled",
+		)
+	}
 	validators := []func() error{
 		m.validateMaxPods,
 		m.validateOSType,
diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go
index 17951fb969fb..82278469ae98 100644
--- a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go
+++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go
@@ -24,7 +24,9 @@ import (
 	. "github.com/onsi/gomega"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/util/validation/field"
+	utilfeature "k8s.io/component-base/featuregate/testing"
 	"sigs.k8s.io/cluster-api-provider-azure/azure"
+	"sigs.k8s.io/cluster-api-provider-azure/feature"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 )
 
@@ -568,6 +570,9 @@ func TestValidateBoolPtrImmutable(t *testing.T) {
 }
 
 func TestAzureManagedMachinePool_ValidateCreate(t *testing.T) {
+	// NOTE: AzureManagedMachinePool is behind AKS feature gate flag; the web hook
+	// must prevent creating new objects in case the feature flag is disabled.
+	defer utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, true)()
 	tests := []struct {
 		name     string
 		ammp     *AzureManagedMachinePool
@@ -575,13 +580,8 @@ func TestAzureManagedMachinePool_ValidateCreate(t *testing.T) {
 		errorLen int
 	}{
 		{
-			name: "valid",
-			ammp: &AzureManagedMachinePool{
-				Spec: AzureManagedMachinePoolSpec{
-					MaxPods:    to.Int32Ptr(30),
-					OsDiskType: to.StringPtr(string(containerservice.OSDiskTypeEphemeral)),
-				},
-			},
+			name:    "valid",
+			ammp:    getKnownValidAzureManagedMachinePool(),
 			wantErr: false,
 		},
 		{
@@ -791,3 +791,42 @@ func TestAzureManagedMachinePool_ValidateCreate(t *testing.T) {
 		})
 	}
 }
+
+func TestAzureManagedMachinePool_ValidateCreateFailure(t *testing.T) {
+	g := NewWithT(t)
+
+	tests := []struct {
+		name      string
+		ammp      *AzureManagedMachinePool
+		deferFunc func()
+	}{
+		{
+			name:      "feature gate explicitly disabled",
+			ammp:      getKnownValidAzureManagedMachinePool(),
+			deferFunc: utilfeature.SetFeatureGateDuringTest(t, feature.Gates, feature.AKS, false),
+		},
+		{
+			name: "feature gate implicitly disabled",
+			ammp: getKnownValidAzureManagedMachinePool(),
+			deferFunc: func() {
+				return
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			defer tc.deferFunc()
+			err := tc.ammp.ValidateCreate(nil)
+			g.Expect(err).To(HaveOccurred())
+		})
+	}
+}
+
+func getKnownValidAzureManagedMachinePool() *AzureManagedMachinePool {
+	return &AzureManagedMachinePool{
+		Spec: AzureManagedMachinePoolSpec{
+			MaxPods:    to.Int32Ptr(30),
+			OsDiskType: to.StringPtr(string(containerservice.OSDiskTypeEphemeral)),
+		},
+	}
+}
diff --git a/main.go b/main.go
index 494c86fdb28d..ab35e46925bf 100644
--- a/main.go
+++ b/main.go
@@ -523,21 +523,19 @@ func registerWebhooks(mgr manager.Manager) {
 		os.Exit(1)
 	}
 
-	if feature.Gates.Enabled(feature.AKS) {
-		hookServer := mgr.GetWebhookServer()
-		hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewMutatingWebhook(
-			&infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(),
-		))
-		hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewValidatingWebhook(
-			&infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(),
-		))
-		hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewMutatingWebhook(
-			&infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(),
-		))
-		hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewValidatingWebhook(
-			&infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(),
-		))
-	}
+	hookServer := mgr.GetWebhookServer()
+	hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewMutatingWebhook(
+		&infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(),
+	))
+	hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedmachinepool", webhook.NewValidatingWebhook(
+		&infrav1exp.AzureManagedMachinePool{}, mgr.GetClient(),
+	))
+	hookServer.Register("/mutate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewMutatingWebhook(
+		&infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(),
+	))
+	hookServer.Register("/validate-infrastructure-cluster-x-k8s-io-v1beta1-azuremanagedcontrolplane", webhook.NewValidatingWebhook(
+		&infrav1exp.AzureManagedControlPlane{}, mgr.GetClient(),
+	))
 
 	if err := mgr.AddReadyzCheck("webhook", mgr.GetWebhookServer().StartedChecker()); err != nil {
 		setupLog.Error(err, "unable to create ready check")