Skip to content

Commit

Permalink
Implement NPL agent unification.
Browse files Browse the repository at this point in the history
* Unify agent behavior across Linux and Windows. Linux agent should support
allocating different nodeports for different protocols when the podports are the same.
* Update port allocation related unit tests.
* Enable windows e2e test.
* Delete unused functions.

Signed-off-by: Shuyang Xin <[email protected]>
  • Loading branch information
XinShuYang committed Jun 29, 2022
1 parent b2c245c commit bf0acfb
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 153 deletions.
27 changes: 14 additions & 13 deletions pkg/agent/nodeportlocal/npl_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ func TestMultiplePods(t *testing.T) {
// TestMultipleProtocols creates multiple Pods with multiple protocols and verifies that
// NPL annotations and iptable rules for both Pods and Protocols are updated correctly.
// In particular we make sure that a given NodePort is never used by more than one Pod,
// irrespective of which protocol is in use.
// A single Pod can use multiple NodePorts, one per protocol, for the same Pod port,
// because of the new NPL unification implementation.
func TestMultipleProtocols(t *testing.T) {
tcpUdpSvcLabel := map[string]string{"tcp": "true", "udp": "true"}
udpSvcLabel := map[string]string{"tcp": "false", "udp": "true"}
Expand Down Expand Up @@ -701,7 +702,7 @@ func TestMultipleProtocols(t *testing.T) {
assert.True(t, testData.portTable.RuleExists(testPod2.Status.PodIP, defaultPort, protocolUDP))

// Update testSvc2 to serve TCP/80 and UDP/81 both, so pod2 is
// exposed on both TCP and UDP, with the same NodePort.
// exposed on both TCP and UDP, with different NodePorts.

testSvc2.Spec.Ports = append(testSvc2.Spec.Ports, corev1.ServicePort{
Port: 80,
Expand All @@ -715,7 +716,11 @@ func TestMultipleProtocols(t *testing.T) {

pod2ValueUpdate, err := testData.pollForPodAnnotation(testPod2.Name, true)
require.NoError(t, err, "Poll for annotation check failed")
expectedAnnotationsPod2.Add(&pod2Value[0].NodePort, defaultPort, protocolTCP)

// The new NodePort should be the one right after the previous NodePort, because
// of how NodePort allocation is implemented.
pod2nodeport := pod2Value[0].NodePort + 1
expectedAnnotationsPod2.Add(&pod2nodeport, defaultPort, protocolTCP)
expectedAnnotationsPod2.Check(t, pod2ValueUpdate)
}

Expand Down Expand Up @@ -760,22 +765,18 @@ var (
portTakenError = fmt.Errorf("Port taken")
)

// TestNodePortAlreadyBoundTo validates that when a port is already bound to, a different port will
// be selected for NPL.
// TestNodePortAlreadyBoundTo validates that when a port is already bound to for the TCP protocol,
// the same port can still be selected for NPL with any other protocol that is available.
func TestNodePortAlreadyBoundTo(t *testing.T) {
nodePort1 := defaultStartPort
nodePort2 := nodePort1 + 1
testConfig := newTestConfig().withCustomPortOpenerExpectations(func(mockPortOpener *portcachetesting.MockLocalPortOpener) {
gomock.InOrder(
// Based on the implementation, we know that TCP is checked first...
// 1. port1 is checked for TCP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(&fakeSocket{}, nil),
// 2. port1 is checked for UDP availability (even if the Service uses TCP only) -> error
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolUDP).Return(nil, portTakenError),
// 3. port2 is checked for TCP availability -> success
// Based on the implementation, we know that only the protocol used in the rule (TCP) is checked...
// 1. port1 is checked for TCP availability -> error
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(nil, portTakenError),
// 2. port2 is checked for TCP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolTCP).Return(&fakeSocket{}, nil),
// 4. port2 is checked for UDP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolUDP).Return(&fakeSocket{}, nil),
)
})
customNodePort := defaultStartPort + 1
Expand Down
11 changes: 10 additions & 1 deletion pkg/agent/nodeportlocal/portcache/port_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,19 @@ func (pt *PortTable) getEntryByPodIPPort(ip string, port int) *NodePortData {
return pt.PodEndpointTable[podIPPortFormat(ip, port)]
}

// podIPPortProtoFormat formats the ip, port and protocol to the string
// ip:port:protocol.
func podIPPortProtoFormat(ip string, port int, protocol string) string {
	const layout = "%s:%d:%s"
	return fmt.Sprintf(layout, ip, port, protocol)
}

// getEntryByPodIPPortProto looks up the PodEndpointTable entry stored under the
// key that podIPPortProtoFormat builds from ip, port and protocol. The result
// is nil when the table has no entry for that key. This method does no locking
// of its own.
func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData {
	key := podIPPortProtoFormat(ip, port, protocol)
	return pt.PodEndpointTable[key]
}

func (pt *PortTable) RuleExists(podIP string, podPort int, protocol string) bool {
pt.tableLock.RLock()
defer pt.tableLock.RUnlock()
if data := pt.getEntryByPodIPPort(podIP, podPort); data != nil {
if data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol); data != nil {
return data.ProtocolInUse(protocol)
}
return false
Expand Down
170 changes: 43 additions & 127 deletions pkg/agent/nodeportlocal/portcache/port_table_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,34 @@ var (
)

func (pt *PortTable) GetEntry(ip string, port int, protocol string) *NodePortData {
var _ = protocol
pt.tableLock.RLock()
defer pt.tableLock.RUnlock()
// Return pointer to copy of data from the PodEndpointTable.
if data := pt.getEntryByPodIPPort(ip, port); data != nil {
if data := pt.getEntryByPodIPPortProto(ip, port, protocol); data != nil {
dataCopy := *data
return &dataCopy
}
return nil
}

func openSocketsForPort(localPortOpener LocalPortOpener, port int) ([]ProtocolSocketData, error) {
// Port needs to be available for all supported protocols: we want to use the same port
// number for all protocols and we don't know at this point which protocols are needed.
// This is to preserve the legacy behavior of allocating the same nodePort for all protocols.
protocols := make([]ProtocolSocketData, 0, len(supportedProtocols))
for _, protocol := range supportedProtocols {
socket, err := localPortOpener.OpenLocalPort(port, protocol)
if err != nil {
klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol)
return protocols, err
}
protocols = append(protocols, ProtocolSocketData{
Protocol: protocol,
State: stateOpen,
socket: socket,
})
func openSocketsForPort(localPortOpener LocalPortOpener, port int, protocol string) ([]ProtocolSocketData, error) {
// The port only needs to be available for the protocol used by the NPL rule.
// We no longer need to allocate the same nodePort for all protocols.
protocols := make([]ProtocolSocketData, 0, 1)
socket, err := localPortOpener.OpenLocalPort(port, protocol)
if err != nil {
klog.V(4).InfoS("Local port cannot be opened", "port", port, "protocol", protocol)
return protocols, err
}
protocols = append(protocols, ProtocolSocketData{
Protocol: protocol,
State: stateInUse,
socket: socket,
})
return protocols, nil
}

func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSocketData, error) {
func (pt *PortTable) getFreePort(podIP string, podPort int, protocol string) (int, []ProtocolSocketData, error) {
klog.V(2).InfoS("Looking for free Node port", "podIP", podIP, "podPort", podPort)
numPorts := pt.EndPort - pt.StartPort + 1
for i := 0; i < numPorts; i++ {
Expand All @@ -84,15 +80,14 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock
// handle wrap around
port = port - numPorts
}
if _, ok := pt.NodePortTable[strconv.Itoa(port)]; ok {
if _, ok := pt.NodePortTable[NodePortProtoFormat(port, protocol)]; ok {
// port is already taken
continue
}

protocols, err := openSocketsForPort(pt.LocalPortOpener, port)
protocols, err := openSocketsForPort(pt.LocalPortOpener, port, protocol)
if err != nil {
klog.V(4).InfoS("Port cannot be reserved, moving on to the next one", "port", port)
closeSocketsOrRetry(protocols)
continue
}

Expand All @@ -105,46 +100,6 @@ func (pt *PortTable) getFreePort(podIP string, podPort int) (int, []ProtocolSock
return 0, nil, fmt.Errorf("no free port found")
}

// closeSockets closes the socket of every entry in protocols whose State is
// stateOpen and then marks that entry as stateClosed. Entries in any other
// state are left untouched. It stops at the first error returned by Close and
// returns it; entries after the failing one are not processed.
func closeSockets(protocols []ProtocolSocketData) error {
	for i := 0; i < len(protocols); i++ {
		// Work on the slice element itself so the state change is stored in
		// the slice rather than in a copy.
		entry := &protocols[i]
		if entry.State == stateOpen {
			err := entry.socket.Close()
			if err != nil {
				return err
			}
			entry.State = stateClosed
		}
	}
	return nil
}

// closeSocketsOrRetry attempts to close all provided sockets with closeSockets.
// If that first attempt fails, it starts a goroutine that logs the error, waits
// 5 seconds and tries again, repeating until closeSockets succeeds. The caller
// is never blocked by the retries.
func closeSocketsOrRetry(protocols []ProtocolSocketData) {
	err := closeSockets(protocols)
	if err == nil {
		return
	}
	// A transient error when closing a socket is unlikely, but we retry in the
	// background just in case. The retries operate on a snapshot of the slice
	// (taken after the first attempt), because the calling goroutine may modify
	// the original one.
	pending := make([]ProtocolSocketData, len(protocols))
	copy(pending, protocols)
	const delay = 5 * time.Second
	go func() {
		for err != nil {
			klog.ErrorS(err, "Unexpected error when closing socket(s), will retry", "retryDelay", delay)
			time.Sleep(delay)
			err = closeSockets(pending)
		}
	}()
}

func (d *NodePortData) CloseSockets() error {
for idx := range d.Protocols {
protocolSocketData := &d.Protocols[idx]
Expand All @@ -170,10 +125,10 @@ func (d *NodePortData) CloseSockets() error {
func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, error) {
pt.tableLock.Lock()
defer pt.tableLock.Unlock()
npData := pt.getEntryByPodIPPort(podIP, podPort)
npData := pt.getEntryByPodIPPortProto(podIP, podPort, protocol)
exists := (npData != nil)
if !exists {
nodePort, protocols, err := pt.getFreePort(podIP, podPort)
nodePort, protocols, err := pt.getFreePort(podIP, podPort, protocol)
if err != nil {
return 0, err
}
Expand All @@ -183,66 +138,35 @@ func (pt *PortTable) AddRule(podIP string, podPort int, protocol string) (int, e
PodPort: podPort,
Protocols: protocols,
}
}
protocolSocketData := npData.FindProtocol(protocol)
if protocolSocketData == nil {
return 0, fmt.Errorf("unknown protocol %s", protocol)
}
if protocolSocketData.State == stateInUse {
return 0, fmt.Errorf("rule for %s:%d:%s already exists", podIP, podPort, protocol)
}
if protocolSocketData.State == stateClosed {
return 0, fmt.Errorf("invalid socket state for %s:%d:%s", podIP, podPort, protocol)
}

nodePort := npData.NodePort
if err := pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil {
return 0, err
}

protocolSocketData.State = stateInUse
if !exists {
pt.NodePortTable[strconv.Itoa(nodePort)] = npData
pt.PodEndpointTable[podIPPortFormat(podIP, podPort)] = npData
nodePort = npData.NodePort
if err := pt.PodPortRules.AddRule(nodePort, podIP, podPort, protocol); err != nil {
return 0, err
}
pt.NodePortTable[NodePortProtoFormat(nodePort, protocol)] = npData
pt.PodEndpointTable[podIPPortProtoFormat(podIP, podPort, protocol)] = npData
} else {
// Only add rules if the entry does not exist.
return 0, fmt.Errorf("existed windows nodeport entry for %s:%d:%s", podIP, podPort, protocol)
}
return npData.NodePort, nil
}

func (pt *PortTable) DeleteRule(podIP string, podPort int, protocol string) error {
pt.tableLock.Lock()
defer pt.tableLock.Unlock()
data := pt.getEntryByPodIPPort(podIP, podPort)
data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol)
if data == nil {
// Delete not required when the PortTable entry does not exist
return nil
}
numProtocolsInUse := 0
var protocolSocketData *ProtocolSocketData
for idx, pData := range data.Protocols {
if pData.State != stateInUse {
continue
}
numProtocolsInUse++
if pData.Protocol == protocol {
protocolSocketData = &data.Protocols[idx]
}
}
if protocolSocketData != nil {
if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil {
return err
}
protocolSocketData.State = stateOpen
numProtocolsInUse--
if err := pt.PodPortRules.DeleteRule(data.NodePort, podIP, podPort, protocol); err != nil {
return err
}
if numProtocolsInUse == 0 {
// Node port is not needed anymore: close all sockets and delete
// table entries.
if err := data.CloseSockets(); err != nil {
return err
}
delete(pt.NodePortTable, strconv.Itoa(data.NodePort))
delete(pt.PodEndpointTable, podIPPortFormat(podIP, podPort))
if err := data.CloseSockets(); err != nil {
return err
}
delete(pt.NodePortTable, NodePortProtoFormat(data.NodePort, protocol))
delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podPort, protocol))
return nil
}

Expand All @@ -251,18 +175,17 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error {
defer pt.tableLock.Unlock()
podEntries := pt.getDataForPodIP(podIP)
for _, podEntry := range podEntries {
for len(podEntry.Protocols) > 0 {
if len(podEntry.Protocols) > 0 {
protocolSocketData := podEntry.Protocols[0]
if err := pt.PodPortRules.DeleteRule(podEntry.NodePort, podIP, podEntry.PodPort, protocolSocketData.Protocol); err != nil {
return err
}
if err := protocolSocketData.socket.Close(); err != nil {
return fmt.Errorf("error when releasing local port %d with protocol %s: %v", podEntry.NodePort, protocolSocketData.Protocol, err)
}
podEntry.Protocols = podEntry.Protocols[1:]
delete(pt.NodePortTable, NodePortProtoFormat(podEntry.NodePort, protocolSocketData.Protocol))
delete(pt.PodEndpointTable, podIPPortProtoFormat(podIP, podEntry.PodPort, protocolSocketData.Protocol))
}
delete(pt.NodePortTable, strconv.Itoa(podEntry.NodePort))
delete(pt.PodEndpointTable, podIPPortFormat(podIP, podEntry.PodPort))
}
return nil
}
Expand All @@ -271,7 +194,7 @@ func (pt *PortTable) DeleteRulesForPod(podIP string) error {
func (pt *PortTable) syncRules() error {
pt.tableLock.Lock()
defer pt.tableLock.Unlock()
nplPorts := make([]rules.PodNodePort, 0, len(pt.NodePortTable))
nplPorts := make([]rules.PodNodePort, 0, 1)
for _, npData := range pt.NodePortTable {
protocols := make([]string, 0, len(supportedProtocols))
for _, protocol := range npData.Protocols {
Expand All @@ -283,6 +206,7 @@ func (pt *PortTable) syncRules() error {
NodePort: npData.NodePort,
PodPort: npData.PodPort,
PodIP: npData.PodIP,
Protocol: protocols[0],
Protocols: protocols,
})
}
Expand All @@ -300,13 +224,12 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<-
pt.tableLock.Lock()
defer pt.tableLock.Unlock()
for _, nplPort := range allNPLPorts {
protocols, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort)
protocols, err := openSocketsForPort(pt.LocalPortOpener, nplPort.NodePort, nplPort.Protocol)
if err != nil {
// This will be handled gracefully by the NPL controller: if there is an
// annotation using this port, it will be removed and replaced with a new
// one with a valid port mapping.
klog.ErrorS(err, "Cannot bind to local port, skipping it", "port", nplPort.NodePort)
closeSocketsOrRetry(protocols)
continue
}

Expand All @@ -316,15 +239,8 @@ func (pt *PortTable) RestoreRules(allNPLPorts []rules.PodNodePort, synced chan<-
PodIP: nplPort.PodIP,
Protocols: protocols,
}
for _, protocol := range nplPort.Protocols {
protocolSocketData := npData.FindProtocol(protocol)
if protocolSocketData == nil {
return fmt.Errorf("unknown protocol %s", protocol)
}
protocolSocketData.State = stateInUse
}
pt.NodePortTable[strconv.Itoa(nplPort.NodePort)] = npData
pt.PodEndpointTable[podIPPortFormat(nplPort.PodIP, nplPort.PodPort)] = pt.NodePortTable[strconv.Itoa(nplPort.NodePort)]
pt.NodePortTable[podIPPortProtoFormat(nplPort.PodIP, nplPort.PodPort, nplPort.Protocol)] = npData
pt.PodEndpointTable[NodePortProtoFormat(nplPort.NodePort, nplPort.Protocol)] = pt.NodePortTable[strconv.Itoa(nplPort.NodePort)]
}
// retry mechanism as iptables-restore can fail if other components (in Antrea or other
// software) are accessing iptables.
Expand Down
12 changes: 10 additions & 2 deletions pkg/agent/nodeportlocal/portcache/port_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,21 @@ func TestRestoreRules(t *testing.T) {
NodePort: nodePort1,
PodPort: 1001,
PodIP: podIP,
Protocols: []string{"tcp", "udp"},
Protocol: "tcp",
Protocols: []string{"tcp"},
},
{
NodePort: nodePort1,
PodPort: 1001,
PodIP: podIP,
Protocol: "udp",
Protocols: []string{"udp"},
},
{
NodePort: nodePort2,
PodPort: 1002,
PodIP: podIP,
Protocol: "udp",
Protocols: []string{"udp"},
},
}
Expand All @@ -73,7 +82,6 @@ func TestRestoreRules(t *testing.T) {
gomock.InOrder(
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "tcp"),
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, "udp"),
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "tcp"),
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, "udp"),
)

Expand Down
Loading

0 comments on commit bf0acfb

Please sign in to comment.