diff --git a/api.go b/api.go index c6010eb526f..f440f0bb358 100644 --- a/api.go +++ b/api.go @@ -4,6 +4,7 @@ package webrtc import ( "github.com/pion/logging" + "github.com/pion/webrtc/v3/pkg/interceptor" ) // API bundles the global functions of the WebRTC and ORTC API. @@ -13,6 +14,7 @@ import ( type API struct { settingEngine *SettingEngine mediaEngine *MediaEngine + interceptor interceptor.Interceptor } // NewAPI Creates a new API object for keeping semi-global settings to WebRTC objects @@ -35,6 +37,10 @@ func NewAPI(options ...func(*API)) *API { a.mediaEngine = &MediaEngine{} } + if a.interceptor == nil { + a.interceptor = &interceptor.NoOp{} + } + return a } @@ -57,3 +63,11 @@ func WithSettingEngine(s SettingEngine) func(a *API) { a.settingEngine = &s } } + +// WithInterceptorRegistry allows providing Interceptors to the API. +// Settings should not be changed after passing the registry to an API. +func WithInterceptorRegistry(interceptorRegistry *InterceptorRegistry) func(a *API) { + return func(a *API) { + a.interceptor = interceptorRegistry.build() + } +} diff --git a/examples/save-to-disk/main.go b/examples/save-to-disk/main.go index bf05bf24634..39f20fdb90d 100644 --- a/examples/save-to-disk/main.go +++ b/examples/save-to-disk/main.go @@ -54,8 +54,15 @@ func main() { panic(err) } + s := webrtc.SettingEngine{} + s.SetSRTPReplayProtectionWindow(8192) // this is needed for nack for now + ir := &webrtc.InterceptorRegistry{} + if err := webrtc.RegisterDefaultInterceptors(&s, &m, ir); err != nil { + panic(err) + } + // Create the API object with the MediaEngine - api := webrtc.NewAPI(webrtc.WithMediaEngine(&m)) + api := webrtc.NewAPI(webrtc.WithSettingEngine(s), webrtc.WithMediaEngine(&m), webrtc.WithInterceptorRegistry(ir)) // Prepare the configuration config := webrtc.Configuration{ diff --git a/go.mod b/go.mod index 1a2d4828878..e7917655721 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pion/webrtc/v3 -go 1.12 +go 1.13 require ( github.com/pion/datachannel v1.4.21 diff --git a/interceptor_registry.go b/interceptor_registry.go new file mode 100644 index 00000000000..15c5dd54f66 --- /dev/null +++ b/interceptor_registry.go @@ -0,0 +1,59 @@ +// +build !js + +package webrtc + +import ( + "github.com/pion/logging" + "github.com/pion/webrtc/v3/pkg/interceptor" +) + +// InterceptorRegistry is a collector for interceptors. +type InterceptorRegistry struct { + interceptors []interceptor.Interceptor +} + +// Add adds a new Interceptor to the registry. +func (i *InterceptorRegistry) Add(icpr interceptor.Interceptor) { + i.interceptors = append(i.interceptors, icpr) +} + +func (i *InterceptorRegistry) build() interceptor.Interceptor { + if len(i.interceptors) == 0 { + return &interceptor.NoOp{} + } + + return interceptor.NewChain(i.interceptors) +} + +// RegisterDefaultInterceptors will register some useful interceptors. If you want to customize which interceptors are loaded, +// you should copy the code from this method and remove unwanted interceptors. +func RegisterDefaultInterceptors(settingEngine *SettingEngine, mediaEngine *MediaEngine, interceptorRegistry *InterceptorRegistry) error { + loggerFactory := settingEngine.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + err := ConfigureNack(loggerFactory, mediaEngine, interceptorRegistry) + if err != nil { + return err + } + + return nil +} + +// ConfigureNack will setup everything necessary for handling generating/responding to nack messages. +func ConfigureNack(loggerFactory logging.LoggerFactory, mediaEngine *MediaEngine, interceptorRegistry *InterceptorRegistry) error { + mediaEngine.RegisterFeedback(RTCPFeedback{Type: "nack"}, RTPCodecTypeVideo) + receiverNack, err := interceptor.NewReceiverNack(8192, loggerFactory.NewLogger("receiver_nack")) + if err != nil { + return err + } + interceptorRegistry.Add(receiverNack) + senderNack, err := interceptor.NewSenderNack(8192, loggerFactory.NewLogger("sender_nack")) + if err != nil { + return err + } + interceptorRegistry.Add(senderNack) + + return nil +} diff --git a/interceptor_test.go b/interceptor_test.go new file mode 100644 index 00000000000..04097b90073 --- /dev/null +++ b/interceptor_test.go @@ -0,0 +1,188 @@ +// +build !js + +package webrtc + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/transport/test" + "github.com/pion/webrtc/v3/pkg/interceptor" + "github.com/pion/webrtc/v3/pkg/media" + "github.com/stretchr/testify/assert" +) + +type testInterceptor struct { + t *testing.T + extensionID uint8 + rtcpWriter atomic.Value + lastRTCP atomic.Value + interceptor.NoOp +} + +func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { + // set extension on outgoing packet + p.Header.Extension = true + p.Header.ExtensionProfile = 0xBEDE + assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write"))) + + return writer.Write(p, attributes) + }) +} + +func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) { + p, attributes, err := reader.Read() + if err != nil { + return nil, nil, err + } + // set extension on incoming packet + p.Header.Extension = true + p.Header.ExtensionProfile = 0xBEDE + assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("read"))) + + // write back a pli + rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter) + pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC} + _, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes)) + assert.NoError(t.t, err) + + return p, attributes, nil + }) +} + +func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) { + pkts, attributes, err := reader.Read() + if err != nil { + return nil, nil, err + } + + t.lastRTCP.Store(pkts[0]) + + return pkts, attributes, nil + }) +} + +func (t *testInterceptor) lastReadRTCP() rtcp.Packet { + p, _ := t.lastRTCP.Load().(rtcp.Packet) + return p +} + +func (t *testInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + t.rtcpWriter.Store(writer) + return writer +} + +func TestPeerConnection_Interceptor(t *testing.T) { + to := test.TimeOut(time.Second * 20) + defer to.Stop() + + report := test.CheckRoutines(t) + defer report() + + createPC := func(interceptor interceptor.Interceptor) *PeerConnection { + m := &MediaEngine{} + err := m.RegisterDefaultCodecs() + if err != nil { + t.Fatal(err) + } + ir := &InterceptorRegistry{} + ir.Add(interceptor) + pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + if err != nil { + t.Fatal(err) + } + + return pc + } + + sendInterceptor := &testInterceptor{t: t, extensionID: 1} + senderPC := createPC(sendInterceptor) + receiverPC := createPC(&testInterceptor{t: t, extensionID: 2}) + + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion") + if err != nil { + t.Fatal(err) + } + + sender, err := senderPC.AddTrack(track) + if err != nil { + t.Fatal(err) + } + + pending := new(int32) + wg := &sync.WaitGroup{} + + wg.Add(1) + *pending++ + receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { + p, readErr := track.ReadRTP() + if readErr != nil { + t.Fatal(readErr) + } + assert.Equal(t, p.Extension, true) + assert.Equal(t, "write", string(p.GetExtension(1))) + assert.Equal(t, "read", string(p.GetExtension(2))) + atomic.AddInt32(pending, -1) + wg.Done() + + for { + _, readErr = track.ReadRTP() + if readErr != nil { + return + } + } + }) + + wg.Add(1) + *pending++ + go func() { + _, readErr := sender.ReadRTCP() + assert.NoError(t, readErr) + atomic.AddInt32(pending, -1) + wg.Done() + + for { + _, readErr = sender.ReadRTCP() + if readErr != nil { + return + } + } + }() + + err = signalPair(senderPC, receiverPC) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + for { + time.Sleep(time.Millisecond * 100) + if routineErr := track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil { + t.Error(routineErr) + return + } + + if atomic.LoadInt32(pending) == 0 { + return + } + } + }() + + wg.Wait() + assert.NoError(t, senderPC.Close()) + assert.NoError(t, receiverPC.Close()) + + pli, _ := sendInterceptor.lastReadRTCP().(*rtcp.PictureLossIndication) + if pli == nil || pli.SenderSSRC == 0 { + t.Errorf("pli not found by send interceptor") + } +} diff --git a/interceptor_track_local.go b/interceptor_track_local.go new file mode 100644 index 00000000000..cbe0c78f8ec --- /dev/null +++ b/interceptor_track_local.go @@ -0,0 +1,29 @@ +// +build !js + +package webrtc + +import ( + "sync/atomic" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v3/pkg/interceptor" +) + +type interceptorTrackLocalWriter struct { + TrackLocalWriter + rtpWriter atomic.Value +} + +func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter) { + i.rtpWriter.Store(writer) +} + +func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) { + writer := i.rtpWriter.Load().(interceptor.RTPWriter) + + if writer == nil { + return 0, nil + } + + return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes)) +} diff --git a/mediaengine.go b/mediaengine.go index 45d678493f7..b8287ef44a9 100644 --- a/mediaengine.go +++ b/mediaengine.go @@ -80,7 +80,7 @@ func (m *MediaEngine) RegisterDefaultCodecs() error { } } - videoRTCPFeedback := []RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}} + videoRTCPFeedback := []RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", "pli"}} for _, codec := range []RTPCodecParameters{ { RTPCodecCapability: RTPCodecCapability{mimeTypeVP8, 90000, 0, "", videoRTCPFeedback}, @@ -233,6 +233,22 @@ func (m *MediaEngine) RegisterHeaderExtension(extension RTPHeaderExtensionCapabi return nil } +// RegisterFeedback adds feedback mechanism to already registered codecs. +func (m *MediaEngine) RegisterFeedback(feedback RTCPFeedback, typ RTPCodecType) { + switch typ { + case RTPCodecTypeVideo: + for i, v := range m.videoCodecs { + v.RTCPFeedback = append(v.RTCPFeedback, feedback) + m.videoCodecs[i] = v + } + case RTPCodecTypeAudio: + for i, v := range m.audioCodecs { + v.RTCPFeedback = append(v.RTCPFeedback, feedback) + m.audioCodecs[i] = v + } + } +} + // GetHeaderExtensionID returns the negotiated ID for a header extension. // If the Header Extension isn't enabled ok will be false func (m *MediaEngine) GetHeaderExtensionID(extension RTPHeaderExtensionCapability) (val int, audioNegotiated, videoNegotiated bool) { @@ -249,19 +265,19 @@ func (m *MediaEngine) GetHeaderExtensionID(extension RTPHeaderExtensionCapabilit return } -func (m *MediaEngine) getCodecByPayload(payloadType PayloadType) (RTPCodecParameters, error) { +func (m *MediaEngine) getCodecByPayload(payloadType PayloadType) (RTPCodecParameters, RTPCodecType, error) { for _, codec := range m.negotiatedVideoCodecs { if codec.PayloadType == payloadType { - return codec, nil + return codec, RTPCodecTypeVideo, nil } } for _, codec := range m.negotiatedAudioCodecs { if codec.PayloadType == payloadType { - return codec, nil + return codec, RTPCodecTypeAudio, nil } } - return RTPCodecParameters{}, ErrCodecNotFound + return RTPCodecParameters{}, 0, ErrCodecNotFound } func (m *MediaEngine) collectStats(collector *statsReportCollector) { @@ -309,7 +325,7 @@ func (m *MediaEngine) updateCodecParameters(remoteCodec RTPCodecParameters, typ return err } - if _, err = m.getCodecByPayload(PayloadType(payloadType)); err != nil { + if _, _, err = m.getCodecByPayload(PayloadType(payloadType)); err != nil { return nil // not an error, we just ignore this codec we don't support } } @@ -378,8 +394,8 @@ func (m *MediaEngine) updateFromRemoteDescription(desc sdp.SessionDescription) e return err } - for id, extension := range extensions { - if err = m.updateHeaderExtension(extension, id, typ); err != nil { + for extension, id := range extensions { + if err = m.updateHeaderExtension(id, extension, typ); err != nil { return err } } @@ -405,6 +421,39 @@ func (m *MediaEngine) getCodecsByKind(typ RTPCodecType) []RTPCodecParameters { return nil } +func (m *MediaEngine) getRTPParametersByKind(typ RTPCodecType) RTPParameters { + headerExtensions := make([]RTPHeaderExtensionParameter, 0) + for id, e := range m.negotiatedHeaderExtensions { + if e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo { + headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri}) + } + } + + return RTPParameters{ + HeaderExtensions: headerExtensions, + Codecs: m.getCodecsByKind(typ), + } +} + +func (m *MediaEngine) getRTPParametersByPayloadType(payloadType PayloadType) (RTPParameters, error) { + codec, typ, err := m.getCodecByPayload(payloadType) + if err != nil { + return RTPParameters{}, err + } + + headerExtensions := make([]RTPHeaderExtensionParameter, 0) + for id, e := range m.negotiatedHeaderExtensions { + if e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo { + headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri}) + } + } + + return RTPParameters{ + HeaderExtensions: headerExtensions, + Codecs: []RTPCodecParameters{codec}, + }, nil +} + func (m *MediaEngine) negotiatedHeaderExtensionsForType(typ RTPCodecType) map[int]mediaEngineHeaderExtension { headerExtensions := map[int]mediaEngineHeaderExtension{} for id, e := range m.negotiatedHeaderExtensions { diff --git a/mediaengine_test.go b/mediaengine_test.go index f3312e876be..b857b2e7549 100644 --- a/mediaengine_test.go +++ b/mediaengine_test.go @@ -63,7 +63,7 @@ a=fmtp:111 minptime=10; useinbandfec=1 assert.False(t, m.negotiatedVideo) assert.True(t, m.negotiatedAudio) - opusCodec, err := m.getCodecByPayload(111) + opusCodec, _, err := m.getCodecByPayload(111) assert.NoError(t, err) assert.Equal(t, opusCodec.MimeType, mimeTypeOpus) }) @@ -85,10 +85,10 @@ a=fmtp:112 minptime=10; useinbandfec=1 assert.False(t, m.negotiatedVideo) assert.True(t, m.negotiatedAudio) - _, err := m.getCodecByPayload(111) + _, _, err := m.getCodecByPayload(111) assert.Error(t, err) - opusCodec, err := m.getCodecByPayload(112) + opusCodec, _, err := m.getCodecByPayload(112) assert.NoError(t, err) assert.Equal(t, opusCodec.MimeType, mimeTypeOpus) }) @@ -110,7 +110,7 @@ a=fmtp:111 minptime=10; useinbandfec=1 assert.False(t, m.negotiatedVideo) assert.True(t, m.negotiatedAudio) - opusCodec, err := m.getCodecByPayload(111) + opusCodec, _, err := m.getCodecByPayload(111) assert.NoError(t, err) assert.Equal(t, opusCodec.MimeType, "audio/OPUS") }) @@ -131,7 +131,7 @@ a=rtpmap:111 opus/48000/2 assert.False(t, m.negotiatedVideo) assert.True(t, m.negotiatedAudio) - opusCodec, err := m.getCodecByPayload(111) + opusCodec, _, err := m.getCodecByPayload(111) assert.NoError(t, err) assert.Equal(t, opusCodec.MimeType, mimeTypeOpus) }) diff --git a/peerconnection.go b/peerconnection.go index aee1360a4b7..49c86ee9e67 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -19,6 +19,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3/internal/util" + "github.com/pion/webrtc/v3/pkg/interceptor" "github.com/pion/webrtc/v3/pkg/rtcerr" ) @@ -76,6 +77,8 @@ type PeerConnection struct { // A reference to the associated API state used by this connection api *API log logging.LeveledLogger + + interceptorRTCPWriter interceptor.RTCPWriter } // NewPeerConnection creates a peerconnection with the default @@ -119,6 +122,8 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, log: api.settingEngine.LoggerFactory.NewLogger("pc"), } + pc.interceptorRTCPWriter = api.interceptor.BindRTCPWriter(interceptor.RTCPWriterFunc(pc.writeRTCP)) + var err error if err = pc.initConfiguration(configuration); err != nil { return nil, err @@ -1125,7 +1130,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece return } - codec, err := pc.api.mediaEngine.getCodecByPayload(receiver.Track().PayloadType()) + params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(receiver.Track().PayloadType()) if err != nil { pc.log.Warnf("no codec could be found for payloadType %d", receiver.Track().PayloadType()) return @@ -1133,7 +1138,9 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece receiver.Track().mu.Lock() receiver.Track().kind = receiver.kind - receiver.Track().codec = codec + receiver.Track().codec = params.Codecs[0] + receiver.Track().params = params + receiver.Track().bindInterceptor() receiver.Track().mu.Unlock() pc.onTrack(receiver.Track(), receiver) @@ -1335,7 +1342,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e continue } - codec, err := pc.api.mediaEngine.getCodecByPayload(payloadType) + params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType) if err != nil { return err } @@ -1345,7 +1352,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e continue } - track, err := t.Receiver().receiveForRid(rid, codec, ssrc) + track, err := t.Receiver().receiveForRid(rid, params, ssrc) if err != nil { return err } @@ -1730,28 +1737,33 @@ func (pc *PeerConnection) SetIdentityProvider(provider string) error { return errPeerConnSetIdentityProviderNotImplemented } -// WriteRTCP sends a user provided RTCP packet to the connected peer -// If no peer is connected the packet is discarded +// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the +// packet is discarded. It also runs any configured interceptors. func (pc *PeerConnection) WriteRTCP(pkts []rtcp.Packet) error { + _, err := pc.interceptorRTCPWriter.Write(pkts, make(interceptor.Attributes)) + return err +} + +func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) { raw, err := rtcp.Marshal(pkts) if err != nil { - return err + return 0, err } srtcpSession, err := pc.dtlsTransport.getSRTCPSession() if err != nil { - return nil + return 0, nil } writeStream, err := srtcpSession.OpenWriteStream() if err != nil { - return fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) + return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) } - if _, err := writeStream.Write(raw); err != nil { - return err + if n, err := writeStream.Write(raw); err != nil { + return n, err } - return nil + return 0, nil } // Close ends the PeerConnection @@ -1775,6 +1787,8 @@ func (pc *PeerConnection) Close() error { // continue the chain the Mux has to be closed. closeErrs := make([]error, 4) + closeErrs = append(closeErrs, pc.api.interceptor.Close()) + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4) for _, t := range pc.GetTransceivers() { if !t.stopped { diff --git a/pkg/interceptor/chain.go b/pkg/interceptor/chain.go new file mode 100644 index 00000000000..2a3766d328d --- /dev/null +++ b/pkg/interceptor/chain.go @@ -0,0 +1,83 @@ +// +build !js + +package interceptor + +import ( + "github.com/pion/webrtc/v3/internal/util" +) + +// Chain is an interceptor that runs all child interceptors in order. +type Chain struct { + interceptors []Interceptor +} + +// NewChain returns a new Chain interceptor. +func NewChain(interceptors []Interceptor) *Chain { + return &Chain{interceptors: interceptors} +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (i *Chain) BindRTCPReader(reader RTCPReader) RTCPReader { + for _, interceptor := range i.interceptors { + reader = interceptor.BindRTCPReader(reader) + } + + return reader +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (i *Chain) BindRTCPWriter(writer RTCPWriter) RTCPWriter { + for _, interceptor := range i.interceptors { + writer = interceptor.BindRTCPWriter(writer) + } + + return writer +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (i *Chain) BindLocalStream(ctx *StreamInfo, writer RTPWriter) RTPWriter { + for _, interceptor := range i.interceptors { + writer = interceptor.BindLocalStream(ctx, writer) + } + + return writer +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *Chain) UnbindLocalStream(ctx *StreamInfo) { + for _, interceptor := range i.interceptors { + interceptor.UnbindLocalStream(ctx) + } +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (i *Chain) BindRemoteStream(ctx *StreamInfo, reader RTPReader) RTPReader { + for _, interceptor := range i.interceptors { + reader = interceptor.BindRemoteStream(ctx, reader) + } + + return reader +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *Chain) UnbindRemoteStream(ctx *StreamInfo) { + for _, interceptor := range i.interceptors { + interceptor.UnbindRemoteStream(ctx) + } +} + +// Close closes the Interceptor, cleaning up any data if necessary. +func (i *Chain) Close() error { + var errs []error + for _, interceptor := range i.interceptors { + if err := interceptor.Close(); err != nil { + errs = append(errs, err) + } + } + + return util.FlattenErrs(errs) +} diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go new file mode 100644 index 00000000000..d168a836559 --- /dev/null +++ b/pkg/interceptor/interceptor.go @@ -0,0 +1,110 @@ +// +build !js + +// Package interceptor contains the Interceptor interface, with some useful interceptors that should be safe to use +// in most cases. +package interceptor + +import ( + "io" + + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// Interceptor can be used to add functionality to you PeerConnections by modifying any incoming/outgoing rtp/rtcp +// packets, or sending your own packets as needed. +type Interceptor interface { + + // BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might + // change in the future. The returned method will be called once per packet batch. + BindRTCPReader(reader RTCPReader) RTCPReader + + // BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method + // will be called once per packet batch. + BindRTCPWriter(writer RTCPWriter) RTCPWriter + + // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method + // will be called once per rtp packet. + BindLocalStream(info *StreamInfo, writer RTPWriter) RTPWriter + + // UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. + UnbindLocalStream(info *StreamInfo) + + // BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method + // will be called once per rtp packet. + BindRemoteStream(info *StreamInfo, reader RTPReader) RTPReader + + // UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. + UnbindRemoteStream(info *StreamInfo) + + io.Closer +} + +// RTPWriter is used by Interceptor.BindLocalStream. +type RTPWriter interface { + // Write a rtp packet + Write(p *rtp.Packet, attributes Attributes) (int, error) +} + +// RTPReader is used by Interceptor.BindRemoteStream. +type RTPReader interface { + // Read a rtp packet + Read() (*rtp.Packet, Attributes, error) +} + +// RTCPWriter is used by Interceptor.BindRTCPWriter. +type RTCPWriter interface { + // Write a batch of rtcp packets + Write(pkts []rtcp.Packet, attributes Attributes) (int, error) +} + +// RTCPReader is used by Interceptor.BindRTCPReader. +type RTCPReader interface { + // Read a batch of rtcp packets + Read() ([]rtcp.Packet, Attributes, error) +} + +// Attributes are a generic key/value store used by interceptors +type Attributes map[interface{}]interface{} + +// RTPWriterFunc is an adapter for RTPWrite interface +type RTPWriterFunc func(p *rtp.Packet, attributes Attributes) (int, error) + +// RTPReaderFunc is an adapter for RTPReader interface +type RTPReaderFunc func() (*rtp.Packet, Attributes, error) + +// RTCPWriterFunc is an adapter for RTCPWriter interface +type RTCPWriterFunc func(pkts []rtcp.Packet, attributes Attributes) (int, error) + +// RTCPReaderFunc is an adapter for RTCPReader interface +type RTCPReaderFunc func() ([]rtcp.Packet, Attributes, error) + +// Write a rtp packet +func (f RTPWriterFunc) Write(p *rtp.Packet, attributes Attributes) (int, error) { + return f(p, attributes) +} + +// Read a rtp packet +func (f RTPReaderFunc) Read() (*rtp.Packet, Attributes, error) { + return f() +} + +// Write a batch of rtcp packets +func (f RTCPWriterFunc) Write(pkts []rtcp.Packet, attributes Attributes) (int, error) { + return f(pkts, attributes) +} + +// Read a batch of rtcp packets +func (f RTCPReaderFunc) Read() ([]rtcp.Packet, Attributes, error) { + return f() +} + +// Get returns the attribute associated with key. +func (a Attributes) Get(key interface{}) interface{} { + return a[key] +} + +// Set sets the attribute associated with key to the given value. +func (a Attributes) Set(key interface{}, val interface{}) { + a[key] = val +} diff --git a/pkg/interceptor/noop.go b/pkg/interceptor/noop.go new file mode 100644 index 00000000000..c8c23c26762 --- /dev/null +++ b/pkg/interceptor/noop.go @@ -0,0 +1,42 @@ +// +build !js + +package interceptor + +// NoOp is an Interceptor that does not modify any packets. It can embedded in other interceptors, so it's +// possible to implement only a subset of the methods. +type NoOp struct{} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (i *NoOp) BindRTCPReader(reader RTCPReader) RTCPReader { + return reader +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (i *NoOp) BindRTCPWriter(writer RTCPWriter) RTCPWriter { + return writer +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (i *NoOp) BindLocalStream(_ *StreamInfo, writer RTPWriter) RTPWriter { + return writer +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *NoOp) UnbindLocalStream(_ *StreamInfo) {} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (i *NoOp) BindRemoteStream(_ *StreamInfo, reader RTPReader) RTPReader { + return reader +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *NoOp) UnbindRemoteStream(_ *StreamInfo) {} + +// Close closes the Interceptor, cleaning up any data if necessary. +func (i *NoOp) Close() error { + return nil +} diff --git a/pkg/interceptor/receive_log.go b/pkg/interceptor/receive_log.go new file mode 100644 index 00000000000..af04b6065db --- /dev/null +++ b/pkg/interceptor/receive_log.go @@ -0,0 +1,130 @@ +package interceptor + +import ( + "errors" + "strconv" +) + +var ( + allowedReceiveLogSizes map[uint16]bool + invalidReceiveLogSizeError string +) + +func init() { + allowedReceiveLogSizes = make(map[uint16]bool, 15) + invalidReceiveLogSizeError = "invalid ReceiveLog size, must be one of: " + for i := 6; i < 16; i++ { + allowedReceiveLogSizes[1< end (with counting for rollovers) + for i := s.end + 1; i != seq; i++ { + // clear packets between end and seq (these may contain packets from a "size" ago) + s.del(i) + } + s.end = seq + + if s.lastConsecutive+1 == seq { + s.lastConsecutive = seq + } else if seq-s.lastConsecutive > s.size { + s.lastConsecutive = seq - s.size + s.fixLastConsecutive() // there might be valid packets at the beginning of the buffer now + } + } else { + // negative diff, seq < end (with counting for rollovers) + if s.lastConsecutive+1 == seq { + s.lastConsecutive = seq + s.fixLastConsecutive() // there might be other valid packets after seq + } + } + + s.set(seq) +} + +func (s *ReceiveLog) Get(seq uint16) bool { + diff := s.end - seq + if diff >= uint16SizeHalf { + return false + } + + if diff >= s.size { + return false + } + + return s.get(seq) +} + +func (s *ReceiveLog) MissingSeqNumbers(skipLastN uint16) []uint16 { + until := s.end - skipLastN + if until-s.lastConsecutive >= uint16SizeHalf { + // until < s.lastConsecutive (counting for rollover) + return nil + } + + missingPacketSeqNums := make([]uint16, 0) + for i := s.lastConsecutive + 1; i != until+1; i++ { + if !s.get(i) { + missingPacketSeqNums = append(missingPacketSeqNums, i) + } + } + + return missingPacketSeqNums +} + +func (s *ReceiveLog) set(seq uint16) { + pos := seq % s.size + s.packets[pos/64] |= 1 << (pos % 64) +} + +func (s *ReceiveLog) del(seq uint16) { + pos := seq % s.size + s.packets[pos/64] &^= 1 << (pos % 64) +} + +func (s *ReceiveLog) get(seq uint16) bool { + pos := seq % s.size + return (s.packets[pos/64] & (1 << (pos % 64))) != 0 +} + +func (s *ReceiveLog) fixLastConsecutive() { + i := s.lastConsecutive + 1 + for ; i != s.end+1 && s.get(i); i++ { + // find all consecutive packets + } + s.lastConsecutive = i - 1 +} diff --git a/pkg/interceptor/receive_log_test.go b/pkg/interceptor/receive_log_test.go new file mode 100644 index 00000000000..014682039d2 --- /dev/null +++ b/pkg/interceptor/receive_log_test.go @@ -0,0 +1,134 @@ +package interceptor + +import ( + "reflect" + "testing" +) + +func TestReceivedBuffer(t *testing.T) { + for _, start := range []uint16{0, 1, 127, 128, 129, 511, 512, 513, 32767, 32768, 32769, 65407, 65408, 65409, 65534, 65535} { + start := start + + rl, err := NewReceiveLog(128) + if err != nil { + t.Fatalf("%+v", err) + } + + all := func(min uint16, max uint16) []uint16 { + result := make([]uint16, 0) + for i := min; i != max+1; i++ { + result = append(result, i) + } + return result + } + join := func(parts ...[]uint16) []uint16 { + result := make([]uint16, 0) + for _, p := range parts { + result = append(result, p...) + } + return result + } + + add := func(nums ...uint16) { + for _, n := range nums { + seq := start + n + rl.Add(seq) + } + } + + assertGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + if !rl.Get(seq) { + t.Errorf("not found: %d", seq) + } + } + } + assertNOTGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + if rl.Get(seq) { + t.Errorf("packet found for %d", seq) + } + } + } + assertMissing := func(skipLastN uint16, nums []uint16) { + t.Helper() + missing := rl.MissingSeqNumbers(skipLastN) + if missing == nil { + missing = []uint16{} + } + want := make([]uint16, 0, len(nums)) + for _, n := range nums { + want = append(want, start+n) + } + if !reflect.DeepEqual(want, missing) { + t.Errorf("missing want/got %v / %v", want, missing) + } + } + assertLastConsecutive := func(lastConsecutive uint16) { + want := lastConsecutive + start + if rl.lastConsecutive != want { + t.Errorf("invalid lastConsecutive want %d got %d", want, rl.lastConsecutive) + } + } + + add(0) + assertGet(0) + assertMissing(0, []uint16{}) + assertLastConsecutive(0) // first element added + + add(all(1, 127)...) + assertGet(all(1, 127)...) + assertMissing(0, []uint16{}) + assertLastConsecutive(127) + + add(128) + assertGet(128) + assertNOTGet(0) + assertMissing(0, []uint16{}) + assertLastConsecutive(128) + + add(130) + assertGet(130) + assertNOTGet(1, 2, 129) + assertMissing(0, []uint16{129}) + assertLastConsecutive(128) + + add(333) + assertGet(333) + assertNOTGet(all(0, 332)...) + assertMissing(0, all(206, 332)) // all 127 elements missing before 333 + assertMissing(10, all(206, 323)) // skip last 10 packets (324-333) from check + assertLastConsecutive(205) // lastConsecutive is still out of the buffer + + add(329) + assertGet(329) + assertMissing(0, join(all(206, 328), all(330, 332))) + assertMissing(5, join(all(206, 328))) // skip last 5 packets (329-333) from check + assertLastConsecutive(205) + + add(all(207, 320)...) + assertGet(all(207, 320)...) + assertMissing(0, join([]uint16{206}, all(321, 328), all(330, 332))) + assertLastConsecutive(205) + + add(334) + assertGet(334) + assertNOTGet(206) + assertMissing(0, join(all(321, 328), all(330, 332))) + assertLastConsecutive(320) // head of buffer is full of consecutive packages + + add(all(322, 328)...) + assertGet(all(322, 328)...) + assertMissing(0, join([]uint16{321}, all(330, 332))) + assertLastConsecutive(320) + + add(321) + assertGet(321) + assertMissing(0, all(330, 332)) + assertLastConsecutive(329) // after adding a single missing packet, lastConsecutive should jump forward + } +} diff --git a/pkg/interceptor/receiver_nack.go b/pkg/interceptor/receiver_nack.go new file mode 100644 index 00000000000..a3fa1c49c91 --- /dev/null +++ b/pkg/interceptor/receiver_nack.go @@ -0,0 +1,159 @@ +// +build !js + +package interceptor + +import ( + "math/rand" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// ReceiverNACK interceptor generates nack messages. +type ReceiverNACK struct { + NoOp + size uint16 + receiveLogs *sync.Map + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger +} + +// NewReceiverNack returns a new ReceiverNACK interceptor +func NewReceiverNack(size uint16, log logging.LeveledLogger) (*ReceiverNACK, error) { + _, err := NewReceiveLog(size) + if err != nil { + return nil, err + } + + return &ReceiverNACK{ + NoOp: NoOp{}, + size: size, + receiveLogs: &sync.Map{}, + close: make(chan struct{}), + log: log, + }, nil +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (n *ReceiverNACK) BindRTCPWriter(writer RTCPWriter) RTCPWriter { + go n.loop(writer) + + return writer +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (n *ReceiverNACK) BindRemoteStream(info *StreamInfo, reader RTPReader) RTPReader { + hasNack := false + for _, fb := range info.RTCPFeedback { + if fb.Type == "nack" && fb.Parameter == "" { + hasNack = true + } + } + + if !hasNack { + return reader + } + + // error is already checked in NewReceiverNack + receiveLog, _ := NewReceiveLog(n.size) + n.receiveLogs.Store(info.SSRC, receiveLog) + + return RTPReaderFunc(func() (*rtp.Packet, Attributes, error) { + p, attr, err := reader.Read() + if err != nil { + return nil, nil, err + } + + receiveLog.Add(p.SequenceNumber) + + return p, attr, nil + }) +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (n *ReceiverNACK) UnbindLocalStream(info *StreamInfo) { + n.receiveLogs.Delete(info.SSRC) +} + +func (n *ReceiverNACK) Close() error { + defer n.wg.Wait() + n.m.Lock() + defer n.m.Unlock() + + select { + case <-n.close: + // already closed + return nil + default: + } + + close(n.close) + + return nil +} + +func (n *ReceiverNACK) loop(rtcpWriter RTCPWriter) { + defer n.wg.Done() + + senderSSRC := rand.Uint32() + + ticker := time.NewTicker(time.Millisecond * 100) + for { + select { + case <-ticker.C: + n.receiveLogs.Range(func(key, value interface{}) bool { + ssrc := key.(uint32) + receiveLog := value.(*ReceiveLog) + + missing := receiveLog.MissingSeqNumbers(10) + if len(missing) == 0 { + return true + } + + nack := &rtcp.TransportLayerNack{ + SenderSSRC: senderSSRC, + MediaSSRC: ssrc, + Nacks: nackPairs(missing), + } + + _, err := rtcpWriter.Write([]rtcp.Packet{nack}, Attributes{}) + if err != nil { + n.log.Warnf("failed sending nack: %+v", err) + } + + return true + }) + + case <-n.close: + return + } + } +} + +func nackPairs(seqNums []uint16) []rtcp.NackPair { + pairs := make([]rtcp.NackPair, 0) + startSeq := seqNums[0] + nackPair := &rtcp.NackPair{PacketID: startSeq} + for i := 1; i < len(seqNums); i++ { + m := seqNums[i] + + if m-nackPair.PacketID > 16 { + pairs = append(pairs, *nackPair) + nackPair = &rtcp.NackPair{PacketID: m} + continue + } + + nackPair.LostPackets |= 1 << (m - nackPair.PacketID - 1) + } + + pairs = append(pairs, *nackPair) + + return pairs +} diff --git a/pkg/interceptor/send_buffer.go b/pkg/interceptor/send_buffer.go new file mode 100644 index 00000000000..bb14bbf01e3 --- /dev/null +++ b/pkg/interceptor/send_buffer.go @@ -0,0 +1,80 @@ +package interceptor + +import ( + "errors" + "strconv" + + "github.com/pion/rtp" +) + +const ( + uint16SizeHalf = 1 << 15 +) + +var ( + allowedSendBufferSizes map[uint16]bool + invalidSendBufferSizeError string +) + +func init() { + allowedSendBufferSizes = make(map[uint16]bool, 15) + invalidSendBufferSizeError = "invalid sendBuffer size, must be one of: " + for i := 0; i < 16; i++ { + allowedSendBufferSizes[1<= uint16SizeHalf { + return nil + } + + if diff >= s.size { + return nil + } + + return s.packets[seq%s.size] +} diff --git a/pkg/interceptor/send_buffer_test.go b/pkg/interceptor/send_buffer_test.go new file mode 100644 index 00000000000..9e149cb763a --- /dev/null +++ b/pkg/interceptor/send_buffer_test.go @@ -0,0 +1,65 @@ +package interceptor + +import ( + "testing" + + "github.com/pion/rtp" +) + +func TestSendBuffer(t *testing.T) { + for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} { + start := start + + sb, err := NewSendBuffer(8) + if err != nil { + t.Fatalf("%+v", err) + } + + add := func(nums ...uint16) { + for _, n := range nums { + seq := start + n + sb.Add(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq}}) + } + } + + assertGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + packet := sb.Get(seq) + if packet == nil { + t.Errorf("packet not found: %d", seq) + continue + } + if packet.SequenceNumber != seq { + t.Errorf("packet for %d returned with incorrect SequenceNumber: %d", seq, packet.SequenceNumber) + } + } + } + assertNOTGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + packet := sb.Get(seq) + if packet != nil { + t.Errorf("packet found for %d: %d", seq, packet.SequenceNumber) + } + } + } + + add(0, 1, 2, 3, 4, 5, 6, 7) + assertGet(0, 1, 2, 3, 4, 5, 6, 7) + + add(8) + assertGet(8) + assertNOTGet(0) + + add(10) + assertGet(10) + assertNOTGet(1, 2, 9) + + add(22) + assertGet(22) + assertNOTGet(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21) + } +} diff --git a/pkg/interceptor/sender_nack.go b/pkg/interceptor/sender_nack.go new file mode 100644 index 00000000000..f217406b95a --- /dev/null +++ b/pkg/interceptor/sender_nack.go @@ -0,0 +1,125 @@ +package interceptor + +import ( + "sync" + + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +type SenderNack struct { + NoOp + size uint16 + streams *sync.Map + log logging.LeveledLogger +} + +type senderNackStream struct { + sendBuffer *SendBuffer + rtpWriter RTPWriter +} + +// NewSenderNack returns a new ReceiverNACK interceptor +func NewSenderNack(size uint16, log logging.LeveledLogger) (*SenderNack, error) { + _, err := NewSendBuffer(size) + if err != nil { + return nil, err + } + + return &SenderNack{ + NoOp: NoOp{}, + size: size, + streams: &sync.Map{}, + log: log, + }, nil +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (n *SenderNack) BindRTCPReader(reader RTCPReader) RTCPReader { + return RTCPReaderFunc(func() ([]rtcp.Packet, Attributes, error) { + pkts, attr, err := reader.Read() + if err != nil { + return nil, nil, err + } + + for _, rtcpPacket := range pkts { + nack, ok := rtcpPacket.(*rtcp.TransportLayerNack) + if !ok { + continue + } + + go n.resendPackets(nack) + } + + return pkts, attr, err + }) +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (n *SenderNack) BindLocalStream(info *StreamInfo, writer RTPWriter) RTPWriter { + hasNack := false + for _, fb := range info.RTCPFeedback { + if fb.Type == "nack" && fb.Parameter == "" { + hasNack = true + } + } + + if !hasNack { + return writer + } + + // error is already checked in NewReceiverNack + sendBuffer, _ := NewSendBuffer(n.size) + n.streams.Store(info.SSRC, &senderNackStream{sendBuffer: sendBuffer, rtpWriter: writer}) + + return RTPWriterFunc(func(p *rtp.Packet, attributes Attributes) (int, error) { + sendBuffer.Add(p) + + return writer.Write(p, attributes) + }) +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (n *SenderNack) UnbindLocalStream(info *StreamInfo) { + n.streams.Delete(info.SSRC) +} + +func (n *SenderNack) resendPackets(nack *rtcp.TransportLayerNack) { + v, ok := n.streams.Load(nack.MediaSSRC) + if !ok { + return + } + + stream := v.(*senderNackStream) + seqNums := nackParsToSequenceNumbers(nack.Nacks) + + for _, seq := range seqNums { + p := stream.sendBuffer.Get(seq) + if p == nil { + continue + } + + _, err := stream.rtpWriter.Write(p, Attributes{}) + if err != nil { + n.log.Warnf("failed resending nacked packet: %+v", err) + } + } +} + +func nackParsToSequenceNumbers(pairs []rtcp.NackPair) []uint16 { + seqs := make([]uint16, 0) + for _, pair := range pairs { + startSeq := pair.PacketID + seqs = append(seqs, startSeq) + for i := 0; i < 16; i++ { + if (pair.LostPackets & (1 << i)) != 0 { + seqs = append(seqs, startSeq+uint16(i)+1) + } + } + } + + return seqs +} diff --git a/pkg/interceptor/streaminfo.go b/pkg/interceptor/streaminfo.go new file mode 100644 index 00000000000..ea90009a49a --- /dev/null +++ b/pkg/interceptor/streaminfo.go @@ -0,0 +1,36 @@ +// +build !js + +package interceptor + +// RTPHeaderExtension represents a negotiated RFC5285 RTP header extension. +type RTPHeaderExtension struct { + URI string + ID int +} + +// StreamInfo is the Context passed when a StreamLocal or StreamRemote has been Binded or Unbinded +type StreamInfo struct { + ID string + Attributes Attributes + SSRC uint32 + PayloadType uint8 + RTPHeaderExtensions []RTPHeaderExtension + MimeType string + ClockRate uint32 + Channels uint16 + SDPFmtpLine string + RTCPFeedback []RTCPFeedback +} + +// RTCPFeedback signals the connection to use additional RTCP packet types. +// https://draft.ortc.org/#dom-rtcrtcpfeedback +type RTCPFeedback struct { + // Type is the type of feedback. + // see: https://draft.ortc.org/#dom-rtcrtcpfeedback + // valid: ack, ccm, nack, goog-remb, transport-cc + Type string + + // The parameter value depends on the type. + // For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. + Parameter string +} diff --git a/rtpcodec.go b/rtpcodec.go index c03abd13d86..2d9b6d7600a 100644 --- a/rtpcodec.go +++ b/rtpcodec.go @@ -57,6 +57,14 @@ type RTPHeaderExtensionCapability struct { URI string } +// RTPHeaderExtensionParameter represents a negotiated RFC5285 RTP header extension. +// +// https://w3c.github.io/webrtc-pc/#dictionary-rtcrtpheaderextensionparameters-members +type RTPHeaderExtensionParameter struct { + URI string + ID int +} + // RTPCodecParameters is a sequence containing the media codecs that an RtpSender // will choose from, as well as entries for RTX, RED and FEC mechanisms. This also // includes the PayloadType that has been negotiated @@ -77,6 +85,14 @@ type RTCRtpCapabilities struct { Codecs []RTPCodecCapability } +// RTPParameters is a list of negotiated codecs and header extensions +// +// https://w3c.github.io/webrtc-pc/#dictionary-rtcrtpparameters-members +type RTPParameters struct { + HeaderExtensions []RTPHeaderExtensionParameter + Codecs []RTPCodecParameters +} + // Do a fuzzy find for a codec in the list of codecs // Used for lookup up a codec in an existing list to find a match func codecParametersFuzzySearch(needle RTPCodecParameters, haystack []RTPCodecParameters) (RTPCodecParameters, error) { diff --git a/rtpreceiver.go b/rtpreceiver.go index 1b37f624aa8..ae0769968e3 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -9,6 +9,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/srtp" + "github.com/pion/webrtc/v3/pkg/interceptor" ) // trackStreams maintains a mapping of RTP/RTCP streams to a specific track @@ -31,6 +32,8 @@ type RTPReceiver struct { // A reference to the associated api object api *API + + interceptorRTCPReader interceptor.RTCPReader } // NewRTPReceiver constructs a new RTPReceiver @@ -39,14 +42,17 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT return nil, errRTPReceiverDTLSTransportNil } - return &RTPReceiver{ + r := &RTPReceiver{ kind: kind, transport: transport, api: api, closed: make(chan interface{}), received: make(chan interface{}), tracks: []trackStreams{}, - }, nil + } + r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP)) + + return r, nil } // Transport returns the currently-configured *DTLSTransport or nil @@ -94,11 +100,12 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { if len(parameters.Encodings) == 1 && parameters.Encodings[0].SSRC != 0 { t := trackStreams{ - track: &TrackRemote{ - kind: r.kind, - ssrc: parameters.Encodings[0].SSRC, - receiver: r, - }, + track: NewTrackRemote( + r.kind, + parameters.Encodings[0].SSRC, + "", + r, + ), } var err error @@ -111,11 +118,12 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { } else { for _, encoding := range parameters.Encodings { r.tracks = append(r.tracks, trackStreams{ - track: &TrackRemote{ - kind: r.kind, - rid: encoding.RID, - receiver: r, - }, + track: NewTrackRemote( + r.kind, + 0, + encoding.RID, + r, + ), }) } } @@ -148,15 +156,27 @@ func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) { } } -// ReadRTCP is a convenience method that wraps Read and unmarshal for you +// ReadRTCP is a convenience method that wraps Read and unmarshal for you. +// It also runs any configured interceptors. func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) { + pkts, _, err := r.interceptorRTCPReader.Read() + return pkts, err +} + +// ReadRTCP is a convenience method that wraps Read and unmarshal for you +func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) i, err := r.Read(b) if err != nil { - return nil, err + return nil, nil, err } - return rtcp.Unmarshal(b[:i]) + pkts, err := rtcp.Unmarshal(b[:i]) + if err != nil { + return nil, nil, err + } + + return pkts, make(interceptor.Attributes), nil } // ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you @@ -232,7 +252,7 @@ func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) // receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs // It populates all the internal state for the given RID -func (r *RTPReceiver) receiveForRid(rid string, codec RTPCodecParameters, ssrc SSRC) (*TrackRemote, error) { +func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) (*TrackRemote, error) { r.mu.Lock() defer r.mu.Unlock() @@ -240,8 +260,10 @@ func (r *RTPReceiver) receiveForRid(rid string, codec RTPCodecParameters, ssrc S if r.tracks[i].track.RID() == rid { r.tracks[i].track.mu.Lock() r.tracks[i].track.kind = r.kind - r.tracks[i].track.codec = codec + r.tracks[i].track.codec = params.Codecs[0] + r.tracks[i].track.params = params r.tracks[i].track.ssrc = ssrc + r.tracks[i].track.bindInterceptor() r.tracks[i].track.mu.Unlock() var err error diff --git a/rtpsender.go b/rtpsender.go index 8cfe97d30a1..b54154fdc83 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -8,7 +8,9 @@ import ( "github.com/pion/randutil" "github.com/pion/rtcp" + "github.com/pion/rtp" "github.com/pion/srtp" + "github.com/pion/webrtc/v3/pkg/interceptor" ) // RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer @@ -16,13 +18,12 @@ type RTPSender struct { track TrackLocal rtcpReadStream *srtp.ReadStreamSRTCP - rtpWriteStream *srtp.WriteStreamSRTP + context TrackLocalContext transport *DTLSTransport payloadType PayloadType ssrc SSRC - codec RTPCodecParameters // nolint:godox // TODO(sgotti) remove this when in future we'll avoid replacing @@ -36,6 +37,8 @@ type RTPSender struct { mu sync.RWMutex sendCalled, stopCalled chan interface{} + + interceptorRTCPReader interceptor.RTCPReader } // NewRTPSender constructs a new RTPSender @@ -51,7 +54,7 @@ func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSe return nil, err } - return &RTPSender{ + r := &RTPSender{ track: track, transport: transport, api: api, @@ -59,7 +62,10 @@ func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSe stopCalled: make(chan interface{}), ssrc: SSRC(randutil.NewMathRandomGenerator().Uint32()), id: id, - }, nil + } + r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP)) + + return r, nil } func (r *RTPSender) isNegotiated() bool { @@ -97,11 +103,7 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error { defer r.mu.Unlock() if r.hasSent() { - if err := r.track.Unbind(TrackLocalContext{ - id: r.id, - ssrc: r.ssrc, - writeStream: r.rtpWriteStream, - }); err != nil { + if err := r.track.Unbind(r.context); err != nil { return err } } @@ -111,12 +113,7 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error { return nil } - if _, err := track.Bind(TrackLocalContext{ - id: r.id, - codecs: []RTPCodecParameters{r.codec}, - ssrc: r.ssrc, - writeStream: r.rtpWriteStream, - }); err != nil { + if _, err := track.Bind(r.context); err != nil { return err } @@ -148,18 +145,53 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { return err } - if r.rtpWriteStream, err = srtpSession.OpenWriteStream(); err != nil { + rtpWriteStream, err := srtpSession.OpenWriteStream() + if err != nil { return err } - if r.codec, err = r.track.Bind(TrackLocalContext{ + writeStream := &interceptorTrackLocalWriter{TrackLocalWriter: rtpWriteStream} + + r.context = TrackLocalContext{ id: r.id, - codecs: r.api.mediaEngine.getCodecsByKind(r.track.Kind()), + params: r.api.mediaEngine.getRTPParametersByKind(r.track.Kind()), ssrc: parameters.Encodings.SSRC, - writeStream: r.rtpWriteStream, - }); err != nil { + writeStream: writeStream, + } + + codec, err := r.track.Bind(r.context) + if err != nil { return err } + r.context.params.Codecs = []RTPCodecParameters{codec} + + headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(r.context.params.HeaderExtensions)) + for _, h := range r.context.params.HeaderExtensions { + headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) + } + feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback)) + for _, f := range codec.RTCPFeedback { + feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) + } + info := &interceptor.StreamInfo{ + ID: r.context.id, + Attributes: interceptor.Attributes{}, + SSRC: uint32(r.context.ssrc), + PayloadType: uint8(codec.PayloadType), + RTPHeaderExtensions: headerExtensions, + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: feedbacks, + } + writeStream.setRTPWriter( + r.api.interceptor.BindLocalStream( + info, + interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { + return rtpWriteStream.WriteRTP(&p.Header, p.Payload) + }), + )) close(r.sendCalled) return nil @@ -194,15 +226,26 @@ func (r *RTPSender) Read(b []byte) (n int, err error) { } } -// ReadRTCP is a convenience method that wraps Read and unmarshals for you +// ReadRTCP is a convenience method that wraps Read and unmarshals for you. +// It also runs any configured interceptors. func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) { + pkts, _, err := r.interceptorRTCPReader.Read() + return pkts, err +} + +func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) i, err := r.Read(b) if err != nil { - return nil, err + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + if err != nil { + return nil, nil, err } - return rtcp.Unmarshal(b[:i]) + return pkts, make(interceptor.Attributes), nil } // hasSent tells if data has been ever sent for this instance diff --git a/track_local.go b/track_local.go index 1b232b9517e..e6e1da1f490 100644 --- a/track_local.go +++ b/track_local.go @@ -11,10 +11,11 @@ type TrackLocalWriter interface { Write(b []byte) (int, error) } -// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection +// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used +// in Interceptors. type TrackLocalContext struct { id string - codecs []RTPCodecParameters + params RTPParameters ssrc SSRC writeStream TrackLocalWriter } @@ -22,7 +23,13 @@ type TrackLocalContext struct { // CodecParameters returns the negotiated RTPCodecParameters. These are the codecs supported by both // PeerConnections and the SSRC/PayloadTypes func (t *TrackLocalContext) CodecParameters() []RTPCodecParameters { - return t.codecs + return t.params.Codecs +} + +// HeaderExtensions returns the negotiated RTPHeaderExtensionParameters. These are the header extensions supported by +// both PeerConnections and the SSRC/PayloadTypes +func (t *TrackLocalContext) HeaderExtensions() []RTPHeaderExtensionParameter { + return t.params.HeaderExtensions } // SSRC requires the negotiated SSRC of this track diff --git a/track_remote.go b/track_remote.go index dd48e3a8451..9f1f6e972b1 100644 --- a/track_remote.go +++ b/track_remote.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/pion/rtp" + "github.com/pion/webrtc/v3/pkg/interceptor" ) // TrackRemote represents a single inbound source of media @@ -19,10 +20,50 @@ type TrackRemote struct { kind RTPCodecType ssrc SSRC codec RTPCodecParameters + params RTPParameters rid string receiver *RTPReceiver peeked []byte + + interceptorRTPReader interceptor.RTPReader +} + +// NewTrackRemote creates a new TrackRemote. +func NewTrackRemote(kind RTPCodecType, ssrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote { + t := &TrackRemote{ + kind: kind, + ssrc: ssrc, + rid: rid, + receiver: receiver, + } + t.interceptorRTPReader = interceptor.RTPReaderFunc(t.readRTP) + + return t +} + +func (t *TrackRemote) bindInterceptor() { + headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(t.params.HeaderExtensions)) + for _, h := range t.params.HeaderExtensions { + headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) + } + feedbacks := make([]interceptor.RTCPFeedback, 0, len(t.codec.RTCPFeedback)) + for _, f := range t.codec.RTCPFeedback { + feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) + } + info := &interceptor.StreamInfo{ + ID: t.id, + Attributes: interceptor.Attributes{}, + SSRC: uint32(t.ssrc), + PayloadType: uint8(t.payloadType), + RTPHeaderExtensions: headerExtensions, + MimeType: t.codec.MimeType, + ClockRate: t.codec.ClockRate, + Channels: t.codec.Channels, + SDPFmtpLine: t.codec.SDPFmtpLine, + RTCPFeedback: feedbacks, + } + t.interceptorRTPReader = t.receiver.api.interceptor.BindRemoteStream(info, interceptor.RTPReaderFunc(t.readRTP)) } // ID is the unique identifier for this Track. This should be unique for the @@ -125,19 +166,25 @@ func (t *TrackRemote) peek(b []byte) (n int, err error) { return } -// ReadRTP is a convenience method that wraps Read and unmarshals for you +// ReadRTP is a convenience method that wraps Read and unmarshals for you. +// It also runs any configured interceptors. func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) { + p, _, err := t.interceptorRTPReader.Read() + return p, err +} + +func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) i, err := t.Read(b) if err != nil { - return nil, err + return nil, nil, err } r := &rtp.Packet{} if err := r.Unmarshal(b[:i]); err != nil { - return nil, err + return nil, nil, err } - return r, nil + return r, make(interceptor.Attributes), nil } // determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track