diff --git a/pkg/udplistener/udplistener.go b/pkg/udplistener/udplistener.go index e748faa01..016f9908a 100644 --- a/pkg/udplistener/udplistener.go +++ b/pkg/udplistener/udplistener.go @@ -115,9 +115,9 @@ func (c *conn) Read(byt []byte) (int, error) { return 0, errTerminated } - copy(byt, buf) - c.listener.readDone <- struct{}{} - return len(buf), nil + n := copy(byt, buf) + c.listener.readDone <- n + return n, nil } // Write implements the net.Conn interface. @@ -162,7 +162,7 @@ type Listener struct { closed bool accept chan net.Conn - readDone chan struct{} + readDone chan int } // New allocates a Listener. @@ -176,10 +176,10 @@ func New(network, address string) (net.Listener, error) { pc: pc, conns: make(map[connIndex]*conn), accept: make(chan net.Conn), - readDone: make(chan struct{}), + readDone: make(chan int), } - go l.reader() + go l.runReader() return l, nil } @@ -211,7 +211,7 @@ func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() } -func (l *Listener) reader() { +func (l *Listener) runReader() { buf := make([]byte, bufferSize) for { @@ -235,7 +235,7 @@ func (l *Listener) reader() { conn, preExisting := l.conns[connIndex] if !preExisting && l.closed { - // listener is closed, ignore new connection + // listener is closed, ignore new connections } else { if !preExisting { conn = newConn(l, connIndex, uaddr) @@ -243,11 +243,16 @@ func (l *Listener) reader() { l.accept <- conn } - // route buffer to connection - conn.read <- buf[:n] + start := 0 + for n > 0 { + // route buffer to connection + conn.read <- buf[start : start+n] - // wait copy since buffer is shared - <-l.readDone + // wait copy since buffer is shared + read := <-l.readDone + n -= read + start += read + } } }() } diff --git a/pkg/udplistener/udplistener_test.go b/pkg/udplistener/udplistener_test.go index b82f77fb3..3b3567493 100644 --- a/pkg/udplistener/udplistener_test.go +++ b/pkg/udplistener/udplistener_test.go @@ -1,6 +1,7 @@ package udplistener import ( + "bytes" "net" "sync" "testing" @@ -9,12 +10,13 @@ import ( "github.com/stretchr/testify/require" ) -func TestUdpListener(t *testing.T) { +func TestMain(t *testing.T) { testBuf1 := []byte("testing testing 1 2 3") testBuf2 := []byte("second part") l, err := New("udp4", "127.0.0.1:18456") require.NoError(t, err) + defer l.Close() var wg sync.WaitGroup wg.Add(5) @@ -64,12 +66,46 @@ func TestUdpListener(t *testing.T) { } wg.Wait() - l.Close() } -func TestUdpListenerDeadline(t *testing.T) { +func TestSamePacketMultipleReads(t *testing.T) { + l, err := New("udp4", "127.0.0.1:18456") + require.NoError(t, err) + defer l.Close() + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 256) + + for i := 0; i < 4; i++ { + n, err := conn.Read(buf) + require.NoError(t, err) + require.Equal(t, 256, n) + } + }() + + conn, err := net.Dial("udp4", "127.0.0.1:18456") + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write(bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4)) + require.NoError(t, err) + + wg.Wait() +} + +func TestDeadline(t *testing.T) { l, err := New("udp4", "127.0.0.1:18456") require.NoError(t, err) + defer l.Close() var wg sync.WaitGroup wg.Add(2) @@ -78,7 +114,6 @@ func TestUdpListenerDeadline(t *testing.T) { go func() { defer wg.Done() - defer l.Close() conn, err := l.Accept() require.NoError(t, err) @@ -122,7 +157,7 @@ func TestUdpListenerDeadline(t *testing.T) { require.NoError(t, err2) } -func TestUdpListenerDoubleClose(t *testing.T) { +func TestDoubleClose(t *testing.T) { l, err := New("udp4", "127.0.0.1:18456") require.NoError(t, err) l.Close()