Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow parallel Accepts #43

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 46 additions & 20 deletions listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type connRequest struct {
socketId uint32
timestamp uint32

config Config
handshake *packet.CIFHandshake
crypto crypto.Crypto
passphrase string
Expand Down Expand Up @@ -160,7 +161,9 @@ type listener struct {

stopReader context.CancelFunc

doneChan chan error
doneChan chan struct{}
doneErr error
doneOnce sync.Once
}

// Listen returns a new listener on the SRT protocol on the address with
Expand Down Expand Up @@ -220,7 +223,7 @@ func Listen(network, address string, config Config) (Listener, error) {
}
ln.syncookie = syncookie

ln.doneChan = make(chan error)
ln.doneChan = make(chan struct{})

ln.start = time.Now()

Expand All @@ -233,7 +236,7 @@ func Listen(network, address string, config Config) (Listener, error) {

for {
if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
ln.markDone(ErrListenerClosed)
return
}

Expand All @@ -245,11 +248,11 @@ func Listen(network, address string, config Config) (Listener, error) {
}

if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
ln.markDone(ErrListenerClosed)
return
}

ln.doneChan <- err
ln.markDone(err)
return
}

Expand All @@ -276,8 +279,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
}

select {
case err := <-ln.doneChan:
return nil, REJECT, err
case <-ln.doneChan:
return nil, REJECT, ln.error()
case request := <-ln.backlog:
if acceptFn == nil {
ln.reject(request, packet.REJ_PEER)
Expand All @@ -299,8 +302,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
socketId := uint32(time.Since(ln.start).Microseconds())

// Select the largest TSBPD delay advertised by the caller, but at least 120ms
recvTsbpdDelay := uint16(ln.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(ln.config.PeerLatency.Milliseconds())
recvTsbpdDelay := uint16(request.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(request.config.PeerLatency.Milliseconds())

if request.handshake.Version == 5 {
if request.handshake.SRTHS.SendTSBPDDelay > recvTsbpdDelay {
Expand All @@ -311,17 +314,17 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
sendTsbpdDelay = request.handshake.SRTHS.RecvTSBPDDelay
}

ln.config.StreamId = request.handshake.StreamId
request.config.StreamId = request.handshake.StreamId
}

ln.config.Passphrase = request.passphrase
request.config.Passphrase = request.passphrase

// Create a new connection
conn := newSRTConn(srtConnConfig{
version: request.handshake.Version,
localAddr: ln.addr,
remoteAddr: request.addr,
config: ln.config,
config: request.config,
start: request.start,
socketId: socketId,
peerSocketId: request.handshake.SRTSocketId,
Expand All @@ -333,7 +336,7 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
keyBaseEncryption: packet.EvenKeyEncrypted,
onSend: ln.send,
onShutdown: ln.handleShutdown,
logger: ln.config.Logger,
logger: request.config.Logger,
})

ln.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s) %s", conn.SocketId(), conn.StreamId(), mode) })
Expand Down Expand Up @@ -369,6 +372,25 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
return nil, REJECT, nil
}

// markDone marks the listener as done by closing
// the done channel & sets the error
func (ln *listener) markDone(err error) {
ln.doneOnce.Do(func() {
ln.lock.Lock()
defer ln.lock.Unlock()
ln.doneErr = err
close(ln.doneChan)
})
}

// error returns the error that caused the listener to be done
// if it's nil then the listener is not done
func (ln *listener) error() error {
ln.lock.Lock()
defer ln.lock.Unlock()
return ln.doneErr
}

func (ln *listener) handleShutdown(socketId uint32) {
ln.lock.Lock()
delete(ln.conns, socketId)
Expand Down Expand Up @@ -542,6 +564,9 @@ func (ln *listener) handleHandshake(p packet.Packet) {

cif.PeerIP.FromNetAddr(ln.addr)

// Create a copy of the configuration for the connection
config := ln.config

if cif.HandshakeType == packet.HSTYPE_INDUCTION {
// cif
cif.Version = 5
Expand Down Expand Up @@ -585,13 +610,13 @@ func (ln *listener) handleHandshake(p packet.Packet) {
}

// If the peer has a smaller MTU size, adjust to it
if cif.MaxTransmissionUnitSize < ln.config.MSS {
ln.config.MSS = cif.MaxTransmissionUnitSize
ln.config.PayloadSize = ln.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE
if cif.MaxTransmissionUnitSize < config.MSS {
config.MSS = cif.MaxTransmissionUnitSize
config.PayloadSize = config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE

if ln.config.PayloadSize < MIN_PAYLOAD_SIZE {
if config.PayloadSize < MIN_PAYLOAD_SIZE {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", ln.config.PayloadSize) })
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", config.PayloadSize) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
Expand All @@ -614,10 +639,10 @@ func (ln *listener) handleHandshake(p packet.Packet) {
}
} else if cif.Version == 5 {
// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < ln.config.MinVersion {
if cif.SRTHS.SRTVersion < config.MinVersion {
cif.HandshakeType = packet.REJ_VERSION
ln.log("handshake:recv:error", func() string {
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, ln.config.MinVersion)
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, config.MinVersion)
})
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
Expand Down Expand Up @@ -668,6 +693,7 @@ func (ln *listener) handleHandshake(p packet.Packet) {
start: time.Now(),
socketId: cif.SRTSocketId,
timestamp: p.Header().Timestamp,
config: config,

handshake: cif,
}
Expand Down
53 changes: 53 additions & 0 deletions listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"bytes"
"context"
"net"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -382,3 +383,55 @@

pc.Close()
}

func TestListenAsync(t *testing.T) {
const parallelCount = 2
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
require.NoError(t, err)
var (
// All streams are pending
pendingWg sync.WaitGroup
pendingSet sync.Map // Set of which streams are pending
// All streams are connected
connectedWg sync.WaitGroup
// All listener goroutines are stopped
listenerWg sync.WaitGroup
)
listenerWg.Add(parallelCount)
pendingWg.Add(parallelCount)
connectedWg.Add(parallelCount)
for i := 0; i < parallelCount; i++ {
go func() {
defer listenerWg.Done()
for {
_, _, err := ln.Accept(func(req ConnRequest) ConnType {
// Only call Done() if we're the first request for this stream
if _, ok := pendingSet.Swap(req.StreamId(), struct{}{}); !ok {

Check failure on line 409 in listen_test.go

View workflow job for this annotation

GitHub Actions / build

pendingSet.Swap undefined (type sync.Map has no field or method Swap)
pendingWg.Done()
}
// Wait for all streams to be pending Before returning
pendingWg.Wait()
return PUBLISH
})
if err == ErrListenerClosed {
return
}
require.NoError(t, err)
}
}()

go func(streamId string) {
config := DefaultConfig()
config.StreamId = streamId
conn, err := Dial("srt", "127.0.0.1:6003", config)
require.NoError(t, err)
connectedWg.Done()
conn.Close()
}(strconv.Itoa(i))
}

// Wait for all streams to be connected
connectedWg.Wait()
ln.Close()
listenerWg.Wait()
}
Loading