diff --git a/webhooks/core/service_mutator.go b/webhooks/core/service_mutator.go index d3c73f9fe..26f57edbc 100644 --- a/webhooks/core/service_mutator.go +++ b/webhooks/core/service_mutator.go @@ -52,7 +52,41 @@ func (m *serviceMutator) MutateCreate(ctx context.Context, obj runtime.Object) ( } func (m *serviceMutator) MutateUpdate(ctx context.Context, obj runtime.Object, oldObj runtime.Object) (runtime.Object, error) { - return obj, nil + // this mutator only cares about Service objects + newSvc, ok := obj.(*corev1.Service) + if !ok { + return obj, nil + } + + oldSvc, ok := oldObj.(*corev1.Service) + if !ok { + return obj, nil + } + + if newSvc.Spec.Type != corev1.ServiceTypeLoadBalancer { + return obj, nil + } + + // does the old Service object have spec.loadBalancerClass? + if oldSvc.Spec.LoadBalancerClass != nil && *oldSvc.Spec.LoadBalancerClass != "" { + // if so, let's inspect the incoming object for the same field + + // does the new Service object lack spec.loadBalancerClass? + // if so, set it to the old value + // if yes, then leave it be because someone wanted it that way, let the user deal with the error + if newSvc.Spec.LoadBalancerClass == nil || *newSvc.Spec.LoadBalancerClass == "" { + newSvc.Spec.LoadBalancerClass = oldSvc.Spec.LoadBalancerClass + + m.logger.Info("preserved loadBalancerClass", "service", newSvc.Name, "loadBalancerClass", *newSvc.Spec.LoadBalancerClass) + return newSvc, nil + } + + m.logger.Info("service already has loadBalancerClass, skipping", "service", newSvc.Name, "loadBalancerClass", *newSvc.Spec.LoadBalancerClass) + return newSvc, nil + } + + m.logger.Info("service did not originally have a loadBalancerClass, skipping", "service", newSvc.Name) + return newSvc, nil } // +kubebuilder:webhook:path=/mutate-v1-service,mutating=true,failurePolicy=fail,groups="",resources=services,verbs=create,versions=v1,name=mservice.elbv2.k8s.aws,sideEffects=None,webhookVersions=v1,admissionReviewVersions=v1beta1 diff --git a/webhooks/core/service_mutator_test.go b/webhooks/core/service_mutator_test.go new file mode 100644 index 000000000..478e616a6 --- /dev/null +++ b/webhooks/core/service_mutator_test.go @@ -0,0 +1,50 @@ +package core + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestMutateUpdate_WhenServiceIsNotLoadBalancer(t *testing.T) { + m := &serviceMutator{} + svc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeClusterIP}} + _, err := m.MutateUpdate(context.Background(), svc, svc) + assert.NoError(t, err) + + assert.Nil(t, svc.Spec.LoadBalancerClass) +} + +func TestMutateUpdate_WhenOldServiceHasLoadBalancerClassAndNewServiceDoesNot(t *testing.T) { + m := &serviceMutator{} + oldSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer, LoadBalancerClass: stringPtr("old-class")}} + newSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer}} + _, err := m.MutateUpdate(context.Background(), newSvc, oldSvc) + assert.NoError(t, err) + assert.Equal(t, "old-class", *newSvc.Spec.LoadBalancerClass) +} + +func TestMutateUpdate_WhenOldServiceHasLoadBalancerClassAndNewServiceHasDifferent(t *testing.T) { + m := &serviceMutator{} + oldSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer, LoadBalancerClass: stringPtr("old-class")}} + newSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer, LoadBalancerClass: stringPtr("new-class")}} + _, err := m.MutateUpdate(context.Background(), newSvc, oldSvc) + assert.NoError(t, err) + assert.Equal(t, "new-class", *newSvc.Spec.LoadBalancerClass) +} + +func TestMutateUpdate_WhenOldServiceDoesNotHaveLoadBalancerClass(t *testing.T) { + m := &serviceMutator{} + oldSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer}} + newSvc := &corev1.Service{Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeLoadBalancer}} + _, err := m.MutateUpdate(context.Background(), newSvc, oldSvc) + assert.NoError(t, err) + + assert.Nil(t, newSvc.Spec.LoadBalancerClass) +} + +func stringPtr(s string) *string { + return &s +}