diff --git a/azure/services/agentpools/agentpools.go b/azure/services/agentpools/agentpools.go index cf1e0158fc8..604acbc3904 100644 --- a/azure/services/agentpools/agentpools.go +++ b/azure/services/agentpools/agentpools.go @@ -19,6 +19,7 @@ package agentpools import ( "context" "fmt" + "strings" "time" "github.com/Azure/azure-sdk-for-go/services/containerservice/mgmt/2021-05-01/containerservice" @@ -28,6 +29,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/converters" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/maps" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -155,6 +157,7 @@ func (s *Service) Reconcile(ctx context.Context) error { diff := cmp.Diff(normalizedProfile, existingProfile) if diff != "" { log.V(2).Info(fmt.Sprintf("Update required (+new -old):\n%s", diff)) + profile.NodeLabels = mergeSystemNodeLabels(profile.NodeLabels, existingPool.NodeLabels) err = s.Client.CreateOrUpdate(ctx, agentPoolSpec.ResourceGroup, agentPoolSpec.Cluster, agentPoolSpec.Name, profile, customHeaders) if err != nil { @@ -191,3 +194,15 @@ func (s *Service) Delete(ctx context.Context) error { log.V(2).Info(fmt.Sprintf("Successfully deleted agent pool %s ", agentPoolSpec.Name)) return nil } + +// mergeSystemNodeLabels +func mergeSystemNodeLabels(capz, aks map[string]*string) map[string]*string { + ret := capz + // Look for labels returned from the AKS node pool API that begin with kubernetes.azure.com + for aksNodeLabelKey := range aks { + if strings.HasPrefix(aksNodeLabelKey, azureutil.AzureSystemNodeLabelPrefix) { + ret[aksNodeLabelKey] = aks[aksNodeLabelKey] + } + } + return ret +} diff --git a/azure/services/agentpools/agentpools_test.go b/azure/services/agentpools/agentpools_test.go index ce85c3fcb44..261f146b2ae 100644 --- a/azure/services/agentpools/agentpools_test.go +++ b/azure/services/agentpools/agentpools_test.go @@ -565,3 +565,82 @@ func TestDeleteAgentPools(t *testing.T) { }) } } + +func TestMergeSystemNodeLabels(t *testing.T) { + testcases := []struct { + name string + capzLabels map[string]*string + aksLabels map[string]*string + expected map[string]*string + }{ + { + name: "update an existing label", + capzLabels: map[string]*string{ + "foo": to.StringPtr("bar"), + }, + aksLabels: map[string]*string{ + "foo": to.StringPtr("baz"), + }, + expected: map[string]*string{ + "foo": to.StringPtr("bar"), + }, + }, + { + name: "delete labels", + capzLabels: map[string]*string{}, + aksLabels: map[string]*string{ + "foo": to.StringPtr("bar"), + "hello": to.StringPtr("world"), + }, + expected: map[string]*string{}, + }, + { + name: "delete one label", + capzLabels: map[string]*string{ + "foo": to.StringPtr("bar"), + }, + aksLabels: map[string]*string{ + "foo": to.StringPtr("bar"), + "hello": to.StringPtr("world"), + }, + expected: map[string]*string{ + "foo": to.StringPtr("bar"), + }, + }, + { + name: "retain system label during update", + capzLabels: map[string]*string{ + "foo": to.StringPtr("bar"), + }, + aksLabels: map[string]*string{ + "kubernetes.azure.com/scalesetpriority": to.StringPtr("spot"), + }, + expected: map[string]*string{ + "foo": to.StringPtr("bar"), + "kubernetes.azure.com/scalesetpriority": to.StringPtr("spot"), + }, + }, + { + name: "retain system label during delete", + capzLabels: map[string]*string{}, + aksLabels: map[string]*string{ + "kubernetes.azure.com/scalesetpriority": to.StringPtr("spot"), + }, + expected: map[string]*string{ + "kubernetes.azure.com/scalesetpriority": to.StringPtr("spot"), + }, + }, + } + + for _, tc := range testcases { + t.Logf("Testing " + tc.name) + tc := tc + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + t.Parallel() + + ret := mergeSystemNodeLabels(tc.capzLabels, tc.aksLabels) + g.Expect(ret).To(Equal(tc.expected)) + }) + } +} diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go index a08ca79c88b..3fd398a31c7 100644 --- a/exp/api/v1beta1/azuremanagedmachinepool_webhook.go +++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "reflect" + "strings" "github.com/Azure/go-autorest/autorest/to" "github.com/pkg/errors" @@ -28,6 +29,7 @@ import ( kerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/cluster-api-provider-azure/azure" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/maps" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/controller-runtime/pkg/client" @@ -59,6 +61,7 @@ func (m *AzureManagedMachinePool) ValidateCreate(client client.Client) error { m.validateMaxPods, m.validateOSType, m.validateName, + m.validateNodeLabels, } var errs []error @@ -223,6 +226,21 @@ func (m *AzureManagedMachinePool) ValidateUpdate(oldRaw runtime.Object, client c "field is immutable, unsetting is not allowed")) } } + + updateValidators := []func() error{ + m.validateNodeLabels, + } + + for _, validator := range updateValidators { + if err := validator(); err != nil { + allErrs = append(allErrs, + field.Invalid( + field.NewPath("Spec", "NodeLabels"), + m.Spec.NodeLabels, + fmt.Sprintf("Node pool labels must not start with %s", azureutil.AzureSystemNodeLabelPrefix))) + } + } + if len(allErrs) != 0 { return apierrors.NewInvalid(GroupVersion.WithKind("AzureManagedMachinePool").GroupKind(), m.Name, allErrs) } @@ -319,6 +337,21 @@ func (m *AzureManagedMachinePool) validateName() error { return nil } +func (m *AzureManagedMachinePool) validateNodeLabels() error { + if m.Spec.NodeLabels != nil { + for key := range m.Spec.NodeLabels { + if strings.HasPrefix(key, azureutil.AzureSystemNodeLabelPrefix) { + return field.Invalid( + field.NewPath("Spec", "NodeLabels"), + m.Spec.NodeLabels[key], + fmt.Sprintf("Node pool label must not start with %s", azureutil.AzureSystemNodeLabelPrefix)) + } + } + } + + return nil +} + func ensureStringSlicesAreEqual(a []string, b []string) bool { if len(a) != len(b) { return false diff --git a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go index 8b708d55997..bcf4082fed6 100644 --- a/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go +++ b/exp/api/v1beta1/azuremanagedmachinepool_webhook_test.go @@ -463,6 +463,25 @@ func TestAzureManagedMachinePoolUpdatingWebhook(t *testing.T) { }, wantErr: true, }, + { + name: "Can't add a node label that begins with kubernetes.azure.com", + new: &AzureManagedMachinePool{ + Spec: AzureManagedMachinePoolSpec{ + NodeLabels: map[string]string{ + "foo": "bar", + "kubernetes.azure.com/scalesetpriority": "spot", + }, + }, + }, + old: &AzureManagedMachinePool{ + Spec: AzureManagedMachinePoolSpec{ + NodeLabels: map[string]string{ + "foo": "bar", + }, + }, + }, + wantErr: true, + }, } var client client.Client for _, tc := range tests { @@ -581,6 +600,33 @@ func TestAzureManagedMachinePool_ValidateCreate(t *testing.T) { wantErr: true, errorLen: 1, }, + { + name: "valid label", + ammp: &AzureManagedMachinePool{ + Spec: AzureManagedMachinePoolSpec{ + Mode: "User", + OSType: to.StringPtr(azure.LinuxOS), + NodeLabels: map[string]string{ + "foo": "bar", + }, + }, + }, + wantErr: false, + }, + { + name: "kubernetes.azure.com label", + ammp: &AzureManagedMachinePool{ + Spec: AzureManagedMachinePoolSpec{ + Mode: "User", + OSType: to.StringPtr(azure.LinuxOS), + NodeLabels: map[string]string{ + "kubernetes.azure.com/scalesetpriority": "spot", + }, + }, + }, + wantErr: true, + errorLen: 1, + }, } var client client.Client for _, tc := range tests { diff --git a/util/azure/azure.go b/util/azure/azure.go index 79b9f5eedd2..bc4f0de6724 100644 --- a/util/azure/azure.go +++ b/util/azure/azure.go @@ -24,6 +24,8 @@ import ( var azureResourceGroupNameRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGroups/(.+)/providers/(?:.*)`) +const AzureSystemNodeLabelPrefix = "kubernetes.azure.com" + // ConvertResourceGroupNameToLower converts the resource group name in the resource ID to be lowered. // Inspired by https://github.com/kubernetes-sigs/cloud-provider-azure/blob/88c9b89611e7c1fcbd39266928cce8406eb0e728/pkg/provider/azure_wrap.go#L409 func ConvertResourceGroupNameToLower(resourceID string) (string, error) {