diff --git a/pkg/agent/nodeportlocal/k8s/npl_controller.go b/pkg/agent/nodeportlocal/k8s/npl_controller.go index 0bdb187f690..0515fd6abaf 100644 --- a/pkg/agent/nodeportlocal/k8s/npl_controller.go +++ b/pkg/agent/nodeportlocal/k8s/npl_controller.go @@ -530,11 +530,10 @@ func (c *NPLController) handleAddUpdatePod(key string, obj interface{}) error { entries := c.portTable.GetDataForPodIP(podIP) if nplExists { for _, data := range entries { - for _, proto := range data.Protocols { - if _, exists := podPorts[util.BuildPortProto(fmt.Sprint(data.PodPort), proto.Protocol)]; !exists { - if err := c.portTable.DeleteRule(podIP, int(data.PodPort), proto.Protocol); err != nil { - return fmt.Errorf("failed to delete rule for Pod IP %s, Pod Port %d, Protocol %s: %v", podIP, data.PodPort, proto.Protocol, err) - } + proto := data.Protocol + if _, exists := podPorts[util.BuildPortProto(fmt.Sprint(data.PodPort), proto.Protocol)]; !exists { + if err := c.portTable.DeleteRule(podIP, int(data.PodPort), proto.Protocol); err != nil { + return fmt.Errorf("failed to delete rule for Pod IP %s, Pod Port %d, Protocol %s: %v", podIP, data.PodPort, proto.Protocol, err) } } } diff --git a/pkg/agent/nodeportlocal/npl_agent_test.go b/pkg/agent/nodeportlocal/npl_agent_test.go index 95a09547d09..f546ebb3451 100644 --- a/pkg/agent/nodeportlocal/npl_agent_test.go +++ b/pkg/agent/nodeportlocal/npl_agent_test.go @@ -67,8 +67,11 @@ const ( func newPortTable(mockIPTables rules.PodPortRules, mockPortOpener portcache.LocalPortOpener) *portcache.PortTable { return &portcache.PortTable{ - NodePortTable: make(map[string]*portcache.NodePortData), - PodEndpointTable: make(map[string]*portcache.NodePortData), + PortTableCache: cache.NewIndexer(portcache.GetPortTableKey, cache.Indexers{ + portcache.nodePortIndex: portcache.NodePortIndexFunc, + portcache.podEndpointIndex: portcache.PodEndpointIndexFunc, + portcache.podIPIndex: portcache.PodIPIndexFunc, + }), StartPort: defaultStartPort, EndPort: defaultEndPort, PortSearchStart: defaultStartPort, @@ -658,7 +661,8 @@ func TestMultiplePods(t *testing.T) { // TestMultipleProtocols creates multiple Pods with multiple protocols and verifies that // NPL annotations and iptable rules for both Pods and Protocols are updated correctly. // In particular we make sure that a given NodePort is never used by more than one Pod, -// irrespective of which protocol is in use. +// One Pod could use multiple Nodeports for different protocol with the same Pod port +// because of the new NPL unification implementation. func TestMultipleProtocols(t *testing.T) { tcpUdpSvcLabel := map[string]string{"tcp": "true", "udp": "true"} udpSvcLabel := map[string]string{"tcp": "false", "udp": "true"} @@ -702,8 +706,7 @@ func TestMultipleProtocols(t *testing.T) { assert.True(t, testData.portTable.RuleExists(testPod2.Status.PodIP, defaultPort, protocolUDP)) // Update testSvc2 to serve TCP/80 and UDP/81 both, so pod2 is - // exposed on both TCP and UDP, with the same NodePort. - + // exposed on both TCP and UDP, with different NodePorts. testSvc2.Spec.Ports = append(testSvc2.Spec.Ports, corev1.ServicePort{ Port: 80, Protocol: corev1.ProtocolTCP, @@ -716,7 +719,16 @@ func TestMultipleProtocols(t *testing.T) { pod2ValueUpdate, err := testData.pollForPodAnnotation(testPod2.Name, true) require.NoError(t, err, "Poll for annotation check failed") - expectedAnnotationsPod2.Add(&pod2Value[0].NodePort, defaultPort, protocolTCP) + + // The new nodeport should be the next of the last used port because of + // the implementation of the nodeport allocation. + var pod2nodeport int + if pod1Value[0].NodePort > pod2Value[0].NodePort { + pod2nodeport = pod1Value[0].NodePort + 1 + } else { + pod2nodeport = pod2Value[0].NodePort + 1 + } + expectedAnnotationsPod2.Add(&pod2nodeport, defaultPort, protocolTCP) expectedAnnotationsPod2.Check(t, pod2ValueUpdate) } @@ -761,22 +773,18 @@ var ( portTakenError = fmt.Errorf("Port taken") ) -// TestNodePortAlreadyBoundTo validates that when a port is already bound to, a different port will -// be selected for NPL. +// TestNodePortAlreadyBoundTo validates that when a port with TCP protocol is already bound to, +// the same port should be selected for NPL if any other protocol is available. func TestNodePortAlreadyBoundTo(t *testing.T) { nodePort1 := defaultStartPort nodePort2 := nodePort1 + 1 testConfig := newTestConfig().withCustomPortOpenerExpectations(func(mockPortOpener *portcachetesting.MockLocalPortOpener) { gomock.InOrder( - // Based on the implementation, we know that TCP is checked first... - // 1. port1 is checked for TCP availability -> success - mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(&fakeSocket{}, nil), - // 2. port1 is checked for UDP availability (even if the Service uses TCP only) -> error - mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolUDP).Return(nil, portTakenError), - // 3. port2 is checked for TCP availability -> success + // Based on the implementation, we know only the TCP protocol used in rule is checked... + // 1. port1 is checked for TCP availability -> error + mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(nil, portTakenError), + // 2. port2 is checked for TCP availability -> success mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolTCP).Return(&fakeSocket{}, nil), - // 4. port2 is checked for UDP availability -> success - mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolUDP).Return(&fakeSocket{}, nil), ) }) customNodePort := defaultStartPort + 1 diff --git a/pkg/agent/nodeportlocal/portcache/port_table.go b/pkg/agent/nodeportlocal/portcache/port_table.go index a2923f74689..4852c242f5a 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table.go +++ b/pkg/agent/nodeportlocal/portcache/port_table.go @@ -22,9 +22,17 @@ import ( "k8s.io/klog/v2" + "k8s.io/client-go/tools/cache" "antrea.io/antrea/pkg/agent/nodeportlocal/rules" ) +const ( + NPLPortTableIndex = "NPLPortTableIndex" + nodePortIndex = "nodePortIndex" + podEndpointIndex = "podEndpointIndex" + podIPIndex = "podIPIndex" +) + // protocolSocketState represents the state of the socket corresponding to a // given (Node port, protocol) tuple. type protocolSocketState int @@ -39,23 +47,18 @@ type NodePortData struct { NodePort int PodPort int PodIP string - Protocols []ProtocolSocketData + Protocol ProtocolSocketData } -func (d *NodePortData) FindProtocol(protocol string) *ProtocolSocketData { - for idx, protocolSocketData := range d.Protocols { - if protocolSocketData.Protocol == protocol { - return &d.Protocols[idx] - } - } - return nil +type CacheNpData struct { + Type string + Data *NodePortData } func (d *NodePortData) ProtocolInUse(protocol string) bool { - for _, protocolSocketData := range d.Protocols { - if protocolSocketData.Protocol == protocol { - return protocolSocketData.State == stateInUse - } + protocolSocketData := d.Protocol + if protocolSocketData.Protocol == protocol { + return protocolSocketData.State == stateInUse } return false } @@ -67,8 +70,7 @@ type LocalPortOpener interface { type localPortOpener struct{} type PortTable struct { - NodePortTable map[string]*NodePortData - PodEndpointTable map[string]*NodePortData + PortTableCache cache.Indexer StartPort int EndPort int PortSearchStart int @@ -77,10 +79,99 @@ type PortTable struct { tableLock sync.RWMutex } +func GetPortTableKey(obj interface{}) (string, error) { + npData := obj.(*NodePortData) + key := fmt.Sprintf("%d:%s:%d:%s", npData.NodePort, npData.PodIP, npData.PodPort, npData.Protocol.Protocol) + return key, nil +} + +func (pt *PortTable) addPortTableCache(npData *NodePortData) error { + if err := pt.PortTableCache.Add(npData); err != nil { + return err + } + return nil +} + +func (pt *PortTable) delPortTableCache(npData *NodePortData) error { + if err := pt.PortTableCache.Delete(npData); err != nil { + return err + } + return nil +} + +func (pt *PortTable) getPortTableCacheFromNodePortIndex(index string) (*NodePortData, bool) { + objs, _ := pt.PortTableCache.ByIndex(nodePortIndex, index) + if len(objs) == 0 { + return nil, false + } + return objs[0].(*NodePortData), true +} + +func (pt *PortTable) getPortTableCacheFromPodEndpointIndex(index string) (*NodePortData, bool) { + objs, _ := pt.PortTableCache.ByIndex(podEndpointIndex, index) + if len(objs) == 0 { + return nil, false + } + return objs[0].(*NodePortData), true +} + +func (pt *PortTable) getPortTableCacheFromPodIPIndex(index string) ([]NodePortData, bool) { + var npData []NodePortData + objs, _ := pt.PortTableCache.ByIndex(podIPIndex, index) + if len(objs) == 0 { + return nil, false + } + for _, obj := range objs { + npData = append(npData, *(obj.(*NodePortData))) + } + return npData, true +} + +func (pt *PortTable) delPortTableCacheFromNodePortIndex(index string) error { + data, exists := pt.getPortTableCacheFromNodePortIndex(index) + if exists == false { + return nil + } + if err := pt.delPortTableCache(data); err != nil { + return err + } + return nil +} + +func (pt *PortTable) releaseDataFromPortTableCache() error { + for _, obj := range pt.PortTableCache.List() { + data := obj.(*NodePortData) + if err := pt.delPortTableCache(data); err != nil { + return err + } + } + return nil +} + +func NodePortIndexFunc(obj interface{}) ([]string, error) { + npData := obj.(*NodePortData) + nodePortTuple := NodePortProtoFormat(npData.NodePort, npData.Protocol.Protocol) + return []string{nodePortTuple}, nil +} + +func PodEndpointIndexFunc(obj interface{}) ([]string, error) { + npData := obj.(*NodePortData) + podEndpointTuple := podIPPortProtoFormat(npData.PodIP, npData.PodPort, npData.Protocol.Protocol) + return []string{podEndpointTuple}, nil +} + +func PodIPIndexFunc(obj interface{}) ([]string, error) { + npData := obj.(*NodePortData) + return []string{npData.PodIP}, nil +} + func NewPortTable(start, end int) (*PortTable, error) { ptable := PortTable{ - NodePortTable: make(map[string]*NodePortData), - PodEndpointTable: make(map[string]*NodePortData), + PortTableCache: cache.NewIndexer(GetPortTableKey, cache.Indexers{ + nodePortIndex: NodePortIndexFunc, + podEndpointIndex: PodEndpointIndexFunc, + podIPIndex: PodIPIndexFunc, + }), StartPort: start, EndPort: end, PortSearchStart: start, @@ -96,8 +187,7 @@ func NewPortTable(start, end int) (*PortTable, error) { func (pt *PortTable) CleanupAllEntries() { pt.tableLock.Lock() defer pt.tableLock.Unlock() - pt.NodePortTable = make(map[string]*NodePortData) - pt.PodEndpointTable = make(map[string]*NodePortData) + pt.releaseDataFromPortTableCache() } func (pt *PortTable) GetDataForPodIP(ip string) []NodePortData { @@ -107,23 +197,30 @@ func (pt *PortTable) GetDataForPodIP(ip string) []NodePortData { } func (pt *PortTable) getDataForPodIP(ip string) []NodePortData { - var allData []NodePortData - for i := range pt.NodePortTable { - if pt.NodePortTable[i].PodIP == ip { - allData = append(allData, *pt.NodePortTable[i]) - } + allData, exist := pt.getPortTableCacheFromPodIPIndex(ip) + if exist == false { + return nil } return allData } -func (pt *PortTable) getEntryByPodIPPort(ip string, port int) *NodePortData { - return pt.PodEndpointTable[podIPPortFormat(ip, port)] +// podIPPortFormat formats the ip, port to string ip:port. +func podIPPortProtoFormat(ip string, port int, protocol string) string { + return fmt.Sprintf("%s:%d:%s", ip, port, protocol) +} + +func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData { + data, exists := pt.getPortTableCacheFromPodEndpointIndex(podIPPortProtoFormat(ip, port, protocol)) + if exists == false { + return nil + } + return data } func (pt *PortTable) RuleExists(podIP string, podPort int, protocol string) bool { pt.tableLock.RLock() defer pt.tableLock.RUnlock() - if data := pt.getEntryByPodIPPort(podIP, podPort); data != nil { + if data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol); data != nil { return data.ProtocolInUse(protocol) } return false diff --git a/pkg/agent/nodeportlocal/portcache/port_table_linux.go b/pkg/agent/nodeportlocal/portcache/port_table_linux.go index dbaa5a32b86..1c35fe27463 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_linux.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_linux.go @@ -19,7 +19,6 @@ package portcache import ( "fmt" - "strconv" "time" "k8s.io/klog/v2" @@ -44,38 +43,33 @@ var ( ) func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortData { - var _ = protocol pt.tableLock.RLock() defer pt.tableLock.RUnlock() // Return pointer to copy of data from the PodEndpointTable. - if data := pt.getEntryByPodIPPort(ip, port); data != nil { + if data := pt.getEntryByPodIPPortProto(ip, port, protocol); data != nil { dataCopy := *data return &dataCopy } return nil } -func openSocketsForPort(localPortOpener LocalPortOpener, port int) ([]ProtocolSocketData, error) { - // Port needs to be available for all supported protocols: we want to use the same port - // number for all protocols and we don't know at this point which protocols are needed. - // This is to preserve the legacy behavior of allocating the same nodePort for all protocols. - protocols := make([]ProtocolSocketData, 0, len(supportedProtocols)) - for _, protocol := range supportedProtocols { - socket, err := localPortOpener.OpenLocalPort(port, protocol) - if err != nil { - klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol) - return protocols, err - } - protocols = append(protocols, ProtocolSocketData{ - Protocol: protocol, - State: stateOpen, - socket: socket, - }) +func openSocketsForPort(localPortOpener LocalPortOpener, port int, protocol string) (ProtocolSocketData, error) { + // Port needs to be only available for the protocol used by NPL rule. + // We don't need to allocate the same nodePort for all protocols anymore. + socket, err := localPortOpener.OpenLocalPort(port, protocol) + if err != nil { + klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol) + return ProtocolSocketData{}, err + } + protocolData := ProtocolSocketData{ + Protocol: protocol, + State: stateInUse, + socket: socket, } - return protocols, nil + return protocolData, nil } -func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSocketData, error) { +func (pt *PortTable) getFreePort(podIP string, podPort int, protocol string) (int, ProtocolSocketData, error) { klog.V(2).InfoS("Looking for free Node port", "podIP", podIP, "podPort", podPort) numPorts := pt.EndPort - pt.StartPort + 1 for i := 0; i < numPorts; i++ { @@ -84,15 +78,14 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock // handle wrap around port = port - numPorts } - if _, ok := pt.NodePortTable[strconv.Itoa(port)]; ok { + if _, ok := pt.getPortTableCacheFromNodePortIndex(NodePortProtoFormat(port, protocol)); ok { // port is already taken continue } - protocols, err := openSocketsForPort(pt.LocalPortOpener, port) + protocolData, err := openSocketsForPort(pt.LocalPortOpener, port, protocol) if err != nil { klog.V(4).InfoS("Port cannot be reserved, moving on to the next one", "port", port) - closeSocketsOrRetry(protocols) continue } @@ -100,58 +93,18 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock if pt.PortSearchStart > pt.EndPort { pt.PortSearchStart = pt.StartPort } - return port, protocols, nil - } - return 0, nil, fmt.Errorf("no free port found") -} - -func closeSockets(protocols []ProtocolSocketData) error { - for idx := range protocols { - protocolSocketData := &protocols[idx] - if protocolSocketData.State != stateOpen { - continue - } - if err := protocolSocketData.socket.Close(); err != nil { - return err - } - protocolSocketData.State = stateClosed - + return port, protocolData, nil } - return nil -} - -// closeSocketsOrRetry closes all provided sockets. In case of an error, it -// creates a goroutine to retry asynchronously. -func closeSocketsOrRetry(protocols []ProtocolSocketData) { - var err error - if err = closeSockets(protocols); err == nil { - return - } - // Unlikely that there could be transient errors when closing a socket, - // but just in case, we create a goroutine to retry. We make a copy of - // the protocols slice, since the calling goroutine may modify the - // original one. - protocolsCopy := make([]ProtocolSocketData, len(protocols)) - copy(protocolsCopy, protocols) - go func() { - const delay = 5 * time.Second - for { - klog.ErrorS(err, "Unexpected error when closing socket(s), will retry", "retryDelay", delay) - time.Sleep(delay) - if err = closeSockets(protocolsCopy); err == nil { - return - } - } - }() + return 0, ProtocolSocketData{}, fmt.Errorf("no free port found") } func (d *NodePortData) CloseSockets() error { - for idx := range d.Protocols { - protocolSocketData := &d.Protocols[idx] + if d.Protocol.Protocol != "" { + protocolSocketData := &d.Protocol switch protocolSocketData.State { case stateClosed: // already closed - continue + return nil case stateInUse: // should not happen return fmt.Errorf("protocol %s is still in use, cannot release socket", protocolSocketData.Protocol) @@ -170,10 +123,10 @@ func (d *NodePortData) CloseSockets() error { func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, error) { pt.tableLock.Lock() defer pt.tableLock.Unlock() - npData := pt.getEntryByPodIPPort(podIP, podPort) + npData := pt.getEntryByPodIPPortProto(podIP, podPort, protocol) exists := (npData != nil) if !exists { - nodePort, protocols, err := pt.getFreePort(podIP, podPort) + nodePort, protocolData, err := pt.getFreePort(podIP, podPort, protocol) if err != nil { return 0, err } @@ -181,29 +134,16 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e NodePort: nodePort, PodIP: podIP, PodPort: podPort, - Protocols: protocols, + Protocol: protocolData, } - } - protocolSocketData := npData.FindProtocol(protocol) - if protocolSocketData == nil { - return 0, fmt.Errorf("unknown protocol %s", protocol) - } - if protocolSocketData.State == stateInUse { - return 0, fmt.Errorf("rule for %s:%d:%s already exists", podIP, podPort, protocol) - } - if protocolSocketData.State == stateClosed { - return 0, fmt.Errorf("invalid socket state for %s:%d:%s", podIP, podPort, protocol) - } - - nodePort := npData.NodePort - if err := pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil { - return 0, err - } - - protocolSocketData.State = stateInUse - if !exists { - pt.NodePortTable[strconv.Itoa(nodePort)] = npData - pt.PodEndpointTable[podIPPortFormat(podIP, podPort)] = npData + nodePort = npData.NodePort + if err := pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil { + return 0, err + } + pt.addPortTableCache(npData) + } else { + // Only add rules for if the entry does not exist. + return 0, fmt.Errorf("existed windows nodeport entry for %s:%d:%s", podIP, podPort, protocol) } return npData.NodePort, nil } @@ -211,38 +151,19 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e func (pt *PortTable) DeleteRule(podIP string, podPort int, protocol string) error { pt.tableLock.Lock() defer pt.tableLock.Unlock() - data := pt.getEntryByPodIPPort(podIP, podPort) + data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol) if data == nil { // Delete not required when the PortTable entry does not exist return nil } - numProtocolsInUse := 0 - var protocolSocketData *ProtocolSocketData - for idx, pData := range data.Protocols { - if pData.State != stateInUse { - continue - } - numProtocolsInUse++ - if pData.Protocol == protocol { - protocolSocketData = &data.Protocols[idx] - } - } - if protocolSocketData != nil { - if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil { - return err - } - protocolSocketData.State = stateOpen - numProtocolsInUse-- + if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil { + return err } - if numProtocolsInUse == 0 { - // Node port is not needed anymore: close all sockets and delete - // table entries. - if err := data.CloseSockets(); err != nil { - return err - } - delete(pt.NodePortTable, strconv.Itoa(data.NodePort)) - delete(pt.PodEndpointTable, podIPPortFormat(podIP, podPort)) + if err := data.CloseSockets(); err != nil { + return err } + // We don't need to delete cache from different indexes repeatedly because they map to the same entry. + pt.delPortTableCacheFromNodePortIndex(NodePortProtoFormat(data.NodePort, protocol)) return nil } @@ -251,18 +172,14 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { defer pt.tableLock.Unlock() podEntries := pt.getDataForPodIP(podIP) for _, podEntry := range podEntries { - for len(podEntry.Protocols) > 0 { - protocolSocketData := podEntry.Protocols[0] - if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil { - return err - } - if err := protocolSocketData.socket.Close(); err != nil { - return fmt.Errorf("error when releasing local port %d with protocol %s: %v", podEntry.NodePort, protocolSocketData.Protocol, err) - } - podEntry.Protocols = podEntry.Protocols[1:] + protocolSocketData := podEntry.Protocol + if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil { + return err } - delete(pt.NodePortTable, strconv.Itoa(podEntry.NodePort)) - delete(pt.PodEndpointTable, podIPPortFormat(podIP, podEntry.PodPort)) + if err := protocolSocketData.socket.Close(); err != nil { + return fmt.Errorf("error when releasing local port %d with protocol %s: %v", podEntry.NodePort, protocolSocketData.Protocol, err) + } + pt.delPortTableCacheFromNodePortIndex(NodePortProtoFormat(podEntry.NodePort, protocolSocketData.Protocol)) } return nil } @@ -271,18 +188,19 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { func (pt *PortTable) syncRules() error { pt.tableLock.Lock() defer pt.tableLock.Unlock() - nplPorts := make([]rules.PodNodePort, 0, len(pt.NodePortTable)) - for _, npData := range pt.NodePortTable { - protocols := make([]string, 0, len(supportedProtocols)) - for _, protocol := range npData.Protocols { - if protocol.State == stateInUse { - protocols = append(protocols, protocol.Protocol) - } + nplPorts := make([]rules.PodNodePort, 0, 1) + for _, obj := range pt.PortTableCache.List() { + npData := obj.(*NodePortData) + protocols := make([]string, 0, 1) + protocol := npData.Protocol + if protocol.State == stateInUse { + protocols = append(protocols, protocol.Protocol) } nplPorts = append(nplPorts, rules.PodNodePort{ NodePort: npData.NodePort, PodPort: npData.PodPort, PodIP: npData.PodIP, + Protocol: protocols[0], Protocols: protocols, }) } @@ -300,13 +218,12 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<- pt.tableLock.Lock() defer pt.tableLock.Unlock() for _, nplPort := range allNPLPorts { - protocols, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort) + protocolData, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort, nplPort.Protocol) if err != nil { // This will be handled gracefully by the NPL controller: if there is an // annotation using this port, it will be removed and replaced with a new // one with a valid port mapping. klog.ErrorS(err, "Cannot bind to local port, skipping it", "port", nplPort.NodePort) - closeSocketsOrRetry(protocols) continue } @@ -314,17 +231,9 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<- NodePort: nplPort.NodePort, PodPort: nplPort.PodPort, PodIP: nplPort.PodIP, - Protocols: protocols, - } - for _, protocol := range nplPort.Protocols { - protocolSocketData := npData.FindProtocol(protocol) - if protocolSocketData == nil { - return fmt.Errorf("unknown protocol %s", protocol) - } - protocolSocketData.State = stateInUse + Protocol: protocolData, } - pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] = npData - pt.PodEndpointTable[podIPPortFormat(nplPort.PodIP, nplPort.PodPort)] = pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] + pt.addPortTableCache(npData) } // retry mechanism as iptables-restore can fail if other components (in Antrea or other // software) are accessing iptables. diff --git a/pkg/agent/nodeportlocal/portcache/port_table_test.go b/pkg/agent/nodeportlocal/portcache/port_table_test.go index 271ffcc2853..4ea467a27b4 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_test.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_test.go @@ -26,6 +26,8 @@ import ( portcachetesting "antrea.io/antrea/pkg/agent/nodeportlocal/portcache/testing" "antrea.io/antrea/pkg/agent/nodeportlocal/rules" rulestesting "antrea.io/antrea/pkg/agent/nodeportlocal/rules/testing" + + "k8s.io/client-go/tools/cache" ) const ( @@ -35,8 +37,11 @@ const ( func newPortTable(mockIPTables rules.PodPortRules, mockPortOpener LocalPortOpener) *PortTable { return &PortTable{ - NodePortTable: make(map[string]*NodePortData), - PodEndpointTable: make(map[string]*NodePortData), + PortTableCache: cache.NewIndexer(GetPortTableKey, cache.Indexers{ + nodePortIndex: NodePortIndexFunc, + podEndpointIndex: PodEndpointIndexFunc, + podIPIndex: PodIPIndexFunc, + }), StartPort: startPort, EndPort: endPort, PortSearchStart: startPort, @@ -59,12 +64,21 @@ func TestRestoreRules(t *testing.T) { NodePort: nodePort1, PodPort: 1001, PodIP: podIP, - Protocols: []string{"tcp", "udp"}, + Protocol: "tcp", + Protocols: []string{"tcp"}, + }, + { + NodePort: nodePort1, + PodPort: 1001, + PodIP: podIP, + Protocol: "udp", + Protocols: []string{"udp"}, }, { NodePort: nodePort2, PodPort: 1002, PodIP: podIP, + Protocol: "udp", Protocols: []string{"udp"}, }, } @@ -73,7 +87,6 @@ func TestRestoreRules(t *testing.T) { gomock.InOrder( mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "tcp"), mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "udp"), - mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "tcp"), mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "udp"), ) diff --git a/pkg/agent/nodeportlocal/portcache/port_table_windows.go b/pkg/agent/nodeportlocal/portcache/port_table_windows.go index ef5d6ee539a..b9a29447800 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_windows.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_windows.go @@ -30,15 +30,6 @@ const ( stateInUse protocolSocketState = 1 ) -// podIPPortFormat formats the ip, port to string ip:port. -func podIPPortProtoFormat(ip string, port int, protocol string) string { - return fmt.Sprintf("%s:%d:%s", ip, port, protocol) -} - -func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData { - return pt.PodEndpointTable[podIPPortProtoFormat(ip, port, protocol)] -} - func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortData { pt.tableLock.RLock() defer pt.tableLock.RUnlock() @@ -50,24 +41,23 @@ func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortDat return nil } -func addRuleForPort(podPortRules rules.PodPortRules, port int, podIP string, podPort int, protocol string) ([]ProtocolSocketData, error) { +func addRuleForPort(podPortRules rules.PodPortRules, port int, podIP string, podPort int, protocol string) (ProtocolSocketData, error) { // Only the protocol used here should be returned if NetNatStaticMapping rule // can be inserted to an unused protocol port. - protocols := make([]ProtocolSocketData, 0, 1) err := podPortRules.AddRule(port, podIP, podPort, protocol) if err != nil { klog.ErrorS(err, "Local port cannot be opened", "port", port, "protocol", protocol) return nil, err } - protocols = append(protocols, ProtocolSocketData{ + protocolData := ProtocolSocketData{ Protocol: protocol, State: stateInUse, socket: nil, - }) - return protocols, nil + } + return protocolData, nil } -func (pt *PortTable) addRuleforFreePort(podIP string, podPort int, protocol string) (int, []ProtocolSocketData, error) { +func (pt *PortTable) addRuleforFreePort(podIP string, podPort int, protocol string) (int, ProtocolSocketData, error) { klog.V(2).InfoS("Looking for free Node port on Windows", "podIP", podIP, "podPort", podPort, "protocol", protocol) numPorts := pt.EndPort - pt.StartPort + 1 for i := 0; i < numPorts; i++ { @@ -76,12 +66,12 @@ func (pt *PortTable) addRuleforFreePort(podIP string, podPort int, protocol stri // handle wrap around port = port - numPorts } - if _, ok := pt.NodePortTable[NodePortProtoFormat(port, protocol)]; ok { + if _, ok := pt.getPortTableCacheFromNodePortIndex(NodePortProtoFormat(port, protocol)); ok { // protocol port is already taken continue } - protocols, err := addRuleForPort(pt.PodPortRules, port, podIP, podPort, protocol) + protocolData, err := addRuleForPort(pt.PodPortRules, port, podIP, podPort, protocol) if err != nil { klog.ErrorS(err, "Port cannot be reserved, moving on to the next one", "port", port) continue @@ -91,7 +81,7 @@ func (pt *PortTable) addRuleforFreePort(podIP string, podPort int, protocol stri if pt.PortSearchStart > pt.EndPort { pt.PortSearchStart = pt.StartPort } - return port, protocols, nil + return port, protocolData, nil } return 0, nil, fmt.Errorf("no free port found") } @@ -114,8 +104,7 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e Protocols: protocols, } - pt.NodePortTable[NodePortProtoFormat(nodePort, protocol)] = npData - pt.PodEndpointTable[podIPPortProtoFormat(podIP, podPort, protocol)] = npData + pt.addPortTableCache(npData) } else { // Only add rules for if the entry does not exist. return 0, fmt.Errorf("existed windows nodeport entry for %s:%d:%s", podIP, podPort, protocol) @@ -143,8 +132,7 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<- PodIP: nplPort.PodIP, Protocols: protocols, } - pt.PodEndpointTable[podIPPortProtoFormat(nplPort.PodIP, nplPort.PodPort, nplPort.Protocol)] = pt.NodePortTable[NodePortProtoFormat(nplPort.NodePort, nplPort.Protocol)] - pt.NodePortTable[NodePortProtoFormat(nplPort.NodePort, nplPort.Protocol)] = npData + pt.addPortTableCache(npData) } // No need to sync up again because addRuleForPort has updated all rules on Windows close(synced) @@ -160,14 +148,13 @@ func (pt *PortTable) DeleteRule(podIP string, podPort int, protocol string) erro return nil } var protocolSocketData *ProtocolSocketData - protocolSocketData = &data.Protocols[0] + protocolSocketData = &data.Protocol if protocolSocketData != nil { if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil { return err } } - delete(pt.NodePortTable, NodePortProtoFormat(data.NodePort, protocol)) - delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podPort, protocol)) + pt.delPortTableCacheFromNodePortIndex(NodePortProtoFormat(data.NodePort, protocol)) return nil } @@ -176,14 +163,11 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { defer pt.tableLock.Unlock() podEntries := pt.getDataForPodIP(podIP) for _, podEntry := range podEntries { - if len(podEntry.Protocols) > 0 { - protocolSocketData := podEntry.Protocols[0] - if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil { - return err - } - delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podEntry.PodPort, protocolSocketData.Protocol)) - delete(pt.NodePortTable, NodePortProtoFormat(podEntry.NodePort, protocolSocketData.Protocol)) + protocolSocketData := podEntry.Protocol + if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil { + return err } + pt.delPortTableCacheFromNodePortIndex(NodePortProtoFormat(podEntry.NodePort, protocolSocketData.Protocol)) } return nil } diff --git a/test/e2e/nodeportlocal_test.go b/test/e2e/nodeportlocal_test.go index 075a73becf7..6d2d09f209f 100644 --- a/test/e2e/nodeportlocal_test.go +++ b/test/e2e/nodeportlocal_test.go @@ -77,7 +77,6 @@ func configureNPLForAgent(t *testing.T, data *TestData, startPort, endPort int) // NodePortLocal related test cases so they can share setup, teardown. func TestNodePortLocal(t *testing.T) { skipIfNotIPv4Cluster(t) - skipIfHasWindowsNodes(t) skipIfNodePortLocalDisabled(t) data, err := setupTest(t)