Skip to content

Commit

Permalink
Implement NPL agent unification.
Browse files Browse the repository at this point in the history
* Unify agent behavior across Linux and Windows. Linux agent should support
allocating different nodeports for different protocols when the podports are the same.
* Replace map with cache.indexer for cachetable to reduce repeated insertion.
* Update port allocation related unit tests.
* Enable windows e2e test.
* Delete unused functions.

Signed-off-by: Shuyang Xin <[email protected]>
  • Loading branch information
XinShuYang committed Jul 22, 2022
1 parent 4b788e7 commit 551e928
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 232 deletions.
9 changes: 4 additions & 5 deletions pkg/agent/nodeportlocal/k8s/npl_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,10 @@ func (c *NPLController) handleAddUpdatePod(key string, obj interface{}) error {
entries := c.portTable.GetDataForPodIP(podIP)
if nplExists {
for _, data := range entries {
for _, proto := range data.Protocols {
if _, exists := podPorts[util.BuildPortProto(fmt.Sprint(data.PodPort), proto.Protocol)]; !exists {
if err := c.portTable.DeleteRule(podIP, int(data.PodPort), proto.Protocol); err != nil {
return fmt.Errorf("failed to delete rule for Pod IP %s, Pod Port %d, Protocol %s: %v", podIP, data.PodPort, proto.Protocol, err)
}
proto := data.Protocol
if _, exists := podPorts[util.BuildPortProto(fmt.Sprint(data.PodPort), proto.Protocol)]; !exists {
if err := c.portTable.DeleteRule(podIP, int(data.PodPort), proto.Protocol); err != nil {
return fmt.Errorf("failed to delete rule for Pod IP %s, Pod Port %d, Protocol %s: %v", podIP, data.PodPort, proto.Protocol, err)
}
}
}
Expand Down
40 changes: 24 additions & 16 deletions pkg/agent/nodeportlocal/npl_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ const (

func newPortTable(mockIPTables rules.PodPortRules, mockPortOpener portcache.LocalPortOpener) *portcache.PortTable {
return &portcache.PortTable{
NodePortTable: make(map[string]*portcache.NodePortData),
PodEndpointTable: make(map[string]*portcache.NodePortData),
PortTableCache: cache.NewIndexer(portcache.GetPortTableKey, cache.Indexers{
nodePortIndex: portcache.NodePortIndexFunc,
podEndpointIndex: portcache.PodEndpointIndexFunc,
podIPIndex: portcache.PodIPIndexFunc,
}),
StartPort: defaultStartPort,
EndPort: defaultEndPort,
PortSearchStart: defaultStartPort,
Expand Down Expand Up @@ -658,7 +661,8 @@ func TestMultiplePods(t *testing.T) {
// TestMultipleProtocols creates multiple Pods with multiple protocols and verifies that
// NPL annotations and iptable rules for both Pods and Protocols are updated correctly.
// In particular we make sure that a given NodePort is never used by more than one Pod,
// irrespective of which protocol is in use.
// One Pod could use multiple Nodeports for different protocol with the same podPort
// because of the new NPL unification implementation.
func TestMultipleProtocols(t *testing.T) {
tcpUdpSvcLabel := map[string]string{"tcp": "true", "udp": "true"}
udpSvcLabel := map[string]string{"tcp": "false", "udp": "true"}
Expand Down Expand Up @@ -702,8 +706,7 @@ func TestMultipleProtocols(t *testing.T) {
assert.True(t, testData.portTable.RuleExists(testPod2.Status.PodIP, defaultPort, protocolUDP))

// Update testSvc2 to serve TCP/80 and UDP/81 both, so pod2 is
// exposed on both TCP and UDP, with the same NodePort.

// exposed on both TCP and UDP, with different NodePorts.
testSvc2.Spec.Ports = append(testSvc2.Spec.Ports, corev1.ServicePort{
Port: 80,
Protocol: corev1.ProtocolTCP,
Expand All @@ -716,7 +719,16 @@ func TestMultipleProtocols(t *testing.T) {

pod2ValueUpdate, err := testData.pollForPodAnnotation(testPod2.Name, true)
require.NoError(t, err, "Poll for annotation check failed")
expectedAnnotationsPod2.Add(&pod2Value[0].NodePort, defaultPort, protocolTCP)

// The new nodeport should be the next of the last used port because of
// the implementation of the nodeport allocation.
var pod2nodeport int
if pod1Value[0].NodePort > pod2Value[0].NodePort {
pod2nodeport = pod1Value[0].NodePort + 1
} else {
pod2nodeport = pod2Value[0].NodePort + 1
}
expectedAnnotationsPod2.Add(&pod2nodeport, defaultPort, protocolTCP)
expectedAnnotationsPod2.Check(t, pod2ValueUpdate)
}

Expand Down Expand Up @@ -761,22 +773,18 @@ var (
portTakenError = fmt.Errorf("Port taken")
)

// TestNodePortAlreadyBoundTo validates that when a port is already bound to, a different port will
// be selected for NPL.
// TestNodePortAlreadyBoundTo validates that when a port with TCP protocol is already bound to,
// the same port should be selected for NPL if any other protocol is available.
func TestNodePortAlreadyBoundTo(t *testing.T) {
nodePort1 := defaultStartPort
nodePort2 := nodePort1 + 1
testConfig := newTestConfig().withCustomPortOpenerExpectations(func(mockPortOpener *portcachetesting.MockLocalPortOpener) {
gomock.InOrder(
// Based on the implementation, we know that TCP is checked first...
// 1. port1 is checked for TCP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(&fakeSocket{}, nil),
// 2. port1 is checked for UDP availability (even if the Service uses TCP only) -> error
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolUDP).Return(nil, portTakenError),
// 3. port2 is checked for TCP availability -> success
// Based on the implementation, we know only the TCP protocol used in rule is checked...
// 1. port1 is checked for TCP availability -> error
mockPortOpener.EXPECT().OpenLocalPort(nodePort1, protocolTCP).Return(nil, portTakenError),
// 2. port2 is checked for TCP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolTCP).Return(&fakeSocket{}, nil),
// 4. port2 is checked for UDP availability -> success
mockPortOpener.EXPECT().OpenLocalPort(nodePort2, protocolUDP).Return(&fakeSocket{}, nil),
)
})
customNodePort := defaultStartPort + 1
Expand Down
162 changes: 137 additions & 25 deletions pkg/agent/nodeportlocal/portcache/port_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ import (

"k8s.io/klog/v2"

"k8s.io/client-go/tools/cache"
"antrea.io/antrea/pkg/agent/nodeportlocal/rules"
)

const (
NPLPortTableIndex = "NPLPortTableIndex"
nodePortIndex = "nodePortIndex"
podEndpointIndex = "podEndpointIndex"
podIPIndex = "podIPIndex"
)

// protocolSocketState represents the state of the socket corresponding to a
// given (Node port, protocol) tuple.
type protocolSocketState int
Expand All @@ -39,23 +47,30 @@ type NodePortData struct {
NodePort int
PodPort int
PodIP string
Protocols []ProtocolSocketData
Protocol ProtocolSocketData
}

func (d *NodePortData) FindProtocol(protocol string) *ProtocolSocketData {
for idx, protocolSocketData := range d.Protocols {
if protocolSocketData.Protocol == protocol {
return &d.Protocols[idx]
}
type CacheNpData struct {
Type string
Data *NodePortData
}

func NewCacheNpData(cacheType string, data *NodePortData) (*CacheNpData, error) {
if cacheType == "" {
err := fmt.Errorf("initialization of NPL CacheNpData failed")
return nil, err
}
return nil
npCache := CacheNpData{
Type: cacheType,
Data: data,
}
return &npCache, nil
}

func (d *NodePortData) ProtocolInUse(protocol string) bool {
for _, protocolSocketData := range d.Protocols {
if protocolSocketData.Protocol == protocol {
return protocolSocketData.State == stateInUse
}
protocolSocketData := d.Protocol
if protocolSocketData.Protocol == protocol {
return protocolSocketData.State == stateInUse
}
return false
}
Expand All @@ -67,8 +82,7 @@ type LocalPortOpener interface {
type localPortOpener struct{}

type PortTable struct {
NodePortTable map[string]*NodePortData
PodEndpointTable map[string]*NodePortData
PortTableCache cache.Indexer
StartPort int
EndPort int
PortSearchStart int
Expand All @@ -77,10 +91,102 @@ type PortTable struct {
tableLock sync.RWMutex
}

func GetPortTableKey(obj interface{}) (string, error) {
npData := obj.(*NodePortData)
key := fmt.Sprintf("%d:%s:%d:%s", npData.NodePort, npData.PodIP, npData.PodPort, npData.Protocol.Protocol)
return key, nil
}

func (pt *PortTable) addPortTableCache(npData *NodePortData) error {
if err := pt.PortTableCache.Add(npData); err != nil {
return err
}
return nil
}

func (pt *PortTable) delPortTableCache(npData *NodePortData) error {
if err := pt.PortTableCache.Delete(npData); err != nil {
return err
}
return nil
}

func (pt *PortTable) getPortTableCacheFromNodePortIndex(index string) (*NodePortData, bool) {
objs, _ := pt.PortTableCache.ByIndex(nodePortIndex, index)
if len(objs) == 0 {
return nil, false
}
return objs[0].(*NodePortData), true
}

func (pt *PortTable) getPortTableCacheFromPodEndpointIndex(index string) (*NodePortData, bool) {
objs, _ := pt.PortTableCache.ByIndex(podEndpointIndex, index)
if len(objs) == 0 {
return nil, false
}
return objs[0].(*NodePortData), true
}

func (pt *PortTable) getPortTableCacheFromPodIPIndex(index string) ([]NodePortData, bool) {
var npData []NodePortData
objs, _ := pt.PortTableCache.ByIndex(podIPIndex, index)
if len(objs) == 0 {
return nil, false
}
for _, obj := range objs {
npData = append(npData, *(obj.(*NodePortData)))
}
return npData, true
}

func (pt *PortTable) delPortTableCacheFromNodePortIndex(index string) error {
data, exists := pt.getPortTableCacheFromNodePortIndex(index)
if exists == false {
return nil
}
if err := pt.delPortTableCache(data); err != nil {
return err
}
return nil
}

func (pt *PortTable) releaseDataFromPortTableCache() error {
for _, obj := range pt.PortTableCache.List() {
data := obj.(*NodePortData)
if err := pt.delPortTableCache(data); err != nil {
return err
}
}
return nil
}

func NodePortIndexFunc(obj interface{}) ([]string, error) {
npData := obj.(*NodePortData)
nodePortTuple := NodePortProtoFormat(npData.NodePort, npData.Protocol.Protocol)
//nodePortTuple := []string{strconv.Itoa(npData.NodePort), npData.Protocol.Protocol}
return []string{nodePortTuple}, nil
}

func PodEndpointIndexFunc(obj interface{}) ([]string, error) {
npData := obj.(*NodePortData)
podEndpointTuple := podIPPortProtoFormat(npData.PodIP, npData.PodPort, npData.Protocol.Protocol)
//podEndpointTuple := []string{npData.PodIP, strconv.Itoa(npData.PodPort), npData.Protocol.Protocol}
return []string{podEndpointTuple}, nil
}

func PodIPIndexFunc(obj interface{}) ([]string, error) {
npData := obj.(*NodePortData)
return []string{npData.PodIP}, nil
}

func NewPortTable(start, end int) (*PortTable, error) {
ptable := PortTable{
NodePortTable: make(map[string]*NodePortData),
PodEndpointTable: make(map[string]*NodePortData),
PortTableCache: cache.NewIndexer(GetPortTableKey, cache.Indexers{
nodePortIndex: NodePortIndexFunc,
podEndpointIndex: PodEndpointIndexFunc,
podIPIndex: PodIPIndexFunc,
}),
//PortTableCache: cache.NewStore(GetPortTableKey),
StartPort: start,
EndPort: end,
PortSearchStart: start,
Expand All @@ -96,8 +202,7 @@ func NewPortTable(start, end int) (*PortTable, error) {
func (pt *PortTable) CleanupAllEntries() {
pt.tableLock.Lock()
defer pt.tableLock.Unlock()
pt.NodePortTable = make(map[string]*NodePortData)
pt.PodEndpointTable = make(map[string]*NodePortData)
pt.releaseDataFromPortTableCache()
}

func (pt *PortTable) GetDataForPodIP(ip string) []NodePortData {
Expand All @@ -107,23 +212,30 @@ func (pt *PortTable) GetDataForPodIP(ip string) []NodePortData {
}

func (pt *PortTable) getDataForPodIP(ip string) []NodePortData {
var allData []NodePortData
for i := range pt.NodePortTable {
if pt.NodePortTable[i].PodIP == ip {
allData = append(allData, *pt.NodePortTable[i])
}
allData, exist := pt.getPortTableCacheFromPodIPIndex(ip)
if exist == false {
return nil
}
return allData
}

func (pt *PortTable) getEntryByPodIPPort(ip string, port int) *NodePortData {
return pt.PodEndpointTable[podIPPortFormat(ip, port)]
// podIPPortFormat formats the ip, port to string ip:port.
func podIPPortProtoFormat(ip string, port int, protocol string) string {
return fmt.Sprintf("%s:%d:%s", ip, port, protocol)
}

func (pt *PortTable) getEntryByPodIPPortProto(ip string, port int, protocol string) *NodePortData {
data, exists := pt.getPortTableCacheFromPodEndpointIndex(podIPPortProtoFormat(ip, port, protocol))
if exists == false {
return nil
}
return data
}

func (pt *PortTable) RuleExists(podIP string, podPort int, protocol string) bool {
pt.tableLock.RLock()
defer pt.tableLock.RUnlock()
if data := pt.getEntryByPodIPPort(podIP, podPort); data != nil {
if data := pt.getEntryByPodIPPortProto(podIP, podPort, protocol); data != nil {
return data.ProtocolInUse(protocol)
}
return false
Expand Down
Loading

0 comments on commit 551e928

Please sign in to comment.