Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CWS] Fix test connect #32394

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 68 additions & 76 deletions pkg/security/tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"context"
"fmt"
"net"
"sync"
"testing"

"golang.org/x/net/nettest"
Expand Down Expand Up @@ -50,22 +49,20 @@ func TestConnectEvent(t *testing.T) {
}

t.Run("connect-af-inet-any-tcp-success", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("tcp", ":4242", done, started)

// Wait until the server is listening before attempting to connect
<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -78,22 +75,19 @@ func TestConnectEvent(t *testing.T) {
})

t.Run("connect-af-inet-any-udp-success", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("udp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -110,22 +104,19 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("tcp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -142,22 +133,19 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()

done := make(chan struct{})
defer close(done)
go func() {
defer wg.Done()
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
}()
done := make(chan error)
started := make(chan struct{})
go bindAndAcceptConnection("udp", ":4242", done, started)

<-started

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp"); err != nil {
return err
}
err := <-done
close(done)
return err
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -168,43 +156,47 @@ func TestConnectEvent(t *testing.T) {
test.validateConnectSchema(t, event)
})
})

}

func bindAndAcceptConnection(proto, address string, done chan struct{}) error {
func bindAndAcceptConnection(proto, address string, done chan error, started chan struct{}) {
switch proto {
case "tcp":
listener, err := net.Listen(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
return
}
defer listener.Close()

// Start a goroutine to accept connections continuously
go func() {
for {
c, err := listener.Accept()
if err != nil {
fmt.Printf("accept error: %v\n", err)
return
}
fmt.Println("Connection accepted")
defer c.Close()
}
}()
// Signal that the server is ready to accept connections
close(started)

c, err := listener.Accept()
if err != nil {
fmt.Printf("accept error: %v\n", err)
done <- err
return
}
defer c.Close()

fmt.Println("Connection accepted")

case "udp", "unixgram":
conn, err := net.ListenPacket(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- err
return
}
defer conn.Close()

// For packet-based connections, we can signal readiness immediately
close(started)

default:
return fmt.Errorf("unsupported protocol: %s", proto)
done <- fmt.Errorf("unsupported protocol: %s", proto)
return
}

// Wait for the test to complete before returning
<-done
return nil
done <- nil
}
Loading