Skip to content

Commit

Permalink
rtmp: simplify API (#2130)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Jul 31, 2023
1 parent 959b017 commit d696a78
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 76 deletions.
9 changes: 3 additions & 6 deletions internal/core/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,8 @@ func TestAPIProtocolList(t *testing.T) {
}()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)

_, err = rtmp.NewWriter(conn, testFormatH264, nil)
Expand Down Expand Up @@ -828,9 +827,8 @@ func TestAPIProtocolGet(t *testing.T) {
}()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)

_, err = rtmp.NewWriter(conn, testFormatH264, nil)
Expand Down Expand Up @@ -1150,9 +1148,8 @@ func TestAPIProtocolKick(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)

_, err = rtmp.NewWriter(conn, testFormatH264, nil)
Expand Down
3 changes: 1 addition & 2 deletions internal/core/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ webrtc_sessions_bytes_sent 0
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand Down
39 changes: 25 additions & 14 deletions internal/core/rtmp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ type rtmpConn struct {
runOnConnect string
runOnConnectRestart bool
wg *sync.WaitGroup
conn *rtmp.Conn
nconn net.Conn
externalCmdPool *externalcmd.Pool
pathManager rtmpConnPathManager
Expand All @@ -75,7 +74,8 @@ type rtmpConn struct {
ctxCancel func()
uuid uuid.UUID
created time.Time
mutex sync.Mutex
mutex sync.RWMutex
conn *rtmp.Conn
state rtmpConnState
pathName string
}
Expand Down Expand Up @@ -106,7 +106,6 @@ func newRTMPConn(
runOnConnect: runOnConnect,
runOnConnectRestart: runOnConnectRestart,
wg: wg,
conn: rtmp.NewConn(nconn),
nconn: nconn,
externalCmdPool: externalCmdPool,
pathManager: pathManager,
Expand Down Expand Up @@ -194,18 +193,22 @@ func (c *rtmpConn) run() {
func (c *rtmpConn) runReader() error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
u, publish, err := c.conn.InitializeServer()
conn, u, publish, err := rtmp.NewServerConn(c.nconn)
if err != nil {
return err
}

c.mutex.Lock()
c.conn = conn
c.mutex.Unlock()

if !publish {
return c.runRead(u)
return c.runRead(conn, u)
}
return c.runPublish(u)
return c.runPublish(conn, u)
}

func (c *rtmpConn) runRead(u *url.URL) error {
func (c *rtmpConn) runRead(conn *rtmp.Conn, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)

res := c.pathManager.addReader(pathAddReaderReq{
Expand Down Expand Up @@ -298,7 +301,7 @@ func (c *rtmpConn) runRead(u *url.URL) error {
}

var err error
w, err = rtmp.NewWriter(c.conn, videoFormat, audioFormat)
w, err = rtmp.NewWriter(conn, videoFormat, audioFormat)
if err != nil {
return err
}
Expand Down Expand Up @@ -569,7 +572,7 @@ func (c *rtmpConn) setupAudio(
return nil, nil
}

func (c *rtmpConn) runPublish(u *url.URL) error {
func (c *rtmpConn) runPublish(conn *rtmp.Conn, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)

res := c.pathManager.addPublisher(pathAddPublisherReq{
Expand Down Expand Up @@ -601,7 +604,7 @@ func (c *rtmpConn) runPublish(u *url.URL) error {
c.pathName = pathName
c.mutex.Unlock()

r, err := rtmp.NewReader(c.conn)
r, err := rtmp.NewReader(conn)
if err != nil {
return err
}
Expand Down Expand Up @@ -731,8 +734,16 @@ func (c *rtmpConn) apiSourceDescribe() pathAPISourceOrReader {
}

func (c *rtmpConn) apiItem() *apiRTMPConn {
c.mutex.Lock()
defer c.mutex.Unlock()
c.mutex.RLock()
defer c.mutex.RUnlock()

bytesReceived := uint64(0)
bytesSent := uint64(0)

if c.conn != nil {
bytesReceived = c.conn.BytesReceived()
bytesSent = c.conn.BytesSent()
}

return &apiRTMPConn{
ID: c.uuid,
Expand All @@ -749,7 +760,7 @@ func (c *rtmpConn) apiItem() *apiRTMPConn {
return "idle"
}(),
Path: c.pathName,
BytesReceived: c.conn.BytesReceived(),
BytesSent: c.conn.BytesSent(),
BytesReceived: bytesReceived,
BytesSent: bytesSent,
}
}
27 changes: 9 additions & 18 deletions internal/core/rtmp_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ func TestRTMPServerRunOnConnect(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

err = conn.InitializeClient(u, true)
_, err = rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)

time.Sleep(500 * time.Millisecond)
Expand Down Expand Up @@ -125,9 +124,8 @@ func TestRTMPServer(t *testing.T) {
}()
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)

err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand Down Expand Up @@ -175,9 +173,8 @@ func TestRTMPServer(t *testing.T) {
}()
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)

err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)

r, err := rtmp.NewReader(conn2)
Expand Down Expand Up @@ -237,9 +234,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)

err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand All @@ -266,9 +262,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)

err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)

_, err = rtmp.NewReader(conn2)
Expand All @@ -291,9 +286,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)

err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand All @@ -320,9 +314,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)

err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)

