Skip to content

Commit

Permalink
srt: process connection requests in parallel (#3382) (#3534)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Jul 5, 2024
1 parent c4987d0 commit 342c257
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 123 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,5 @@ replace code.cloudfoundry.org/bytefmt => github.com/cloudfoundry/bytefmt v0.0.0-
replace github.com/pion/ice/v2 => github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9

replace github.com/pion/webrtc/v3 => github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06

replace github.com/datarhei/gosrt => github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ github.com/alecthomas/kong v0.9.0 h1:G5diXxc85KvoV2f0ZRVuMsi45IrBgx9zDNGNj165aPA
github.com/alecthomas/kong v0.9.0/go.mod h1:Y47y5gKfHp1hDc7CH7OeXgLIpp+Q2m1Ni0L5s3bI8Os=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7 h1:4WE1Nez3YyD1CgJfWlnyp+uLLPZOKD5ywWPvwbf/Jp4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 h1:Vax9SzYE68ZYLwFaK7lnCV2ZhX9/YqAJX6xxROPRqEM=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw=
github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 h1:WtKhXOpd8lgTeXF3RQVOzkNRuy83ygvWEpMYD2aoY3Q=
Expand Down Expand Up @@ -37,8 +39,6 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc=
github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
106 changes: 30 additions & 76 deletions internal/servers/srt/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,13 @@ type conn struct {
pathName string
query string
sconn srt.Conn

chNew chan srtNewConnReq
chSetConn chan srt.Conn
}

func (c *conn) initialize() {
c.ctx, c.ctxCancel = context.WithCancel(c.parentCtx)

c.created = time.Now()
c.uuid = uuid.New()
c.chNew = make(chan srtNewConnReq)
c.chSetConn = make(chan srt.Conn)

c.Log(logger.Info, "opened")

Expand Down Expand Up @@ -130,36 +125,20 @@ func (c *conn) run() { //nolint:dupl
}

func (c *conn) runInner() error {
var req srtNewConnReq
select {
case req = <-c.chNew:
case <-c.ctx.Done():
return errors.New("terminated")
}

answerSent, err := c.runInner2(req)

if !answerSent {
req.res <- nil
}

return err
}

func (c *conn) runInner2(req srtNewConnReq) (bool, error) {
var streamID streamID
err := streamID.unmarshal(req.connReq.StreamId())
err := streamID.unmarshal(c.connReq.StreamId())
if err != nil {
return false, fmt.Errorf("invalid stream ID '%s': %w", req.connReq.StreamId(), err)
c.connReq.Reject(srt.REJ_PEER)
return fmt.Errorf("invalid stream ID '%s': %w", c.connReq.StreamId(), err)
}

if streamID.mode == streamIDModePublish {
return c.runPublish(req, &streamID)
return c.runPublish(&streamID)
}
return c.runRead(req, &streamID)
return c.runRead(&streamID)
}

func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
func (c *conn) runPublish(streamID *streamID) error {
path, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{
Author: c,
AccessRequest: defs.PathAccessRequest{
Expand All @@ -178,21 +157,24 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)
return false, terr
c.connReq.Reject(srt.REJ_PEER)
return terr
}
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c})

err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTPublishPassphrase)
err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTPublishPassphrase)
if err != nil {
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

sconn, err := c.exchangeRequestWithConn(req)
sconn, err := c.connReq.Accept()
if err != nil {
return true, err
return err
}

c.mutex.Lock()
Expand All @@ -210,12 +192,12 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
select {
case err := <-readerErr:
sconn.Close()
return true, err
return err

case <-c.ctx.Done():
sconn.Close()
<-readerErr
return true, errors.New("terminated")
return errors.New("terminated")
}
}

Expand Down Expand Up @@ -256,7 +238,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, path defs.Path) error {
}
}

func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
func (c *conn) runRead(streamID *streamID) error {
path, stream, err := c.pathManager.AddReader(defs.PathAddReaderReq{
Author: c,
AccessRequest: defs.PathAccessRequest{
Expand All @@ -274,21 +256,24 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)
return false, err
c.connReq.Reject(srt.REJ_PEER)
return terr
}
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c})

