diff --git a/pkg/security/tests/connect_test.go b/pkg/security/tests/connect_test.go index 1e64fe57e1989..7f56dcfb8956c 100644 --- a/pkg/security/tests/connect_test.go +++ b/pkg/security/tests/connect_test.go @@ -12,7 +12,6 @@ import ( "context" "fmt" "net" - "sync" "testing" "golang.org/x/net/nettest" @@ -50,22 +49,20 @@ func TestConnectEvent(t *testing.T) { } t.Run("connect-af-inet-any-tcp-success", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - - done := make(chan struct{}) - defer close(done) - go func() { - defer wg.Done() - err := bindAndAcceptConnection("tcp", ":4242", done) - if err != nil { - t.Error(err) - } - }() + done := make(chan error) + started := make(chan struct{}) + go bindAndAcceptConnection("tcp", ":4242", done, started) + + // Wait until the server is listening before attempting to connect + <-started test.WaitSignal(t, func() error { - return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp") + if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp"); err != nil { + return err + } + err := <-done + close(done) + return err }, func(event *model.Event, _ *rules.Rule) { assert.Equal(t, "connect", event.GetType(), "wrong event type") assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family") @@ -78,22 +75,19 @@ func TestConnectEvent(t *testing.T) { }) t.Run("connect-af-inet-any-udp-success", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - - done := make(chan struct{}) - defer close(done) - go func() { - defer wg.Done() - err := bindAndAcceptConnection("udp", ":4242", done) - if err != nil { - t.Error(err) - } - }() + done := make(chan error) + started := make(chan struct{}) + go bindAndAcceptConnection("udp", ":4242", done, started) + + <-started test.WaitSignal(t, func() error { - return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp") + if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp"); err != nil { + return err + } + err := <-done + close(done) + return err }, func(event *model.Event, _ *rules.Rule) { assert.Equal(t, "connect", event.GetType(), "wrong event type") assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family") @@ -110,22 +104,19 @@ func TestConnectEvent(t *testing.T) { t.Skip("IPv6 is not supported") } - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - - done := make(chan struct{}) - defer close(done) - go func() { - defer wg.Done() - err := bindAndAcceptConnection("tcp", ":4242", done) - if err != nil { - t.Error(err) - } - }() + done := make(chan error) + started := make(chan struct{}) + go bindAndAcceptConnection("tcp", ":4242", done, started) + + <-started test.WaitSignal(t, func() error { - return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp") + if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp"); err != nil { + return err + } + err := <-done + close(done) + return err }, func(event *model.Event, _ *rules.Rule) { assert.Equal(t, "connect", event.GetType(), "wrong event type") assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family") @@ -142,22 +133,19 @@ func TestConnectEvent(t *testing.T) { t.Skip("IPv6 is not supported") } - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - - done := make(chan struct{}) - defer close(done) - go func() { - defer wg.Done() - err := bindAndAcceptConnection("udp", ":4242", done) - if err != nil { - t.Error(err) - } - }() + done := make(chan error) + started := make(chan struct{}) + go bindAndAcceptConnection("udp", ":4242", done, started) + + <-started test.WaitSignal(t, func() error { - return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp") + if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp"); err != nil { + return err + } + err := <-done + close(done) + return err }, func(event *model.Event, _ *rules.Rule) { assert.Equal(t, "connect", event.GetType(), "wrong event type") assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family") @@ -168,43 +156,47 @@ func TestConnectEvent(t *testing.T) { test.validateConnectSchema(t, event) }) }) - } -func bindAndAcceptConnection(proto, address string, done chan struct{}) error { +func bindAndAcceptConnection(proto, address string, done chan error, started chan struct{}) { switch proto { case "tcp": listener, err := net.Listen(proto, address) if err != nil { - return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err) + done <- fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err) + return } defer listener.Close() - // Start a goroutine to accept connections continuously - go func() { - for { - c, err := listener.Accept() - if err != nil { - fmt.Printf("accept error: %v\n", err) - return - } - fmt.Println("Connection accepted") - defer c.Close() - } - }() + // Signal that the server is ready to accept connections + close(started) + + c, err := listener.Accept() + if err != nil { + fmt.Printf("accept error: %v\n", err) + done <- err + return + } + defer c.Close() + + fmt.Println("Connection accepted") case "udp", "unixgram": conn, err := net.ListenPacket(proto, address) if err != nil { - return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err) + done <- err + return } defer conn.Close() + // For packet-based connections, we can signal readiness immediately + close(started) + default: - return fmt.Errorf("unsupported protocol: %s", proto) + done <- fmt.Errorf("unsupported protocol: %s", proto) + return } // Wait for the test to complete before returning - <-done - return nil + done <- nil }