_, err = rtmp.NewReader(conn2)
Expand All @@ -346,9 +339,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)

err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand All @@ -375,9 +367,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)

err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)

_, err = rtmp.NewReader(conn2)
Expand Down
4 changes: 1 addition & 3 deletions internal/core/rtmp_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,9 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha
}

func (s *rtmpSource) runReader(u *url.URL, nconn net.Conn) error {
conn := rtmp.NewConn(nconn)

nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err := conn.InitializeClient(u, false)
conn, err := rtmp.NewClientConn(nconn, u, false)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions internal/core/rtmp_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ func TestRTMPSource(t *testing.T) {
nconn, err := ln.Accept()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)

_, _, err = conn.InitializeServer()
conn, _, _, err := rtmp.NewServerConn(nconn)
require.NoError(t, err)

videoTrack := &formats.H264{
Expand Down
65 changes: 45 additions & 20 deletions internal/rtmp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,29 +139,21 @@ type Conn struct {
mrw *message.ReadWriter
}

// NewConn initializes a connection.
func NewConn(rw io.ReadWriter) *Conn {
return &Conn{
// NewClientConn initializes a client-side connection.
func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}
}

// BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 {
return c.bc.Reader.Count()
}

// BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 {
return c.bc.Writer.Count()
}
err := c.initializeClient(u, publish)
if err != nil {
return nil, err
}

func (c *Conn) skipInitialization() {
c.mrw = message.NewReadWriter(c.bc, false)
return c, nil
}

// InitializeClient performs the initialization of a client-side connection.
func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
func (c *Conn) initializeClient(u *url.URL, publish bool) error {
connectpath, actionpath := splitPath(u)

err := handshake.DoClient(c.bc, false)
Expand Down Expand Up @@ -219,7 +211,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
return err
}

if !isPublishing {
if !publish {
err = c.mrw.Write(&message.CommandAMF0{
ChunkStreamID: 3,
Name: "createStream",
Expand Down Expand Up @@ -322,8 +314,21 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
return readCommandResult(c.mrw, 5, "onStatus", resultIsOK1)
}

// InitializeServer performs the initialization of a server-side connection.
func (c *Conn) InitializeServer() (*url.URL, bool, error) {
// NewServerConn initializes a server-side connection.
func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}

u, publish, err := c.initializeServer()
if err != nil {
return nil, nil, false, err
}

return c, u, publish, nil
}

func (c *Conn) initializeServer() (*url.URL, bool, error) {
err := handshake.DoServer(c.bc, false)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -571,6 +576,26 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) {
}
}

func newNoHandshakeConn(rw io.ReadWriter) *Conn {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}

c.mrw = message.NewReadWriter(c.bc, false)

return c
}

// BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 {
return c.bc.Reader.Count()
}

// BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 {
return c.bc.Writer.Count()
}

// Read reads a message.
func (c *Conn) Read() (message.Message, error) {
return c.mrw.Read()
Expand Down
Loading

0 comments on commit d696a78

Please sign in to comment.