err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTReadPassphrase)
err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTReadPassphrase)
if err != nil {
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

sconn, err := c.exchangeRequestWithConn(req)
sconn, err := c.connReq.Accept()
if err != nil {
return true, err
return err
}
defer sconn.Close()

Expand All @@ -307,7 +292,7 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {

err = mpegts.FromStream(stream, writer, bw, sconn, time.Duration(c.writeTimeout))
if err != nil {
return true, err
return err
}

c.Log(logger.Info, "is reading from path '%s', %s",
Expand All @@ -331,41 +316,10 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {

select {
case <-c.ctx.Done():
return true, fmt.Errorf("terminated")
return fmt.Errorf("terminated")

case err := <-writer.Error():
return true, err
}
}

func (c *conn) exchangeRequestWithConn(req srtNewConnReq) (srt.Conn, error) {
req.res <- c

select {
case sconn := <-c.chSetConn:
return sconn, nil

case <-c.ctx.Done():
return nil, errors.New("terminated")
}
}

// new is called by srtListener through srtServer.
func (c *conn) new(req srtNewConnReq) *conn {
select {
case c.chNew <- req:
return <-req.res

case <-c.ctx.Done():
return nil
}
}

// setConn is called by srtListener .
func (c *conn) setConn(sconn srt.Conn) {
select {
case c.chSetConn <- sconn:
case <-c.ctx.Done():
return err
}
}

Expand Down
17 changes: 2 additions & 15 deletions internal/servers/srt/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,11 @@ func (l *listener) run() {

func (l *listener) runInner() error {
for {
var sconn *conn
conn, _, err := l.ln.Accept(func(req srt.ConnRequest) srt.ConnType {
sconn = l.parent.newConnRequest(req)
if sconn == nil {
return srt.REJECT
}

// currently it's the same to return SUBSCRIBE or PUBLISH
return srt.SUBSCRIBE
})
req, err := l.ln.Accept2()
if err != nil {
return err
}

if conn == nil {
continue
}

sconn.setConn(conn)
l.parent.newConnRequest(req)
}
}
27 changes: 6 additions & 21 deletions internal/servers/srt/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ func srtMaxPayloadSize(u int) int {
return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet
}

type srtNewConnReq struct {
connReq srt.ConnRequest
res chan *conn
}

type serverAPIConnsListRes struct {
data *defs.APISRTConnList
err error
Expand Down Expand Up @@ -90,7 +85,7 @@ type Server struct {
conns map[*conn]struct{}

// in
chNewConnRequest chan srtNewConnReq
chNewConnRequest chan srt.ConnRequest
chAcceptErr chan error
chCloseConn chan *conn
chAPIConnsList chan serverAPIConnsListReq
Expand All @@ -113,7 +108,7 @@ func (s *Server) Initialize() error {
s.ctx, s.ctxCancel = context.WithCancel(context.Background())

s.conns = make(map[*conn]struct{})
s.chNewConnRequest = make(chan srtNewConnReq)
s.chNewConnRequest = make(chan srt.ConnRequest)
s.chAcceptErr = make(chan error)
s.chCloseConn = make(chan *conn)
s.chAPIConnsList = make(chan serverAPIConnsListReq)
Expand Down Expand Up @@ -165,7 +160,7 @@ outer:
writeTimeout: s.WriteTimeout,
writeQueueSize: s.WriteQueueSize,
udpMaxPayloadSize: s.UDPMaxPayloadSize,
connReq: req.connReq,
connReq: req,
runOnConnect: s.RunOnConnect,
runOnConnectRestart: s.RunOnConnectRestart,
runOnDisconnect: s.RunOnDisconnect,
Expand All @@ -176,7 +171,6 @@ outer:
}
c.initialize()
s.conns[c] = struct{}{}
req.res <- c

case c := <-s.chCloseConn:
delete(s.conns, c)
Expand Down Expand Up @@ -236,20 +230,11 @@ func (s *Server) findConnByUUID(uuid uuid.UUID) *conn {
}

// newConnRequest is called by srtListener.
func (s *Server) newConnRequest(connReq srt.ConnRequest) *conn {
req := srtNewConnReq{
connReq: connReq,
res: make(chan *conn),
}

func (s *Server) newConnRequest(connReq srt.ConnRequest) {
select {
case s.chNewConnRequest <- req:
c := <-req.res

return c.new(req)

case s.chNewConnRequest <- connReq:
case <-s.ctx.Done():
return nil
connReq.Reject(srt.REJ_CLOSE)
}
}

Expand Down
17 changes: 8 additions & 9 deletions internal/staticsources/srt/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ func TestSource(t *testing.T) {
defer ln.Close()

go func() {
conn, _, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType {
require.Equal(t, "sidname", req.StreamId())
err := req.SetPassphrase("ttest1234567")
if err != nil {
return srt.REJECT
}
return srt.SUBSCRIBE
})
req, err := ln.Accept2()
require.NoError(t, err)

require.Equal(t, "sidname", req.StreamId())
err = req.SetPassphrase("ttest1234567")
require.NoError(t, err)

conn, err := req.Accept()
require.NoError(t, err)
require.NotNil(t, conn)
defer conn.Close()

track := &mpegts.Track{
Expand Down

0 comments on commit 342c257

Please sign in to comment.