From 61cc8ca3602344041020d03a0cdb975ef28cdaff Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:17:06 +0100 Subject: [PATCH] client: return ErrClientEOS when a playlist ends (#59) --- client.go | 4 ++++ client_primary_downloader.go | 22 +++++++++++++++++++++- client_stream_downloader.go | 6 +++++- client_stream_processor_fmp4.go | 7 +++++++ client_stream_processor_mpegts.go | 7 +++++++ client_test.go | 4 +++- 6 files changed, 47 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index a317171..2c64980 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ package gohlslib import ( "context" + "errors" "fmt" "log" "net/http" @@ -22,6 +23,9 @@ const ( clientMaxDTSRTCDiff = 10 * time.Second ) +// ErrClientEOS is returned by Wait() when the stream has ended. +var ErrClientEOS = errors.New("end of stream") + // ClientOnDownloadPrimaryPlaylistFunc is the prototype of Client.OnDownloadPrimaryPlaylist. type ClientOnDownloadPrimaryPlaylistFunc func(url string) diff --git a/client_primary_downloader.go b/client_primary_downloader.go index 08467f5..6f86408 100644 --- a/client_primary_downloader.go +++ b/client_primary_downloader.go @@ -113,6 +113,7 @@ type clientPrimaryDownloader struct { // in chStreamTracks chan clientStreamProcessor + chStreamEnded chan struct{} // out startStreaming chan struct{} @@ -122,6 +123,7 @@ type clientPrimaryDownloader struct { func (d *clientPrimaryDownloader) initialize() { d.streamProcByTrack = make(map[*Track]clientStreamProcessor) d.chStreamTracks = make(chan clientStreamProcessor) + d.chStreamEnded = make(chan struct{}) d.startStreaming = make(chan struct{}) d.leadingTimeSyncReady = make(chan struct{}) } @@ -148,6 +150,7 @@ func (d *clientPrimaryDownloader) run(ctx context.Context) error { initialPlaylist: plt, rp: d.rp, onStreamTracks: d.onStreamTracks, + onStreamEnded: d.onStreamEnded, onSetLeadingTimeSync: d.onSetLeadingTimeSync, onGetLeadingTimeSync: d.onGetLeadingTimeSync, onData: d.onData, @@ -176,6 +179,7 @@ func (d *clientPrimaryDownloader) run(ctx context.Context) error { initialPlaylist: nil, rp: d.rp, onStreamTracks: d.onStreamTracks, + onStreamEnded: d.onStreamEnded, onSetLeadingTimeSync: d.onSetLeadingTimeSync, onGetLeadingTimeSync: d.onGetLeadingTimeSync, onData: d.onData, @@ -208,6 +212,7 @@ func (d *clientPrimaryDownloader) run(ctx context.Context) error { onSetLeadingTimeSync: d.onSetLeadingTimeSync, onGetLeadingTimeSync: d.onGetLeadingTimeSync, onData: d.onData, + onStreamEnded: d.onStreamEnded, } d.rp.add(ds) streamCount++ @@ -251,7 +256,15 @@ func (d *clientPrimaryDownloader) run(ctx context.Context) error { close(d.startStreaming) - return nil + for i := 0; i < streamCount; i++ { + select { + case <-d.chStreamEnded: + case <-ctx.Done(): + return fmt.Errorf("terminated") + } + } + + return ErrClientEOS } func (d *clientPrimaryDownloader) onStreamTracks(ctx context.Context, streamProc clientStreamProcessor) bool { @@ -270,6 +283,13 @@ func (d *clientPrimaryDownloader) onStreamTracks(ctx context.Context, streamProc return true } +func (d *clientPrimaryDownloader) onStreamEnded(ctx context.Context) { + select { + case d.chStreamEnded <- struct{}{}: + case <-ctx.Done(): + } +} + func (d *clientPrimaryDownloader) onSetLeadingTimeSync(ts clientTimeSync) { d.leadingTimeSync = ts close(d.leadingTimeSyncReady) diff --git a/client_stream_downloader.go b/client_stream_downloader.go index e26604f..ed1b255 100644 --- a/client_stream_downloader.go +++ b/client_stream_downloader.go @@ -39,6 +39,7 @@ type clientStreamDownloader struct { initialPlaylist *playlist.Media rp *clientRoutinePool onStreamTracks clientOnStreamTracksFunc + onStreamEnded func(context.Context) onSetLeadingTimeSync func(clientTimeSync) onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool) onData map[*Track]interface{} @@ -77,6 +78,7 @@ func (d *clientStreamDownloader) run(ctx context.Context) error { segmentQueue: segmentQueue, rp: d.rp, onStreamTracks: d.onStreamTracks, + onStreamEnded: d.onStreamEnded, onSetLeadingTimeSync: d.onSetLeadingTimeSync, onGetLeadingTimeSync: d.onGetLeadingTimeSync, onData: d.onData, @@ -94,6 +96,7 @@ func (d *clientStreamDownloader) run(ctx context.Context) error { segmentQueue: segmentQueue, rp: d.rp, onStreamTracks: d.onStreamTracks, + onStreamEnded: d.onStreamEnded, onSetLeadingTimeSync: d.onSetLeadingTimeSync, onGetLeadingTimeSync: d.onGetLeadingTimeSync, onData: d.onData, @@ -231,8 +234,9 @@ func (d *clientStreamDownloader) fillSegmentQueue( }) if pl.Endlist && pl.Segments[len(pl.Segments)-1] == seg { + segmentQueue.push(nil) <-ctx.Done() - return fmt.Errorf("stream has ended") + return fmt.Errorf("terminated") } return nil diff --git a/client_stream_processor_fmp4.go b/client_stream_processor_fmp4.go index a93c19f..622ea6f 100644 --- a/client_stream_processor_fmp4.go +++ b/client_stream_processor_fmp4.go @@ -50,6 +50,7 @@ type clientStreamProcessorFMP4 struct { segmentQueue *clientSegmentQueue rp *clientRoutinePool onStreamTracks clientOnStreamTracksFunc + onStreamEnded func(context.Context) onSetLeadingTimeSync func(clientTimeSync) onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool) onData map[*Track]interface{} @@ -119,6 +120,12 @@ func (p *clientStreamProcessorFMP4) run(ctx context.Context) error { } func (p *clientStreamProcessorFMP4) processSegment(ctx context.Context, seg *segmentData) error { + if seg == nil { + p.onStreamEnded(ctx) + <-ctx.Done() + return fmt.Errorf("terminated") + } + var parts fmp4.Parts err := parts.Unmarshal(seg.payload) if err != nil { diff --git a/client_stream_processor_mpegts.go b/client_stream_processor_mpegts.go index 890d685..88d509f 100644 --- a/client_stream_processor_mpegts.go +++ b/client_stream_processor_mpegts.go @@ -40,6 +40,7 @@ type clientStreamProcessorMPEGTS struct { segmentQueue *clientSegmentQueue rp *clientRoutinePool onStreamTracks clientOnStreamTracksFunc + onStreamEnded func(context.Context) onSetLeadingTimeSync func(clientTimeSync) onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool) onData map[*Track]interface{} @@ -85,6 +86,12 @@ func (p *clientStreamProcessorMPEGTS) run(ctx context.Context) error { } func (p *clientStreamProcessorMPEGTS) processSegment(ctx context.Context, seg *segmentData) error { + if seg == nil { + p.onStreamEnded(ctx) + <-ctx.Done() + return fmt.Errorf("terminated") + } + if p.switchableReader == nil { err := p.initializeReader(ctx, seg.payload) if err != nil { diff --git a/client_test.go b/client_test.go index a009aba..e35ae84 100644 --- a/client_test.go +++ b/client_test.go @@ -403,8 +403,10 @@ func TestClient(t *testing.T) { <-audioRecv <-audioRecv + err = <-c.Wait() + require.Equal(t, ErrClientEOS, err) + c.Close() - <-c.Wait() }) } }