Skip to content

Commit

Permalink
add resetFragments() to all fragment-based decoders (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Oct 23, 2024
1 parent 2899668 commit f41b196
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 100 deletions.
23 changes: 11 additions & 12 deletions pkg/format/rtpac3/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,32 @@ func (d *Decoder) Init() error {
return nil
}

func (d *Decoder) resetFragments() {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
}

// Decode decodes frames from a RTP packet.
// It returns the frames and the PTS of the first frame.
func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
if len(pkt.Payload) < 2 {
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("payload is too short")
}

mbz := pkt.Payload[0] >> 2
ft := pkt.Payload[0] & 0b11

if mbz != 0 {
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("invalid MBZ: %v", mbz)
}

var frames [][]byte

switch ft {
case 0:
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
d.firstPacketReceived = true

buf := pkt.Payload[2:]
Expand All @@ -91,8 +93,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
}

case 1, 2:
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()

var syncInfo ac3.SyncInfo
err := syncInfo.Unmarshal(pkt.Payload[2:])
Expand Down Expand Up @@ -122,8 +123,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
d.fragmentsExpected -= le

if d.fragmentsExpected < 0 {
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("fragment is too big")
}

Expand All @@ -134,8 +134,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
}

frames = [][]byte{joinFragments(d.fragments, d.fragmentsSize)}
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()
}

return frames, nil
Expand Down
23 changes: 12 additions & 11 deletions pkg/format/rtpav1/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,21 @@ func (d *Decoder) Init() error {
return nil
}

func (d *Decoder) resetFragments() {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
}

func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
var av1header codecs.AV1Packet
_, err := av1header.Unmarshal(pkt.Payload)
if err != nil {
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("invalid header: %w", err)
}

