Skip to content

Commit

Permalink
fix overriding of previously-received RTP packets that leaded to crashes
Browse files Browse the repository at this point in the history
RTP packets were previously take from a buffer pool. This was messing
up the Client, since that buffer pool was used by multiple routines at
once, and was probably messing up the Server too, since packets can be
pushed to different queues and there's no guarantee that these queues
have an overall size less than ReadBufferCount.

This buffer pool is removed; this decreases performance but avoids bugs.
  • Loading branch information
aler9 committed Dec 19, 2022
1 parent b3de3cf commit ffe8c87
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 270 deletions.
2 changes: 0 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ type Client struct {
medias map[*media.Media]*clientMedia
tcpMediasByChannel map[int]*clientMedia
lastRange *headers.Range
rtpPacketBuffer *rtpPacketMultiBuffer // play
checkStreamTimer *time.Timer
checkStreamInitial bool
tcpLastFrameTime *int64
Expand Down Expand Up @@ -630,7 +629,6 @@ func (c *Client) playRecordStart() {

if c.state == clientStatePlay {
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
c.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(c.ReadBufferCount))

switch *c.effectiveTransport {
case TransportUDP:
Expand Down
254 changes: 140 additions & 114 deletions client_play_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -274,10 +275,16 @@ func TestClientPlay(t *testing.T) {
err = forma.Init()
require.NoError(t, err)

medias := media.Medias{&media.Media{
Type: "application",
Formats: []format.Format{forma},
}}
medias := media.Medias{
&media.Media{
Type: "application",
Formats: []format.Format{forma},
},
&media.Media{
Type: "application",
Formats: []format.Format{forma},
},
}
medias.SetControls()

err = conn.WriteResponse(&base.Response{
Expand All @@ -290,87 +297,92 @@ func TestClientPlay(t *testing.T) {
})
require.NoError(t, err)

req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID=0"), req.URL)
var l1s [2]net.PacketConn
var l2s [2]net.PacketConn
var clientPorts [2]*[2]int

var inTH headers.Transport
err = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err)
for i := 0; i < 2; i++ {
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(
scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID="+strconv.FormatInt(int64(i), 10)), req.URL)

th := headers.Transport{}
var inTH headers.Transport
err = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err)

var l1 net.PacketConn
var l2 net.PacketConn
var th headers.Transport

switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{34556, 34557}
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts
clientPorts[i] = inTH.ClientPorts
th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2}

l1, err = net.ListenPacket("udp", listenIP+":34556")
require.NoError(t, err)
defer l1.Close()
l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10))
require.NoError(t, err)
defer l1s[i].Close()

l2, err = net.ListenPacket("udp", listenIP+":34557")
require.NoError(t, err)
defer l2.Close()
l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10))
require.NoError(t, err)
defer l2s[i].Close()

case "multicast":
v := headers.TransportDeliveryMulticast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2
th.Ports = &[2]int{25000, 25001}
case "multicast":
v := headers.TransportDeliveryMulticast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2
th.Ports = &[2]int{25000 + i*2, 25001 + i*2}

l1s[i], err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err)
defer l1s[i].Close()

l1, err = net.ListenPacket("udp", "224.0.0.0:25000")
require.NoError(t, err)
defer l1.Close()
p := ipv4.NewPacketConn(l1s[i])

p := ipv4.NewPacketConn(l1)
intfs, err := net.Interfaces()
require.NoError(t, err)

intfs, err := net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
}

for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
l2s[i], err = net.ListenPacket("udp", "224.0.0.0:25001")
require.NoError(t, err)
}
defer l2s[i].Close()

l2, err = net.ListenPacket("udp", "224.0.0.0:25001")
require.NoError(t, err)
defer l2.Close()
p = ipv4.NewPacketConn(l2s[i])

p = ipv4.NewPacketConn(l2)
intfs, err = net.Interfaces()
require.NoError(t, err)

intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
}

for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
case "tcp", "tls":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2}
}

case "tcp", "tls":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0, 1}
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err)
}

err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err)

req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Play, req.Method)
Expand All @@ -382,56 +394,58 @@ func TestClientPlay(t *testing.T) {
})
require.NoError(t, err)

// server -> client (RTP)
switch transport {
case "udp":
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})

case "multicast":
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("224.1.0.1"),
Port: 25000,
})
for i := 0; i < 2; i++ {
// server -> client (RTP)
switch transport {
case "udp":
l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: clientPorts[i][0],
})

case "tcp", "tls":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: testRTPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
}
case "multicast":
l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("224.1.0.1"),
Port: 25000,
})

// client -> server (RTCP)
switch transport {
case "udp", "multicast":
// skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf)
case "tcp", "tls":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0 + i*2,
Payload: testRTPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
}

buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
packets, err := rtcp.Unmarshal(buf[:n])
require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0])
close(packetRecv)
// client -> server (RTCP)
switch transport {
case "udp", "multicast":
// skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048)
_, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err)
}

case "tcp", "tls":
f, err := conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload)
require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0])
close(packetRecv)
buf := make([]byte, 2048)
n, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err)
packets, err := rtcp.Unmarshal(buf[:n])
require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0])

case "tcp", "tls":
f, err := conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 1+i*2, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload)
require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0])
}
}

close(packetRecv)

req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
Expand Down Expand Up @@ -464,16 +478,28 @@ func TestClientPlay(t *testing.T) {
}(),
}

err = readAll(&c,
scheme+"://"+listenIP+":8554/test/stream?param=value",
func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt)
err := c.WritePacketRTCP(medi, &testRTCPPacket)
require.NoError(t, err)
})
u, err := url.Parse(scheme + "://" + listenIP + ":8554/test/stream?param=value")
require.NoError(t, err)

err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()

medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)

err = c.SetupAll(medias, baseURL)
require.NoError(t, err)

c.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt)
err := c.WritePacketRTCP(medi, &testRTCPPacket)
require.NoError(t, err)
})

_, err = c.Play(nil)
require.NoError(t, err)

<-packetRecv
})
}
Expand Down
5 changes: 3 additions & 2 deletions clientmedia.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/pion/rtcp"
"github.com/pion/rtp"

"github.com/aler9/gortsplib/v2/pkg/base"
"github.com/aler9/gortsplib/v2/pkg/media"
Expand Down Expand Up @@ -187,7 +188,7 @@ func (cm *clientMedia) readRTPTCPPlay(payload []byte) error {
now := time.Now()
atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix())

pkt := cm.c.rtpPacketBuffer.next()
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
if err != nil {
return err
Expand Down Expand Up @@ -259,7 +260,7 @@ func (cm *clientMedia) readRTPUDPPlay(payload []byte) error {
return nil
}

pkt := cm.c.rtpPacketBuffer.next()
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
if err != nil {
cm.c.OnDecodeError(err)
Expand Down
24 changes: 0 additions & 24 deletions rtppacketmultibuffer.go

This file was deleted.

Loading

0 comments on commit ffe8c87

Please sign in to comment.