diff --git a/pkg/agent/nodeportlocal/npl_agent_test.go b/pkg/agent/nodeportlocal/npl_agent_test.go index 3cc5957e4d2..83d5fea033b 100644 --- a/pkg/agent/nodeportlocal/npl_agent_test.go +++ b/pkg/agent/nodeportlocal/npl_agent_test.go @@ -657,7 +657,8 @@ func TestMultiplePods(t *testing.T) { // TestMultipleProtocols creates multiple Pods with multiple protocols and verifies that // NPL annotations and iptable rules for both Pods and Protocols are updated correctly. // In particular we make sure that a given NodePort is never used by more than one Pod, -// irrespective of which protocol is in use. +// and that one Pod can use multiple NodePorts for different protocols with the same podPort, +// because of the new NPL unification implementation. func TestMultipleProtocols(t *testing.T) { tcpUdpSvcLabel := map[string]string{"tcp": "true", "udp": "true"} udpSvcLabel := map[string]string{"tcp": "false", "udp": "true"} @@ -701,7 +702,7 @@ func TestMultipleProtocols(t *testing.T) { assert.True(t, testData.portTable.RuleExists(testPod2.Status.PodIP, defaultPort, protocolUDP)) // Update testSvc2 to serve TCP/80 and UDP/81 both, so pod2 is - // exposed on both TCP and UDP, with the same NodePort. + // exposed on both TCP and UDP, with different NodePorts. testSvc2.Spec.Ports = append(testSvc2.Spec.Ports, corev1.ServicePort{ Port: 80, @@ -715,7 +716,11 @@ func TestMultipleProtocols(t *testing.T) { pod2ValueUpdate, err := testData.pollForPodAnnotation(testPod2.Name, true) require.NoError(t, err, "Poll for annotation check failed") - expectedAnnotationsPod2.Add(&pod2Value[0].NodePort, defaultPort, protocolTCP) + + // The new NodePort should be the one right after the previous port, because of + // how NodePort allocation is implemented. 
+ pod2nodeport := pod2Value[0].NodePort + 1 + expectedAnnotationsPod2.Add(&pod2nodeport, defaultPort, protocolTCP) expectedAnnotationsPod2.Check(t, pod2ValueUpdate) } @@ -760,22 +765,18 @@ var ( portTakenError = fmt.Errorf("Port taken") ) -// TestNodePortAlreadyBoundTo validates that when a port is already bound to, a different port will -// be selected for NPL. +// TestNodePortAlreadyBoundTo validates that when a port is already bound to for TCP, a different +// port is selected for a TCP NPL rule; only the protocol used by the rule is checked. func TestNodePortAlreadyBoundTo(t *testing.T) { nodePort1 := defaultStartPort nodePort2 := nodePort1 + 1 testConfig := newTestConfig().withCustomPortOpenerExpectations(func(mockPortOpener *portcachetesting.MockLocalPortOpener) { gomock.InOrder( - // Based on the implementation, we know that TCP is checked first... - // 1. port1 is checked for TCP availability -> success - mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(&fakeSocket{}, nil), - // 2. port1 is checked for UDP availability (even if the Service uses TCP only) -> error - mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolUDP).Return(nil, portTakenError), - // 3. port2 is checked for TCP availability -> success + // Based on the implementation, we know that only the protocol used by the rule (TCP) is checked... + // 1. port1 is checked for TCP availability -> error + mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(nil, portTakenError), + // 2. port2 is checked for TCP availability -> success mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolTCP).Return(&fakeSocket{}, nil), - // 4. 
port2 is checked for UDP availability -> success - mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolUDP).Return(&fakeSocket{}, nil), ) }) customNodePort := defaultStartPort + 1 diff --git a/pkg/agent/nodeportlocal/portcache/port_table.go b/pkg/agent/nodeportlocal/portcache/port_table.go index a2923f74689..cbae1aa725b 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table.go +++ b/pkg/agent/nodeportlocal/portcache/port_table.go @@ -120,10 +120,19 @@ func (pt *PortTable) getEntryByPodIPPort(ip string, port int) *NodePortData { return pt.PodEndpointTable[podIPPortFormat(ip, port)] } +// podIPPortProtoFormat formats the ip, port and protocol to string ip:port:protocol. +func podIPPortProtoFormat(ip string, port int, protocol string) string { + return fmt.Sprintf("%s:%d:%s", ip, port, protocol) +} + +func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData { + return pt.PodEndpointTable[podIPPortProtoFormat(ip, port, protocol)] +} + func (pt *PortTable) RuleExists(podIP string, podPort int, protocol string) bool { pt.tableLock.RLock() defer pt.tableLock.RUnlock() - if data := pt.getEntryByPodIPPort(podIP, podPort); data != nil { + if data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol); data != nil { return data.ProtocolInUse(protocol) } return false diff --git a/pkg/agent/nodeportlocal/portcache/port_table_linux.go b/pkg/agent/nodeportlocal/portcache/port_table_linux.go index dbaa5a32b86..c079101c21e 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_linux.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_linux.go @@ -44,38 +44,34 @@ var ( ) func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortData { - var _ = protocol pt.tableLock.RLock() defer pt.tableLock.RUnlock() // Return pointer to copy of data from the PodEndpointTable. 
- if data := pt.getEntryByPodIPPort(ip, port); data != nil { + if data := pt.getEntryByPodIPPortProto(ip, port, protocol); data != nil { dataCopy := *data return &dataCopy } return nil } -func openSocketsForPort(localPortOpener LocalPortOpener, port int) ([]ProtocolSocketData, error) { - // Port needs to be available for all supported protocols: we want to use the same port - // number for all protocols and we don't know at this point which protocols are needed. - // This is to preserve the legacy behavior of allocating the same nodePort for all protocols. - protocols := make([]ProtocolSocketData, 0, len(supportedProtocols)) - for _, protocol := range supportedProtocols { - socket, err := localPortOpener.OpenLocalPort(port, protocol) - if err != nil { - klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol) - return protocols, err - } - protocols = append(protocols, ProtocolSocketData{ - Protocol: protocol, - State: stateOpen, - socket: socket, - }) +func openSocketsForPort(localPortOpener LocalPortOpener, port int, protocol string) ([]ProtocolSocketData, error) { + // Port needs to be only available for the protocol used by NPL rule. + // We don't need to allocate the same nodePort for all protocols anymore. 
+ protocols := make([]ProtocolSocketData, 0, 1) + socket, err := localPortOpener.OpenLocalPort(port, protocol) + if err != nil { + klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol) + return protocols, err } + protocols = append(protocols, ProtocolSocketData{ + Protocol: protocol, + State: stateInUse, + socket: socket, + }) return protocols, nil } -func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSocketData, error) { +func (pt *PortTable) getFreePort(podIP string, podPort int, protocol string) (int, []ProtocolSocketData, error) { klog.V(2).InfoS("Looking for free Node port", "podIP", podIP, "podPort", podPort) numPorts := pt.EndPort - pt.StartPort + 1 for i := 0; i < numPorts; i++ { @@ -84,15 +80,14 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock // handle wrap around port = port - numPorts } - if _, ok := pt.NodePortTable[strconv.Itoa(port)]; ok { + if _, ok := pt.NodePortTable[NodePortProtoFormat(port, protocol)]; ok { // port is already taken continue } - protocols, err := openSocketsForPort(pt.LocalPortOpener, port) + protocols, err := openSocketsForPort(pt.LocalPortOpener, port, protocol) if err != nil { klog.V(4).InfoS("Port cannot be reserved, moving on to the next one", "port", port) - closeSocketsOrRetry(protocols) continue } @@ -105,46 +100,6 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock return 0, nil, fmt.Errorf("no free port found") } -func closeSockets(protocols []ProtocolSocketData) error { - for idx := range protocols { - protocolSocketData := &protocols[idx] - if protocolSocketData.State != stateOpen { - continue - } - if err := protocolSocketData.socket.Close(); err != nil { - return err - } - protocolSocketData.State = stateClosed - - } - return nil -} - -// closeSocketsOrRetry closes all provided sockets. In case of an error, it -// creates a goroutine to retry asynchronously. 
-func closeSocketsOrRetry(protocols []ProtocolSocketData) { - var err error - if err = closeSockets(protocols); err == nil { - return - } - // Unlikely that there could be transient errors when closing a socket, - // but just in case, we create a goroutine to retry. We make a copy of - // the protocols slice, since the calling goroutine may modify the - // original one. - protocolsCopy := make([]ProtocolSocketData, len(protocols)) - copy(protocolsCopy, protocols) - go func() { - const delay = 5 * time.Second - for { - klog.ErrorS(err, "Unexpected error when closing socket(s), will retry", "retryDelay", delay) - time.Sleep(delay) - if err = closeSockets(protocolsCopy); err == nil { - return - } - } - }() -} - func (d *NodePortData) CloseSockets() error { for idx := range d.Protocols { protocolSocketData := &d.Protocols[idx] @@ -170,10 +125,10 @@ func (d *NodePortData) CloseSockets() error { func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, error) { pt.tableLock.Lock() defer pt.tableLock.Unlock() - npData := pt.getEntryByPodIPPort(podIP, podPort) + npData := pt.getEntryByPodIPPortProto(podIP, podPort, protocol) exists := (npData != nil) if !exists { - nodePort, protocols, err := pt.getFreePort(podIP, podPort) + nodePort, protocols, err := pt.getFreePort(podIP, podPort, protocol) if err != nil { return 0, err } @@ -183,27 +138,15 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e PodPort: podPort, Protocols: protocols, } - } - protocolSocketData := npData.FindProtocol(protocol) - if protocolSocketData == nil { - return 0, fmt.Errorf("unknown protocol %s", protocol) - } - if protocolSocketData.State == stateInUse { - return 0, fmt.Errorf("rule for %s:%d:%s already exists", podIP, podPort, protocol) - } - if protocolSocketData.State == stateClosed { - return 0, fmt.Errorf("invalid socket state for %s:%d:%s", podIP, podPort, protocol) - } - - nodePort := npData.NodePort - if err := 
pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil { - return 0, err - } - - protocolSocketData.State = stateInUse - if !exists { - pt.NodePortTable[strconv.Itoa(nodePort)] = npData - pt.PodEndpointTable[podIPPortFormat(podIP, podPort)] = npData + nodePort = npData.NodePort + if err := pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil { + return 0, err + } + pt.NodePortTable[NodePortProtoFormat(nodePort, protocol)] = npData + pt.PodEndpointTable[podIPPortProtoFormat(podIP, podPort, protocol)] = npData + } else { + // A rule is only added when the entry does not exist yet. + return 0, fmt.Errorf("existed windows nodeport entry for %s:%d:%s", podIP, podPort, protocol) } return npData.NodePort, nil } @@ -211,38 +154,19 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e func (pt *PortTable) DeleteRule(podIP string, podPort int, protocol string) error { pt.tableLock.Lock() defer pt.tableLock.Unlock() - data := pt.getEntryByPodIPPort(podIP, podPort) + data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol) if data == nil { // Delete not required when the PortTable entry does not exist return nil } - numProtocolsInUse := 0 - var protocolSocketData *ProtocolSocketData - for idx, pData := range data.Protocols { - if pData.State != stateInUse { - continue - } - numProtocolsInUse++ - if pData.Protocol == protocol { - protocolSocketData = &data.Protocols[idx] - } - } - if protocolSocketData != nil { - if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil { - return err - } - protocolSocketData.State = stateOpen - numProtocolsInUse-- + if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil { + return err } - if numProtocolsInUse == 0 { - // Node port is not needed anymore: close all sockets and delete - // table entries. 
- if err := data.CloseSockets(); err != nil { - return err - } - delete(pt.NodePortTable, strconv.Itoa(data.NodePort)) - delete(pt.PodEndpointTable, podIPPortFormat(podIP, podPort)) + if err := data.CloseSockets(); err != nil { + return err } + delete(pt.NodePortTable, NodePortProtoFormat(data.NodePort, protocol)) + delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podPort, protocol)) return nil } @@ -251,7 +175,7 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { defer pt.tableLock.Unlock() podEntries := pt.getDataForPodIP(podIP) for _, podEntry := range podEntries { - for len(podEntry.Protocols) > 0 { + if len(podEntry.Protocols) > 0 { protocolSocketData := podEntry.Protocols[0] if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil { return err @@ -259,10 +183,9 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { if err := protocolSocketData.socket.Close(); err != nil { return fmt.Errorf("error when releasing local port %d with protocol %s: %v", podEntry.NodePort, protocolSocketData.Protocol, err) } - podEntry.Protocols = podEntry.Protocols[1:] + delete(pt.NodePortTable, NodePortProtoFormat(podEntry.NodePort, protocolSocketData.Protocol)) + delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podEntry.PodPort, protocolSocketData.Protocol)) } - delete(pt.NodePortTable, strconv.Itoa(podEntry.NodePort)) - delete(pt.PodEndpointTable, podIPPortFormat(podIP, podEntry.PodPort)) } return nil } @@ -271,7 +194,7 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error { func (pt *PortTable) syncRules() error { pt.tableLock.Lock() defer pt.tableLock.Unlock() - nplPorts := make([]rules.PodNodePort, 0, len(pt.NodePortTable)) + nplPorts := make([]rules.PodNodePort, 0, 1) for _, npData := range pt.NodePortTable { protocols := make([]string, 0, len(supportedProtocols)) for _, protocol := range npData.Protocols { @@ -283,6 +206,7 @@ func (pt *PortTable) syncRules() error { 
NodePort: npData.NodePort, PodPort: npData.PodPort, PodIP: npData.PodIP, + Protocol: protocols[0], Protocols: protocols, }) } @@ -300,13 +224,12 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<- pt.tableLock.Lock() defer pt.tableLock.Unlock() for _, nplPort := range allNPLPorts { - protocols, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort) + protocols, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort, nplPort.Protocol) if err != nil { // This will be handled gracefully by the NPL controller: if there is an // annotation using this port, it will be removed and replaced with a new // one with a valid port mapping. klog.ErrorS(err, "Cannot bind to local port, skipping it", "port", nplPort.NodePort) - closeSocketsOrRetry(protocols) continue } @@ -316,15 +239,8 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<- PodIP: nplPort.PodIP, Protocols: protocols, } - for _, protocol := range nplPort.Protocols { - protocolSocketData := npData.FindProtocol(protocol) - if protocolSocketData == nil { - return fmt.Errorf("unknown protocol %s", protocol) - } - protocolSocketData.State = stateInUse - } - pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] = npData - pt.PodEndpointTable[podIPPortFormat(nplPort.PodIP, nplPort.PodPort)] = pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] + pt.NodePortTable[podIPPortProtoFormat(nplPort.PodIP, nplPort.PodPort, nplPort.Protocol)] = npData + pt.PodEndpointTable[NodePortProtoFormat(nplPort.NodePort, nplPort.Protocol)] = pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] } // retry mechanism as iptables-restore can fail if other components (in Antrea or other // software) are accessing iptables. 
diff --git a/pkg/agent/nodeportlocal/portcache/port_table_test.go b/pkg/agent/nodeportlocal/portcache/port_table_test.go index 271ffcc2853..1693f42b1c8 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_test.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_test.go @@ -59,12 +59,21 @@ func TestRestoreRules(t *testing.T) { NodePort: nodePort1, PodPort: 1001, PodIP: podIP, - Protocols: []string{"tcp", "udp"}, + Protocol: "tcp", + Protocols: []string{"tcp"}, + }, + { + NodePort: nodePort1, + PodPort: 1001, + PodIP: podIP, + Protocol: "udp", + Protocols: []string{"udp"}, }, { NodePort: nodePort2, PodPort: 1002, PodIP: podIP, + Protocol: "udp", Protocols: []string{"udp"}, }, } @@ -73,7 +82,6 @@ func TestRestoreRules(t *testing.T) { gomock.InOrder( mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "tcp"), mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "udp"), - mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "tcp"), mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "udp"), ) diff --git a/pkg/agent/nodeportlocal/portcache/port_table_windows.go b/pkg/agent/nodeportlocal/portcache/port_table_windows.go index ef5d6ee539a..36ad2e6bdbf 100644 --- a/pkg/agent/nodeportlocal/portcache/port_table_windows.go +++ b/pkg/agent/nodeportlocal/portcache/port_table_windows.go @@ -30,15 +30,6 @@ const ( stateInUse protocolSocketState = 1 ) -// podIPPortFormat formats the ip, port to string ip:port. 
-func podIPPortProtoFormat(ip string, port int, protocol string) string { - return fmt.Sprintf("%s:%d:%s", ip, port, protocol) -} - -func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData { - return pt.PodEndpointTable[podIPPortProtoFormat(ip, port, protocol)] -} - func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortData { pt.tableLock.RLock() defer pt.tableLock.RUnlock() diff --git a/test/e2e/nodeportlocal_test.go b/test/e2e/nodeportlocal_test.go index 725036cd6f8..feeacc46149 100644 --- a/test/e2e/nodeportlocal_test.go +++ b/test/e2e/nodeportlocal_test.go @@ -77,7 +77,6 @@ func configureNPLForAgent(t *testing.T, data *TestData, startPort, endPort int) // NodePortLocal related test cases so they can share setup, teardown. func TestNodePortLocal(t *testing.T) { skipIfNotIPv4Cluster(t) - skipIfHasWindowsNodes(t) data, err := setupTest(t) if err != nil {