if av1header.Z {
if len(d.fragments) == 0 {
if d.fragmentsSize == 0 {
if !d.firstPacketReceived {
return nil, ErrNonStartingPacketAndNoPrevious
}
Expand All @@ -66,8 +70,7 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {

d.fragmentsSize += len(av1header.OBUElements[0])
if d.fragmentsSize > av1.MaxTemporalUnitSize {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("OBU size (%d) is too big, maximum is %d", d.fragmentsSize, av1.MaxTemporalUnitSize)
}

Expand All @@ -82,17 +85,16 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
if len(av1header.OBUElements) > 0 {
if d.fragmentsSize != 0 {
obus = append(obus, joinFragments(d.fragments, d.fragmentsSize))
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()
}

if av1header.Y {
elementCount := len(av1header.OBUElements)

d.fragmentsSize += len(av1header.OBUElements[elementCount-1])

if d.fragmentsSize > av1.MaxTemporalUnitSize {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("OBU size (%d) is too big, maximum is %d", d.fragmentsSize, av1.MaxTemporalUnitSize)
}

Expand All @@ -103,8 +105,7 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
obus = append(obus, av1header.OBUElements...)
} else if !av1header.Y {
obus = append(obus, joinFragments(d.fragments, d.fragmentsSize))
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()
}

if len(obus) == 0 {
Expand Down
21 changes: 13 additions & 8 deletions pkg/format/rtph264/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,14 @@ func (d *Decoder) Init() error {
return nil
}

func (d *Decoder) resetFragments() {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
}

func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
if len(pkt.Payload) < 1 {
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()
return nil, fmt.Errorf("payload is too short")
}

Expand All @@ -82,7 +87,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
end := (pkt.Payload[1] >> 6) & 0x01

if start == 1 {
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()

if end != 0 {
return nil, fmt.Errorf("invalid FU-A packet (can't contain both a start and end bit)")
Expand All @@ -97,7 +102,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
return nil, ErrMorePacketsNeeded
}

if len(d.fragments) == 0 {
if d.fragmentsSize == 0 {
if !d.firstPacketReceived {
return nil, ErrNonStartingPacketAndNoPrevious
}
Expand All @@ -108,7 +113,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
d.fragmentsSize += len(pkt.Payload[2:])

if d.fragmentsSize > h264.MaxAccessUnitSize {
d.fragments = d.fragments[:0]
d.resetFragments()
return nil, fmt.Errorf("NALU size (%d) is too big, maximum is %d", d.fragmentsSize, h264.MaxAccessUnitSize)
}

Expand All @@ -119,10 +124,10 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
}

nalus = [][]byte{joinFragments(d.fragments, d.fragmentsSize)}
d.fragments = d.fragments[:0]
d.resetFragments()

case h264.NALUTypeSTAPA:
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()

payload := pkt.Payload[1:]

Expand Down Expand Up @@ -159,12 +164,12 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {

case h264.NALUTypeSTAPB, h264.NALUTypeMTAP16,
h264.NALUTypeMTAP24, h264.NALUTypeFUB:
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()
d.firstPacketReceived = true
return nil, fmt.Errorf("packet type not supported (%v)", typ)

default:
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()
d.firstPacketReceived = true
nalus = [][]byte{pkt.Payload}
}
Expand Down
27 changes: 15 additions & 12 deletions pkg/format/rtph265/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ type Decoder struct {
MaxDONDiff int

firstPacketReceived bool
fragmentsSize int
fragments [][]byte
fragmentsSize int

// for Decode()
frameBuffer [][]byte
Expand All @@ -52,9 +52,14 @@ func (d *Decoder) Init() error {
return nil
}

func (d *Decoder) resetFragments() {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
}

func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
if len(pkt.Payload) < 2 {
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()
return nil, fmt.Errorf("payload is too short")
}

Expand All @@ -63,7 +68,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {

switch typ {
case h265.NALUType_AggregationUnit:
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()

payload := pkt.Payload[2:]

Expand Down Expand Up @@ -95,15 +100,15 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {

case h265.NALUType_FragmentationUnit:
if len(pkt.Payload) < 3 {
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()
return nil, fmt.Errorf("payload is too short")
}

start := pkt.Payload[2] >> 7
end := (pkt.Payload[2] >> 6) & 0x01

if start == 1 {
d.fragments = d.fragments[:0] // discard pending fragments
d.resetFragments()

if end != 0 {
return nil, fmt.Errorf("invalid fragmentation unit (can't contain both a start and end bit)")
Expand All @@ -118,7 +123,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
return nil, ErrMorePacketsNeeded
}

if len(d.fragments) == 0 {
if d.fragmentsSize == 0 {
if !d.firstPacketReceived {
return nil, ErrNonStartingPacketAndNoPrevious
}
Expand All @@ -128,7 +133,7 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {

d.fragmentsSize += len(pkt.Payload[3:])
if d.fragmentsSize > h265.MaxAccessUnitSize {
d.fragments = d.fragments[:0]
d.resetFragments()
return nil, fmt.Errorf("NALU size (%d) is too big, maximum is %d", d.fragmentsSize, h265.MaxAccessUnitSize)
}

Expand All @@ -139,16 +144,14 @@ func (d *Decoder) decodeNALUs(pkt *rtp.Packet) ([][]byte, error) {
}

nalus = [][]byte{joinFragments(d.fragments, d.fragmentsSize)}
d.fragments = d.fragments[:0]
d.resetFragments()

case h265.NALUType_PACI:
d.fragments = d.fragments[:0] // discard pending fragments
d.firstPacketReceived = true
d.resetFragments()
return nil, fmt.Errorf("PACI packets are not supported (yet)")

default:
d.fragments = d.fragments[:0] // discard pending fragments
d.firstPacketReceived = true
d.resetFragments()
nalus = [][]byte{pkt.Payload}
}

Expand Down
16 changes: 9 additions & 7 deletions pkg/format/rtpmjpeg/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ func joinFragments(fragments [][]byte, size int) []byte {
// Specification: https://datatracker.ietf.org/doc/html/rfc2435
type Decoder struct {
firstPacketReceived bool
fragmentsSize int
fragments [][]byte
fragmentsSize int
firstJpegHeader *headerJPEG
quantizationTables [][]byte
}
Expand All @@ -174,6 +174,11 @@ func (d *Decoder) Init() error {
return nil
}

func (d *Decoder) resetFragments() {
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
}

// Decode decodes an image from a RTP packet.
func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) {
byts := pkt.Payload
Expand All @@ -194,8 +199,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) {
}

if jh.FragmentOffset == 0 {
d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
d.firstPacketReceived = true

if jh.Quantization >= 128 {
Expand All @@ -219,8 +223,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) {
return nil, ErrNonStartingPacketAndNoPrevious
}

d.fragments = d.fragments[:0] // discard pending fragments
d.fragmentsSize = 0
d.resetFragments()
return nil, fmt.Errorf("received wrong fragment")
}

Expand All @@ -237,8 +240,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) {
}

data := joinFragments(d.fragments, d.fragmentsSize)
d.fragments = d.fragments[:0]
d.fragmentsSize = 0
d.resetFragments()

var buf []byte

Expand Down
Loading

0 comments on commit f41b196

Please sign in to comment.