diff --git a/api/v1beta1/azuremachine_webhook.go b/api/v1beta1/azuremachine_webhook.go index 4aee5ddf41b..8082ddec352 100644 --- a/api/v1beta1/azuremachine_webhook.go +++ b/api/v1beta1/azuremachine_webhook.go @@ -154,7 +154,10 @@ func (mw *azureMachineWebhook) ValidateUpdate(ctx context.Context, oldObj, newOb allErrs = append(allErrs, err) } - if err := webhookutils.ValidateImmutable( + // Spec.AcceleratedNetworking can only be reset to nil and no other changes apart from that + // is accepted if the field is set. + // Ref issue #3518 + if err := webhookutils.ValidateZeroTransition( field.NewPath("Spec", "AcceleratedNetworking"), old.Spec.AcceleratedNetworking, m.Spec.AcceleratedNetworking); err != nil { diff --git a/api/v1beta1/azuremachine_webhook_test.go b/api/v1beta1/azuremachine_webhook_test.go index 510d3f59332..bf675fb24bb 100644 --- a/api/v1beta1/azuremachine_webhook_test.go +++ b/api/v1beta1/azuremachine_webhook_test.go @@ -536,6 +536,34 @@ func TestAzureMachine_ValidateUpdate(t *testing.T) { }, wantErr: false, }, + { + name: "validTest: azuremachine.spec.AcceleratedNetworking transition(from true) to nil is acceptable", + oldMachine: &AzureMachine{ + Spec: AzureMachineSpec{ + AcceleratedNetworking: pointer.Bool(true), + }, + }, + newMachine: &AzureMachine{ + Spec: AzureMachineSpec{ + AcceleratedNetworking: nil, + }, + }, + wantErr: false, + }, + { + name: "validTest: azuremachine.spec.AcceleratedNetworking transition(from false) to nil is acceptable", + oldMachine: &AzureMachine{ + Spec: AzureMachineSpec{ + AcceleratedNetworking: pointer.Bool(false), + }, + }, + newMachine: &AzureMachine{ + Spec: AzureMachineSpec{ + AcceleratedNetworking: nil, + }, + }, + wantErr: false, + }, { name: "invalidTest: azuremachine.spec.SpotVMOptions is immutable", oldMachine: &AzureMachine{ diff --git a/util/webhook/validator.go b/util/webhook/validator.go index 00b099cc1c4..cfb3228e792 100644 --- a/util/webhook/validator.go +++ b/util/webhook/validator.go @@ -52,6 +52,16 @@ func ValidateImmutable(path *field.Path, oldVal, newVal any) *field.Error { return nil } +// ValidateZeroTransition validates equality across two values, with only exception to allow +// the value to transition of a zero value. +func ValidateZeroTransition(path *field.Path, oldVal, newVal any) *field.Error { + if reflect.ValueOf(newVal).IsZero() { + // unsetting the field is allowed + return nil + } + return ValidateImmutable(path, oldVal, newVal) +} + // EnsureStringSlicesAreEquivalent returns if two string slices have equal lengths, // and that they have the exact same items; it does not enforce strict ordering of items. func EnsureStringSlicesAreEquivalent(a []string, b []string) bool { diff --git a/util/webhook/validator_test.go b/util/webhook/validator_test.go index 7657fbbbd27..39f77dab520 100644 --- a/util/webhook/validator_test.go +++ b/util/webhook/validator_test.go @@ -299,3 +299,215 @@ func TestEnsureStringSlicesAreEquivalent(t *testing.T) { }) } } + +func TestValidateZeroTransitionPtr(t *testing.T) { + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + input1 *bool + input2 *bool + expectedOutput *field.Error + }{ + { + name: "nil", + input1: nil, + input2: nil, + }, + { + name: "no change", + input1: pointer.Bool(true), + input2: pointer.Bool(true), + }, + { + name: "can unset", + input1: pointer.Bool(true), + input2: nil, + }, + { + name: "can't set from empty", + input1: nil, + input2: pointer.Bool(true), + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + input1: pointer.Bool(true), + input2: pointer.Bool(false), + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + err := ValidateZeroTransition(testPath, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateZeroTransitionString(t *testing.T) { + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + input1 string + input2 string + expectedOutput *field.Error + }{ + { + name: "empty string", + input1: "", + input2: "", + }, + { + name: "no change", + input1: "foo", + input2: "foo", + }, + { + name: "can unset", + input1: "foo", + input2: "", + }, + { + name: "can't set from empty", + input1: "", + input2: "foo", + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + input1: "foo", + input2: "bar", + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + err := ValidateZeroTransition(testPath, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateZeroTransitionStringPtr(t *testing.T) { + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + input1 *string + input2 *string + expectedOutput *field.Error + }{ + { + name: "nil", + input1: nil, + input2: nil, + }, + { + name: "no change", + input1: pointer.String("foo"), + input2: pointer.String("foo"), + }, + { + name: "can unset", + input1: pointer.String("foo"), + input2: nil, + }, + { + name: "can't set from empty", + input1: nil, + input2: pointer.String("foo"), + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + input1: pointer.String("foo"), + input2: pointer.String("bar"), + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + err := ValidateZeroTransition(testPath, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestValidateZeroTransitionInt32(t *testing.T) { + testPath := field.NewPath("Spec", "Foo") + + tests := []struct { + name string + input1 int32 + input2 int32 + expectedOutput *field.Error + }{ + { + name: "unset", + input1: 0, + input2: 0, + }, + { + name: "no change", + input1: 5, + input2: 5, + }, + { + name: "can unset", + input1: 5, + input2: 0, + }, + { + name: "can't set from empty", + input1: 0, + input2: 5, + expectedOutput: field.Invalid(testPath, nil, setMessage), + }, + { + name: "can't change", + input1: 5, + input2: 6, + expectedOutput: field.Invalid(testPath, nil, immutableMessage), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + err := ValidateZeroTransition(testPath, tc.input1, tc.input2) + if tc.expectedOutput != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Detail).To(Equal(tc.expectedOutput.Detail)) + g.Expect(err.Type).To(Equal(tc.expectedOutput.Type)) + g.Expect(err.Field).To(Equal(tc.expectedOutput.Field)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +}