Skip to content

Commit

Permalink
Merge branch 'bp/parallel-listens' of github.com:muxinc/gosrt into mu…
Browse files Browse the repository at this point in the history
…xinc-bp/parallel-listens
  • Loading branch information
ioppermann committed Jan 15, 2024
2 parents 49a1a62 + 9b72190 commit 151a03c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 20 deletions.
66 changes: 46 additions & 20 deletions listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ type connRequest struct {
socketId uint32
timestamp uint32

config Config
handshake *packet.CIFHandshake
crypto crypto.Crypto
passphrase string
Expand Down Expand Up @@ -215,7 +216,9 @@ type listener struct {

stopReader context.CancelFunc

doneChan chan error
doneChan chan struct{}
doneErr error
doneOnce sync.Once
}

// Listen returns a new listener on the SRT protocol on the address with
Expand Down Expand Up @@ -275,7 +278,7 @@ func Listen(network, address string, config Config) (Listener, error) {
}
ln.syncookie = syncookie

ln.doneChan = make(chan error)
ln.doneChan = make(chan struct{})

ln.start = time.Now()

Expand All @@ -288,7 +291,7 @@ func Listen(network, address string, config Config) (Listener, error) {

for {
if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
ln.markDone(ErrListenerClosed)
return
}

Expand All @@ -300,11 +303,11 @@ func Listen(network, address string, config Config) (Listener, error) {
}

if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
ln.markDone(ErrListenerClosed)
return
}

ln.doneChan <- err
ln.markDone(err)
return
}

Expand All @@ -331,8 +334,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
}

select {
case err := <-ln.doneChan:
return nil, REJECT, err
case <-ln.doneChan:
return nil, REJECT, ln.error()
case request := <-ln.backlog:
if acceptFn == nil {
ln.reject(request, REJ_PEER)
Expand All @@ -359,8 +362,8 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
socketId := uint32(time.Since(ln.start).Microseconds())

// Select the largest TSBPD delay advertised by the caller, but at least 120ms
recvTsbpdDelay := uint16(ln.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(ln.config.PeerLatency.Milliseconds())
recvTsbpdDelay := uint16(request.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(request.config.PeerLatency.Milliseconds())

if request.handshake.Version == 5 {
if request.handshake.SRTHS.SendTSBPDDelay > recvTsbpdDelay {
Expand All @@ -371,17 +374,17 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
sendTsbpdDelay = request.handshake.SRTHS.RecvTSBPDDelay
}

ln.config.StreamId = request.handshake.StreamId
request.config.StreamId = request.handshake.StreamId
}

ln.config.Passphrase = request.passphrase
request.config.Passphrase = request.passphrase

// Create a new connection
conn := newSRTConn(srtConnConfig{
version: request.handshake.Version,
localAddr: ln.addr,
remoteAddr: request.addr,
config: ln.config,
config: request.config,
start: request.start,
socketId: socketId,
peerSocketId: request.handshake.SRTSocketId,
Expand All @@ -393,7 +396,7 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
keyBaseEncryption: packet.EvenKeyEncrypted,
onSend: ln.send,
onShutdown: ln.handleShutdown,
logger: ln.config.Logger,
logger: request.config.Logger,
})

ln.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s) %s", conn.SocketId(), conn.StreamId(), mode) })
Expand Down Expand Up @@ -429,6 +432,25 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
return nil, REJECT, nil
}

// markDone marks the listener as done by closing
// the done channel & sets the error
func (ln *listener) markDone(err error) {
ln.doneOnce.Do(func() {
ln.lock.Lock()
defer ln.lock.Unlock()
ln.doneErr = err
close(ln.doneChan)
})
}

// error returns the error that caused the listener to be done
// if it's nil then the listener is not done
func (ln *listener) error() error {
ln.lock.Lock()
defer ln.lock.Unlock()
return ln.doneErr
}

func (ln *listener) handleShutdown(socketId uint32) {
ln.lock.Lock()
delete(ln.conns, socketId)
Expand Down Expand Up @@ -602,6 +624,9 @@ func (ln *listener) handleHandshake(p packet.Packet) {

cif.PeerIP.FromNetAddr(ln.addr)

// Create a copy of the configuration for the connection
config := ln.config

if cif.HandshakeType == packet.HSTYPE_INDUCTION {
// cif
cif.Version = 5
Expand Down Expand Up @@ -645,13 +670,13 @@ func (ln *listener) handleHandshake(p packet.Packet) {
}

// If the peer has a smaller MTU size, adjust to it
if cif.MaxTransmissionUnitSize < ln.config.MSS {
ln.config.MSS = cif.MaxTransmissionUnitSize
ln.config.PayloadSize = ln.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE
if cif.MaxTransmissionUnitSize < config.MSS {
config.MSS = cif.MaxTransmissionUnitSize
config.PayloadSize = config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE

if ln.config.PayloadSize < MIN_PAYLOAD_SIZE {
if config.PayloadSize < MIN_PAYLOAD_SIZE {
cif.HandshakeType = packet.HandshakeType(REJ_ROGUE)
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", ln.config.PayloadSize) })
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", config.PayloadSize) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
Expand All @@ -674,10 +699,10 @@ func (ln *listener) handleHandshake(p packet.Packet) {
}
} else if cif.Version == 5 {
// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < ln.config.MinVersion {
if cif.SRTHS.SRTVersion < config.MinVersion {
cif.HandshakeType = packet.HandshakeType(REJ_VERSION)
ln.log("handshake:recv:error", func() string {
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, ln.config.MinVersion)
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, config.MinVersion)
})
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
Expand Down Expand Up @@ -728,6 +753,7 @@ func (ln *listener) handleHandshake(p packet.Packet) {
start: time.Now(),
socketId: cif.SRTSocketId,
timestamp: p.Header().Timestamp,
config: config,

handshake: cif,
}
Expand Down
53 changes: 53 additions & 0 deletions listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"net"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -382,3 +383,55 @@ func TestListenHSV5(t *testing.T) {

pc.Close()
}

func TestListenAsync(t *testing.T) {
const parallelCount = 2
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
require.NoError(t, err)
var (
// All streams are pending
pendingWg sync.WaitGroup
pendingSet sync.Map // Set of which streams are pending
// All streams are connected
connectedWg sync.WaitGroup
// All listener goroutines are stopped
listenerWg sync.WaitGroup
)
listenerWg.Add(parallelCount)
pendingWg.Add(parallelCount)
connectedWg.Add(parallelCount)
for i := 0; i < parallelCount; i++ {
go func() {
defer listenerWg.Done()
for {
_, _, err := ln.Accept(func(req ConnRequest) ConnType {
// Only call Done() if we're the first request for this stream
if _, ok := pendingSet.Swap(req.StreamId(), struct{}{}); !ok {
pendingWg.Done()
}
// Wait for all streams to be pending Before returning
pendingWg.Wait()
return PUBLISH
})
if err == ErrListenerClosed {
return
}
require.NoError(t, err)
}
}()

go func(streamId string) {
config := DefaultConfig()
config.StreamId = streamId
conn, err := Dial("srt", "127.0.0.1:6003", config)
require.NoError(t, err)
connectedWg.Done()
conn.Close()
}(strconv.Itoa(i))
}

// Wait for all streams to be connected
connectedWg.Wait()
ln.Close()
listenerWg.Wait()
}

0 comments on commit 151a03c

Please sign in to comment.