Skip to content

Commit

Permalink
Merge pull request #12 from sonroyaalmerol/fix-concurrency-limit
Browse files Browse the repository at this point in the history
Fix concurrency limit not persisting across different requests
  • Loading branch information
sonroyaalmerol authored Mar 4, 2024
2 parents f65d62d + 90e0478 commit 8c6dc91
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 71 deletions.
24 changes: 24 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,30 @@ func GetStreamByTitle(db *sql.DB, title string) (s StreamInfo, err error) {
return s, nil
}

func GetStreamUrlByUrlAndIndex(db *sql.DB, url string, m3u_index int) (s StreamURL, err error) {
mutex.Lock()
defer mutex.Unlock()

rows, err := db.Query("SELECT id, content, m3u_index, max_concurrency FROM stream_urls WHERE content = ? AND m3u_index = ?", url, m3u_index)
if err != nil {
return s, fmt.Errorf("error querying streams: %v", err)
}
defer rows.Close()

for rows.Next() {
err = rows.Scan(&s.DbId, &s.Content, &s.M3UIndex, &s.MaxConcurrency)
if err != nil {
return s, fmt.Errorf("error scanning stream: %v", err)
}
}

if err := rows.Err(); err != nil {
return s, fmt.Errorf("error iterating over rows: %v", err)
}

return s, nil
}

func GetStreams(db *sql.DB) ([]StreamInfo, error) {
mutex.Lock()
defer mutex.Unlock()
Expand Down
196 changes: 135 additions & 61 deletions m3u/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"bufio"
"database/sql"
"fmt"
"io"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"time"

"m3u-stream-merger/database"
)
Expand All @@ -20,6 +23,16 @@ func ParseM3UFromURL(db *sql.DB, m3uURL string, m3uIndex int, maxConcurrency int
userAgent = "IPTV Smarters/1.0.3 (iPad; iOS 16.6.1; Scale/2.00)"
}

maxRetries := 10
var err error
maxRetriesStr, maxRetriesExists := os.LookupEnv("MAX_RETRIES")
if !maxRetriesExists {
maxRetries, err = strconv.Atoi(maxRetriesStr)
if err != nil {
maxRetries = 10
}
}

