Merge pull request #202 from sonroyaalmerol/sessions
Add a memory session to retain tested indexes across requests
sonroyaalmerol authored Dec 21, 2024
2 parents 67b921a + 0189bef commit 8c9f44a
Showing 5 changed files with 118 additions and 13 deletions.
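Taken together, the change keys a small in-memory session to a fingerprint of the incoming request and uses it to remember which M3U indexes have already been tried, so a retry from the same client skips upstreams that just failed. A minimal self-contained sketch of that idea (illustrative names only; the repository's actual types and helpers appear in the diffs below):

package main

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
    "sync"
)

// session remembers which upstream indexes have already failed for one client.
type session struct {
    ID            string
    TestedIndexes []int
}

// sessions is a process-wide map of fingerprint -> session, guarded by a mutex.
var sessions = struct {
    sync.RWMutex
    m map[string]session
}{m: make(map[string]session)}

// getOrCreate returns the stored session for a fingerprint, creating it on first use.
func getOrCreate(fingerprint string) session {
    sessions.RLock()
    s, ok := sessions.m[fingerprint]
    sessions.RUnlock()
    if ok {
        return s
    }
    s = session{ID: fingerprint}
    sessions.Lock()
    sessions.m[s.ID] = s
    sessions.Unlock()
    return s
}

// markTested records a failed index and writes the updated copy back to the map.
func markTested(s *session, index int) {
    s.TestedIndexes = append(s.TestedIndexes, index)
    sessions.Lock()
    sessions.m[s.ID] = *s
    sessions.Unlock()
}

func main() {
    // The fingerprint stands in for a hash of client attributes (IP, headers, path).
    sum := sha256.Sum256([]byte("203.0.113.7|VLC/3.0|*/*|en-US|/stream/1"))
    id := hex.EncodeToString(sum[:])

    s := getOrCreate(id)
    markTested(&s, 0) // upstream M3U_1 failed for this client
    markTested(&s, 2) // upstream M3U_3 failed as well

    again := getOrCreate(id)         // a later request from the same client
    fmt.Println(again.TestedIndexes) // [0 2] -- retained across requests
}
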
6 changes: 3 additions & 3 deletions handlers/stream_handler.go
@@ -35,7 +35,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
     var selectedIndex int
     var selectedUrl string
 
-    testedIndexes := []int{}
+    session := store.GetOrCreateSession(r)
     firstWrite := true
 
     var resp *http.Response
@@ -46,7 +46,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
     }()
 
     for {
-        resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &testedIndexes, r.Method)
+        resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &session, r.Method)
        if err != nil {
             utils.SafeLogf("Error reloading stream for %s: %v\n", streamUrl, err)
             return
@@ -78,7 +78,6 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
         defer proxyCtxCancel()
 
         go stream.ProxyStream(proxyCtx, selectedIndex, resp, r, w, exitStatus)
-        testedIndexes = append(testedIndexes, selectedIndex)
 
         select {
         case <-ctx.Done():
@@ -92,6 +91,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
             return
         } else if streamExitCode == 1 || streamExitCode == 2 {
             // Retry on server-side connection errors
+            session.SetTestedIndexes(append(session.TestedIndexes, selectedIndex))
             utils.SafeLogf("Retrying other servers...\n")
             proxyCtxCancel()
         } else if streamExitCode == 4 {

16 changes: 6 additions & 10 deletions proxy/load_balancer.go
@@ -31,7 +31,7 @@ func NewStreamInstance(streamUrl string, cm *store.ConcurrencyManager) (*StreamI
     }, nil
 }
 
-func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]int, method string) (*http.Response, string, int, error) {
+func (instance *StreamInstance) LoadBalancer(ctx context.Context, session *store.Session, method string) (*http.Response, string, int, error) {
     debug := os.Getenv("DEBUG") == "true"
 
     m3uIndexes := utils.GetM3UIndexes()
@@ -57,14 +57,13 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
         if debug {
             utils.SafeLogf("[DEBUG] Stream attempt %d out of %d\n", lap+1, maxLaps)
         }
-        allSkipped := true // Assume all URLs might be skipped
 
         select {
         case <-ctx.Done():
             return nil, "", -1, fmt.Errorf("Cancelling load balancer.")
         default:
             for _, index := range m3uIndexes {
-                if slices.Contains(*previous, index) {
+                if slices.Contains(session.TestedIndexes, index) {
                     utils.SafeLogf("Skipping M3U_%d: marked as previous stream\n", index+1)
                     continue
                 }
@@ -80,8 +79,6 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
                     continue
                 }
 
-                allSkipped = false // At least one URL is not skipped
-
                 resp, err := utils.CustomHttpRequest(method, url)
                 if err == nil {
                     if debug {
@@ -93,14 +90,13 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
                 if debug {
                     utils.SafeLogf("[DEBUG] Error fetching stream from %s: %s\n", url, err.Error())
                 }
+                session.SetTestedIndexes(append(session.TestedIndexes, index))
             }
 
-            if allSkipped {
-                if debug {
-                    utils.SafeLogf("[DEBUG] All streams skipped in lap %d\n", lap)
-                }
-                *previous = []int{}
+            if debug {
+                utils.SafeLogf("[DEBUG] All streams skipped in lap %d\n", lap)
+            }
+            session.SetTestedIndexes([]int{})
 
         }

73 changes: 73 additions & 0 deletions store/sessions.go
@@ -0,0 +1,73 @@
package store

import (
    "m3u-stream-merger/utils"
    "net/http"
    "os"
    "sync"
    "time"
)

type Session struct {
    ID            string
    CreatedAt     time.Time
    TestedIndexes []int
}

var sessionStore = struct {
    sync.RWMutex
    sessions map[string]Session
}{sessions: make(map[string]Session)}

func GetOrCreateSession(r *http.Request) Session {
    debug := os.Getenv("DEBUG") == "true"
    fingerprint := utils.GenerateFingerprint(r)

    sessionStore.RLock()
    session, exists := sessionStore.sessions[fingerprint]
    sessionStore.RUnlock()
    if exists {
        if debug {
            utils.SafeLogf("[DEBUG] Existing session found: %s\n", fingerprint)
        }
        return session
    }

    session = Session{
        ID:            fingerprint,
        CreatedAt:     time.Now(),
        TestedIndexes: []int{},
    }

    sessionStore.Lock()
    sessionStore.sessions[session.ID] = session
    sessionStore.Unlock()

    if debug {
        utils.SafeLogf("[DEBUG] Generating new session: %s\n", fingerprint)
    }

    return session
}

func ClearSessionStore() {
    sessionStore.Lock()
    for k := range sessionStore.sessions {
        delete(sessionStore.sessions, k)
    }
    sessionStore.Unlock()
}

func (s *Session) SetTestedIndexes(indexes []int) {
    debug := os.Getenv("DEBUG") == "true"

    s.TestedIndexes = indexes

    if debug {
        utils.SafeLogf("[DEBUG] Setting tested indexes for session - %s: %v\n", s.ID, s.TestedIndexes)
    }

    sessionStore.Lock()
    sessionStore.sessions[s.ID] = *s
    sessionStore.Unlock()
}
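Note that GetOrCreateSession returns the Session by value, so changes persist only once SetTestedIndexes writes the modified copy back into sessionStore. A rough usage sketch under that reading (the module path is taken from the import above; the request path and headers are hypothetical):

package main

import (
    "fmt"
    "net/http/httptest"

    "m3u-stream-merger/store"
)

func main() {
    // Two requests with identical client attributes resolve to the same session.
    r1 := httptest.NewRequest("GET", "/p/stream/abc.m3u8", nil)
    r1.Header.Set("User-Agent", "VLC/3.0")

    s := store.GetOrCreateSession(r1)
    s.SetTestedIndexes(append(s.TestedIndexes, 1)) // remember that M3U_2 failed

    r2 := httptest.NewRequest("GET", "/p/stream/abc.m3u8", nil)
    r2.Header.Set("User-Agent", "VLC/3.0")

    again := store.GetOrCreateSession(r2)
    fmt.Println(again.TestedIndexes) // expected: [1], carried over from the first request
}
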
2 changes: 2 additions & 0 deletions updater/updater.go
@@ -95,6 +95,8 @@ func (instance *Updater) UpdateSources(ctx context.Context) {
 
     utils.SafeLogf("Background process: M3U fetching complete.\n")
 
+    store.ClearSessionStore()
+
     cacheOnSync := os.Getenv("CACHE_ON_SYNC")
     if len(strings.TrimSpace(cacheOnSync)) == 0 {
         cacheOnSync = "false"

34 changes: 34 additions & 0 deletions utils/fingerprint.go
@@ -0,0 +1,34 @@
package utils

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
    "net/http"
    "os"
    "strings"
)

func GenerateFingerprint(r *http.Request) string {
    debug := os.Getenv("DEBUG") == "true"

    // Collect relevant attributes
    ip := strings.Split(r.RemoteAddr, ":")[0]
    if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
        ip = xff
    }
    userAgent := r.Header.Get("User-Agent")
    accept := r.Header.Get("Accept")
    acceptLang := r.Header.Get("Accept-Language")
    path := r.URL.Path

    // Combine into a single string
    data := fmt.Sprintf("%s|%s|%s|%s|%s", ip, userAgent, accept, acceptLang, path)
    if debug {
        SafeLogf("[DEBUG] Generating fingerprint from: %s\n", data)
    }

    // Hash the string for a compact, fixed-length identifier
    hash := sha256.Sum256([]byte(data))
    return hex.EncodeToString(hash[:])
}
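The fingerprint is simply a SHA-256 of the pipe-joined attributes, so identical attributes always yield the same session ID. A small sketch with hypothetical values:

package main

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
)

func main() {
    // Attributes in the order GenerateFingerprint joins them: ip|user-agent|accept|accept-language|path.
    data := "203.0.113.7|VLC/3.0.20|*/*|en-US|/p/stream/abc.m3u8"
    sum := sha256.Sum256([]byte(data))
    fmt.Println(hex.EncodeToString(sum[:])) // 64-character hex string used as the session ID
}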
