From cc03232a664f18834a11ee8b6c7ca74dd4a672cb Mon Sep 17 00:00:00 2001 From: Son Roy Almerol Date: Fri, 25 Oct 2024 22:05:14 -0400 Subject: [PATCH] major refactoring client and background stream --- proxy/buffer.go | 142 ----------------------- proxy/buffer_stream.go | 244 ++++++++++++++++++++++++++++++++++++++++ proxy/load_balancer.go | 7 +- proxy/stream_handler.go | 165 ++++++++++----------------- 4 files changed, 306 insertions(+), 252 deletions(-) create mode 100644 proxy/buffer_stream.go diff --git a/proxy/buffer.go b/proxy/buffer.go index 3e135766..d732c9ea 100644 --- a/proxy/buffer.go +++ b/proxy/buffer.go @@ -3,15 +3,11 @@ package proxy import ( "context" "fmt" - "io" "m3u-stream-merger/database" - "m3u-stream-merger/utils" - "net/http" "os" "strconv" "time" - "github.com/bsm/redislock" "github.com/redis/go-redis/v9" ) @@ -99,141 +95,3 @@ func (b *Buffer) Subscribe(ctx context.Context) (*chan []byte, error) { return &ch, nil } - -func BufferStream(instance *StreamInstance, m3uIndex int, resp *http.Response, r *http.Request, w http.ResponseWriter, statusChan chan int) { - debug := os.Getenv("DEBUG") == "true" - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - locker := redislock.New(instance.Database.Redis) - - lock, err := locker.Obtain(ctx, instance.Buffer.streamKey, time.Minute, nil) - if err == redislock.ErrNotObtained { - return - } else if err != nil { - utils.SafeLogf("Obtaining lock error: %v\n", err) - return - } - defer func() { - err := lock.Release(context.Background()) - if err != nil && debug { - utils.SafeLogf("Releasing lock error: %v\n", err) - } - }() - - instance.Database.UpdateConcurrency(m3uIndex, true) - defer instance.Database.UpdateConcurrency(m3uIndex, false) - - timeoutSecond, err := strconv.Atoi(os.Getenv("STREAM_TIMEOUT")) - if err != nil || timeoutSecond < 0 { - timeoutSecond = 3 - } - - timeoutDuration := time.Duration(timeoutSecond) * time.Second - if timeoutSecond == 0 { - timeoutDuration = time.Minute - } - timer := time.NewTimer(timeoutDuration) - defer timer.Stop() - - // Backoff settings - initialBackoff := 200 * time.Millisecond - maxBackoff := time.Duration(timeoutSecond-1) * time.Second - currentBackoff := initialBackoff - - returnStatus := 0 - - sourceChunk := make([]byte, 1024) - - for { - select { - case <-ctx.Done(): // handle context cancellation - utils.SafeLogf("Context canceled for stream: %s\n", r.RemoteAddr) - statusChan <- 0 - return - case <-timer.C: - utils.SafeLogf("Timeout reached while trying to stream: %s\n", r.RemoteAddr) - statusChan <- returnStatus - return - default: - err := lock.Refresh(ctx, time.Minute, nil) - if err != nil { - utils.SafeLogf("Failed to refresh lock: %s\n", err) - } - - clients, err := instance.Database.GetBufferUser(instance.Buffer.streamKey) - if err != nil { - utils.SafeLogf("Failed to get number of clients: %s\n", err) - } - - if clients <= 0 { - cancel() - continue - } - - n, err := resp.Body.Read(sourceChunk) - if err != nil { - if err == io.EOF { - utils.SafeLogf("Stream ended (EOF reached): %s\n", r.RemoteAddr) - if timeoutSecond == 0 { - statusChan <- 2 - return - } - - returnStatus = 2 - utils.SafeLogf("Retrying same stream until timeout (%d seconds) is reached...\n", timeoutSecond) - if debug { - utils.SafeLogf("[DEBUG] Retrying same stream with backoff of %v...\n", currentBackoff) - } - - time.Sleep(currentBackoff) - currentBackoff *= 2 - if currentBackoff > maxBackoff { - currentBackoff = maxBackoff - } - - continue - } - - utils.SafeLogf("Error reading stream: 
%s\n", err.Error()) - - returnStatus = 1 - - if timeoutSecond == 0 { - statusChan <- 1 - return - } - - if debug { - utils.SafeLogf("[DEBUG] Retrying same stream with backoff of %v...\n", currentBackoff) - } - - time.Sleep(currentBackoff) - currentBackoff *= 2 - if currentBackoff > maxBackoff { - currentBackoff = maxBackoff - } - - continue - } - - err = instance.Buffer.Write(ctx, sourceChunk[:n]) - if err != nil { - utils.SafeLogf("Failed to store buffer: %s\n", err.Error()) - } - - // Reset the timer on each successful write and backoff - if !timer.Stop() { - select { - case <-timer.C: // drain the channel to avoid blocking - default: - } - } - timer.Reset(timeoutDuration) - - // Reset the backoff duration after successful read/write - currentBackoff = initialBackoff - } - } -} diff --git a/proxy/buffer_stream.go b/proxy/buffer_stream.go new file mode 100644 index 00000000..698cb33f --- /dev/null +++ b/proxy/buffer_stream.go @@ -0,0 +1,244 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "m3u-stream-merger/database" + "m3u-stream-merger/utils" + "net/http" + "os" + "strconv" + "time" + + "github.com/bsm/redislock" +) + +var BufferStreams map[string]*BufferStream + +type BufferStream struct { + Id string + TestedIndexes []int + Buffer *Buffer + Info database.StreamInfo + Database *database.Instance + Started bool + Response *http.Response +} + +func InitializeBufferStream(streamUrl string) (*BufferStream, error) { + if BufferStreams == nil { + BufferStreams = make(map[string]*BufferStream) + } + + if BufferStreams[streamUrl] != nil { + return BufferStreams[streamUrl], nil + } + + db, err := database.InitializeDb() + if err != nil { + return nil, fmt.Errorf("InitializeBufferStream error: %v", err) + } + + stream, err := db.GetStreamBySlug(streamUrl) + if err != nil { + return nil, fmt.Errorf("InitializeBufferStream error: %v", err) + } + + buffer, err := NewBuffer(db, streamUrl) + if err != nil { + return nil, fmt.Errorf("InitializeBufferStream error: %v", err) + } + + BufferStreams[streamUrl] = &BufferStream{ + Id: "streambuffer:" + streamUrl, + TestedIndexes: make([]int, 0), + Buffer: buffer, + Database: db, + Info: stream, + Started: false, + } + + return BufferStreams[streamUrl], nil +} + +func (stream *BufferStream) Start(r *http.Request) error { + debug := os.Getenv("DEBUG") == "true" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if stream.Started { + return nil + } + + stream.Started = true + defer func() { + stream.Started = false + }() + + resp, selectedIndex, err := stream.LoadBalancer(r.Method) + if err != nil { + return err + } + + stream.Response = resp + + firstIteration := true + + go func() { + for { + select { + case <-ctx.Done(): + utils.SafeLogf("Context cancelled: %s\n", stream.Info.Title) + resp.Body.Close() + return + default: + if !firstIteration { + resp, selectedIndex, err = stream.LoadBalancer(r.Method) + if err != nil { + utils.SafeLogf("Error reloading stream for %s: %v\n", stream.Info.Slug, err) + return + } + stream.Response = resp + } else { + firstIteration = false + } + + stream.Buffer.testedIndexes = append(stream.Buffer.testedIndexes, selectedIndex) + locker := redislock.New(stream.Database.Redis) + + lock, err := locker.Obtain(ctx, stream.Id, time.Minute, nil) + if err == redislock.ErrNotObtained { + return + } else if err != nil { + utils.SafeLogf("Obtaining lock error: %v\n", err) + return + } + defer func() { + err := lock.Release(context.Background()) + if err != nil && debug { + utils.SafeLogf("Releasing 
lock error: %v\n", err) + } + }() + + stream.streamToBuffer(ctx, resp, lock, selectedIndex) + resp.Body.Close() + } + } + }() + + return nil +} + +func (stream *BufferStream) streamToBuffer(ctx context.Context, resp *http.Response, lock *redislock.Lock, selectedIndex int) { + debug := os.Getenv("DEBUG") == "true" + + stream.Database.UpdateConcurrency(selectedIndex, true) + defer stream.Database.UpdateConcurrency(selectedIndex, false) + + timeoutSecond, err := strconv.Atoi(os.Getenv("STREAM_TIMEOUT")) + if err != nil || timeoutSecond < 0 { + timeoutSecond = 3 + } + + lastClientSeen := time.Now() + + timeoutDuration := time.Duration(timeoutSecond) * time.Second + if timeoutSecond == 0 { + timeoutDuration = time.Minute + } + timer := time.NewTimer(timeoutDuration) + defer timer.Stop() + + // Backoff settings + initialBackoff := 200 * time.Millisecond + maxBackoff := time.Duration(timeoutSecond-1) * time.Second + currentBackoff := initialBackoff + + sourceChunk := make([]byte, 1024) + + for { + select { + case <-ctx.Done(): // handle context cancellation + utils.SafeLogf("Context canceled for stream: %s\n", stream.Info.Title) + return + case <-timer.C: + utils.SafeLogf("Timeout reached while trying to stream: %s\n", stream.Info.Title) + return + default: + err := lock.Refresh(ctx, time.Minute, nil) + if err != nil { + utils.SafeLogf("Failed to refresh lock: %s\n", err) + } + + clients, err := stream.Database.GetBufferUser(stream.Id) + if err != nil { + utils.SafeLogf("Failed to get number of clients: %s\n", err) + } + + if clients <= 0 && time.Now().Sub(lastClientSeen) > 3*time.Second { + utils.SafeLogf("No clients left: %s\n", stream.Info.Title) + return + } + lastClientSeen = time.Now() + + n, err := resp.Body.Read(sourceChunk) + if err != nil { + if err == io.EOF { + utils.SafeLogf("Stream ended (EOF reached): %s\n", stream.Info.Title) + if timeoutSecond == 0 { + return + } + + utils.SafeLogf("Retrying same stream until timeout (%d seconds) is reached...\n", timeoutSecond) + if debug { + utils.SafeLogf("[DEBUG] Retrying same stream with backoff of %v...\n", currentBackoff) + } + + time.Sleep(currentBackoff) + currentBackoff *= 2 + if currentBackoff > maxBackoff { + currentBackoff = maxBackoff + } + + continue + } + + utils.SafeLogf("Error reading stream: %s\n", err.Error()) + + if timeoutSecond == 0 { + return + } + + if debug { + utils.SafeLogf("[DEBUG] Retrying same stream with backoff of %v...\n", currentBackoff) + } + + time.Sleep(currentBackoff) + currentBackoff *= 2 + if currentBackoff > maxBackoff { + currentBackoff = maxBackoff + } + + continue + } + + err = stream.Buffer.Write(ctx, sourceChunk[:n]) + if err != nil { + utils.SafeLogf("Failed to store buffer: %s\n", err.Error()) + } + + // Reset the timer on each successful write and backoff + if !timer.Stop() { + select { + case <-timer.C: // drain the channel to avoid blocking + default: + } + } + timer.Reset(timeoutDuration) + + // Reset the backoff duration after successful read/write + currentBackoff = initialBackoff + } + } +} diff --git a/proxy/load_balancer.go b/proxy/load_balancer.go index 526ac876..06896e58 100644 --- a/proxy/load_balancer.go +++ b/proxy/load_balancer.go @@ -11,10 +11,11 @@ import ( "strings" ) -func (instance *StreamInstance) LoadBalancer(previous *[]int, method string) (*http.Response, string, int, error) { +func (instance *BufferStream) LoadBalancer(method string) (*http.Response, int, error) { debug := os.Getenv("DEBUG") == "true" m3uIndexes := utils.GetM3UIndexes() + previous := 
&instance.Buffer.testedIndexes sort.Slice(m3uIndexes, func(i, j int) bool { return instance.Database.ConcurrencyPriorityValue(i) > instance.Database.ConcurrencyPriorityValue(j) @@ -58,7 +59,7 @@ func (instance *StreamInstance) LoadBalancer(previous *[]int, method string) (*h if debug { utils.SafeLogf("[DEBUG] Successfully fetched stream from %s\n", url) } - return resp, url, index, nil + return resp, index, nil } utils.SafeLogf("Error fetching stream: %s\n", err.Error()) if debug { @@ -76,5 +77,5 @@ func (instance *StreamInstance) LoadBalancer(previous *[]int, method string) (*h lap++ } - return nil, "", -1, fmt.Errorf("Error fetching stream. Exhausted all streams.") + return nil, -1, fmt.Errorf("Error fetching stream. Exhausted all streams.") } diff --git a/proxy/stream_handler.go b/proxy/stream_handler.go index 65ee05c9..0a53d68c 100644 --- a/proxy/stream_handler.go +++ b/proxy/stream_handler.go @@ -3,6 +3,7 @@ package proxy import ( "bufio" "context" + "fmt" "m3u-stream-merger/database" "m3u-stream-merger/utils" "net/http" @@ -12,44 +13,29 @@ import ( "strings" ) -type StreamInstance struct { +var streamStatusChans map[string]*chan int + +type ClientInstance struct { Database *database.Instance - Info database.StreamInfo - Buffer *Buffer } -func InitializeStream(ctx context.Context, streamUrl string) (*StreamInstance, error) { +func InitializeClient(ctx context.Context, streamUrl string) (*ClientInstance, error) { initDb, err := database.InitializeDb() if err != nil { utils.SafeLogf("Error initializing Redis database: %v", err) return nil, err } - buffer, err := NewBuffer(initDb, streamUrl) - if err != nil { - utils.SafeLogf("Error initializing stream buffer: %v", err) - return nil, err - } - - stream, err := initDb.GetStreamBySlug(streamUrl) - if err != nil { - return nil, err - } - - return &StreamInstance{ + return &ClientInstance{ Database: initDb, - Info: stream, - Buffer: buffer, }, nil } -func (instance *StreamInstance) DirectProxy(ctx context.Context, resp *http.Response, w http.ResponseWriter, statusChan chan int) { +func (instance *ClientInstance) DirectProxy(ctx context.Context, resp *http.Response, w http.ResponseWriter) error { scanner := bufio.NewScanner(resp.Body) base, err := url.Parse(resp.Request.URL.String()) if err != nil { - utils.SafeLogf("Invalid base URL for M3U8 stream: %v", err) - statusChan <- 4 - return + return fmt.Errorf("Invalid base URL for M3U8 stream: %v", err) } for scanner.Scan() { @@ -57,9 +43,7 @@ func (instance *StreamInstance) DirectProxy(ctx context.Context, resp *http.Resp if strings.HasPrefix(line, "#") { _, err := w.Write([]byte(line + "\n")) if err != nil { - utils.SafeLogf("Failed to write line to response: %v", err) - statusChan <- 4 - return + return fmt.Errorf("Failed to write line to response: %v", err) } } else if strings.TrimSpace(line) != "" { u, err := url.Parse(line) @@ -67,9 +51,7 @@ func (instance *StreamInstance) DirectProxy(ctx context.Context, resp *http.Resp utils.SafeLogf("Failed to parse M3U8 URL in line: %v", err) _, err := w.Write([]byte(line + "\n")) if err != nil { - utils.SafeLogf("Failed to write line to response: %v", err) - statusChan <- 4 - return + return fmt.Errorf("Failed to write line to response: %v", err) } continue } @@ -80,30 +62,27 @@ func (instance *StreamInstance) DirectProxy(ctx context.Context, resp *http.Resp _, err = w.Write([]byte(u.String() + "\n")) if err != nil { - utils.SafeLogf("Failed to write line to response: %v", err) - statusChan <- 4 - return + return fmt.Errorf("Failed to write line to 
response: %v", err) } } } - statusChan <- 4 + return nil } -func (instance *StreamInstance) StreamBuffer(ctx context.Context, w http.ResponseWriter) { +func (instance *ClientInstance) StreamBuffer(ctx context.Context, buffer *Buffer, w http.ResponseWriter) error { debug := os.Getenv("DEBUG") == "true" - streamCh, err := instance.Buffer.Subscribe(ctx) + streamCh, err := buffer.Subscribe(ctx) if err != nil { - utils.SafeLogf("Error subscribing client: %v", err) - return + return fmt.Errorf("Error subscribing client: %v", err) } - err = instance.Database.IncrementBufferUser(instance.Buffer.streamKey) + err = instance.Database.IncrementBufferUser(buffer.streamKey) if err != nil && debug { utils.SafeLogf("Error incrementing buffer user: %v\n", err) } defer func() { - err = instance.Database.DecrementBufferUser(instance.Buffer.streamKey) + err = instance.Database.DecrementBufferUser(buffer.streamKey) if err != nil && debug { utils.SafeLogf("Error decrementing buffer user: %v\n", err) } @@ -112,12 +91,11 @@ func (instance *StreamInstance) StreamBuffer(ctx context.Context, w http.Respons for { select { case <-ctx.Done(): // handle context cancellation - return + return nil case chunk := <-*streamCh: _, err := w.Write(chunk) if err != nil { - utils.SafeLogf("Error writing to client: %v", err) - return + return fmt.Errorf("Error writing to client: %v", err) } if flusher, ok := w.(http.Flusher); ok { @@ -133,6 +111,10 @@ func Handler(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithCancel(r.Context()) defer cancel() + if streamStatusChans == nil { + streamStatusChans = make(map[string]*chan int) + } + utils.SafeLogf("Received request from %s for URL: %s\n", r.RemoteAddr, r.URL.Path) streamUrl := strings.Split(path.Base(r.URL.Path), ".")[0] @@ -142,83 +124,52 @@ func Handler(w http.ResponseWriter, r *http.Request) { return } - stream, err := InitializeStream(ctx, strings.TrimPrefix(streamUrl, "/")) + streamUrl = strings.TrimPrefix(streamUrl, "/") + + client, err := InitializeClient(ctx, streamUrl) if err != nil { utils.SafeLogf("Error retrieving stream for slug %s: %v\n", streamUrl, err) http.NotFound(w, r) return } - var selectedIndex int - var selectedUrl string - - firstWrite := true - - var resp *http.Response - - for { - select { - case <-ctx.Done(): - utils.SafeLogf("Client disconnected: %s\n", r.RemoteAddr) - return - default: - resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(&stream.Buffer.testedIndexes, r.Method) - if err != nil { - utils.SafeLogf("Error reloading stream for %s: %v\n", streamUrl, err) - return - } - - // HTTP header initialization - if firstWrite { - for k, v := range resp.Header { - if strings.ToLower(k) == "content-length" { - continue - } - - for _, val := range v { - w.Header().Set(k, val) - } - } - w.WriteHeader(resp.StatusCode) + stream, err := InitializeBufferStream(streamUrl) + if err != nil { + utils.SafeLogf("Error retrieving buffer stream for slug %s: %v\n", streamUrl, err) + http.NotFound(w, r) + return + } - if debug { - utils.SafeLogf("[DEBUG] Headers set for response: %v\n", w.Header()) - } - firstWrite = false - } + err = stream.Start(r) + if err != nil { + utils.SafeLogf("Error starting buffer stream for slug %s: %v\n", streamUrl, err) + http.NotFound(w, r) + return + } - exitStatus := make(chan int) + utils.SafeLogf("Proxying %s to %s\n", stream.Info.Title, r.RemoteAddr) - utils.SafeLogf("Proxying %s to %s\n", r.RemoteAddr, selectedUrl) + for k, v := range stream.Response.Header { + if strings.ToLower(k) == "content-length" { 
+ continue + } - if r.Method != http.MethodGet || utils.EOFIsExpected(resp) { - go stream.DirectProxy(ctx, resp, w, exitStatus) - } else { - go stream.StreamBuffer(ctx, w) - go BufferStream(stream, selectedIndex, resp, r, w, exitStatus) - } + for _, val := range v { + w.Header().Set(k, val) + } + } + w.WriteHeader(stream.Response.StatusCode) - stream.Buffer.testedIndexes = append(stream.Buffer.testedIndexes, selectedIndex) - - streamExitCode := <-exitStatus - utils.SafeLogf("Exit code %d received from %s\n", streamExitCode, selectedUrl) - - if streamExitCode == 2 && utils.EOFIsExpected(resp) { - utils.SafeLogf("Successfully proxied playlist: %s\n", r.RemoteAddr) - cancel() - } else if streamExitCode == 1 || streamExitCode == 2 { - // Retry on server-side connection errors - utils.SafeLogf("Retrying other servers...\n") - } else if streamExitCode == 4 { - utils.SafeLogf("Finished handling %s request: %s\n", r.Method, r.RemoteAddr) - cancel() - } else { - // Consider client-side connection errors as complete closure - utils.SafeLogf("Client has closed the stream: %s\n", r.RemoteAddr) - cancel() - } + if debug { + utils.SafeLogf("[DEBUG] Headers set for response: %v\n", w.Header()) + } - resp.Body.Close() - } + if r.Method != http.MethodGet || utils.EOFIsExpected(stream.Response) { + err = client.DirectProxy(ctx, stream.Response, w) + } else { + err = client.StreamBuffer(ctx, stream.Buffer, w) } + + utils.SafeLogf("Error stream for slug %s: %v\n", streamUrl, err) + http.Error(w, err.Error(), http.StatusInternalServerError) }
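
Note on the new flow (illustrative only, not part of the diff): each slug now has a single shared BufferStream that load-balances upstream sources and produces into the Redis-backed Buffer, while every HTTP request only attaches as a consumer. A minimal, hypothetical caller that mirrors the new Handler looks roughly like this (the serveSlug helper is invented for illustration and error handling is trimmed):

	package main

	import (
		"context"
		"net/http"

		"m3u-stream-merger/proxy"
		"m3u-stream-merger/utils"
	)

	// serveSlug is a hypothetical wrapper showing how the pieces introduced in
	// this patch fit together; it is not part of the change itself.
	func serveSlug(ctx context.Context, slug string, w http.ResponseWriter, r *http.Request) error {
		stream, err := proxy.InitializeBufferStream(slug) // one shared background producer per slug
		if err != nil {
			return err
		}
		if err := stream.Start(r); err != nil { // returns immediately if the producer is already running
			return err
		}
		client, err := proxy.InitializeClient(ctx, slug) // lightweight per-request consumer
		if err != nil {
			return err
		}
		// Playlists and non-GET requests are proxied directly; everything else is
		// relayed from the shared buffer that Start keeps filling in the background.
		if r.Method != http.MethodGet || utils.EOFIsExpected(stream.Response) {
			return client.DirectProxy(ctx, stream.Response, w)
		}
		return client.StreamBuffer(ctx, stream.Buffer, w)
	}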
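
Lifecycle note: the producer goroutine spawned by Start holds a Redis lock keyed on "streambuffer:<slug>", observes the subscriber count that clients maintain through IncrementBufferUser/DecrementBufferUser, and stops reading the upstream roughly three seconds after the last client disappears. EOF and read errors on the upstream are retried with exponential backoff (starting at 200ms and capped just under STREAM_TIMEOUT) until the STREAM_TIMEOUT window elapses, at which point the outer loop in Start asks LoadBalancer for the next untested index.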