diff --git a/database/db.go b/database/db.go index 5be363b3..999c29db 100644 --- a/database/db.go +++ b/database/db.go @@ -345,6 +345,30 @@ func GetStreamByTitle(db *sql.DB, title string) (s StreamInfo, err error) { return s, nil } +func GetStreamUrlByUrlAndIndex(db *sql.DB, url string, m3u_index int) (s StreamURL, err error) { + mutex.Lock() + defer mutex.Unlock() + + rows, err := db.Query("SELECT id, content, m3u_index, max_concurrency FROM stream_urls WHERE content = ? AND m3u_index = ?", url, m3u_index) + if err != nil { + return s, fmt.Errorf("error querying streams: %v", err) + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(&s.DbId, &s.Content, &s.M3UIndex, &s.MaxConcurrency) + if err != nil { + return s, fmt.Errorf("error scanning stream: %v", err) + } + } + + if err := rows.Err(); err != nil { + return s, fmt.Errorf("error iterating over rows: %v", err) + } + + return s, nil +} + func GetStreams(db *sql.DB) ([]StreamInfo, error) { mutex.Lock() defer mutex.Unlock() diff --git a/m3u/parser.go b/m3u/parser.go index e0a4bc63..ffe7cbf2 100644 --- a/m3u/parser.go +++ b/m3u/parser.go @@ -4,11 +4,14 @@ import ( "bufio" "database/sql" "fmt" + "io" "log" "net/http" "os" "regexp" + "strconv" "strings" + "time" "m3u-stream-merger/database" ) @@ -20,6 +23,16 @@ func ParseM3UFromURL(db *sql.DB, m3uURL string, m3uIndex int, maxConcurrency int userAgent = "IPTV Smarters/1.0.3 (iPad; iOS 16.6.1; Scale/2.00)" } + maxRetries := 10 + var err error + maxRetriesStr, maxRetriesExists := os.LookupEnv("MAX_RETRIES") + if !maxRetriesExists { + maxRetries, err = strconv.Atoi(maxRetriesStr) + if err != nil { + maxRetries = 10 + } + } + // Create a new HTTP client with a custom User-Agent header client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -31,81 +44,142 @@ func ParseM3UFromURL(db *sql.DB, m3uURL string, m3uIndex int, maxConcurrency int log.Printf("Parsing M3U from URL: %s\n", m3uURL) - resp, err := client.Get(m3uURL) - if err != nil { - return fmt.Errorf("HTTP GET error: %v", err) - } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) + for i := 0; i <= maxRetries; i++ { + resp, err := client.Get(m3uURL) + if err != nil { + return fmt.Errorf("HTTP GET error: %v", err) + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + + var currentStream database.StreamInfo + + for scanner.Scan() { + line := scanner.Text() + extInfLine := "" + + if strings.HasPrefix(line, "#EXTINF:") { + currentStream = database.StreamInfo{} + extInfLine = line + + lineWithoutPairs := line + + // Define a regular expression to capture key-value pairs + regex := regexp.MustCompile(`([a-zA-Z0-9_-]+)=("[^"]+"|[^",]+)`) + + // Find all key-value pairs in the line + matches := regex.FindAllStringSubmatch(line, -1) + + for _, match := range matches { + key := strings.TrimSpace(match[1]) + value := strings.TrimSpace(match[2]) + + if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) { + value = strings.Trim(value, `"`) + } + + switch key { + case "tvg-id": + currentStream.TvgID = value + case "tvg-name": + currentStream.Title = value + case "group-title": + currentStream.Group = value + case "tvg-logo": + currentStream.LogoURL = value + default: + if os.Getenv("DEBUG") == "true" { + log.Printf("Uncaught attribute: %s=%s\n", key, value) + } + } + + var pair string + if strings.Contains(value, `"`) || strings.Contains(value, ",") { + // If the value contains double quotes or commas, format it as key="value" + pair = fmt.Sprintf(`%s="%s"`, key, value) + } else { + // Otherwise, format it as key=value + pair = fmt.Sprintf(`%s=%s`, key, value) + } + lineWithoutPairs = strings.Replace(lineWithoutPairs, pair, "", 1) + } - var currentStream database.StreamInfo + lineCommaSplit := strings.SplitN(lineWithoutPairs, ",", 2) - for scanner.Scan() { - line := scanner.Text() + if len(lineCommaSplit) > 1 { + currentStream.Title = strings.TrimSpace(lineCommaSplit[1]) + } + } else if strings.HasPrefix(line, "#EXTVLCOPT:") { + // Extract logo URL from #EXTVLCOPT line + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + if os.Getenv("DEBUG") == "true" { + log.Printf("Uncaught attribute (#EXTVLCOPT): %s=%s\n", parts[0], parts[1]) + } + } + } else if strings.HasPrefix(line, "http") { + if len(strings.TrimSpace(currentStream.Title)) == 0 { + log.Printf("Error capturing title, line will be skipped: %s\n", extInfLine) + continue + } - if strings.HasPrefix(line, "#EXTINF:") { - currentStream = database.StreamInfo{} - currentStream.Title = strings.TrimSpace(strings.SplitN(line, ",", 2)[1]) + existingStream, err := database.GetStreamByTitle(db, currentStream.Title) + if err != nil { + return fmt.Errorf("GetStreamByTitle error (title: %s): %v", currentStream.Title, err) + } - // Define a regular expression to capture key-value pairs - regex := regexp.MustCompile(`(\S+?)="([^"]*?)"`) + var dbId int64 + if existingStream.Title != currentStream.Title { + if os.Getenv("DEBUG") == "true" { + log.Printf("Creating new database entry: %s\n", currentStream.Title) + } + dbId, err = database.InsertStream(db, currentStream) + if err != nil { + return fmt.Errorf("InsertStream error (title: %s): %v", currentStream.Title, err) + } + } else { + if os.Getenv("DEBUG") == "true" { + log.Printf("Using existing database entry: %s\n", existingStream.Title) + } + dbId = existingStream.DbId + } - // Find all key-value pairs in the line - matches := regex.FindAllStringSubmatch(line, -1) + if os.Getenv("DEBUG") == "true" { + log.Printf("Adding MP4 url entry to %s: %s\n", currentStream.Title, line) + } - for _, match := range matches { - key := strings.TrimSpace(match[1]) - value := strings.TrimSpace(match[2]) + existingUrl, err := database.GetStreamUrlByUrlAndIndex(db, line, m3uIndex) + if err != nil { + return fmt.Errorf("GetStreamUrlByUrlAndIndex error (url: %s): %v", line, err) + } - switch key { - case "tvg-id": - currentStream.TvgID = value - case "group-title": - currentStream.Group = value - case "tvg-logo": - currentStream.LogoURL = value + if existingUrl.Content != line || existingUrl.M3UIndex != m3uIndex { + _, err = database.InsertStreamUrl(db, dbId, database.StreamURL{ + Content: line, + M3UIndex: m3uIndex, + MaxConcurrency: maxConcurrency, + }) } - } - } else if strings.HasPrefix(line, "#EXTVLCOPT:") { - // Extract logo URL from #EXTVLCOPT line - parts := strings.SplitN(line, "=", 2) - if len(parts) == 2 { - currentStream.LogoURL = parts[1] - } - } else if strings.HasPrefix(line, "http") { - existingStream, err := database.GetStreamByTitle(db, currentStream.Title) - if err != nil { - return fmt.Errorf("GetStreamByTitle error (title: %s): %v", currentStream.Title, err) - } - var dbId int64 - if existingStream.Title != currentStream.Title { - log.Printf("Creating new database entry: %s", currentStream.Title) - dbId, err = database.InsertStream(db, currentStream) if err != nil { - return fmt.Errorf("InsertStream error (title: %s): %v", currentStream.Title, err) + return fmt.Errorf("InsertStreamUrl error (title: %s): %v", currentStream.Title, err) } - } else { - log.Printf("Using existing database entry: %s", existingStream.Title) - dbId = existingStream.DbId } + } - log.Printf("Adding MP4 url entry to %s: %s", currentStream.Title, line) - _, err = database.InsertStreamUrl(db, dbId, database.StreamURL{ - Content: line, - M3UIndex: m3uIndex, - MaxConcurrency: maxConcurrency, - }) - if err != nil { - return fmt.Errorf("InsertStreamUrl error (title: %s): %v", currentStream.Title, err) - } + if scanner.Err() == io.EOF { + // Unexpected EOF, retry + log.Printf("Unexpected EOF. Retrying in 5 secs... (url: %s)\n", m3uURL) + time.Sleep(5 * time.Second) + continue } - } - if err := scanner.Err(); err != nil { - return fmt.Errorf("scanner error: %v", err) - } + if err := scanner.Err(); err != nil { + return fmt.Errorf("scanner error: %v", err) + } - return nil + return nil + } + return fmt.Errorf("Max retries reached without success. Failed to fetch %s\n", m3uURL) } diff --git a/mp4_handler.go b/mp4_handler.go index 7b7b23f5..5323c9ce 100644 --- a/mp4_handler.go +++ b/mp4_handler.go @@ -10,6 +10,7 @@ import ( "m3u-stream-merger/database" "m3u-stream-merger/utils" "net/http" + "os" "strconv" "strings" "syscall" @@ -20,7 +21,7 @@ import ( func loadBalancer(ctx context.Context, stream database.StreamInfo) (resp *http.Response, selectedUrl *database.StreamURL, err error) { // Concurrency check mode for _, url := range stream.URLs { - if checkConcurrency(ctx, url.Content, url.MaxConcurrency) { + if checkConcurrency(ctx, url.M3UIndex) { log.Printf("Concurrency limit reached (%d): %s", url.MaxConcurrency, url.Content) continue // Skip this stream if concurrency limit reached } @@ -105,7 +106,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { return } log.Printf("Proxying %s to %s\n", r.RemoteAddr, selectedUrl.Content) - updateConcurrency(ctx, selectedUrl.Content, true) + updateConcurrency(ctx, selectedUrl.M3UIndex, true) // Log the successful response log.Printf("Sent MP4 stream to %s\n", r.RemoteAddr) @@ -115,7 +116,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { case <-ctx.Done(): // Connection closed, handle accordingly log.Println("Client disconnected after fetching MP4 stream") - updateConcurrency(ctx, selectedUrl.Content, false) + updateConcurrency(ctx, selectedUrl.M3UIndex, false) return default: // Connection still open, proceed with writing to the response @@ -124,7 +125,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { // Log the error if errors.Is(err, syscall.EPIPE) { log.Println("Client disconnected after fetching MP4 stream") - updateConcurrency(ctx, selectedUrl.Content, false) + updateConcurrency(ctx, selectedUrl.M3UIndex, false) } else { log.Printf("Error copying MP4 stream to response: %s\n", err.Error()) } @@ -133,9 +134,19 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { } } -func checkConcurrency(ctx context.Context, url string, maxConcurrency int) bool { +func checkConcurrency(ctx context.Context, m3uIndex int) bool { + maxConcurrency := 1 + var err error + rawMaxConcurrency, maxConcurrencyExists := os.LookupEnv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", m3uIndex)) + if maxConcurrencyExists { + maxConcurrency, err = strconv.Atoi(rawMaxConcurrency) + if err != nil { + maxConcurrency = 1 + } + } + redisClient := database.InitializeRedis() - val, err := redisClient.Get(ctx, url).Result() + val, err := redisClient.Get(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Result() if err == redis.Nil { return false // Key does not exist } else if err != nil { @@ -143,17 +154,22 @@ func checkConcurrency(ctx context.Context, url string, maxConcurrency int) bool return false // Error occurred, treat as concurrency not reached } - count, _ := strconv.Atoi(val) + count, err := strconv.Atoi(val) + if err != nil { + count = 0 + } + + log.Printf("Current concurrent connections for M3U_%d: %d", m3uIndex, count) return count >= maxConcurrency } -func updateConcurrency(ctx context.Context, url string, incr bool) { +func updateConcurrency(ctx context.Context, m3uIndex int, incr bool) { redisClient := database.InitializeRedis() var err error if incr { - err = redisClient.Incr(ctx, url).Err() + err = redisClient.Incr(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Err() } else { - err = redisClient.Decr(ctx, url).Err() + err = redisClient.Decr(ctx, fmt.Sprintf("m3u_%d", m3uIndex)).Err() } if err != nil { log.Printf("Error updating concurrency: %s\n", err.Error())