Skip to content

Commit

Permalink
[CWS] Fix test connect (#32394)
Browse files Browse the repository at this point in the history
  • Loading branch information
mftoure authored Dec 19, 2024
1 parent c92b3b0 commit 2ca6541
Showing 1 changed file with 68 additions and 76 deletions.
144 changes: 68 additions & 76 deletions pkg/security/tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"context"
"fmt"
"net"
"sync"
"testing"

"golang.org/x/net/nettest"
Expand Down Expand Up @@ -50,22 +49,20 @@ func TestConnectEvent(t *testing.T) {
}

t.Run("connect-af-inet-any-tcp-success", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("tcp", ":4242", done, started)

// Wait until the server is listening before attempting to connect
<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -78,22 +75,19 @@ func TestConnectEvent(t *testing.T) {
})

t.Run("connect-af-inet-any-udp-success", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("udp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -110,22 +104,19 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("tcp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -142,22 +133,19 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("udp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -168,43 +156,47 @@ func TestConnectEvent(t *testing.T) {
test.validateConnectSchema(t, event)
})
})

}

func bindAndAcceptConnection(proto, address string, done chan struct{}) error {
func bindAndAcceptConnection(proto, address string, done chan error, started chan struct{}) {
switch proto {
case "tcp":
listener, err := net.Listen(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
return
}
defer listener.Close()

// Start a goroutine to accept connections continuously
go func() {
for {
c, err := listener.Accept()
if err != nil {
fmt.Printf("accept error: %v\n", err)
return
}
fmt.Println("Connection accepted")
defer c.Close()
}
}()
// Signal that the server is ready to accept connections
close(started)

c, err := listener.Accept()
if err != nil {
fmt.Printf("accept error: %v\n", err)
done <- err
return
}
defer c.Close()

fmt.Println("Connection accepted")

case "udp", "unixgram":
conn, err := net.ListenPacket(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- err
return
}
defer conn.Close()

// For packet-based connections, we can signal readiness immediately
close(started)

default:
return fmt.Errorf("unsupported protocol: %s", proto)
done <- fmt.Errorf("unsupported protocol: %s", proto)
return
}

// Wait for the test to complete before returning
<-done
return nil
done <- nil
}

0 comments on commit 2ca6541

Please sign in to comment.