Skip to content

Commit

Permalink
refactor: set instead of map for mergeGateways (#2803)
Browse files Browse the repository at this point in the history
* refactor:set[T] instead of map[T]bool

Signed-off-by: Dennis Zhou <[email protected]>

* fix lint

Signed-off-by: Dennis Zhou <[email protected]>

---------

Signed-off-by: Dennis Zhou <[email protected]>
Co-authored-by: Xunzhuo <[email protected]>
  • Loading branch information
deszhou and Xunzhuo authored Mar 7, 2024
1 parent bdda774 commit 44ede66
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 75 deletions.
17 changes: 8 additions & 9 deletions internal/cmd/egctl/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/util/sets"
gwapiv1 "sigs.k8s.io/gateway-api/apis/v1"
gwapiv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2"

Expand Down Expand Up @@ -849,21 +850,19 @@ func kubernetesYAMLToResources(str string, addMissingResources bool) (*gatewayap
if provided, found := providedServiceMap[key]; !found {
resources.Services = append(resources.Services, service)
} else {
providedPorts := map[string]bool{}
providedPorts := sets.NewString()
for _, port := range provided.Spec.Ports {
providedPorts[fmt.Sprintf("%s-%d", port.Protocol, port.Port)] = true
portKey := fmt.Sprintf("%s-%d", port.Protocol, port.Port)
providedPorts.Insert(portKey)
}

for _, port := range service.Spec.Ports {
protocol := port.Protocol
port := port.Port
name := fmt.Sprintf("%s-%d", protocol, port)

if _, found := providedPorts[name]; !found {
name := fmt.Sprintf("%s-%d", port.Protocol, port.Port)
if !providedPorts.Has(name) {
servicePort := v1.ServicePort{
Name: name,
Protocol: protocol,
Port: port,
Protocol: port.Protocol,
Port: port.Port,
}
provided.Spec.Ports = append(provided.Spec.Ports, servicePort)
}
Expand Down
41 changes: 21 additions & 20 deletions internal/gatewayapi/backendtrafficpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,14 +825,15 @@ func (t *Translator) buildHTTPActiveHealthChecker(h *egv1a1.HTTPActiveHealthChec
*irHTTP.Method = strings.ToUpper(*irHTTP.Method)
}

var irStatuses []ir.HTTPStatus
// deduplicate http statuses
statusSet := make(map[egv1a1.HTTPStatus]bool, len(h.ExpectedStatuses))
statusSet := sets.NewInt()
for _, r := range h.ExpectedStatuses {
if _, ok := statusSet[r]; !ok {
statusSet[r] = true
irStatuses = append(irStatuses, ir.HTTPStatus(r))
}
statusSet.Insert(int(r))
}
irStatuses := make([]ir.HTTPStatus, 0, statusSet.Len())

for _, r := range statusSet.List() {
irStatuses = append(irStatuses, ir.HTTPStatus(r))
}
irHTTP.ExpectedStatuses = irStatuses

Expand Down Expand Up @@ -1158,27 +1159,27 @@ func (t *Translator) buildRetry(policy *egv1a1.BackendTrafficPolicy) *ir.Retry {
}

func makeIrStatusSet(in []egv1a1.HTTPStatus) []ir.HTTPStatus {
var irStatuses []ir.HTTPStatus
// deduplicate http statuses
statusSet := make(map[egv1a1.HTTPStatus]bool, len(in))
statusSet := sets.NewInt()
for _, r := range in {
if _, ok := statusSet[r]; !ok {
statusSet[r] = true
irStatuses = append(irStatuses, ir.HTTPStatus(r))
}
statusSet.Insert(int(r))
}
irStatuses := make([]ir.HTTPStatus, 0, statusSet.Len())

for _, r := range statusSet.List() {
irStatuses = append(irStatuses, ir.HTTPStatus(r))
}
return irStatuses
}

func makeIrTriggerSet(in []egv1a1.TriggerEnum) []ir.TriggerEnum {
var irTriggers []ir.TriggerEnum
// deduplicate http statuses
triggerSet := make(map[egv1a1.TriggerEnum]bool, len(in))
triggerSet := sets.NewString()
for _, r := range in {
if _, ok := triggerSet[r]; !ok {
triggerSet[r] = true
irTriggers = append(irTriggers, ir.TriggerEnum(r))
}
triggerSet.Insert(string(r))
}
irTriggers := make([]ir.TriggerEnum, 0, triggerSet.Len())

for _, r := range triggerSet.List() {
irTriggers = append(irTriggers, ir.TriggerEnum(r))
}
return irTriggers
}
58 changes: 58 additions & 0 deletions internal/gatewayapi/backendtrafficpolicy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ package gatewayapi

import (
"math"
"reflect"
"testing"

"github.com/stretchr/testify/require"

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
)

func TestInt64ToUint32(t *testing.T) {
Expand Down Expand Up @@ -50,3 +54,57 @@ func TestInt64ToUint32(t *testing.T) {
})
}
}

func TestMakeIrStatusSet(t *testing.T) {
tests := []struct {
name string
in []egv1a1.HTTPStatus
want []ir.HTTPStatus
}{
{
name: "no duplicates",
in: []egv1a1.HTTPStatus{200, 404},
want: []ir.HTTPStatus{200, 404},
},
{
name: "with duplicates",
in: []egv1a1.HTTPStatus{200, 404, 200},
want: []ir.HTTPStatus{200, 404},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := makeIrStatusSet(tt.in); !reflect.DeepEqual(got, tt.want) {
t.Errorf("makeIrStatusSet() = %v, want %v", got, tt.want)
}
})
}
}

func TestMakeIrTriggerSet(t *testing.T) {
tests := []struct {
name string
in []egv1a1.TriggerEnum
want []ir.TriggerEnum
}{
{
name: "no duplicates",
in: []egv1a1.TriggerEnum{"5xx", "reset"},
want: []ir.TriggerEnum{"5xx", "reset"},
},
{
name: "with duplicates",
in: []egv1a1.TriggerEnum{"5xx", "reset", "5xx"},
want: []ir.TriggerEnum{"5xx", "reset"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := makeIrTriggerSet(tt.in); !reflect.DeepEqual(got, tt.want) {
t.Errorf("makeIrTriggerSet() = %v, want %v", got, tt.want)
}
})
}
}
22 changes: 5 additions & 17 deletions internal/gatewayapi/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
v1 "sigs.k8s.io/gateway-api/apis/v1"

"github.com/envoyproxy/gateway/api/v1alpha1"
Expand Down Expand Up @@ -356,23 +357,10 @@ func (r *Runner) deleteAllStatusKeys() {
// based on the difference between the current keys and the
// new keys parameters passed to the function.
func getIRKeysToDelete(curKeys, newKeys []string) []string {
var delKeys []string
remaining := make(map[string]bool)
curSet := sets.NewString(curKeys...)
newSet := sets.NewString(newKeys...)

// Add all current keys to the remaining map
for _, key := range curKeys {
remaining[key] = true
}

// Delete newKeys from the remaining map
// to get keys that need to be deleted
for _, key := range newKeys {
delete(remaining, key)
}

for key := range remaining {
delKeys = append(delKeys, key)
}
delSet := curSet.Difference(newSet)

return delKeys
return delSet.List()
}
29 changes: 13 additions & 16 deletions internal/ir/xds.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation"
"sigs.k8s.io/yaml"

Expand Down Expand Up @@ -714,53 +715,49 @@ func (h HTTPRoute) Validate() error {
}
}
if len(h.AddRequestHeaders) > 0 {
occurred := map[string]bool{}
occurred := sets.NewString()
for _, header := range h.AddRequestHeaders {
if err := header.Validate(); err != nil {
errs = errors.Join(errs, err)
}
if !occurred[header.Name] {
occurred[header.Name] = true
} else {
if occurred.Has(header.Name) {
errs = errors.Join(errs, ErrAddHeaderDuplicate)
break
}
occurred.Insert(header.Name)
}
}
if len(h.RemoveRequestHeaders) > 0 {
occurred := map[string]bool{}
occurred := sets.NewString()
for _, header := range h.RemoveRequestHeaders {
if !occurred[header] {
occurred[header] = true
} else {
if occurred.Has(header) {
errs = errors.Join(errs, ErrRemoveHeaderDuplicate)
break
}
occurred.Insert(header)
}
}
if len(h.AddResponseHeaders) > 0 {
occurred := map[string]bool{}
occurred := sets.NewString()
for _, header := range h.AddResponseHeaders {
if err := header.Validate(); err != nil {
errs = errors.Join(errs, err)
}
if !occurred[header.Name] {
occurred[header.Name] = true
} else {
if occurred.Has(header.Name) {
errs = errors.Join(errs, ErrAddHeaderDuplicate)
break
}
occurred.Insert(header.Name)
}
}
if len(h.RemoveResponseHeaders) > 0 {
occurred := map[string]bool{}
occurred := sets.NewString()
for _, header := range h.RemoveResponseHeaders {
if !occurred[header] {
occurred[header] = true
} else {
if occurred.Has(header) {
errs = errors.Join(errs, ErrRemoveHeaderDuplicate)
break
}
occurred.Insert(header)
}
}
if h.LoadBalancer != nil {
Expand Down
11 changes: 8 additions & 3 deletions internal/provider/kubernetes/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/discovery"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller"
Expand Down Expand Up @@ -51,7 +52,7 @@ type gatewayAPIReconciler struct {
namespace string
namespaceLabel *metav1.LabelSelector
envoyGateway *egv1a1.EnvoyGateway
mergeGateways map[string]bool
mergeGateways sets.Set[string]
resources *message.ProviderResources
extGVKs []schema.GroupVersionKind
}
Expand Down Expand Up @@ -87,7 +88,7 @@ func newGatewayAPIController(mgr manager.Manager, cfg *config.Server, su status.
extGVKs: extGVKs,
store: newProviderStore(),
envoyGateway: cfg.EnvoyGateway,
mergeGateways: map[string]bool{},
mergeGateways: sets.New[string](),
}

if byNamespaceSelector {
Expand Down Expand Up @@ -356,7 +357,11 @@ func (r *gatewayAPIReconciler) Reconcile(ctx context.Context, _ reconcile.Reques
}

if gwcResource.EnvoyProxy != nil && gwcResource.EnvoyProxy.Spec.MergeGateways != nil {
r.mergeGateways[acceptedGC.Name] = *gwcResource.EnvoyProxy.Spec.MergeGateways
if *gwcResource.EnvoyProxy.Spec.MergeGateways {
r.mergeGateways.Insert(acceptedGC.Name)
} else {
r.mergeGateways.Delete(acceptedGC.Name)
}
}

if err := r.updateStatusForGatewayClass(ctx, acceptedGC, true, string(gwapiv1.GatewayClassReasonAccepted), status.MsgValidGatewayClass); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions internal/provider/kubernetes/predicates.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (r *gatewayAPIReconciler) validateServiceForReconcile(obj client.Object) bo

// Merged gateways will have only this label, update status of all Gateways under found GatewayClass.
gcName, ok := labels[gatewayapi.OwningGatewayClassLabel]
if ok && r.mergeGateways[gcName] {
if ok && r.mergeGateways.Has(gcName) {
if err := r.updateStatusForGatewaysUnderGatewayClass(ctx, gcName); err != nil {
r.log.Info("no Gateways found under GatewayClass", "name", gcName)
return false
Expand Down Expand Up @@ -390,7 +390,7 @@ func (r *gatewayAPIReconciler) validateDeploymentForReconcile(obj client.Object)

// Merged gateways will have only this label, update status of all Gateways under found GatewayClass.
gcName, ok := labels[gatewayapi.OwningGatewayClassLabel]
if ok && r.mergeGateways[gcName] {
if ok && r.mergeGateways.Has(gcName) {
if err := r.updateStatusForGatewaysUnderGatewayClass(ctx, gcName); err != nil {
r.log.Info("no Gateways found under GatewayClass", "name", gcName)
return false
Expand All @@ -406,7 +406,7 @@ func (r *gatewayAPIReconciler) validateDeploymentForReconcile(obj client.Object)
func (r *gatewayAPIReconciler) envoyDeploymentForGateway(ctx context.Context, gateway *gwapiv1.Gateway) (*appsv1.Deployment, error) {
key := types.NamespacedName{
Namespace: r.namespace,
Name: infraName(gateway, r.mergeGateways[string(gateway.Spec.GatewayClassName)]),
Name: infraName(gateway, r.mergeGateways.Has(string(gateway.Spec.GatewayClassName))),
}
deployment := new(appsv1.Deployment)
if err := r.client.Get(ctx, key, deployment); err != nil {
Expand All @@ -422,7 +422,7 @@ func (r *gatewayAPIReconciler) envoyDeploymentForGateway(ctx context.Context, ga
func (r *gatewayAPIReconciler) envoyServiceForGateway(ctx context.Context, gateway *gwapiv1.Gateway) (*corev1.Service, error) {
key := types.NamespacedName{
Namespace: r.namespace,
Name: infraName(gateway, r.mergeGateways[string(gateway.Spec.GatewayClassName)]),
Name: infraName(gateway, r.mergeGateways.Has(string(gateway.Spec.GatewayClassName))),
}
svc := new(corev1.Service)
if err := r.client.Get(ctx, key, svc); err != nil {
Expand Down
9 changes: 3 additions & 6 deletions internal/provider/kubernetes/predicates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
fakeclient "sigs.k8s.io/controller-runtime/pkg/client/fake"
Expand Down Expand Up @@ -556,9 +557,7 @@ func TestValidateServiceForReconcile(t *testing.T) {
r := gatewayAPIReconciler{
classController: v1alpha1.GatewayControllerName,
log: logger,
mergeGateways: map[string]bool{
"test-mg": true,
},
mergeGateways: sets.New[string]("test-mg"),
}

for _, tc := range testCases {
Expand Down Expand Up @@ -653,9 +652,7 @@ func TestValidateDeploymentForReconcile(t *testing.T) {
r := gatewayAPIReconciler{
classController: v1alpha1.GatewayControllerName,
log: logger,
mergeGateways: map[string]bool{
"test-mg": true,
},
mergeGateways: sets.New[string]("test-mg"),
}

for _, tc := range testCases {
Expand Down

0 comments on commit 44ede66

Please sign in to comment.