Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a memory session to retain tested indexes across requests #202

Merged
merged 9 commits into from
Dec 21, 2024
6 changes: 3 additions & 3 deletions handlers/stream_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
var selectedIndex int
var selectedUrl string

testedIndexes := []int{}
session := store.GetOrCreateSession(r)
firstWrite := true

var resp *http.Response
Expand All @@ -46,7 +46,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
}()

for {
resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &testedIndexes, r.Method)
resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &session, r.Method)
if err != nil {
utils.SafeLogf("Error reloading stream for %s: %v\n", streamUrl, err)
return
Expand Down Expand Up @@ -78,7 +78,6 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
defer proxyCtxCancel()

go stream.ProxyStream(proxyCtx, selectedIndex, resp, r, w, exitStatus)
testedIndexes = append(testedIndexes, selectedIndex)

select {
case <-ctx.Done():
Expand All @@ -92,6 +91,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
return
} else if streamExitCode == 1 || streamExitCode == 2 {
// Retry on server-side connection errors
session.SetTestedIndexes(append(session.TestedIndexes, selectedIndex))
utils.SafeLogf("Retrying other servers...\n")
proxyCtxCancel()
} else if streamExitCode == 4 {
Expand Down
16 changes: 6 additions & 10 deletions proxy/load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewStreamInstance(streamUrl string, cm *store.ConcurrencyManager) (*StreamI
}, nil
}

func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]int, method string) (*http.Response, string, int, error) {
func (instance *StreamInstance) LoadBalancer(ctx context.Context, session *store.Session, method string) (*http.Response, string, int, error) {
debug := os.Getenv("DEBUG") == "true"

m3uIndexes := utils.GetM3UIndexes()
Expand All @@ -57,14 +57,13 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
if debug {
utils.SafeLogf("[DEBUG] Stream attempt %d out of %d\n", lap+1, maxLaps)
}
allSkipped := true // Assume all URLs might be skipped

select {
case <-ctx.Done():
return nil, "", -1, fmt.Errorf("Cancelling load balancer.")
default:
for _, index := range m3uIndexes {
if slices.Contains(*previous, index) {
if slices.Contains(session.TestedIndexes, index) {
utils.SafeLogf("Skipping M3U_%d: marked as previous stream\n", index+1)
continue
}
Expand All @@ -80,8 +79,6 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
continue
}

allSkipped = false // At least one URL is not skipped

resp, err := utils.CustomHttpRequest(method, url)
if err == nil {
if debug {
Expand All @@ -93,14 +90,13 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
if debug {
utils.SafeLogf("[DEBUG] Error fetching stream from %s: %s\n", url, err.Error())
}
session.SetTestedIndexes(append(session.TestedIndexes, index))
}

if allSkipped {
if debug {
utils.SafeLogf("[DEBUG] All streams skipped in lap %d\n", lap)
}
*previous = []int{}
if debug {
utils.SafeLogf("[DEBUG] All streams skipped in lap %d\n", lap)
}
session.SetTestedIndexes([]int{})

}

Expand Down
73 changes: 73 additions & 0 deletions store/sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package store

import (
"m3u-stream-merger/utils"
"net/http"
"os"
"sync"
"time"
)

type Session struct {
ID string
CreatedAt time.Time
TestedIndexes []int
}

var sessionStore = struct {
sync.RWMutex
sessions map[string]Session
}{sessions: make(map[string]Session)}

func GetOrCreateSession(r *http.Request) Session {
debug := os.Getenv("DEBUG") == "true"
fingerprint := utils.GenerateFingerprint(r)

sessionStore.RLock()
session, exists := sessionStore.sessions[fingerprint]
sessionStore.RUnlock()
if exists {
if debug {
utils.SafeLogf("[DEBUG] Existing session found: %s\n", fingerprint)
}
return session
}

session = Session{
ID: fingerprint,
CreatedAt: time.Now(),
TestedIndexes: []int{},
}

sessionStore.Lock()
sessionStore.sessions[session.ID] = session
sessionStore.Unlock()

if debug {
utils.SafeLogf("[DEBUG] Generating new session: %s\n", fingerprint)
}

return session
}

func ClearSessionStore() {
sessionStore.Lock()
for k := range sessionStore.sessions {
delete(sessionStore.sessions, k)
}
sessionStore.Unlock()
}

func (s *Session) SetTestedIndexes(indexes []int) {
debug := os.Getenv("DEBUG") == "true"

s.TestedIndexes = indexes

if debug {
utils.SafeLogf("[DEBUG] Setting tested indexes for session - %s: %v\n", s.ID, s.TestedIndexes)
}

sessionStore.Lock()
sessionStore.sessions[s.ID] = *s
sessionStore.Unlock()
}
2 changes: 2 additions & 0 deletions updater/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ func (instance *Updater) UpdateSources(ctx context.Context) {

utils.SafeLogf("Background process: M3U fetching complete.\n")

store.ClearSessionStore()

cacheOnSync := os.Getenv("CACHE_ON_SYNC")
if len(strings.TrimSpace(cacheOnSync)) == 0 {
cacheOnSync = "false"
Expand Down
34 changes: 34 additions & 0 deletions utils/fingerprint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package utils

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"os"
"strings"
)

func GenerateFingerprint(r *http.Request) string {
debug := os.Getenv("DEBUG") == "true"

// Collect relevant attributes
ip := strings.Split(r.RemoteAddr, ":")[0]
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ip = xff
}
userAgent := r.Header.Get("User-Agent")
accept := r.Header.Get("Accept")
acceptLang := r.Header.Get("Accept-Language")
path := r.URL.Path

// Combine into a single string
data := fmt.Sprintf("%s|%s|%s|%s|%s", ip, userAgent, accept, acceptLang, path)
if debug {
SafeLogf("[DEBUG] Generating fingerprint from: %s\n", data)
}

// Hash the string for a compact, fixed-length identifier
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
Loading