From d696a782f7d0d44e82afe0a6bc1fc25084c60397 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Mon, 31 Jul 2023 19:41:59 +0200 Subject: [PATCH] rtmp: simplify API (#2130) --- internal/core/api_test.go | 9 ++--- internal/core/metrics_test.go | 3 +- internal/core/rtmp_conn.go | 39 ++++++++++++------- internal/core/rtmp_server_test.go | 27 +++++-------- internal/core/rtmp_source.go | 4 +- internal/core/rtmp_source_test.go | 3 +- internal/rtmp/conn.go | 65 +++++++++++++++++++++---------- internal/rtmp/conn_test.go | 13 +++---- internal/rtmp/reader_test.go | 3 +- internal/rtmp/writer_test.go | 3 +- 10 files changed, 93 insertions(+), 76 deletions(-) diff --git a/internal/core/api_test.go b/internal/core/api_test.go index d6c5762e2ec..5d5c3b1e36d 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -573,9 +573,8 @@ func TestAPIProtocolList(t *testing.T) { }() require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - err = conn.InitializeClient(u, true) + conn, err := rtmp.NewClientConn(nconn, u, true) require.NoError(t, err) _, err = rtmp.NewWriter(conn, testFormatH264, nil) @@ -828,9 +827,8 @@ func TestAPIProtocolGet(t *testing.T) { }() require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - err = conn.InitializeClient(u, true) + conn, err := rtmp.NewClientConn(nconn, u, true) require.NoError(t, err) _, err = rtmp.NewWriter(conn, testFormatH264, nil) @@ -1150,9 +1148,8 @@ func TestAPIProtocolKick(t *testing.T) { nconn, err := net.Dial("tcp", u.Host) require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - err = conn.InitializeClient(u, true) + conn, err := rtmp.NewClientConn(nconn, u, true) require.NoError(t, err) _, err = rtmp.NewWriter(conn, testFormatH264, nil) diff --git a/internal/core/metrics_test.go b/internal/core/metrics_test.go index 3a980f6a098..709f96d37a3 100644 --- a/internal/core/metrics_test.go +++ b/internal/core/metrics_test.go @@ -85,9 +85,8 @@ webrtc_sessions_bytes_sent 0 nconn, err := net.Dial("tcp", u.Host) require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - err = conn.InitializeClient(u, true) + conn, err := rtmp.NewClientConn(nconn, u, true) require.NoError(t, err) videoTrack := &formats.H264{ diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index f3739553b2e..486a5dea281 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -65,7 +65,6 @@ type rtmpConn struct { runOnConnect string runOnConnectRestart bool wg *sync.WaitGroup - conn *rtmp.Conn nconn net.Conn externalCmdPool *externalcmd.Pool pathManager rtmpConnPathManager @@ -75,7 +74,8 @@ type rtmpConn struct { ctxCancel func() uuid uuid.UUID created time.Time - mutex sync.Mutex + mutex sync.RWMutex + conn *rtmp.Conn state rtmpConnState pathName string } @@ -106,7 +106,6 @@ func newRTMPConn( runOnConnect: runOnConnect, runOnConnectRestart: runOnConnectRestart, wg: wg, - conn: rtmp.NewConn(nconn), nconn: nconn, externalCmdPool: externalCmdPool, pathManager: pathManager, @@ -194,18 +193,22 @@ func (c *rtmpConn) run() { func (c *rtmpConn) runReader() error { c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) - u, publish, err := c.conn.InitializeServer() + conn, u, publish, err := rtmp.NewServerConn(c.nconn) if err != nil { return err } + c.mutex.Lock() + c.conn = conn + c.mutex.Unlock() + if !publish { - return c.runRead(u) + return c.runRead(conn, u) } - return c.runPublish(u) + return c.runPublish(conn, u) } -func (c *rtmpConn) runRead(u *url.URL) error { +func (c *rtmpConn) runRead(conn *rtmp.Conn, u *url.URL) error { pathName, query, rawQuery := pathNameAndQuery(u) res := c.pathManager.addReader(pathAddReaderReq{ @@ -298,7 +301,7 @@ func (c *rtmpConn) runRead(u *url.URL) error { } var err error - w, err = rtmp.NewWriter(c.conn, videoFormat, audioFormat) + w, err = rtmp.NewWriter(conn, videoFormat, audioFormat) if err != nil { return err } @@ -569,7 +572,7 @@ func (c *rtmpConn) setupAudio( return nil, nil } -func (c *rtmpConn) runPublish(u *url.URL) error { +func (c *rtmpConn) runPublish(conn *rtmp.Conn, u *url.URL) error { pathName, query, rawQuery := pathNameAndQuery(u) res := c.pathManager.addPublisher(pathAddPublisherReq{ @@ -601,7 +604,7 @@ func (c *rtmpConn) runPublish(u *url.URL) error { c.pathName = pathName c.mutex.Unlock() - r, err := rtmp.NewReader(c.conn) + r, err := rtmp.NewReader(conn) if err != nil { return err } @@ -731,8 +734,16 @@ func (c *rtmpConn) apiSourceDescribe() pathAPISourceOrReader { } func (c *rtmpConn) apiItem() *apiRTMPConn { - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.RLock() + defer c.mutex.RUnlock() + + bytesReceived := uint64(0) + bytesSent := uint64(0) + + if c.conn != nil { + bytesReceived = c.conn.BytesReceived() + bytesSent = c.conn.BytesSent() + } return &apiRTMPConn{ ID: c.uuid, @@ -749,7 +760,7 @@ func (c *rtmpConn) apiItem() *apiRTMPConn { return "idle" }(), Path: c.pathName, - BytesReceived: c.conn.BytesReceived(), - BytesSent: c.conn.BytesSent(), + BytesReceived: bytesReceived, + BytesSent: bytesSent, } } diff --git a/internal/core/rtmp_server_test.go b/internal/core/rtmp_server_test.go index d9f846f5ad2..c782b4dffdc 100644 --- a/internal/core/rtmp_server_test.go +++ b/internal/core/rtmp_server_test.go @@ -34,9 +34,8 @@ func TestRTMPServerRunOnConnect(t *testing.T) { nconn, err := net.Dial("tcp", u.Host) require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - err = conn.InitializeClient(u, true) + _, err = rtmp.NewClientConn(nconn, u, true) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -125,9 +124,8 @@ func TestRTMPServer(t *testing.T) { }() require.NoError(t, err) defer nconn1.Close() - conn1 := rtmp.NewConn(nconn1) - err = conn1.InitializeClient(u1, true) + conn1, err := rtmp.NewClientConn(nconn1, u1, true) require.NoError(t, err) videoTrack := &formats.H264{ @@ -175,9 +173,8 @@ func TestRTMPServer(t *testing.T) { }() require.NoError(t, err) defer nconn2.Close() - conn2 := rtmp.NewConn(nconn2) - err = conn2.InitializeClient(u2, false) + conn2, err := rtmp.NewClientConn(nconn2, u2, false) require.NoError(t, err) r, err := rtmp.NewReader(conn2) @@ -237,9 +234,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn1, err := net.Dial("tcp", u1.Host) require.NoError(t, err) defer nconn1.Close() - conn1 := rtmp.NewConn(nconn1) - err = conn1.InitializeClient(u1, true) + conn1, err := rtmp.NewClientConn(nconn1, u1, true) require.NoError(t, err) videoTrack := &formats.H264{ @@ -266,9 +262,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn2, err := net.Dial("tcp", u2.Host) require.NoError(t, err) defer nconn2.Close() - conn2 := rtmp.NewConn(nconn2) - err = conn2.InitializeClient(u2, false) + conn2, err := rtmp.NewClientConn(nconn2, u2, false) require.NoError(t, err) _, err = rtmp.NewReader(conn2) @@ -291,9 +286,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn1, err := net.Dial("tcp", u1.Host) require.NoError(t, err) defer nconn1.Close() - conn1 := rtmp.NewConn(nconn1) - err = conn1.InitializeClient(u1, true) + conn1, err := rtmp.NewClientConn(nconn1, u1, true) require.NoError(t, err) videoTrack := &formats.H264{ @@ -320,9 +314,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn2, err := net.Dial("tcp", u2.Host) require.NoError(t, err) defer nconn2.Close() - conn2 := rtmp.NewConn(nconn2) - err = conn2.InitializeClient(u2, false) + conn2, err := rtmp.NewClientConn(nconn2, u2, false) require.NoError(t, err) _, err = rtmp.NewReader(conn2) @@ -346,9 +339,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn1, err := net.Dial("tcp", u1.Host) require.NoError(t, err) defer nconn1.Close() - conn1 := rtmp.NewConn(nconn1) - err = conn1.InitializeClient(u1, true) + conn1, err := rtmp.NewClientConn(nconn1, u1, true) require.NoError(t, err) videoTrack := &formats.H264{ @@ -375,9 +367,8 @@ func TestRTMPServerAuthFail(t *testing.T) { nconn2, err := net.Dial("tcp", u2.Host) require.NoError(t, err) defer nconn2.Close() - conn2 := rtmp.NewConn(nconn2) - err = conn2.InitializeClient(u2, false) + conn2, err := rtmp.NewClientConn(nconn2, u2, false) require.NoError(t, err) _, err = rtmp.NewReader(conn2) diff --git a/internal/core/rtmp_source.go b/internal/core/rtmp_source.go index b7a63da1377..93aead5f9a3 100644 --- a/internal/core/rtmp_source.go +++ b/internal/core/rtmp_source.go @@ -99,11 +99,9 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha } func (s *rtmpSource) runReader(u *url.URL, nconn net.Conn) error { - conn := rtmp.NewConn(nconn) - nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) - err := conn.InitializeClient(u, false) + conn, err := rtmp.NewClientConn(nconn, u, false) if err != nil { return err } diff --git a/internal/core/rtmp_source_test.go b/internal/core/rtmp_source_test.go index a308c03db36..49d198720b4 100644 --- a/internal/core/rtmp_source_test.go +++ b/internal/core/rtmp_source_test.go @@ -52,9 +52,8 @@ func TestRTMPSource(t *testing.T) { nconn, err := ln.Accept() require.NoError(t, err) defer nconn.Close() - conn := rtmp.NewConn(nconn) - _, _, err = conn.InitializeServer() + conn, _, _, err := rtmp.NewServerConn(nconn) require.NoError(t, err) videoTrack := &formats.H264{ diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index 2a27a4df45f..3bcc736a478 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -139,29 +139,21 @@ type Conn struct { mrw *message.ReadWriter } -// NewConn initializes a connection. -func NewConn(rw io.ReadWriter) *Conn { - return &Conn{ +// NewClientConn initializes a client-side connection. +func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) { + c := &Conn{ bc: bytecounter.NewReadWriter(rw), } -} - -// BytesReceived returns the number of bytes received. -func (c *Conn) BytesReceived() uint64 { - return c.bc.Reader.Count() -} -// BytesSent returns the number of bytes sent. -func (c *Conn) BytesSent() uint64 { - return c.bc.Writer.Count() -} + err := c.initializeClient(u, publish) + if err != nil { + return nil, err + } -func (c *Conn) skipInitialization() { - c.mrw = message.NewReadWriter(c.bc, false) + return c, nil } -// InitializeClient performs the initialization of a client-side connection. -func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { +func (c *Conn) initializeClient(u *url.URL, publish bool) error { connectpath, actionpath := splitPath(u) err := handshake.DoClient(c.bc, false) @@ -219,7 +211,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - if !isPublishing { + if !publish { err = c.mrw.Write(&message.CommandAMF0{ ChunkStreamID: 3, Name: "createStream", @@ -322,8 +314,21 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return readCommandResult(c.mrw, 5, "onStatus", resultIsOK1) } -// InitializeServer performs the initialization of a server-side connection. -func (c *Conn) InitializeServer() (*url.URL, bool, error) { +// NewServerConn initializes a server-side connection. +func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) { + c := &Conn{ + bc: bytecounter.NewReadWriter(rw), + } + + u, publish, err := c.initializeServer() + if err != nil { + return nil, nil, false, err + } + + return c, u, publish, nil +} + +func (c *Conn) initializeServer() (*url.URL, bool, error) { err := handshake.DoServer(c.bc, false) if err != nil { return nil, false, err @@ -571,6 +576,26 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) { } } +func newNoHandshakeConn(rw io.ReadWriter) *Conn { + c := &Conn{ + bc: bytecounter.NewReadWriter(rw), + } + + c.mrw = message.NewReadWriter(c.bc, false) + + return c +} + +// BytesReceived returns the number of bytes received. +func (c *Conn) BytesReceived() uint64 { + return c.bc.Reader.Count() +} + +// BytesSent returns the number of bytes sent. +func (c *Conn) BytesSent() uint64 { + return c.bc.Writer.Count() +} + // Read reads a message. func (c *Conn) Read() (message.Message, error) { return c.mrw.Read() diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index db583c8f661..acd63d6aaef 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -14,7 +14,7 @@ import ( "github.com/bluenviron/mediamtx/internal/rtmp/message" ) -func TestInitializeClient(t *testing.T) { +func TestNewClientConn(t *testing.T) { for _, ca := range []string{"read", "publish"} { t.Run(ca, func(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:9121") @@ -236,9 +236,8 @@ func TestInitializeClient(t *testing.T) { nconn, err := net.Dial("tcp", u.Host) require.NoError(t, err) defer nconn.Close() - conn := NewConn(nconn) - err = conn.InitializeClient(u, ca == "publish") + conn, err := NewClientConn(nconn, u, ca == "publish") require.NoError(t, err) if ca == "read" { @@ -254,7 +253,7 @@ func TestInitializeClient(t *testing.T) { } } -func TestInitializeServer(t *testing.T) { +func TestNewServerConn(t *testing.T) { for _, ca := range []string{ "read", "publish", @@ -272,9 +271,9 @@ func TestInitializeServer(t *testing.T) { require.NoError(t, err) defer nconn.Close() - conn := NewConn(nconn) - u, isPublishing, err := conn.InitializeServer() + _, u, isPublishing, err := NewServerConn(nconn) require.NoError(t, err) + require.Equal(t, &url.URL{ Scheme: "rtmp", Host: "127.0.0.1:9121", @@ -488,7 +487,7 @@ func BenchmarkRead(b *testing.B) { }) } - conn := NewConn(&buf) + conn := newNoHandshakeConn(&buf) for n := 0; n < b.N; n++ { conn.Read() diff --git a/internal/rtmp/reader_test.go b/internal/rtmp/reader_test.go index 0cc4a5de1bb..01e7c9a7c61 100644 --- a/internal/rtmp/reader_test.go +++ b/internal/rtmp/reader_test.go @@ -536,8 +536,7 @@ func TestReadTracks(t *testing.T) { require.NoError(t, err) } - c := NewConn(&buf) - c.skipInitialization() + c := newNoHandshakeConn(&buf) r, err := NewReader(c) require.NoError(t, err) diff --git a/internal/rtmp/writer_test.go b/internal/rtmp/writer_test.go index ed8bb945aa1..c04778c94ca 100644 --- a/internal/rtmp/writer_test.go +++ b/internal/rtmp/writer_test.go @@ -40,8 +40,7 @@ func TestWriteTracks(t *testing.T) { } var buf bytes.Buffer - c := NewConn(&buf) - c.skipInitialization() + c := newNoHandshakeConn(&buf) _, err := NewWriter(c, videoTrack, audioTrack) require.NoError(t, err)