From 9b7219067d4aa15d1ff81c91bdca12c32285239c Mon Sep 17 00:00:00 2001 From: Bobby Peck Date: Tue, 19 Dec 2023 11:31:32 -0700 Subject: [PATCH] Allow parallel Accepts --- listen.go | 66 +++++++++++++++++++++++++++++++++++--------------- listen_test.go | 53 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 20 deletions(-) diff --git a/listen.go b/listen.go index fb503e6..b03b774 100644 --- a/listen.go +++ b/listen.go @@ -72,6 +72,7 @@ type connRequest struct { socketId uint32 timestamp uint32 + config Config handshake *packet.CIFHandshake crypto crypto.Crypto passphrase string @@ -160,7 +161,9 @@ type listener struct { stopReader context.CancelFunc - doneChan chan error + doneChan chan struct{} + doneErr error + doneOnce sync.Once } // Listen returns a new listener on the SRT protocol on the address with @@ -220,7 +223,7 @@ func Listen(network, address string, config Config) (Listener, error) { } ln.syncookie = syncookie - ln.doneChan = make(chan error) + ln.doneChan = make(chan struct{}) ln.start = time.Now() @@ -233,7 +236,7 @@ func Listen(network, address string, config Config) (Listener, error) { for { if ln.isShutdown() { - ln.doneChan <- ErrListenerClosed + ln.markDone(ErrListenerClosed) return } @@ -245,11 +248,11 @@ func Listen(network, address string, config Config) (Listener, error) { } if ln.isShutdown() { - ln.doneChan <- ErrListenerClosed + ln.markDone(ErrListenerClosed) return } - ln.doneChan <- err + ln.markDone(err) return } @@ -276,8 +279,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) { } select { - case err := <-ln.doneChan: - return nil, REJECT, err + case <-ln.doneChan: + return nil, REJECT, ln.error() case request := <-ln.backlog: if acceptFn == nil { ln.reject(request, packet.REJ_PEER) @@ -299,8 +302,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) { socketId := uint32(time.Since(ln.start).Microseconds()) // Select the largest TSBPD delay advertised by the caller, but at least 120ms - recvTsbpdDelay := uint16(ln.config.ReceiverLatency.Milliseconds()) - sendTsbpdDelay := uint16(ln.config.PeerLatency.Milliseconds()) + recvTsbpdDelay := uint16(request.config.ReceiverLatency.Milliseconds()) + sendTsbpdDelay := uint16(request.config.PeerLatency.Milliseconds()) if request.handshake.Version == 5 { if request.handshake.SRTHS.SendTSBPDDelay > recvTsbpdDelay { @@ -311,17 +314,17 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) { sendTsbpdDelay = request.handshake.SRTHS.RecvTSBPDDelay } - ln.config.StreamId = request.handshake.StreamId + request.config.StreamId = request.handshake.StreamId } - ln.config.Passphrase = request.passphrase + request.config.Passphrase = request.passphrase // Create a new connection conn := newSRTConn(srtConnConfig{ version: request.handshake.Version, localAddr: ln.addr, remoteAddr: request.addr, - config: ln.config, + config: request.config, start: request.start, socketId: socketId, peerSocketId: request.handshake.SRTSocketId, @@ -333,7 +336,7 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) { keyBaseEncryption: packet.EvenKeyEncrypted, onSend: ln.send, onShutdown: ln.handleShutdown, - logger: ln.config.Logger, + logger: request.config.Logger, }) ln.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s) %s", conn.SocketId(), conn.StreamId(), mode) }) @@ -369,6 +372,25 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) { return nil, REJECT, nil } +// markDone marks the listener as done by closing +// the done channel & sets the error +func (ln *listener) markDone(err error) { + ln.doneOnce.Do(func() { + ln.lock.Lock() + defer ln.lock.Unlock() + ln.doneErr = err + close(ln.doneChan) + }) +} + +// error returns the error that caused the listener to be done +// if it's nil then the listener is not done +func (ln *listener) error() error { + ln.lock.Lock() + defer ln.lock.Unlock() + return ln.doneErr +} + func (ln *listener) handleShutdown(socketId uint32) { ln.lock.Lock() delete(ln.conns, socketId) @@ -542,6 +564,9 @@ func (ln *listener) handleHandshake(p packet.Packet) { cif.PeerIP.FromNetAddr(ln.addr) + // Create a copy of the configuration for the connection + config := ln.config + if cif.HandshakeType == packet.HSTYPE_INDUCTION { // cif cif.Version = 5 @@ -585,13 +610,13 @@ func (ln *listener) handleHandshake(p packet.Packet) { } // If the peer has a smaller MTU size, adjust to it - if cif.MaxTransmissionUnitSize < ln.config.MSS { - ln.config.MSS = cif.MaxTransmissionUnitSize - ln.config.PayloadSize = ln.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE + if cif.MaxTransmissionUnitSize < config.MSS { + config.MSS = cif.MaxTransmissionUnitSize + config.PayloadSize = config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE - if ln.config.PayloadSize < MIN_PAYLOAD_SIZE { + if config.PayloadSize < MIN_PAYLOAD_SIZE { cif.HandshakeType = packet.REJ_ROGUE - ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", ln.config.PayloadSize) }) + ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", config.PayloadSize) }) p.MarshalCIF(cif) ln.log("handshake:send:dump", func() string { return p.Dump() }) ln.log("handshake:send:cif", func() string { return cif.String() }) @@ -614,10 +639,10 @@ func (ln *listener) handleHandshake(p packet.Packet) { } } else if cif.Version == 5 { // Check if the peer version is sufficient - if cif.SRTHS.SRTVersion < ln.config.MinVersion { + if cif.SRTHS.SRTVersion < config.MinVersion { cif.HandshakeType = packet.REJ_VERSION ln.log("handshake:recv:error", func() string { - return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, ln.config.MinVersion) + return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, config.MinVersion) }) p.MarshalCIF(cif) ln.log("handshake:send:dump", func() string { return p.Dump() }) @@ -668,6 +693,7 @@ func (ln *listener) handleHandshake(p packet.Packet) { start: time.Now(), socketId: cif.SRTSocketId, timestamp: p.Header().Timestamp, + config: config, handshake: cif, } diff --git a/listen_test.go b/listen_test.go index 760a47d..1569c1c 100644 --- a/listen_test.go +++ b/listen_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "net" + "strconv" "sync" "testing" "time" @@ -382,3 +383,55 @@ func TestListenHSV5(t *testing.T) { pc.Close() } + +func TestListenAsync(t *testing.T) { + const parallelCount = 2 + ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig()) + require.NoError(t, err) + var ( + // All streams are pending + pendingWg sync.WaitGroup + pendingSet sync.Map // Set of which streams are pending + // All streams are connected + connectedWg sync.WaitGroup + // All listener goroutines are stopped + listenerWg sync.WaitGroup + ) + listenerWg.Add(parallelCount) + pendingWg.Add(parallelCount) + connectedWg.Add(parallelCount) + for i := 0; i < parallelCount; i++ { + go func() { + defer listenerWg.Done() + for { + _, _, err := ln.Accept(func(req ConnRequest) ConnType { + // Only call Done() if we're the first request for this stream + if _, ok := pendingSet.Swap(req.StreamId(), struct{}{}); !ok { + pendingWg.Done() + } + // Wait for all streams to be pending Before returning + pendingWg.Wait() + return PUBLISH + }) + if err == ErrListenerClosed { + return + } + require.NoError(t, err) + } + }() + + go func(streamId string) { + config := DefaultConfig() + config.StreamId = streamId + conn, err := Dial("srt", "127.0.0.1:6003", config) + require.NoError(t, err) + connectedWg.Done() + conn.Close() + }(strconv.Itoa(i)) + } + + // Wait for all streams to be connected + connectedWg.Wait() + ln.Close() + listenerWg.Wait() +}