// Create a new HTTP client with a custom User-Agent header
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
Expand All @@ -31,81 +44,142 @@ func ParseM3UFromURL(db *sql.DB, m3uURL string, m3uIndex int, maxConcurrency int

log.Printf("Parsing M3U from URL: %s\n", m3uURL)

resp, err := client.Get(m3uURL)
if err != nil {
return fmt.Errorf("HTTP GET error: %v", err)
}
defer resp.Body.Close()

scanner := bufio.NewScanner(resp.Body)
for i := 0; i <= maxRetries; i++ {
resp, err := client.Get(m3uURL)
if err != nil {
return fmt.Errorf("HTTP GET error: %v", err)
}
defer resp.Body.Close()

scanner := bufio.NewScanner(resp.Body)

var currentStream database.StreamInfo

for scanner.Scan() {
line := scanner.Text()
extInfLine := ""

if strings.HasPrefix(line, "#EXTINF:") {
currentStream = database.StreamInfo{}
extInfLine = line

lineWithoutPairs := line

// Define a regular expression to capture key-value pairs
regex := regexp.MustCompile(`([a-zA-Z0-9_-]+)=("[^"]+"|[^",]+)`)

// Find all key-value pairs in the line
matches := regex.FindAllStringSubmatch(line, -1)

for _, match := range matches {
key := strings.TrimSpace(match[1])
value := strings.TrimSpace(match[2])

if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) {
value = strings.Trim(value, `"`)
}

switch key {
case "tvg-id":
currentStream.TvgID = value
case "tvg-name":
currentStream.Title = value
case "group-title":
currentStream.Group = value
case "tvg-logo":
currentStream.LogoURL = value
default:
if os.Getenv("DEBUG") == "true" {
log.Printf("Uncaught attribute: %s=%s\n", key, value)
}
}

var pair string
if strings.Contains(value, `"`) || strings.Contains(value, ",") {
// If the value contains double quotes or commas, format it as key="value"
pair = fmt.Sprintf(`%s="%s"`, key, value)
} else {
// Otherwise, format it as key=value
pair = fmt.Sprintf(`%s=%s`, key, value)
}
lineWithoutPairs = strings.Replace(lineWithoutPairs, pair, "", 1)
}

var currentStream database.StreamInfo
lineCommaSplit := strings.SplitN(lineWithoutPairs, ",", 2)

for scanner.Scan() {
line := scanner.Text()
if len(lineCommaSplit) > 1 {
currentStream.Title = strings.TrimSpace(lineCommaSplit[1])
}
} else if strings.HasPrefix(line, "#EXTVLCOPT:") {
// Extract logo URL from #EXTVLCOPT line
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
if os.Getenv("DEBUG") == "true" {
log.Printf("Uncaught attribute (#EXTVLCOPT): %s=%s\n", parts[0], parts[1])
}
}
} else if strings.HasPrefix(line, "http") {
if len(strings.TrimSpace(currentStream.Title)) == 0 {
log.Printf("Error capturing title, line will be skipped: %s\n", extInfLine)
continue
}

if strings.HasPrefix(line, "#EXTINF:") {
currentStream = database.StreamInfo{}
currentStream.Title = strings.TrimSpace(strings.SplitN(line, ",", 2)[1])
existingStream, err := database.GetStreamByTitle(db, currentStream.Title)
if err != nil {
return fmt.Errorf("GetStreamByTitle error (title: %s): %v", currentStream.Title, err)
}

// Define a regular expression to capture key-value pairs
regex := regexp.MustCompile(`(\S+?)="([^"]*?)"`)
var dbId int64
if existingStream.Title != currentStream.Title {
if os.Getenv("DEBUG") == "true" {
log.Printf("Creating new database entry: %s\n", currentStream.Title)
}
dbId, err = database.InsertStream(db, currentStream)
if err != nil {
return fmt.Errorf("InsertStream error (title: %s): %v", currentStream.Title, err)
}
} else {
if os.Getenv("DEBUG") == "true" {
log.Printf("Using existing database entry: %s\n", existingStream.Title)
}
dbId = existingStream.DbId
}

// Find all key-value pairs in the line
matches := regex.FindAllStringSubmatch(line, -1)
if os.Getenv("DEBUG") == "true" {
log.Printf("Adding MP4 url entry to %s: %s\n", currentStream.Title, line)
}

for _, match := range matches {
key := strings.TrimSpace(match[1])
value := strings.TrimSpace(match[2])
existingUrl, err := database.GetStreamUrlByUrlAndIndex(db, line, m3uIndex)
if err != nil {
return fmt.Errorf("GetStreamUrlByUrlAndIndex error (url: %s): %v", line, err)
}

switch key {
case "tvg-id":
currentStream.TvgID = value
case "group-title":
currentStream.Group = value
case "tvg-logo":
currentStream.LogoURL = value
if existingUrl.Content != line || existingUrl.M3UIndex != m3uIndex {
_, err = database.InsertStreamUrl(db, dbId, database.StreamURL{
Content: line,
M3UIndex: m3uIndex,
MaxConcurrency: maxConcurrency,
})
}
}
} else if strings.HasPrefix(line, "#EXTVLCOPT:") {
// Extract logo URL from #EXTVLCOPT line
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
currentStream.LogoURL = parts[1]
}
} else if strings.HasPrefix(line, "http") {
existingStream, err := database.GetStreamByTitle(db, currentStream.Title)
if err != nil {
return fmt.Errorf("GetStreamByTitle error (title: %s): %v", currentStream.Title, err)
}

var dbId int64
if existingStream.Title != currentStream.Title {
log.Printf("Creating new database entry: %s", currentStream.Title)
dbId, err = database.InsertStream(db, currentStream)
if err != nil {
return fmt.Errorf("InsertStream error (title: %s): %v", currentStream.Title, err)
return fmt.Errorf("InsertStreamUrl error (title: %s): %v", currentStream.Title, err)
}
} else {
log.Printf("Using existing database entry: %s", existingStream.Title)
dbId = existingStream.DbId
}
}

log.Printf("Adding MP4 url entry to %s: %s", currentStream.Title, line)
_, err = database.InsertStreamUrl(db, dbId, database.StreamURL{
Content: line,
M3UIndex: m3uIndex,
MaxConcurrency: maxConcurrency,
})
if err != nil {
return fmt.Errorf("InsertStreamUrl error (title: %s): %v", currentStream.Title, err)
}
if scanner.Err() == io.EOF {
// Unexpected EOF, retry
log.Printf("Unexpected EOF. Retrying in 5 secs... (url: %s)\n", m3uURL)
time.Sleep(5 * time.Second)
continue
}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %v", err)
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %v", err)
}

return nil
return nil
}
return fmt.Errorf("Max retries reached without success. Failed to fetch %s\n", m3uURL)
}
36 changes: 26 additions & 10 deletions mp4_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"m3u-stream-merger/database"
"m3u-stream-merger/utils"
"net/http"
"os"
"strconv"
"strings"
"syscall"
Expand All @@ -20,7 +21,7 @@ import (
func loadBalancer(ctx context.Context, stream database.StreamInfo) (resp *http.Response, selectedUrl *database.StreamURL, err error) {
// Concurrency check mode
for _, url := range stream.URLs {
if checkConcurrency(ctx, url.Content, url.MaxConcurrency) {
if checkConcurrency(ctx, url.M3UIndex) {
log.Printf("Concurrency limit reached (%d): %s", url.MaxConcurrency, url.Content)
continue // Skip this stream if concurrency limit reached
}
Expand Down Expand Up @@ -105,7 +106,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) {
return
}
log.Printf("Proxying %s to %s\n", r.RemoteAddr, selectedUrl.Content)
updateConcurrency(ctx, selectedUrl.Content, true)
updateConcurrency(ctx, selectedUrl.M3UIndex, true)

// Log the successful response
log.Printf("Sent MP4 stream to %s\n", r.RemoteAddr)
Expand All @@ -115,7 +116,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) {
case <-ctx.Done():
// Connection closed, handle accordingly
log.Println("Client disconnected after fetching MP4 stream")
updateConcurrency(ctx, selectedUrl.Content, false)
updateConcurrency(ctx, selectedUrl.M3UIndex, false)
return
default:
// Connection still open, proceed with writing to the response
Expand All @@ -124,7 +125,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) {
// Log the error
if errors.Is(err, syscall.EPIPE) {
log.Println("Client disconnected after fetching MP4 stream")
updateConcurrency(ctx, selectedUrl.Content, false)
updateConcurrency(ctx, selectedUrl.M3UIndex, false)
} else {
log.Printf("Error copying MP4 stream to response: %s\n", err.Error())
}
Expand All @@ -133,27 +134,42 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) {
}
}

func checkConcurrency(ctx context.Context, url string, maxConcurrency int) bool {
func checkConcurrency(ctx context.Context, m3uIndex int) bool {
maxConcurrency := 1
var err error
rawMaxConcurrency, maxConcurrencyExists := os.LookupEnv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", m3uIndex))
if maxConcurrencyExists {
maxConcurrency, err = strconv.Atoi(rawMaxConcurrency)
if err != nil {
maxConcurrency = 1
}
}

redisClient := database.InitializeRedis()
val, err := redisClient.Get(ctx, url).Result()
val, err := redisClient.Get(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Result()
if err == redis.Nil {
return false // Key does not exist
} else if err != nil {
log.Printf("Error checking concurrency: %s\n", err.Error())
return false // Error occurred, treat as concurrency not reached
}

count, _ := strconv.Atoi(val)
count, err := strconv.Atoi(val)
if err != nil {
count = 0
}

log.Printf("Current concurrent connections for M3U_%d: %d", m3uIndex, count)
return count >= maxConcurrency
}

func updateConcurrency(ctx context.Context, url string, incr bool) {
func updateConcurrency(ctx context.Context, m3uIndex int, incr bool) {
redisClient := database.InitializeRedis()
var err error
if incr {
err = redisClient.Incr(ctx, url).Err()
err = redisClient.Incr(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Err()
} else {
err = redisClient.Decr(ctx, url).Err()
err = redisClient.Decr(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Err()
}
if err != nil {
log.Printf("Error updating concurrency: %s\n", err.Error())
Expand Down

0 comments on commit 8c6dc91

Please sign in to comment.