From e8ccd212952cebf768171982b2d0cbaf7c59fe6b Mon Sep 17 00:00:00 2001 From: Matthew Booth Date: Thu, 22 Feb 2024 02:18:12 +0000 Subject: [PATCH] Reduce cyclomatic complexity of ReconcileLoadBalancer This function had become genuinely too complex over time, to the point that even the linter was starting to complain about it when making almost any change. This change refactors ReconcileLoadBalancer into several smaller logical functions which are much easier to read and reason about. It also revealed some trivial optimisations: * Only fetch Octavia providers if we need them to create a new loadbalancer * Only calculate allowed CIDRs once * Don't re-fetch a loadbalancer to check it's active if it's already active Co-Authored-By: Emilien Macchi --- .../services/loadbalancer/loadbalancer.go | 358 ++++++++++------- .../loadbalancer/loadbalancer_test.go | 360 +++++++++++++++++- 2 files changed, 568 insertions(+), 150 deletions(-) diff --git a/pkg/cloud/services/loadbalancer/loadbalancer.go b/pkg/cloud/services/loadbalancer/loadbalancer.go index 1d0812c353..38dda5ecb6 100644 --- a/pkg/cloud/services/loadbalancer/loadbalancer.go +++ b/pkg/cloud/services/loadbalancer/loadbalancer.go @@ -76,42 +76,7 @@ func (s *Service) ReconcileLoadBalancer(openStackCluster *infrav1.OpenStackClust openStackCluster.Status.APIServerLoadBalancer = lbStatus } - var fixedIPAddress string - var err error - - switch { - case lbStatus.InternalIP != "": - fixedIPAddress = lbStatus.InternalIP - case openStackCluster.Spec.APIServerFixedIP != "": - fixedIPAddress = openStackCluster.Spec.APIServerFixedIP - case openStackCluster.Spec.DisableAPIServerFloatingIP && openStackCluster.Spec.ControlPlaneEndpoint.IsValid(): - fixedIPAddress, err = lookupHost(openStackCluster.Spec.ControlPlaneEndpoint.Host) - if err != nil { - return false, fmt.Errorf("lookup host: %w", err) - } - } - - providers, err := s.loadbalancerClient.ListLoadBalancerProviders() - if err != nil { - return false, err - } - - // Choose the selected provider if it is set in cluster spec, if not, omit the field and Octavia will use the default provider. - lbProvider := "" - if openStackCluster.Spec.APIServerLoadBalancer.Provider != "" { - for _, v := range providers { - if v.Name == openStackCluster.Spec.APIServerLoadBalancer.Provider { - lbProvider = v.Name - break - } - } - if lbProvider == "" { - record.Warnf(openStackCluster, "OctaviaProviderNotFound", "Provider specified for Octavia not found.") - record.Eventf(openStackCluster, "OctaviaProviderNotFound", "Provider %s specified for Octavia not found, using the default provider.", openStackCluster.Spec.APIServerLoadBalancer.Provider) - } - } - - lb, err := s.getOrCreateLoadBalancer(openStackCluster, loadBalancerName, openStackCluster.Status.Network.Subnets[0].ID, clusterName, fixedIPAddress, lbProvider) + lb, err := s.getOrCreateAPILoadBalancer(openStackCluster, clusterName) if err != nil { return false, err } @@ -121,23 +86,20 @@ func (s *Service) ReconcileLoadBalancer(openStackCluster *infrav1.OpenStackClust lbStatus.InternalIP = lb.VipAddress lbStatus.Tags = lb.Tags - if err := s.waitForLoadBalancerActive(lb.ID); err != nil { - return false, fmt.Errorf("load balancer %q with id %s is not active after timeout: %v", loadBalancerName, lb.ID, err) + if lb.ProvisioningStatus != loadBalancerProvisioningStatusActive { + var err error + lb, err = s.waitForLoadBalancerActive(lb.ID) + if err != nil { + return false, fmt.Errorf("load balancer %q with id %s is not active after timeout: %v", loadBalancerName, lb.ID, err) + } } if !openStackCluster.Spec.DisableAPIServerFloatingIP { - var floatingIPAddress string - switch { - case lbStatus.IP != "": - floatingIPAddress = lbStatus.IP - case openStackCluster.Spec.APIServerFloatingIP != "": - floatingIPAddress = openStackCluster.Spec.APIServerFloatingIP - case openStackCluster.Spec.ControlPlaneEndpoint.IsValid(): - floatingIPAddress, err = lookupHost(openStackCluster.Spec.ControlPlaneEndpoint.Host) - if err != nil { - return false, fmt.Errorf("lookup host: %w", err) - } + floatingIPAddress, err := getAPIServerFloatingIP(openStackCluster) + if err != nil { + return false, err } + fp, err := s.networkingService.GetOrCreateFloatingIP(openStackCluster, openStackCluster, clusterName, floatingIPAddress) if err != nil { return false, err @@ -153,71 +115,190 @@ func (s *Service) ReconcileLoadBalancer(openStackCluster *infrav1.OpenStackClust } } - allowedCIDRs := []string{} - // To reduce API calls towards OpenStack API, let's handle the CIDR support verification for all Ports only once. - allowedCIDRsSupported := false - octaviaVersions, err := s.loadbalancerClient.ListOctaviaVersions() + allowedCIDRsSupported, err := s.isAllowsCIDRSSupported(lb) if err != nil { return false, err } - // The current version is always the last one in the list. - octaviaVersion := octaviaVersions[len(octaviaVersions)-1].ID - if openstackutil.IsOctaviaFeatureSupported(octaviaVersion, openstackutil.OctaviaFeatureVIPACL, lbProvider) { - allowedCIDRsSupported = true + + // AllowedCIDRs will be nil if allowed CIDRs is not supported by the Octavia provider + if allowedCIDRsSupported { + lbStatus.AllowedCIDRs = getCanonicalAllowedCIDRs(openStackCluster) + } else { + lbStatus.AllowedCIDRs = nil } portList := []int{apiServerPort} portList = append(portList, openStackCluster.Spec.APIServerLoadBalancer.AdditionalPorts...) for _, port := range portList { - lbPortObjectsName := fmt.Sprintf("%s-%d", loadBalancerName, port) + if err := s.reconcileAPILoadBalancerListener(lb, openStackCluster, clusterName, port); err != nil { + return false, err + } + } + + return false, nil +} + +// getAPIServerVIPAddress gets the VIP address for the API server from wherever it is specified. +// Returns an empty string if the VIP address is not specified and it should be allocated automatically. +func getAPIServerVIPAddress(openStackCluster *infrav1.OpenStackCluster) (string, error) { + switch { + // We only use call this function when creating the loadbalancer, so this case should never be used + case openStackCluster.Status.APIServerLoadBalancer != nil && openStackCluster.Status.APIServerLoadBalancer.InternalIP != "": + return openStackCluster.Status.APIServerLoadBalancer.InternalIP, nil + + // Explicit fixed IP in the cluster spec + case openStackCluster.Spec.APIServerFixedIP != "": + return openStackCluster.Spec.APIServerFixedIP, nil - listener, err := s.getOrCreateListener(openStackCluster, lbPortObjectsName, lb.ID, port) + // If we are using the VIP as the control plane endpoint, use any value explicitly set on the control plane endpoint + case openStackCluster.Spec.DisableAPIServerFloatingIP && openStackCluster.Spec.ControlPlaneEndpoint.IsValid(): + fixedIPAddress, err := lookupHost(openStackCluster.Spec.ControlPlaneEndpoint.Host) if err != nil { - return false, err + return "", fmt.Errorf("lookup host: %w", err) } + return fixedIPAddress, nil + } + + return "", nil +} + +// getAPIServerFloatingIP gets the floating IP from wherever it is specified. +// Returns an empty string if the floating IP is not specified and it should be allocated automatically. +func getAPIServerFloatingIP(openStackCluster *infrav1.OpenStackCluster) (string, error) { + switch { + // The floating IP was created previously + case openStackCluster.Status.APIServerLoadBalancer != nil && openStackCluster.Status.APIServerLoadBalancer.IP != "": + return openStackCluster.Status.APIServerLoadBalancer.IP, nil - pool, err := s.getOrCreatePool(openStackCluster, lbPortObjectsName, listener.ID, lb.ID, lb.Provider) + // Explicit floating IP in the cluster spec + case openStackCluster.Spec.APIServerFloatingIP != "": + return openStackCluster.Spec.APIServerFloatingIP, nil + + // An IP address is specified explicitly in the control plane endpoint + case openStackCluster.Spec.ControlPlaneEndpoint.IsValid(): + floatingIPAddress, err := lookupHost(openStackCluster.Spec.ControlPlaneEndpoint.Host) if err != nil { - return false, err + return "", fmt.Errorf("lookup host: %w", err) } + return floatingIPAddress, nil + } - if err := s.getOrCreateMonitor(openStackCluster, lbPortObjectsName, pool.ID, lb.ID); err != nil { - return false, err + return "", nil +} + +// getCanonicalAllowedCIDRs gets a filtered list of CIDRs which should be allowed to access the API server loadbalancer. +// Invalid CIDRs are filtered from the list and emil a warning event. +// It returns a canonical representation that can be directly compared with other canonicalized lists. +func getCanonicalAllowedCIDRs(openStackCluster *infrav1.OpenStackCluster) []string { + allowedCIDRs := []string{} + + if len(openStackCluster.Spec.APIServerLoadBalancer.AllowedCIDRs) > 0 { + allowedCIDRs = append(allowedCIDRs, openStackCluster.Spec.APIServerLoadBalancer.AllowedCIDRs...) + + // In the first reconciliation loop, only the Ready field is set in openStackCluster.Status + // All other fields are empty/nil + if openStackCluster.Status.Bastion != nil { + if openStackCluster.Status.Bastion.FloatingIP != "" { + allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Bastion.FloatingIP) + } + + if openStackCluster.Status.Bastion.IP != "" { + allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Bastion.IP) + } } - if allowedCIDRsSupported { - // Skip reconciliation if network status is nil (e.g. during clusterctl move) - if openStackCluster.Status.Network != nil { - if err := s.getOrUpdateAllowedCIDRS(openStackCluster, listener); err != nil { - return false, err + if openStackCluster.Status.Network != nil { + for _, subnet := range openStackCluster.Status.Network.Subnets { + if subnet.CIDR != "" { + allowedCIDRs = append(allowedCIDRs, subnet.CIDR) } - allowedCIDRs = listener.AllowedCIDRs + } + + if openStackCluster.Status.Router != nil && len(openStackCluster.Status.Router.IPs) > 0 { + allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Router.IPs...) } } } - lbStatus.AllowedCIDRs = allowedCIDRs - return false, nil + // Filter invalid CIDRs and convert any IPs into CIDRs. + validCIDRs := []string{} + for _, v := range allowedCIDRs { + switch { + case utilsnet.IsIPv4String(v): + validCIDRs = append(validCIDRs, v+"/32") + case utilsnet.IsIPv4CIDRString(v): + validCIDRs = append(validCIDRs, v) + default: + record.Warnf(openStackCluster, "FailedIPAddressValidation", "%s is not a valid IPv4 nor CIDR address and will not get applied to allowed_cidrs", v) + } + } + + // Sort and remove duplicates + return capostrings.Canonicalize(validCIDRs) +} + +// isAllowsCIDRSSupported returns true if Octavia supports allowed CIDRs for the loadbalancer provider in use. +func (s *Service) isAllowsCIDRSSupported(lb *loadbalancers.LoadBalancer) (bool, error) { + octaviaVersions, err := s.loadbalancerClient.ListOctaviaVersions() + if err != nil { + return false, err + } + // The current version is always the last one in the list. + octaviaVersion := octaviaVersions[len(octaviaVersions)-1].ID + + return openstackutil.IsOctaviaFeatureSupported(octaviaVersion, openstackutil.OctaviaFeatureVIPACL, lb.Provider), nil } -func (s *Service) getOrCreateLoadBalancer(openStackCluster *infrav1.OpenStackCluster, loadBalancerName, subnetID, clusterName, vipAddress, provider string) (*loadbalancers.LoadBalancer, error) { +// getOrCreateAPILoadBalancer returns an existing API loadbalancer if it already exists, or creates a new one if it does not. +func (s *Service) getOrCreateAPILoadBalancer(openStackCluster *infrav1.OpenStackCluster, clusterName string) (*loadbalancers.LoadBalancer, error) { + loadBalancerName := getLoadBalancerName(clusterName) lb, err := s.checkIfLbExists(loadBalancerName) if err != nil { return nil, err } - if lb != nil { return lb, nil } + if openStackCluster.Status.Network == nil { + return nil, fmt.Errorf("network is not yet available in OpenStackCluster.Status") + } + + // Create the VIP on the first cluster subnet + subnetID := openStackCluster.Status.Network.Subnets[0].ID s.scope.Logger().Info("Creating load balancer in subnet", "subnetID", subnetID, "name", loadBalancerName) + providers, err := s.loadbalancerClient.ListLoadBalancerProviders() + if err != nil { + return nil, err + } + + // Choose the selected provider if it is set in cluster spec, if not, omit the field and Octavia will use the default provider. + lbProvider := "" + if openStackCluster.Spec.APIServerLoadBalancer.Provider != "" { + for _, v := range providers { + if v.Name == openStackCluster.Spec.APIServerLoadBalancer.Provider { + lbProvider = v.Name + break + } + } + if lbProvider == "" { + record.Warnf(openStackCluster, "OctaviaProviderNotFound", "Provider specified for Octavia not found.") + record.Eventf(openStackCluster, "OctaviaProviderNotFound", "Provider %s specified for Octavia not found, using the default provider.", openStackCluster.Spec.APIServerLoadBalancer.Provider) + } + } + + vipAddress, err := getAPIServerVIPAddress(openStackCluster) + if err != nil { + return nil, err + } + lbCreateOpts := loadbalancers.CreateOpts{ Name: loadBalancerName, VipSubnetID: subnetID, VipAddress: vipAddress, Description: names.GetDescription(clusterName), - Provider: provider, + Provider: lbProvider, Tags: openStackCluster.Spec.Tags, } lb, err = s.loadbalancerClient.CreateLoadBalancer(lbCreateOpts) @@ -230,7 +311,45 @@ func (s *Service) getOrCreateLoadBalancer(openStackCluster *infrav1.OpenStackClu return lb, nil } -func (s *Service) getOrCreateListener(openStackCluster *infrav1.OpenStackCluster, listenerName, lbID string, port int) (*listeners.Listener, error) { +// reconcileAPILoadBalancerListener ensures that the listener on the given port exists and is configured correctly. +func (s *Service) reconcileAPILoadBalancerListener(lb *loadbalancers.LoadBalancer, openStackCluster *infrav1.OpenStackCluster, clusterName string, port int) error { + loadBalancerName := getLoadBalancerName(clusterName) + lbPortObjectsName := fmt.Sprintf("%s-%d", loadBalancerName, port) + + if openStackCluster.Status.APIServerLoadBalancer == nil { + return fmt.Errorf("APIServerLoadBalancer is not yet available in OpenStackCluster.Status") + } + + allowedCIDRs := openStackCluster.Status.APIServerLoadBalancer.AllowedCIDRs + + listener, err := s.getOrCreateListener(openStackCluster, lbPortObjectsName, lb.ID, allowedCIDRs, port) + if err != nil { + return err + } + + pool, err := s.getOrCreatePool(openStackCluster, lbPortObjectsName, listener.ID, lb.ID, lb.Provider) + if err != nil { + return err + } + + if err := s.getOrCreateMonitor(openStackCluster, lbPortObjectsName, pool.ID, lb.ID); err != nil { + return err + } + + // allowedCIDRs is nil if allowedCIDRs is not supported by the Octavia provider + // A non-nil empty slice is an explicitly empty list + if allowedCIDRs != nil { + if err := s.getOrUpdateAllowedCIDRs(openStackCluster, listener, allowedCIDRs); err != nil { + return err + } + } + + return nil +} + +// getOrCreateListener returns an existing listener for the given loadbalancer +// and port if it already exists, or creates a new one if it does not. +func (s *Service) getOrCreateListener(openStackCluster *infrav1.OpenStackCluster, listenerName, lbID string, allowedCIDRs []string, port int) (*listeners.Listener, error) { listener, err := s.checkIfListenerExists(listenerName) if err != nil { return nil, err @@ -248,6 +367,7 @@ func (s *Service) getOrCreateListener(openStackCluster *infrav1.OpenStackCluster ProtocolPort: port, LoadbalancerID: lbID, Tags: openStackCluster.Spec.Tags, + AllowedCIDRs: allowedCIDRs, } listener, err = s.loadbalancerClient.CreateListener(listenerCreateOpts) if err != nil { @@ -255,7 +375,7 @@ func (s *Service) getOrCreateListener(openStackCluster *infrav1.OpenStackCluster return nil, err } - if err := s.waitForLoadBalancerActive(lbID); err != nil { + if _, err := s.waitForLoadBalancerActive(lbID); err != nil { record.Warnf(openStackCluster, "FailedCreateListener", "Failed to create listener %s with id %s: wait for load balancer active %s: %v", listenerName, listener.ID, lbID, err) return nil, err } @@ -269,42 +389,9 @@ func (s *Service) getOrCreateListener(openStackCluster *infrav1.OpenStackCluster return listener, nil } -func (s *Service) getOrUpdateAllowedCIDRS(openStackCluster *infrav1.OpenStackCluster, listener *listeners.Listener) error { - allowedCIDRs := []string{} - - if len(openStackCluster.Spec.APIServerLoadBalancer.AllowedCIDRs) > 0 { - allowedCIDRs = append(allowedCIDRs, openStackCluster.Spec.APIServerLoadBalancer.AllowedCIDRs...) - - // In the first reconciliation loop, only the Ready field is set in openStackCluster.Status - // All other fields are empty/nil - if openStackCluster.Status.Bastion != nil { - if openStackCluster.Status.Bastion.FloatingIP != "" { - allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Bastion.FloatingIP) - } - - if openStackCluster.Status.Bastion.IP != "" { - allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Bastion.IP) - } - } - - if openStackCluster.Status.Network != nil { - for _, subnet := range openStackCluster.Status.Network.Subnets { - if subnet.CIDR != "" { - allowedCIDRs = append(allowedCIDRs, subnet.CIDR) - } - } - - if len(openStackCluster.Status.Router.IPs) > 0 { - allowedCIDRs = append(allowedCIDRs, openStackCluster.Status.Router.IPs...) - } - } - } - - // Validate CIDRs and convert any given IP into a CIDR. - allowedCIDRs = validateIPs(openStackCluster, allowedCIDRs) - +// getOrUpdateAllowedCIDRs ensures that the allowed CIDRs configured on a listener correspond to the expected list. +func (s *Service) getOrUpdateAllowedCIDRs(openStackCluster *infrav1.OpenStackCluster, listener *listeners.Listener, allowedCIDRs []string) error { // Sort and remove duplicates - allowedCIDRs = capostrings.Canonicalize(allowedCIDRs) listener.AllowedCIDRs = capostrings.Canonicalize(listener.AllowedCIDRs) if !slices.Equal(allowedCIDRs, listener.AllowedCIDRs) { @@ -330,24 +417,6 @@ func (s *Service) getOrUpdateAllowedCIDRS(openStackCluster *infrav1.OpenStackClu return nil } -// validateIPs validates given IPs/CIDRs and removes non valid network objects. -func validateIPs(openStackCluster *infrav1.OpenStackCluster, definedCIDRs []string) []string { - marshaledCIDRs := []string{} - - for _, v := range definedCIDRs { - switch { - case utilsnet.IsIPv4String(v): - marshaledCIDRs = append(marshaledCIDRs, v+"/32") - case utilsnet.IsIPv4CIDRString(v): - marshaledCIDRs = append(marshaledCIDRs, v) - default: - record.Warnf(openStackCluster, "FailedIPAddressValidation", "%s is not a valid IPv4 nor CIDR address and will not get applied to allowed_cidrs", v) - } - } - - return marshaledCIDRs -} - func (s *Service) getOrCreatePool(openStackCluster *infrav1.OpenStackCluster, poolName, listenerID, lbID string, lbProvider string) (*pools.Pool, error) { pool, err := s.checkIfPoolExists(poolName) if err != nil { @@ -379,7 +448,7 @@ func (s *Service) getOrCreatePool(openStackCluster *infrav1.OpenStackCluster, po return nil, err } - if err := s.waitForLoadBalancerActive(lbID); err != nil { + if _, err := s.waitForLoadBalancerActive(lbID); err != nil { record.Warnf(openStackCluster, "FailedCreatePool", "Failed to create pool %s with id %s: wait for load balancer active %s: %v", poolName, pool.ID, lbID, err) return nil, err } @@ -421,7 +490,7 @@ func (s *Service) getOrCreateMonitor(openStackCluster *infrav1.OpenStackCluster, return err } - if err = s.waitForLoadBalancerActive(lbID); err != nil { + if _, err = s.waitForLoadBalancerActive(lbID); err != nil { record.Warnf(openStackCluster, "FailedCreateMonitor", "Failed to create monitor %s with id %s: wait for load balancer active %s: %v", monitorName, monitor.ID, lbID, err) return err } @@ -474,14 +543,14 @@ func (s *Service) ReconcileLoadBalancerMember(openStackCluster *infrav1.OpenStac s.scope.Logger().Info("Deleting load balancer member because the IP of the machine changed", "name", name) // lb member changed so let's delete it so we can create it again with the correct IP - err = s.waitForLoadBalancerActive(lbID) + _, err = s.waitForLoadBalancerActive(lbID) if err != nil { return err } if err := s.loadbalancerClient.DeletePoolMember(pool.ID, lbMember.ID); err != nil { return err } - err = s.waitForLoadBalancerActive(lbID) + _, err = s.waitForLoadBalancerActive(lbID) if err != nil { return err } @@ -497,7 +566,7 @@ func (s *Service) ReconcileLoadBalancerMember(openStackCluster *infrav1.OpenStac Tags: openStackCluster.Spec.Tags, } - if err := s.waitForLoadBalancerActive(lbID); err != nil { + if _, err := s.waitForLoadBalancerActive(lbID); err != nil { return err } @@ -505,7 +574,7 @@ func (s *Service) ReconcileLoadBalancerMember(openStackCluster *infrav1.OpenStac return err } - if err := s.waitForLoadBalancerActive(lbID); err != nil { + if _, err := s.waitForLoadBalancerActive(lbID); err != nil { return err } } @@ -598,14 +667,14 @@ func (s *Service) DeleteLoadBalancerMember(openStackCluster *infrav1.OpenStackCl if lbMember != nil { // lb member changed so let's delete it so we can create it again with the correct IP - err = s.waitForLoadBalancerActive(lbID) + _, err = s.waitForLoadBalancerActive(lbID) if err != nil { return err } if err := s.loadbalancerClient.DeletePoolMember(pool.ID, lbMember.ID); err != nil { return err } - err = s.waitForLoadBalancerActive(lbID) + _, err = s.waitForLoadBalancerActive(lbID) if err != nil { return err } @@ -681,15 +750,22 @@ var backoff = wait.Backoff{ } // Possible LoadBalancer states are documented here: https://docs.openstack.org/api-ref/load-balancer/v2/index.html#prov-status -func (s *Service) waitForLoadBalancerActive(id string) error { +func (s *Service) waitForLoadBalancerActive(id string) (*loadbalancers.LoadBalancer, error) { + var lb *loadbalancers.LoadBalancer + s.scope.Logger().Info("Waiting for load balancer", "id", id, "targetStatus", "ACTIVE") - return wait.ExponentialBackoff(backoff, func() (bool, error) { - lb, err := s.loadbalancerClient.GetLoadBalancer(id) + err := wait.ExponentialBackoff(backoff, func() (bool, error) { + var err error + lb, err = s.loadbalancerClient.GetLoadBalancer(id) if err != nil { return false, err } return lb.ProvisioningStatus == loadBalancerProvisioningStatusActive, nil }) + if err != nil { + return nil, err + } + return lb, nil } func (s *Service) waitForListener(id, target string) error { diff --git a/pkg/cloud/services/loadbalancer/loadbalancer_test.go b/pkg/cloud/services/loadbalancer/loadbalancer_test.go index d4a6b69a09..0c9996ca0f 100644 --- a/pkg/cloud/services/loadbalancer/loadbalancer_test.go +++ b/pkg/cloud/services/loadbalancer/loadbalancer_test.go @@ -18,6 +18,7 @@ package loadbalancer import ( "errors" + "fmt" "net" "testing" @@ -37,6 +38,8 @@ import ( "sigs.k8s.io/cluster-api-provider-openstack/pkg/scope" ) +const apiHostname = "api.test-cluster.test" + func Test_ReconcileLoadBalancer(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -45,7 +48,7 @@ func Test_ReconcileLoadBalancer(t *testing.T) { lookupHost = func(host string) (addrs string, err error) { if net.ParseIP(host) != nil { return host, nil - } else if host == "api.test-cluster.test" { + } else if host == apiHostname { ips := []string{"192.168.100.10"} return ips[0], nil } @@ -56,7 +59,7 @@ func Test_ReconcileLoadBalancer(t *testing.T) { Spec: infrav1.OpenStackClusterSpec{ DisableAPIServerFloatingIP: true, ControlPlaneEndpoint: clusterv1.APIEndpoint{ - Host: "api.test-cluster.test", + Host: apiHostname, Port: 6443, }, }, @@ -83,13 +86,6 @@ func Test_ReconcileLoadBalancer(t *testing.T) { // add network api call results here }, expectLoadBalancer: func(m *mock.MockLbClientMockRecorder) { - // return loadbalancer providers - providers := []providers.Provider{ - {Name: "amphora", Description: "The Octavia Amphora driver."}, - {Name: "octavia", Description: "Deprecated alias of the Octavia Amphora driver."}, - } - m.ListLoadBalancerProviders().Return(providers, nil) - pendingLB := loadbalancers.LoadBalancer{ ID: "aaaaaaaa-bbbb-cccc-dddd-333333333333", Name: "k8s-clusterapi-cluster-AAAAA-kubeapi", @@ -159,3 +155,349 @@ func Test_ReconcileLoadBalancer(t *testing.T) { }) } } + +func Test_getAPIServerVIPAddress(t *testing.T) { + // Stub the call to net.LookupHost + lookupHost = func(host string) (addrs string, err error) { + if net.ParseIP(host) != nil { + return host, nil + } else if host == apiHostname { + ips := []string{"192.168.100.10"} + return ips[0], nil + } + return "", errors.New("Unknown Host " + host) + } + tests := []struct { + name string + openStackCluster *infrav1.OpenStackCluster + want string + wantError bool + }{ + { + name: "empty cluster returns empty VIP", + openStackCluster: &infrav1.OpenStackCluster{}, + want: "", + wantError: false, + }, + { + name: "API server VIP is InternalIP", + openStackCluster: &infrav1.OpenStackCluster{ + Status: infrav1.OpenStackClusterStatus{ + APIServerLoadBalancer: &infrav1.LoadBalancer{ + InternalIP: "1.2.3.4", + }, + }, + }, + want: "1.2.3.4", + wantError: false, + }, + { + name: "API server VIP is API Server Fixed IP", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerFixedIP: "1.2.3.4", + }, + }, + want: "1.2.3.4", + wantError: false, + }, + { + name: "API server VIP with valid control plane endpoint", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + DisableAPIServerFloatingIP: true, + ControlPlaneEndpoint: clusterv1.APIEndpoint{ + Host: apiHostname, + Port: 6443, + }, + }, + }, + want: "192.168.100.10", + wantError: false, + }, + { + name: "API server VIP with invalid control plane endpoint", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + DisableAPIServerFloatingIP: true, + ControlPlaneEndpoint: clusterv1.APIEndpoint{ + Host: "invalid-api.test-cluster.test", + Port: 6443, + }, + }, + }, + wantError: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + got, err := getAPIServerVIPAddress(tt.openStackCluster) + if tt.wantError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal(tt.want)) + } + }) + } +} + +func Test_getAPIServerFloatingIP(t *testing.T) { + // Stub the call to net.LookupHost + lookupHost = func(host string) (addrs string, err error) { + if net.ParseIP(host) != nil { + return host, nil + } else if host == apiHostname { + ips := []string{"192.168.100.10"} + return ips[0], nil + } + return "", errors.New("Unknown Host " + host) + } + tests := []struct { + name string + openStackCluster *infrav1.OpenStackCluster + want string + wantError bool + }{ + { + name: "empty cluster returns empty FIP", + openStackCluster: &infrav1.OpenStackCluster{}, + want: "", + wantError: false, + }, + { + name: "API server FIP is API Server LB IP", + openStackCluster: &infrav1.OpenStackCluster{ + Status: infrav1.OpenStackClusterStatus{ + APIServerLoadBalancer: &infrav1.LoadBalancer{ + IP: "1.2.3.4", + }, + }, + }, + want: "1.2.3.4", + wantError: false, + }, + { + name: "API server FIP is API Server Floating IP", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerFloatingIP: "1.2.3.4", + }, + }, + want: "1.2.3.4", + wantError: false, + }, + { + name: "API server FIP with valid control plane endpoint", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + ControlPlaneEndpoint: clusterv1.APIEndpoint{ + Host: apiHostname, + Port: 6443, + }, + }, + }, + want: "192.168.100.10", + wantError: false, + }, + { + name: "API server FIP with invalid control plane endpoint", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + ControlPlaneEndpoint: clusterv1.APIEndpoint{ + Host: "invalid-api.test-cluster.test", + Port: 6443, + }, + }, + }, + wantError: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + got, err := getAPIServerFloatingIP(tt.openStackCluster) + if tt.wantError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal(tt.want)) + } + }) + } +} + +func Test_getCanonicalAllowedCIDRs(t *testing.T) { + tests := []struct { + name string + openStackCluster *infrav1.OpenStackCluster + want []string + }{ + { + name: "allowed CIDRs are empty", + openStackCluster: &infrav1.OpenStackCluster{}, + want: []string{}, + }, + { + name: "allowed CIDRs are set", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerLoadBalancer: infrav1.APIServerLoadBalancer{ + AllowedCIDRs: []string{"1.2.3.4/32"}, + }, + }, + }, + want: []string{"1.2.3.4/32"}, + }, + { + name: "allowed CIDRs are set with bastion", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerLoadBalancer: infrav1.APIServerLoadBalancer{ + AllowedCIDRs: []string{"1.2.3.4/32"}, + }, + }, + Status: infrav1.OpenStackClusterStatus{ + Bastion: &infrav1.BastionStatus{ + FloatingIP: "1.2.3.5", + IP: "192.168.0.1", + }, + }, + }, + want: []string{"1.2.3.4/32", "1.2.3.5/32", "192.168.0.1/32"}, + }, + { + name: "allowed CIDRs are set with network status", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerLoadBalancer: infrav1.APIServerLoadBalancer{ + AllowedCIDRs: []string{"1.2.3.4/32"}, + }, + }, + Status: infrav1.OpenStackClusterStatus{ + Network: &infrav1.NetworkStatusWithSubnets{ + Subnets: []infrav1.Subnet{ + { + CIDR: "192.168.0.0/24", + }, + }, + }, + }, + }, + want: []string{"1.2.3.4/32", "192.168.0.0/24"}, + }, + { + name: "allowed CIDRs are set with network status and router IP", + openStackCluster: &infrav1.OpenStackCluster{ + Spec: infrav1.OpenStackClusterSpec{ + APIServerLoadBalancer: infrav1.APIServerLoadBalancer{ + AllowedCIDRs: []string{"1.2.3.4/32"}, + }, + }, + Status: infrav1.OpenStackClusterStatus{ + Network: &infrav1.NetworkStatusWithSubnets{ + Subnets: []infrav1.Subnet{ + { + CIDR: "192.168.0.0/24", + }, + }, + }, + Router: &infrav1.Router{ + IPs: []string{"1.2.3.5"}, + }, + }, + }, + want: []string{"1.2.3.4/32", "1.2.3.5/32", "192.168.0.0/24"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + got := getCanonicalAllowedCIDRs(tt.openStackCluster) + g.Expect(got).To(Equal(tt.want)) + }) + } +} + +func Test_getOrCreateAPILoadBalancer(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + octaviaProviders := []providers.Provider{ + { + Name: "ovn", + }, + } + lbtests := []struct { + name string + openStackCluster *infrav1.OpenStackCluster + expectLoadBalancer func(m *mock.MockLbClientMockRecorder) + want *loadbalancers.LoadBalancer + wantError error + }{ + { + name: "nothing exists", + openStackCluster: &infrav1.OpenStackCluster{}, + expectLoadBalancer: func(m *mock.MockLbClientMockRecorder) { + m.ListLoadBalancers(gomock.Any()).Return([]loadbalancers.LoadBalancer{}, nil) + }, + want: &loadbalancers.LoadBalancer{}, + wantError: fmt.Errorf("network is not yet available in OpenStackCluster.Status"), + }, + { + name: "loadbalancer already exists", + openStackCluster: &infrav1.OpenStackCluster{}, + expectLoadBalancer: func(m *mock.MockLbClientMockRecorder) { + m.ListLoadBalancers(gomock.Any()).Return([]loadbalancers.LoadBalancer{{ID: "AAAAA"}}, nil) + }, + want: &loadbalancers.LoadBalancer{ + ID: "AAAAA", + }, + }, + { + name: "loadbalancer created", + openStackCluster: &infrav1.OpenStackCluster{ + Status: infrav1.OpenStackClusterStatus{ + Network: &infrav1.NetworkStatusWithSubnets{ + Subnets: []infrav1.Subnet{ + {ID: "aaaaaaaa-bbbb-cccc-dddd-222222222222"}, + {ID: "aaaaaaaa-bbbb-cccc-dddd-333333333333"}, + }, + }, + }, + }, + expectLoadBalancer: func(m *mock.MockLbClientMockRecorder) { + m.ListLoadBalancers(gomock.Any()).Return([]loadbalancers.LoadBalancer{}, nil) + m.ListLoadBalancerProviders().Return(octaviaProviders, nil) + m.CreateLoadBalancer(gomock.Any()).Return(&loadbalancers.LoadBalancer{ + ID: "AAAAA", + }, nil) + }, + want: &loadbalancers.LoadBalancer{ + ID: "AAAAA", + }, + }, + } + for _, tt := range lbtests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + mockScopeFactory := scope.NewMockScopeFactory(mockCtrl, "", logr.Discard()) + lbs, err := NewService(mockScopeFactory) + g.Expect(err).NotTo(HaveOccurred()) + + tt.expectLoadBalancer(mockScopeFactory.LbClient.EXPECT()) + lb, err := lbs.getOrCreateAPILoadBalancer(tt.openStackCluster, "AAAAA") + if tt.wantError != nil { + g.Expect(err).To(MatchError(tt.wantError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(lb).To(Equal(tt.want)) + } + }) + } +}