Skip to content

Commit

Permalink
add a memory session to retain tested indexes across requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sonroyaalmerol committed Dec 20, 2024
1 parent 67b921a commit 5a50686
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
6 changes: 3 additions & 3 deletions handlers/stream_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
var selectedIndex int
var selectedUrl string

testedIndexes := []int{}
session := store.GetOrCreateSession(r)
firstWrite := true

var resp *http.Response
Expand All @@ -46,7 +46,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
}()

for {
resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &testedIndexes, r.Method)
resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(ctx, &session, r.Method)
if err != nil {
utils.SafeLogf("Error reloading stream for %s: %v\n", streamUrl, err)
return
Expand Down Expand Up @@ -78,7 +78,7 @@ func StreamHandler(w http.ResponseWriter, r *http.Request, cm *store.Concurrency
defer proxyCtxCancel()

go stream.ProxyStream(proxyCtx, selectedIndex, resp, r, w, exitStatus)
testedIndexes = append(testedIndexes, selectedIndex)
session.SetTestedIndexes(append(session.TestedIndexes, selectedIndex))

select {
case <-ctx.Done():
Expand Down
6 changes: 3 additions & 3 deletions proxy/load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewStreamInstance(streamUrl string, cm *store.ConcurrencyManager) (*StreamI
}, nil
}

func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]int, method string) (*http.Response, string, int, error) {
func (instance *StreamInstance) LoadBalancer(ctx context.Context, session *store.Session, method string) (*http.Response, string, int, error) {
debug := os.Getenv("DEBUG") == "true"

m3uIndexes := utils.GetM3UIndexes()
Expand Down Expand Up @@ -64,7 +64,7 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
return nil, "", -1, fmt.Errorf("Cancelling load balancer.")
default:
for _, index := range m3uIndexes {
if slices.Contains(*previous, index) {
if slices.Contains(session.TestedIndexes, index) {
utils.SafeLogf("Skipping M3U_%d: marked as previous stream\n", index+1)
continue
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func (instance *StreamInstance) LoadBalancer(ctx context.Context, previous *[]in
if debug {
utils.SafeLogf("[DEBUG] All streams skipped in lap %d\n", lap)
}
*previous = []int{}
session.SetTestedIndexes([]int{})
}

}
Expand Down
50 changes: 50 additions & 0 deletions store/sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package store

import (
"m3u-stream-merger/utils"
"net/http"
"sync"
"time"
)

type Session struct {
ID string
CreatedAt time.Time
TestedIndexes []int
}

var sessionStore = struct {
sync.RWMutex
sessions map[string]Session
}{sessions: make(map[string]Session)}

func GetOrCreateSession(r *http.Request) Session {
fingerprint := utils.GenerateFingerprint(r)

sessionStore.RLock()
session, exists := sessionStore.sessions[fingerprint]
sessionStore.RUnlock()
if exists {
return session
}

session = Session{
ID: fingerprint,
CreatedAt: time.Now(),
TestedIndexes: []int{},
}

sessionStore.Lock()
sessionStore.sessions[session.ID] = session
sessionStore.Unlock()

return session
}

func (s *Session) SetTestedIndexes(indexes []int) {
s.TestedIndexes = indexes

sessionStore.Lock()
sessionStore.sessions[s.ID] = *s
sessionStore.Unlock()
}
26 changes: 26 additions & 0 deletions utils/fingerprint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package utils

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
)

func GenerateFingerprint(r *http.Request) string {
// Collect relevant attributes
ip := r.RemoteAddr
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ip = xff
}
userAgent := r.Header.Get("User-Agent")
accept := r.Header.Get("Accept")
acceptLang := r.Header.Get("Accept-Language")

// Combine into a single string
data := fmt.Sprintf("%s|%s|%s|%s", ip, userAgent, accept, acceptLang)

// Hash the string for a compact, fixed-length identifier
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}

0 comments on commit 5a50686

Please sign in to comment.