Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mftoure committed Dec 18, 2024
1 parent edb40fd commit 7485988
Showing 1 changed file with 39 additions and 34 deletions.
73 changes: 39 additions & 34 deletions pkg/security/tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ func TestConnectEvent(t *testing.T) {

t.Run("connect-af-inet-any-tcp-success", func(t *testing.T) {

done := make(chan struct{})
done := make(chan error)
defer close(done)
go func() {
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
bindAndAcceptConnection("tcp", ":4242", done)
}()

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "tcp"); err != nil {
return err
}
return <-done
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -74,17 +74,17 @@ func TestConnectEvent(t *testing.T) {

t.Run("connect-af-inet-any-udp-success", func(t *testing.T) {

done := make(chan struct{})
done := make(chan error)
defer close(done)
go func() {
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
bindAndAcceptConnection("udp", ":4242", done)
}()

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET", "any", "udp"); err != nil {
return err
}
return <-done
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -101,17 +101,17 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

done := make(chan struct{})
done := make(chan error)
defer close(done)
go func() {
err := bindAndAcceptConnection("tcp", ":4242", done)
if err != nil {
t.Error(err)
}
bindAndAcceptConnection("tcp", ":4242", done)
}()

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "tcp"); err != nil {
return err
}
return <-done
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -128,17 +128,17 @@ func TestConnectEvent(t *testing.T) {
t.Skip("IPv6 is not supported")
}

done := make(chan struct{})
done := make(chan error)
defer close(done)
go func() {
err := bindAndAcceptConnection("udp", ":4242", done)
if err != nil {
t.Error(err)
}
bindAndAcceptConnection("udp", ":4242", done)
}()

test.WaitSignal(t, func() error {
return runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp")
if err := runSyscallTesterFunc(context.Background(), t, syscallTester, "connect", "AF_INET6", "any", "udp"); err != nil {
return err
}
return <-done
}, func(event *model.Event, _ *rules.Rule) {
assert.Equal(t, "connect", event.GetType(), "wrong event type")
assert.Equal(t, uint16(unix.AF_INET6), event.Connect.AddrFamily, "wrong address family")
Expand All @@ -152,36 +152,41 @@ func TestConnectEvent(t *testing.T) {

}

func bindAndAcceptConnection(proto, address string, done chan struct{}) error {
func bindAndAcceptConnection(proto, address string, done chan error) {
switch proto {
case "tcp":
listener, err := net.Listen(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
return
} else {
defer listener.Close()
}
defer listener.Close()

c, err := listener.Accept()
defer c.Close()

if err != nil {
fmt.Printf("accept error: %v\n", err)
return err
done <- err
return
} else {
defer c.Close()
}

fmt.Println("Connection accepted")

case "udp", "unixgram":
conn, err := net.ListenPacket(proto, address)
if err != nil {
return fmt.Errorf("failed to bind to %s:%s: %w", proto, address, err)
done <- err
return
}
defer conn.Close()

default:
return fmt.Errorf("unsupported protocol: %s", proto)
done <- fmt.Errorf("unsupported protocol: %s", proto)
return
}

// Wait for the test to complete before returning
<-done
return nil
done <- nil
}

0 comments on commit 7485988

Please sign in to comment.