Skip to content

Commit

Permalink
Fix NACK interceptor sendBuffer overridden
Browse files Browse the repository at this point in the history
Keeps a copy of packet in responder_interceptor, fixes #84
  • Loading branch information
davidzhao committed Jan 6, 2022
1 parent 36d6df8 commit d9afd75
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 25 deletions.
1 change: 1 addition & 0 deletions AUTHORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ aler9 <[email protected]>
Antoine Baché <[email protected]>
Atsushi Watanabe <[email protected]>
boks1971 <[email protected]>
David Zhao <[email protected]>
Jonathan Müller <[email protected]>
Mathis Engelbart <[email protected]>
Sean DuBois <[email protected]>
2 changes: 2 additions & 0 deletions pkg/nack/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ import "errors"

// ErrInvalidSize is returned by newReceiveLog/newSendBuffer, when an incorrect buffer size is supplied.
var ErrInvalidSize = errors.New("invalid buffer size")

var errPacketReleased = errors.New("could not retain packet, already released")
21 changes: 14 additions & 7 deletions pkg/nack/responder_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ type ResponderInterceptorFactory struct {
// NewInterceptor constructs a new ResponderInterceptor
func (r *ResponderInterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) {
i := &ResponderInterceptor{
size: 8192,
log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"),
streams: map[uint32]*localStream{},
size: 8192,
log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"),
streams: map[uint32]*localStream{},
packetMan: newPacketManager(),
}

for _, opt := range r.opts {
Expand All @@ -38,8 +39,9 @@ func (r *ResponderInterceptorFactory) NewInterceptor(id string) (interceptor.Int
// ResponderInterceptor responds to nack feedback messages
type ResponderInterceptor struct {
interceptor.NoOp
size uint16
log logging.LeveledLogger
size uint16
log logging.LeveledLogger
packetMan *packetManager

streams map[uint32]*localStream
streamsMu sync.Mutex
Expand Down Expand Up @@ -98,7 +100,11 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri
n.streamsMu.Unlock()

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
sendBuffer.add(&rtp.Packet{Header: *header, Payload: payload})
pkt, err := n.packetMan.NewPacket(header, payload)
if err != nil {
return 0, err
}
sendBuffer.add(pkt)
return writer.Write(header, payload, attributes)
})
}
Expand All @@ -121,9 +127,10 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) {
for i := range nack.Nacks {
nack.Nacks[i].Range(func(seq uint16) bool {
if p := stream.sendBuffer.get(seq); p != nil {
if _, err := stream.rtpWriter.Write(&p.Header, p.Payload, interceptor.Attributes{}); err != nil {
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)
}
p.Release()
}

return true
Expand Down
105 changes: 105 additions & 0 deletions pkg/nack/retainable_packet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package nack

import (
"io"
"sync"

"github.com/pion/rtp"
)

const maxPayloadLen = 1460

type packetManager struct {
headerPool *sync.Pool
payloadPool *sync.Pool
}

func newPacketManager() *packetManager {
return &packetManager{
headerPool: &sync.Pool{
New: func() interface{} {
return &rtp.Header{}
},
},
payloadPool: &sync.Pool{
New: func() interface{} {
buf := make([]byte, maxPayloadLen)
return &buf
},
},
}
}

func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) {
if len(payload) > maxPayloadLen {
return nil, io.ErrShortBuffer
}

p := &retainablePacket{
onRelease: m.releasePacket,
// new packets have retain count of 1
count: 1,
}

p.header = m.headerPool.Get().(*rtp.Header)
*p.header = *header

if payload != nil {
p.buffer = m.payloadPool.Get().(*[]byte)
size := copy(*p.buffer, payload)
p.payload = (*p.buffer)[:size]
}

return p, nil
}

func (m *packetManager) releasePacket(header *rtp.Header, payload *[]byte) {
m.headerPool.Put(header)
if payload != nil {
m.payloadPool.Put(payload)
}
}

type retainablePacket struct {
onRelease func(*rtp.Header, *[]byte)

countMu sync.Mutex
count int

header *rtp.Header
buffer *[]byte
payload []byte
}

func (p *retainablePacket) Header() *rtp.Header {
return p.header
}

func (p *retainablePacket) Payload() []byte {
return p.payload
}

func (p *retainablePacket) Retain() error {
if p.count == 0 {
// already released
return errPacketReleased
}
p.countMu.Lock()
p.count++
p.countMu.Unlock()
return nil
}

func (p *retainablePacket) Release() {
p.countMu.Lock()
defer p.countMu.Unlock()
p.count--

if p.count == 0 {
// release back to pool
p.onRelease(p.header, p.buffer)
p.header = nil
p.buffer = nil
p.payload = nil
}
}
38 changes: 28 additions & 10 deletions pkg/nack/send_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@ package nack
import (
"fmt"
"sync"

"github.com/pion/rtp"
)

const (
uint16SizeHalf = 1 << 15
)

type sendBuffer struct {
packets []*rtp.Packet
packets []*retainablePacket
size uint16
lastAdded uint16
started bool
Expand All @@ -36,16 +34,16 @@ func newSendBuffer(size uint16) (*sendBuffer, error) {
}

return &sendBuffer{
packets: make([]*rtp.Packet, size),
packets: make([]*retainablePacket, size),
size: size,
}, nil
}

func (s *sendBuffer) add(packet *rtp.Packet) {
func (s *sendBuffer) add(packet *retainablePacket) {
s.m.Lock()
defer s.m.Unlock()

seq := packet.SequenceNumber
seq := packet.Header().SequenceNumber
if !s.started {
s.packets[seq%s.size] = packet
s.lastAdded = seq
Expand All @@ -58,15 +56,25 @@ func (s *sendBuffer) add(packet *rtp.Packet) {
return
} else if diff < uint16SizeHalf {
for i := s.lastAdded + 1; i != seq; i++ {
s.packets[i%s.size] = nil
idx := i % s.size
prevPacket := s.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
s.packets[idx] = nil
}
}

s.packets[seq%s.size] = packet
idx := seq % s.size
prevPacket := s.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
s.packets[idx] = packet
s.lastAdded = seq
}

func (s *sendBuffer) get(seq uint16) *rtp.Packet {
func (s *sendBuffer) get(seq uint16) *retainablePacket {
s.m.RLock()
defer s.m.RUnlock()

Expand All @@ -79,5 +87,15 @@ func (s *sendBuffer) get(seq uint16) *rtp.Packet {
return nil
}

return s.packets[seq%s.size]
pkt := s.packets[seq%s.size]
if pkt != nil {
if pkt.Header().SequenceNumber != seq {
return nil
}
// already released
if err := pkt.Retain(); err != nil {
return nil
}
}
return pkt
}
52 changes: 44 additions & 8 deletions pkg/nack/send_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"testing"

"github.com/pion/rtp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSendBuffer(t *testing.T) {
pm := newPacketManager()
for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} {
start := start

sb, err := newSendBuffer(8)
assert.NoError(t, err)
require.NoError(t, err)

add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
sb.add(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq}})
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
require.NoError(t, err)
sb.add(pkt)
}
}

Expand All @@ -30,9 +33,10 @@ func TestSendBuffer(t *testing.T) {
t.Errorf("packet not found: %d", seq)
continue
}
if packet.SequenceNumber != seq {
t.Errorf("packet for %d returned with incorrect SequenceNumber: %d", seq, packet.SequenceNumber)
if packet.Header().SequenceNumber != seq {
t.Errorf("packet for %d returned with incorrect SequenceNumber: %d", seq, packet.Header().SequenceNumber)
}
packet.Release()
}
}
assertNOTGet := func(nums ...uint16) {
Expand All @@ -41,7 +45,7 @@ func TestSendBuffer(t *testing.T) {
seq := start + n
packet := sb.get(seq)
if packet != nil {
t.Errorf("packet found for %d: %d", seq, packet.SequenceNumber)
t.Errorf("packet found for %d: %d", seq, packet.Header().SequenceNumber)
}
}
}
Expand All @@ -63,20 +67,52 @@ func TestSendBuffer(t *testing.T) {
}
}

func TestSendBuffer_Overridden(t *testing.T) {
// override original packet content and get
pm := newPacketManager()
sb, err := newSendBuffer(1)
require.NoError(t, err)
require.Equal(t, uint16(1), sb.size)

originalBytes := []byte("originalContent")
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes)
require.NoError(t, err)
sb.add(pkt)

// change payload
copy(originalBytes, "altered")
retrieved := sb.get(1)
require.NotNil(t, retrieved)
require.Equal(t, "originalContent", string(retrieved.Payload()))
retrieved.Release()
require.Equal(t, 1, retrieved.count)

// ensure original packet is released
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes)
require.NoError(t, err)
sb.add(pkt)
require.Equal(t, 0, retrieved.count)

require.Nil(t, sb.get(1))
}

// this test is only useful when being run with the race detector, it won't fail otherwise:
//
// go test -race ./pkg/nack/
func TestSendBuffer_Race(t *testing.T) {
pm := newPacketManager()
for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} {
start := start

sb, err := newSendBuffer(8)
assert.NoError(t, err)
require.NoError(t, err)

add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
sb.add(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq}})
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
require.NoError(t, err)
sb.add(pkt)
}
}

Expand Down

0 comments on commit d9afd75

Please sign in to comment.