diff --git a/internal/cmd/egctl/translate.go b/internal/cmd/egctl/translate.go index a97c83261b7..c8c3d51d24a 100644 --- a/internal/cmd/egctl/translate.go +++ b/internal/cmd/egctl/translate.go @@ -27,6 +27,7 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/util/sets" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" gwapiv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" @@ -849,21 +850,19 @@ func kubernetesYAMLToResources(str string, addMissingResources bool) (*gatewayap if provided, found := providedServiceMap[key]; !found { resources.Services = append(resources.Services, service) } else { - providedPorts := map[string]bool{} + providedPorts := sets.NewString() for _, port := range provided.Spec.Ports { - providedPorts[fmt.Sprintf("%s-%d", port.Protocol, port.Port)] = true + portKey := fmt.Sprintf("%s-%d", port.Protocol, port.Port) + providedPorts.Insert(portKey) } for _, port := range service.Spec.Ports { - protocol := port.Protocol - port := port.Port - name := fmt.Sprintf("%s-%d", protocol, port) - - if _, found := providedPorts[name]; !found { + name := fmt.Sprintf("%s-%d", port.Protocol, port.Port) + if !providedPorts.Has(name) { servicePort := v1.ServicePort{ Name: name, - Protocol: protocol, - Port: port, + Protocol: port.Protocol, + Port: port.Port, } provided.Spec.Ports = append(provided.Spec.Ports, servicePort) } diff --git a/internal/gatewayapi/backendtrafficpolicy.go b/internal/gatewayapi/backendtrafficpolicy.go index ecabbfef29b..f57f17e5726 100644 --- a/internal/gatewayapi/backendtrafficpolicy.go +++ b/internal/gatewayapi/backendtrafficpolicy.go @@ -825,14 +825,15 @@ func (t *Translator) buildHTTPActiveHealthChecker(h *egv1a1.HTTPActiveHealthChec *irHTTP.Method = strings.ToUpper(*irHTTP.Method) } - var irStatuses []ir.HTTPStatus // deduplicate http statuses - statusSet := make(map[egv1a1.HTTPStatus]bool, len(h.ExpectedStatuses)) + statusSet := sets.NewInt() for _, r := range h.ExpectedStatuses { - if _, ok := statusSet[r]; !ok { - statusSet[r] = true - irStatuses = append(irStatuses, ir.HTTPStatus(r)) - } + statusSet.Insert(int(r)) + } + irStatuses := make([]ir.HTTPStatus, 0, statusSet.Len()) + + for _, r := range statusSet.List() { + irStatuses = append(irStatuses, ir.HTTPStatus(r)) } irHTTP.ExpectedStatuses = irStatuses @@ -1158,27 +1159,27 @@ func (t *Translator) buildRetry(policy *egv1a1.BackendTrafficPolicy) *ir.Retry { } func makeIrStatusSet(in []egv1a1.HTTPStatus) []ir.HTTPStatus { - var irStatuses []ir.HTTPStatus - // deduplicate http statuses - statusSet := make(map[egv1a1.HTTPStatus]bool, len(in)) + statusSet := sets.NewInt() for _, r := range in { - if _, ok := statusSet[r]; !ok { - statusSet[r] = true - irStatuses = append(irStatuses, ir.HTTPStatus(r)) - } + statusSet.Insert(int(r)) + } + irStatuses := make([]ir.HTTPStatus, 0, statusSet.Len()) + + for _, r := range statusSet.List() { + irStatuses = append(irStatuses, ir.HTTPStatus(r)) } return irStatuses } func makeIrTriggerSet(in []egv1a1.TriggerEnum) []ir.TriggerEnum { - var irTriggers []ir.TriggerEnum - // deduplicate http statuses - triggerSet := make(map[egv1a1.TriggerEnum]bool, len(in)) + triggerSet := sets.NewString() for _, r := range in { - if _, ok := triggerSet[r]; !ok { - triggerSet[r] = true - irTriggers = append(irTriggers, ir.TriggerEnum(r)) - } + triggerSet.Insert(string(r)) + } + irTriggers := make([]ir.TriggerEnum, 0, triggerSet.Len()) + + for _, r := range triggerSet.List() { + irTriggers = append(irTriggers, ir.TriggerEnum(r)) } return irTriggers } diff --git a/internal/gatewayapi/backendtrafficpolicy_test.go b/internal/gatewayapi/backendtrafficpolicy_test.go index df943a2032b..d40d1e68c76 100644 --- a/internal/gatewayapi/backendtrafficpolicy_test.go +++ b/internal/gatewayapi/backendtrafficpolicy_test.go @@ -7,9 +7,13 @@ package gatewayapi import ( "math" + "reflect" "testing" "github.com/stretchr/testify/require" + + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "github.com/envoyproxy/gateway/internal/ir" ) func TestInt64ToUint32(t *testing.T) { @@ -50,3 +54,57 @@ func TestInt64ToUint32(t *testing.T) { }) } } + +func TestMakeIrStatusSet(t *testing.T) { + tests := []struct { + name string + in []egv1a1.HTTPStatus + want []ir.HTTPStatus + }{ + { + name: "no duplicates", + in: []egv1a1.HTTPStatus{200, 404}, + want: []ir.HTTPStatus{200, 404}, + }, + { + name: "with duplicates", + in: []egv1a1.HTTPStatus{200, 404, 200}, + want: []ir.HTTPStatus{200, 404}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := makeIrStatusSet(tt.in); !reflect.DeepEqual(got, tt.want) { + t.Errorf("makeIrStatusSet() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMakeIrTriggerSet(t *testing.T) { + tests := []struct { + name string + in []egv1a1.TriggerEnum + want []ir.TriggerEnum + }{ + { + name: "no duplicates", + in: []egv1a1.TriggerEnum{"5xx", "reset"}, + want: []ir.TriggerEnum{"5xx", "reset"}, + }, + { + name: "with duplicates", + in: []egv1a1.TriggerEnum{"5xx", "reset", "5xx"}, + want: []ir.TriggerEnum{"5xx", "reset"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := makeIrTriggerSet(tt.in); !reflect.DeepEqual(got, tt.want) { + t.Errorf("makeIrTriggerSet() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/gatewayapi/runner/runner.go b/internal/gatewayapi/runner/runner.go index 2b34b8ad33f..13f2c6b9d08 100644 --- a/internal/gatewayapi/runner/runner.go +++ b/internal/gatewayapi/runner/runner.go @@ -10,6 +10,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" v1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/envoyproxy/gateway/api/v1alpha1" @@ -356,23 +357,10 @@ func (r *Runner) deleteAllStatusKeys() { // based on the difference between the current keys and the // new keys parameters passed to the function. func getIRKeysToDelete(curKeys, newKeys []string) []string { - var delKeys []string - remaining := make(map[string]bool) + curSet := sets.NewString(curKeys...) + newSet := sets.NewString(newKeys...) - // Add all current keys to the remaining map - for _, key := range curKeys { - remaining[key] = true - } - - // Delete newKeys from the remaining map - // to get keys that need to be deleted - for _, key := range newKeys { - delete(remaining, key) - } - - for key := range remaining { - delKeys = append(delKeys, key) - } + delSet := curSet.Difference(newSet) - return delKeys + return delSet.List() } diff --git a/internal/ir/xds.go b/internal/ir/xds.go index 4bd51d37c0a..7d1af7c0602 100644 --- a/internal/ir/xds.go +++ b/internal/ir/xds.go @@ -16,6 +16,7 @@ import ( apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation" "sigs.k8s.io/yaml" @@ -714,53 +715,49 @@ func (h HTTPRoute) Validate() error { } } if len(h.AddRequestHeaders) > 0 { - occurred := map[string]bool{} + occurred := sets.NewString() for _, header := range h.AddRequestHeaders { if err := header.Validate(); err != nil { errs = errors.Join(errs, err) } - if !occurred[header.Name] { - occurred[header.Name] = true - } else { + if occurred.Has(header.Name) { errs = errors.Join(errs, ErrAddHeaderDuplicate) break } + occurred.Insert(header.Name) } } if len(h.RemoveRequestHeaders) > 0 { - occurred := map[string]bool{} + occurred := sets.NewString() for _, header := range h.RemoveRequestHeaders { - if !occurred[header] { - occurred[header] = true - } else { + if occurred.Has(header) { errs = errors.Join(errs, ErrRemoveHeaderDuplicate) break } + occurred.Insert(header) } } if len(h.AddResponseHeaders) > 0 { - occurred := map[string]bool{} + occurred := sets.NewString() for _, header := range h.AddResponseHeaders { if err := header.Validate(); err != nil { errs = errors.Join(errs, err) } - if !occurred[header.Name] { - occurred[header.Name] = true - } else { + if occurred.Has(header.Name) { errs = errors.Join(errs, ErrAddHeaderDuplicate) break } + occurred.Insert(header.Name) } } if len(h.RemoveResponseHeaders) > 0 { - occurred := map[string]bool{} + occurred := sets.NewString() for _, header := range h.RemoveResponseHeaders { - if !occurred[header] { - occurred[header] = true - } else { + if occurred.Has(header) { errs = errors.Join(errs, ErrRemoveHeaderDuplicate) break } + occurred.Insert(header) } } if h.LoadBalancer != nil { diff --git a/internal/provider/kubernetes/controller.go b/internal/provider/kubernetes/controller.go index 5e66e0e6a7b..fb3ade2bcff 100644 --- a/internal/provider/kubernetes/controller.go +++ b/internal/provider/kubernetes/controller.go @@ -18,6 +18,7 @@ import ( "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/discovery" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" @@ -51,7 +52,7 @@ type gatewayAPIReconciler struct { namespace string namespaceLabel *metav1.LabelSelector envoyGateway *egv1a1.EnvoyGateway - mergeGateways map[string]bool + mergeGateways sets.Set[string] resources *message.ProviderResources extGVKs []schema.GroupVersionKind } @@ -87,7 +88,7 @@ func newGatewayAPIController(mgr manager.Manager, cfg *config.Server, su status. extGVKs: extGVKs, store: newProviderStore(), envoyGateway: cfg.EnvoyGateway, - mergeGateways: map[string]bool{}, + mergeGateways: sets.New[string](), } if byNamespaceSelector { @@ -356,7 +357,11 @@ func (r *gatewayAPIReconciler) Reconcile(ctx context.Context, _ reconcile.Reques } if gwcResource.EnvoyProxy != nil && gwcResource.EnvoyProxy.Spec.MergeGateways != nil { - r.mergeGateways[acceptedGC.Name] = *gwcResource.EnvoyProxy.Spec.MergeGateways + if *gwcResource.EnvoyProxy.Spec.MergeGateways { + r.mergeGateways.Insert(acceptedGC.Name) + } else { + r.mergeGateways.Delete(acceptedGC.Name) + } } if err := r.updateStatusForGatewayClass(ctx, acceptedGC, true, string(gwapiv1.GatewayClassReasonAccepted), status.MsgValidGatewayClass); err != nil { diff --git a/internal/provider/kubernetes/predicates.go b/internal/provider/kubernetes/predicates.go index f77b46ea99a..3585a2913ae 100644 --- a/internal/provider/kubernetes/predicates.go +++ b/internal/provider/kubernetes/predicates.go @@ -238,7 +238,7 @@ func (r *gatewayAPIReconciler) validateServiceForReconcile(obj client.Object) bo // Merged gateways will have only this label, update status of all Gateways under found GatewayClass. gcName, ok := labels[gatewayapi.OwningGatewayClassLabel] - if ok && r.mergeGateways[gcName] { + if ok && r.mergeGateways.Has(gcName) { if err := r.updateStatusForGatewaysUnderGatewayClass(ctx, gcName); err != nil { r.log.Info("no Gateways found under GatewayClass", "name", gcName) return false @@ -390,7 +390,7 @@ func (r *gatewayAPIReconciler) validateDeploymentForReconcile(obj client.Object) // Merged gateways will have only this label, update status of all Gateways under found GatewayClass. gcName, ok := labels[gatewayapi.OwningGatewayClassLabel] - if ok && r.mergeGateways[gcName] { + if ok && r.mergeGateways.Has(gcName) { if err := r.updateStatusForGatewaysUnderGatewayClass(ctx, gcName); err != nil { r.log.Info("no Gateways found under GatewayClass", "name", gcName) return false @@ -406,7 +406,7 @@ func (r *gatewayAPIReconciler) validateDeploymentForReconcile(obj client.Object) func (r *gatewayAPIReconciler) envoyDeploymentForGateway(ctx context.Context, gateway *gwapiv1.Gateway) (*appsv1.Deployment, error) { key := types.NamespacedName{ Namespace: r.namespace, - Name: infraName(gateway, r.mergeGateways[string(gateway.Spec.GatewayClassName)]), + Name: infraName(gateway, r.mergeGateways.Has(string(gateway.Spec.GatewayClassName))), } deployment := new(appsv1.Deployment) if err := r.client.Get(ctx, key, deployment); err != nil { @@ -422,7 +422,7 @@ func (r *gatewayAPIReconciler) envoyDeploymentForGateway(ctx context.Context, ga func (r *gatewayAPIReconciler) envoyServiceForGateway(ctx context.Context, gateway *gwapiv1.Gateway) (*corev1.Service, error) { key := types.NamespacedName{ Namespace: r.namespace, - Name: infraName(gateway, r.mergeGateways[string(gateway.Spec.GatewayClassName)]), + Name: infraName(gateway, r.mergeGateways.Has(string(gateway.Spec.GatewayClassName))), } svc := new(corev1.Service) if err := r.client.Get(ctx, key, svc); err != nil { diff --git a/internal/provider/kubernetes/predicates_test.go b/internal/provider/kubernetes/predicates_test.go index cd88d7b7100..f923eef8a26 100644 --- a/internal/provider/kubernetes/predicates_test.go +++ b/internal/provider/kubernetes/predicates_test.go @@ -13,6 +13,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" fakeclient "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -556,9 +557,7 @@ func TestValidateServiceForReconcile(t *testing.T) { r := gatewayAPIReconciler{ classController: v1alpha1.GatewayControllerName, log: logger, - mergeGateways: map[string]bool{ - "test-mg": true, - }, + mergeGateways: sets.New[string]("test-mg"), } for _, tc := range testCases { @@ -653,9 +652,7 @@ func TestValidateDeploymentForReconcile(t *testing.T) { r := gatewayAPIReconciler{ classController: v1alpha1.GatewayControllerName, log: logger, - mergeGateways: map[string]bool{ - "test-mg": true, - }, + mergeGateways: sets.New[string]("test-mg"), } for _, tc := range testCases {