diff --git a/Dockerfile b/Dockerfile index e1f1da4e..01fe443c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,7 +30,9 @@ FROM alpine:latest # add bash and timezone data # hadolint ignore=DL3018 -RUN apk --no-cache add tzdata +RUN apk --no-cache add tzdata \ + ca-certificates \ + && update-ca-certificates # set the current workdir WORKDIR /app diff --git a/database/database_test.go b/database/database_test.go index 40d5fca4..04739fa1 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -1,16 +1,15 @@ package database import ( + "os" "testing" ) func TestSaveAndLoadFromDb(t *testing.T) { // Test InitializeDb and check if the database file exists - REDIS_ADDR := "127.0.0.1:6379" - REDIS_PASS := "" - REDIS_DB := 1 + os.Setenv("REDIS_DB", "1") - db, err := InitializeDb(REDIS_ADDR, REDIS_PASS, REDIS_DB) + db, err := InitializeDb() if err != nil { t.Errorf("InitializeDb returned error: %v", err) } @@ -21,7 +20,7 @@ func TestSaveAndLoadFromDb(t *testing.T) { } // Test LoadFromDb with existing data in the database - expected := []StreamInfo{{ + expected := []*StreamInfo{{ Slug: "stream1", Title: "stream1", TvgID: "test1", @@ -54,7 +53,7 @@ func TestSaveAndLoadFromDb(t *testing.T) { } for i, expectedStream := range expected { - if !streamInfoEqual(result[i], expectedStream) { + if !streamInfoEqual(result[i], *expectedStream) { t.Errorf("GetStreams returned %+v, expected %+v", result[i], expectedStream) } } @@ -64,11 +63,6 @@ func TestSaveAndLoadFromDb(t *testing.T) { t.Errorf("DeleteStreamBySlug returned error: %v", err) } - err = db.DeleteStreamURL(expected[0], 0) - if err != nil { - t.Errorf("DeleteStreamURL returned error: %v", err) - } - streamChan = db.GetStreams() result = []StreamInfo{} @@ -77,14 +71,13 @@ func TestSaveAndLoadFromDb(t *testing.T) { } expected = expected[:1] - expected[0].URLs = map[int]string{} if len(result) != len(expected) { t.Errorf("GetStreams returned %+v, expected %+v", result, expected) } for i, expectedStream := range expected { - if !streamInfoEqual(result[i], expectedStream) { + if !streamInfoEqual(result[i], *expectedStream) { t.Errorf("GetStreams returned %+v, expected %+v", result[i], expectedStream) } } diff --git a/database/db.go b/database/db.go index c57b4384..74ba14ca 100644 --- a/database/db.go +++ b/database/db.go @@ -2,6 +2,7 @@ package database import ( "context" + "encoding/json" "fmt" "log" "m3u-stream-merger/utils" @@ -19,7 +20,14 @@ type Instance struct { Ctx context.Context } -func InitializeDb(addr string, password string, db int) (*Instance, error) { +func InitializeDb() (*Instance, error) { + addr := os.Getenv("REDIS_ADDR") + password := os.Getenv("REDIS_PASS") + db := 0 + if i, err := strconv.Atoi(os.Getenv("REDIS_DB")); err == nil { + db = i + } + var redisOptions *redis.Options if password == "" { @@ -58,39 +66,27 @@ func (db *Instance) ClearDb() error { return nil } -func (db *Instance) SaveToDb(streams []StreamInfo) error { +func (db *Instance) SaveToDb(streams []*StreamInfo) error { var debug = os.Getenv("DEBUG") == "true" pipeline := db.Redis.Pipeline() for _, s := range streams { streamKey := fmt.Sprintf("stream:%s", s.Slug) - streamData := map[string]interface{}{ - "title": s.Title, - "tvg_id": s.TvgID, - "tvg_chno": s.TvgChNo, - "logo_url": s.LogoURL, - "group_name": s.Group, + + streamDataJson, err := json.Marshal(s) + if err != nil { + return fmt.Errorf("SaveToDb error: %v", err) } if debug { - utils.SafeLogPrintf(nil, nil, "[DEBUG] Preparing to set data for stream key %s: %v\n", streamKey, streamData) + utils.SafeLogPrintf(nil, nil, "[DEBUG] Preparing to set data for stream key %s: %v\n", streamKey, s) } - pipeline.HSet(db.Ctx, streamKey, streamData) - - for index, u := range s.URLs { - streamURLKey := fmt.Sprintf("stream:%s:url:%d", s.Slug, index) - - if debug { - utils.SafeLogPrintf(nil, nil, "[DEBUG] Preparing to set URL for key %s: %s\n", streamURLKey, u) - } - - pipeline.Set(db.Ctx, streamURLKey, u, 0) - } + pipeline.Set(db.Ctx, streamKey, string(streamDataJson), 0) // Add to the sorted set - sortScore := calculateSortScore(s) + sortScore := calculateSortScore(*s) if debug { utils.SafeLogPrintf(nil, nil, "[DEBUG] Adding to sorted set with score %f and member %s\n", sortScore, streamKey) @@ -122,26 +118,6 @@ func (db *Instance) SaveToDb(streams []StreamInfo) error { func (db *Instance) DeleteStreamBySlug(slug string) error { streamKey := fmt.Sprintf("stream:%s", slug) - // Delete associated URLs - cursor := uint64(0) - for { - keys, newCursor, err := db.Redis.Scan(db.Ctx, cursor, fmt.Sprintf("%s:url:*", streamKey), 10).Result() - if err != nil { - return fmt.Errorf("error scanning associated URLs: %v", err) - } - - for _, key := range keys { - if err := db.Redis.Del(db.Ctx, key).Err(); err != nil { - return fmt.Errorf("error deleting stream URL from Redis: %v", err) - } - } - - cursor = newCursor - if cursor == 0 { - break - } - } - // Delete from the sorted set if err := db.Redis.ZRem(db.Ctx, "streams_sorted", streamKey).Err(); err != nil { return fmt.Errorf("error removing stream from sorted set: %v", err) @@ -155,71 +131,25 @@ func (db *Instance) DeleteStreamBySlug(slug string) error { return nil } -func (db *Instance) DeleteStreamURL(s StreamInfo, m3uIndex int) error { - if err := db.Redis.Del(db.Ctx, fmt.Sprintf("stream:%s:url:%d", s.Slug, m3uIndex)).Err(); err != nil { - return fmt.Errorf("error deleting stream URL from Redis: %v", err) - } - - return nil -} - func (db *Instance) GetStreamBySlug(slug string) (StreamInfo, error) { streamKey := fmt.Sprintf("stream:%s", slug) - streamData, err := db.Redis.HGetAll(db.Ctx, streamKey).Result() + streamDataJson, err := db.Redis.Get(db.Ctx, streamKey).Result() if err != nil { return StreamInfo{}, fmt.Errorf("error getting stream from Redis: %v", err) } - if len(streamData) == 0 { - return StreamInfo{}, fmt.Errorf("stream not found: %s", slug) - } + stream := StreamInfo{} - s := StreamInfo{ - Slug: slug, - Title: streamData["title"], - TvgID: streamData["tvg_id"], - TvgChNo: streamData["tvg_chno"], - LogoURL: streamData["logo_url"], - Group: streamData["group_name"], - URLs: map[int]string{}, + err = json.Unmarshal([]byte(streamDataJson), &stream) + if err != nil { + return StreamInfo{}, fmt.Errorf("error getting stream: %v", err) } - cursor := uint64(0) - for { - keys, newCursor, err := db.Redis.Scan(db.Ctx, cursor, fmt.Sprintf("%s:url:*", streamKey), 10).Result() - if err != nil { - return s, fmt.Errorf("error finding URLs for stream: %v", err) - } - - if len(keys) > 0 { - results, err := db.Redis.Pipelined(db.Ctx, func(pipe redis.Pipeliner) error { - for _, key := range keys { - pipe.Get(db.Ctx, key) - } - return nil - }) - if err != nil { - return s, fmt.Errorf("error getting URL data from Redis: %v", err) - } - - for i, result := range results { - urlData := result.(*redis.StringCmd).Val() - - m3uIndex, err := strconv.Atoi(extractM3UIndex(keys[i])) - if err != nil { - return s, fmt.Errorf("m3u index is not an integer: %v", err) - } - s.URLs[m3uIndex] = urlData - } - } - - cursor = newCursor - if cursor == 0 { - break - } + if strings.TrimSpace(stream.Title) == "" { + return StreamInfo{}, fmt.Errorf("stream not found: %s", slug) } - return s, nil + return stream, nil } func (db *Instance) GetStreams() <-chan StreamInfo { @@ -239,52 +169,26 @@ func (db *Instance) GetStreams() <-chan StreamInfo { // Filter out URL keys for _, key := range keys { - if !strings.Contains(key, ":url:") { - streamData, err := db.Redis.HGetAll(db.Ctx, key).Result() - if err != nil { - utils.SafeLogPrintf(nil, nil, "error retrieving stream data: %v", err) - return - } - - slug := extractSlug(key) - stream := StreamInfo{ - Slug: slug, - Title: streamData["title"], - TvgID: streamData["tvg_id"], - TvgChNo: streamData["tvg_chno"], - LogoURL: streamData["logo_url"], - Group: streamData["group_name"], - URLs: map[int]string{}, - } - - if debug { - utils.SafeLogPrintf(nil, nil, "[DEBUG] Processing stream: %v\n", stream) - } - - urlKeys, err := db.Redis.Keys(db.Ctx, fmt.Sprintf("%s:url:*", key)).Result() - if err != nil { - utils.SafeLogPrintf(nil, nil, "error finding URLs for stream: %v", err) - return - } - - for _, urlKey := range urlKeys { - urlData, err := db.Redis.Get(db.Ctx, urlKey).Result() - if err != nil { - utils.SafeLogPrintf(nil, nil, "error getting URL data from Redis: %v", err) - return - } - - m3uIndex, err := strconv.Atoi(extractM3UIndex(urlKey)) - if err != nil { - utils.SafeLogPrintf(nil, nil, "m3u index is not an integer: %v", err) - return - } - stream.URLs[m3uIndex] = urlData - } - - // Send the stream to the channel - streamChan <- stream + streamDataJson, err := db.Redis.Get(db.Ctx, key).Result() + if err != nil { + utils.SafeLogPrintf(nil, nil, "error retrieving stream data: %v", err) + return + } + + stream := StreamInfo{} + + err = json.Unmarshal([]byte(streamDataJson), &stream) + if err != nil { + utils.SafeLogPrintf(nil, nil, "error retrieving stream data: %v", err) + return + } + + if debug { + utils.SafeLogPrintf(nil, nil, "[DEBUG] Processing stream: %v\n", stream) } + + // Send the stream to the channel + streamChan <- stream } if debug { @@ -353,22 +257,6 @@ func (db *Instance) ClearConcurrencies() error { return nil } -func extractSlug(key string) string { - parts := strings.Split(key, ":") - if len(parts) > 1 { - return parts[1] - } - return "" -} - -func extractM3UIndex(key string) string { - parts := strings.Split(key, ":") - if len(parts) > 1 { - return parts[3] - } - return "" -} - func getSortingValue(s StreamInfo) string { key := os.Getenv("SORTING_KEY") diff --git a/database/types.go b/database/types.go index 47455046..7723bac7 100644 --- a/database/types.go +++ b/database/types.go @@ -1,11 +1,11 @@ package database type StreamInfo struct { - Slug string - Title string - TvgID string - TvgChNo string - LogoURL string - Group string - URLs map[int]string + Slug string `json:"slug"` + Title string `json:"title"` + TvgID string `json:"tvg_id"` + TvgChNo string `json:"tvg_chno"` + LogoURL string `json:"logo_url"` + Group string `json:"group_name"` + URLs map[int]string `json:"urls"` } diff --git a/go.mod b/go.mod index a940a31f..6adbd815 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module m3u-stream-merger -go 1.22.1 +go 1.23.0 require ( github.com/gosimple/slug v1.14.0 diff --git a/m3u/generate.go b/m3u/generate.go index 87c07f5d..6df94edd 100644 --- a/m3u/generate.go +++ b/m3u/generate.go @@ -72,7 +72,7 @@ func GenerateAndCacheM3UContent(db *database.Instance, r *http.Request) string { log.Println("[DEBUG] Regenerating M3U cache in the background") } - baseUrl := determineBaseURL(r) + baseUrl := utils.DetermineBaseURL(r) if debug { utils.SafeLogPrintf(r, nil, "[DEBUG] Base URL set to %s\n", baseUrl) @@ -111,23 +111,11 @@ func GenerateAndCacheM3UContent(db *database.Instance, r *http.Request) string { return content.String() } -func determineBaseURL(r *http.Request) string { - if r != nil { - if r.TLS == nil { - return fmt.Sprintf("http://%s/stream", r.Host) - } else { - return fmt.Sprintf("https://%s/stream", r.Host) - } - } - - if customBase, ok := os.LookupEnv("BASE_URL"); ok { - return fmt.Sprintf("%s/stream", strings.TrimSuffix(customBase, "/")) +func Handler(w http.ResponseWriter, r *http.Request) { + db, err := database.InitializeDb() + if err != nil { + log.Fatalf("Error initializing Redis database: %v", err) } - - return "" -} - -func Handler(w http.ResponseWriter, r *http.Request, db *database.Instance) { debug := isDebugMode() if debug { diff --git a/m3u/m3u_test.go b/m3u/m3u_test.go index b24e600e..089b1b61 100644 --- a/m3u/m3u_test.go +++ b/m3u/m3u_test.go @@ -2,9 +2,11 @@ package m3u import ( "fmt" + "maps" "net/http" "net/http/httptest" "os" + "slices" "testing" "m3u-stream-merger/database" @@ -21,12 +23,9 @@ func TestGenerateM3UContent(t *testing.T) { URLs: map[int]string{0: "http://example.com/stream"}, } - // Test InitializeSQLite and check if the database file exists - REDIS_ADDR := "127.0.0.1:6379" - REDIS_PASS := "" - REDIS_DB := 2 + os.Setenv("REDIS_DB", "2") - db, err := database.InitializeDb(REDIS_ADDR, REDIS_PASS, REDIS_DB) + db, err := database.InitializeDb() if err != nil { t.Errorf("InitializeDb returned error: %v", err) } @@ -36,7 +35,7 @@ func TestGenerateM3UContent(t *testing.T) { t.Errorf("ClearDb returned error: %v", err) } - err = db.SaveToDb([]database.StreamInfo{stream}) + err = db.SaveToDb([]*database.StreamInfo{&stream}) if err != nil { t.Fatal(err) } @@ -50,7 +49,7 @@ func TestGenerateM3UContent(t *testing.T) { // Create a ResponseRecorder to record the response rr := httptest.NewRecorder() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Handler(w, r, db) + Handler(w, r) }) // Call the ServeHTTP method of the handler to execute the test @@ -106,12 +105,8 @@ http://example.com/fox t.Errorf("WriteFile returned error: %v", err) } - // Test InitializeSQLite and check if the database file exists - REDIS_ADDR := "127.0.0.1:6379" - REDIS_PASS := "" - REDIS_DB := 3 - - db, err := database.InitializeDb(REDIS_ADDR, REDIS_PASS, REDIS_DB) + os.Setenv("REDIS_DB", "3") + db, err := database.InitializeDb() if err != nil { t.Errorf("InitializeDb returned error: %v", err) } @@ -121,18 +116,25 @@ http://example.com/fox t.Errorf("ClearDb returned error: %v", err) } + tmpStore := map[string]*database.StreamInfo{} + // Test the parseM3UFromURL function with the mock server URL - err = ParseM3UFromURL(db, mockServer.URL, 0) + err = ParseM3UFromURL(tmpStore, mockServer.URL, 0) if err != nil { t.Errorf("Error parsing M3U from URL: %v", err) } // Test the parseM3UFromURL function with the mock server URL - err = ParseM3UFromURL(db, fmt.Sprintf("file://%s", mockFile), 1) + err = ParseM3UFromURL(tmpStore, fmt.Sprintf("file://%s", mockFile), 1) if err != nil { t.Errorf("Error parsing M3U from URL: %v", err) } + err = db.SaveToDb(slices.Collect(maps.Values(tmpStore))) + if err != nil { + t.Errorf("Error store to db: %v", err) + } + // Verify expected values expectedStreams := []database.StreamInfo{ {Slug: "bbc-one", Title: "BBC One", TvgChNo: "0.0", TvgID: "bbc1", Group: "UK", URLs: map[int]string{0: "http://example.com/bbc1", 1: "http://example.com/bbc1"}}, diff --git a/m3u/parser.go b/m3u/parser.go index b8e986a1..385a2e92 100644 --- a/m3u/parser.go +++ b/m3u/parser.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log" + "maps" "os" "regexp" "strconv" @@ -95,7 +96,7 @@ func checkIncludeGroup(groups []string, line string) bool { } else { for _, group := range groups { toMatch := "group-title=" + "\"" + group + "\"" - if strings.Contains(line, toMatch) { + if strings.Contains(strings.ToLower(line), toMatch) { if debug { utils.SafeLogPrintf(nil, nil, "[DEBUG] Line matches group: %s\n", group) } @@ -152,7 +153,7 @@ func downloadM3UToBuffer(m3uURL string, buffer *bytes.Buffer) (err error) { return nil } -func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { +func ParseM3UFromURL(streams map[string]*database.StreamInfo, m3uURL string, m3uIndex int) error { debug := os.Getenv("DEBUG") == "true" maxRetries := 10 @@ -196,7 +197,7 @@ func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { var wg sync.WaitGroup parserWorkers := os.Getenv("PARSER_WORKERS") - if strings.TrimSpace(parserWorkers) != "" { + if strings.TrimSpace(parserWorkers) == "" { parserWorkers = "5" } @@ -211,7 +212,6 @@ func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { numWorkers = 5 } - var streamInfos []database.StreamInfo var mu sync.Mutex for w := 0; w < numWorkers; w++ { @@ -223,7 +223,12 @@ func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { utils.SafeLogPrintf(nil, nil, "[DEBUG] Worker processing stream info: %s\n", streamInfo.Slug) } mu.Lock() - streamInfos = append(streamInfos, streamInfo) + _, ok := streams[streamInfo.Title] + if !ok { + streams[streamInfo.Title] = &streamInfo + } else { + maps.Copy(streams[streamInfo.Title].URLs, streamInfo.URLs) + } mu.Unlock() } }() @@ -251,6 +256,7 @@ func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { } streamInfo := parseLine(line, nextLine, m3uIndex) + log.Println(streamInfo) streamInfoCh <- streamInfo } } @@ -264,15 +270,6 @@ func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { return fmt.Errorf("scanner error: %v", err) } - if len(streamInfos) > 0 { - if debug { - utils.SafeLogPrintf(nil, nil, "[DEBUG] Saving %d stream infos to database\n", len(streamInfos)) - } - if err := db.SaveToDb(streamInfos); err != nil { - return fmt.Errorf("failed to save data to database: %v", err) - } - } - buffer.Reset() if debug { diff --git a/main.go b/main.go index 7035c6d5..f34e0a13 100644 --- a/main.go +++ b/main.go @@ -2,76 +2,23 @@ package main import ( "context" - "fmt" "log" "m3u-stream-merger/database" "m3u-stream-merger/m3u" + "m3u-stream-merger/proxy" + "m3u-stream-merger/updater" "net/http" "os" - "strconv" - "strings" - "sync" "time" - - "github.com/robfig/cron/v3" ) -var db *database.Instance -var cronMutex sync.Mutex - -func updateSource(nextDb *database.Instance, m3uUrl string, index int) { - log.Printf("Background process: Updating M3U #%d from %s\n", index+1, m3uUrl) - err := m3u.ParseM3UFromURL(nextDb, m3uUrl, index) - if err != nil { - log.Printf("Background process: Error updating M3U: %v\n", err) - } else { - log.Printf("Background process: Updated M3U #%d from %s\n", index+1, m3uUrl) - } - -} - -func updateSources(ctx context.Context, ewg *sync.WaitGroup) { - // Ensure only one job is running at a time - cronMutex.Lock() - defer cronMutex.Unlock() - if ewg != nil { - defer ewg.Done() - } - - select { - case <-ctx.Done(): - return - default: - log.Println("Background process: Checking M3U_URLs...") - var wg sync.WaitGroup - index := 0 - for { - m3uUrl, m3uExists := os.LookupEnv(fmt.Sprintf("M3U_URL_%d", index+1)) - if !m3uExists { - break - } - - log.Printf("Background process: Fetching M3U_URL_%d...\n", index+1) - wg.Add(1) - // Start the goroutine for periodic updates - go func(currDb *database.Instance, m3uUrl string, index int) { - defer wg.Done() - updateSource(currDb, m3uUrl, index) - }(db, m3uUrl, index) - - index++ - } - wg.Wait() - - log.Println("Background process: Updated M3U database.") - } -} - func main() { // Context for graceful shutdown ctx, cancel := context.WithCancel(context.Background()) defer cancel() + updater.Initialize(ctx) + // manually set time zone if tz := os.Getenv("TZ"); tz != "" { var err error @@ -81,15 +28,8 @@ func main() { } } - REDIS_ADDR := os.Getenv("REDIS_ADDR") - REDIS_PASS := os.Getenv("REDIS_PASS") - REDIS_DB := 0 - if i, err := strconv.Atoi(os.Getenv("REDIS_DB")); err == nil { - REDIS_DB = i - } - var err error - db, err = database.InitializeDb(REDIS_ADDR, REDIS_PASS, REDIS_DB) + db, err := database.InitializeDb() if err != nil { log.Fatalf("Error initializing Redis database: %v", err) } @@ -99,77 +39,12 @@ func main() { log.Fatalf("Error clearing concurrency database: %v", err) } - clearOnBoot := os.Getenv("CLEAR_ON_BOOT") - if len(strings.TrimSpace(clearOnBoot)) == 0 { - clearOnBoot = "false" - } - - if clearOnBoot == "true" { - log.Println("CLEAR_ON_BOOT enabled. Clearing current database.") - if err := db.ClearDb(); err != nil { - log.Fatalf("Error clearing database: %v", err) - } - } - - cacheOnSync := os.Getenv("CACHE_ON_SYNC") - if len(strings.TrimSpace(cacheOnSync)) == 0 { - cacheOnSync = "false" - } - - cronSched := os.Getenv("SYNC_CRON") - if len(strings.TrimSpace(cronSched)) == 0 { - log.Println("SYNC_CRON not initialized. Defaulting to 0 0 * * * (12am every day).") - cronSched = "0 0 * * *" - } - - c := cron.New() - _, err = c.AddFunc(cronSched, func() { - var wg sync.WaitGroup - - wg.Add(1) - go updateSources(ctx, &wg) - if cacheOnSync == "true" { - if _, ok := os.LookupEnv("BASE_URL"); !ok { - log.Println("BASE_URL is required for CACHE_ON_SYNC to work.") - } - wg.Wait() - log.Println("CACHE_ON_SYNC enabled. Building cache.") - m3u.InitCache(db) - } - }) - if err != nil { - log.Fatalf("Error initializing background processes: %v", err) - } - c.Start() - - syncOnBoot := os.Getenv("SYNC_ON_BOOT") - if len(strings.TrimSpace(syncOnBoot)) == 0 { - syncOnBoot = "true" - } - - if syncOnBoot == "true" { - log.Println("SYNC_ON_BOOT enabled. Starting initial M3U update.") - - var wg sync.WaitGroup - - wg.Add(1) - go updateSources(ctx, &wg) - if cacheOnSync == "true" { - if _, ok := os.LookupEnv("BASE_URL"); !ok { - log.Println("BASE_URL is required for CACHE_ON_SYNC to work.") - } - wg.Wait() - log.Println("CACHE_ON_SYNC enabled. Building cache.") - m3u.InitCache(db) - } - } - // HTTP handlers http.HandleFunc("/playlist.m3u", func(w http.ResponseWriter, r *http.Request) { - m3u.Handler(w, r, db) + m3u.Handler(w, r) }) http.HandleFunc("/stream/", func(w http.ResponseWriter, r *http.Request) { - streamHandler(w, r, db) + proxy.Handler(w, r) }) // Start the server diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 0f75b821..00000000 --- a/main_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "bytes" - "context" - "io" - "log" - "m3u-stream-merger/database" - "m3u-stream-merger/m3u" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync" - "testing" -) - -func TestStreamHandler(t *testing.T) { - REDIS_ADDR := "127.0.0.1:6379" - REDIS_PASS := "" - REDIS_DB := 0 - - db, err := database.InitializeDb(REDIS_ADDR, REDIS_PASS, REDIS_DB) - if err != nil { - t.Errorf("InitializeDb returned error: %v", err) - } - - err = db.ClearDb() - if err != nil { - t.Errorf("ClearDb returned error: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - os.Setenv("M3U_URL_1", "https://gist.githubusercontent.com/sonroyaalmerol/de1c90e8681af040924da5d15c7f530d/raw/06844df09e69ea278060252ca5aa8d767eb4543d/test-m3u.m3u") - os.Setenv("BUFFER_MB", "3") - os.Setenv("INCLUDE_GROUPS_1", "movies") - - updateSources(ctx, nil) - - streamChan := db.GetStreams() - streams := []database.StreamInfo{} - - for stream := range streamChan { - streams = append(streams, stream) - } - - m3uReq := httptest.NewRequest("GET", "/playlist.m3u", nil) - m3uW := httptest.NewRecorder() - - func() { - m3u.Handler(m3uW, m3uReq, db) - }() - - m3uResp := m3uW.Result() - if m3uResp.StatusCode != http.StatusOK { - t.Errorf("Playlist Route - Expected status code %d, got %d", http.StatusOK, m3uResp.StatusCode) - } - - var wg sync.WaitGroup - for _, stream := range streams { - wg.Add(1) - go func(stream database.StreamInfo) { - defer wg.Done() - log.Printf("Stream (%s): %v", stream.Title, stream) - req := httptest.NewRequest("GET", strings.TrimSpace(m3u.GenerateStreamURL("", stream.Slug, stream.URLs[0])), nil) - w := httptest.NewRecorder() - - // Call the handler function - streamHandler(w, req, db) - - // Check the response status code - resp := w.Result() - if resp.StatusCode != http.StatusOK { - t.Errorf("%s - Expected status code %d, got %d", stream.Title, http.StatusOK, resp.StatusCode) - } - - res, err := http.Get(stream.URLs[0]) - if err != nil { - t.Errorf("HttpGet returned error: %v", err) - } - defer res.Body.Close() - - // Example of checking response body content - expected, _ := io.ReadAll(res.Body) - body, _ := io.ReadAll(resp.Body) - if !bytes.Equal(body, expected) { - t.Errorf("Streams did not match for: %s", stream.Title) - } - }(stream) - } - - wg.Wait() -} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 00000000..52eb5548 --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,89 @@ +package proxy + +import ( + "bytes" + "context" + "io" + "log" + "m3u-stream-merger/database" + "m3u-stream-merger/m3u" + "m3u-stream-merger/updater" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestStreamHandler(t *testing.T) { + os.Setenv("REDIS_ADDR", "127.0.0.1:6379") + os.Setenv("REDIS_PASS", "") + os.Setenv("REDIS_DB", "0") + + db, err := database.InitializeDb() + if err != nil { + t.Errorf("InitializeDb returned error: %v", err) + } + + err = db.ClearDb() + if err != nil { + t.Errorf("ClearDb returned error: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + os.Setenv("M3U_URL_1", "https://gist.githubusercontent.com/sonroyaalmerol/de1c90e8681af040924da5d15c7f530d/raw/06844df09e69ea278060252ca5aa8d767eb4543d/test-m3u.m3u") + os.Setenv("INCLUDE_GROUPS_1", "movies") + + updater.Initialize(ctx) + + streamChan := db.GetStreams() + streams := []database.StreamInfo{} + + for stream := range streamChan { + streams = append(streams, stream) + } + + m3uReq := httptest.NewRequest("GET", "/playlist.m3u", nil) + m3uW := httptest.NewRecorder() + + func() { + m3u.Handler(m3uW, m3uReq) + }() + + m3uResp := m3uW.Result() + if m3uResp.StatusCode != http.StatusOK { + t.Errorf("Playlist Route - Expected status code %d, got %d", http.StatusOK, m3uResp.StatusCode) + } + + for _, stream := range streams { + log.Printf("Stream (%s): %v", stream.Title, stream) + genStreamUrl := strings.TrimSpace(m3u.GenerateStreamURL("", stream.Slug, stream.URLs[0])) + + req := httptest.NewRequest("GET", genStreamUrl, nil) + w := httptest.NewRecorder() + + // Call the handler function + Handler(w, req) + + // Check the response status code + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("%s - Expected status code %d, got %d", stream.Title, http.StatusOK, resp.StatusCode) + } + + res, err := http.Get(stream.URLs[0]) + if err != nil { + t.Errorf("HttpGet returned error: %v", err) + } + defer res.Body.Close() + + // Example of checking response body content + expected, _ := io.ReadAll(res.Body) + body, _ := io.ReadAll(resp.Body) + if !bytes.Equal(body, expected) { + t.Errorf("Streams did not match for: %s", stream.Title) + } + } +} diff --git a/stream_handler.go b/proxy/stream_handler.go similarity index 84% rename from stream_handler.go rename to proxy/stream_handler.go index a05d272c..19d06663 100644 --- a/stream_handler.go +++ b/proxy/stream_handler.go @@ -1,4 +1,4 @@ -package main +package proxy import ( "context" @@ -16,13 +16,35 @@ import ( "time" ) -func loadBalancer(stream database.StreamInfo, previous *[]int, method string) (*http.Response, string, int, error) { +type StreamInstance struct { + Database *database.Instance + Info database.StreamInfo +} + +func InitializeStream(streamUrl string) (*StreamInstance, error) { + initDb, err := database.InitializeDb() + if err != nil { + log.Fatalf("Error initializing Redis database: %v", err) + } + + stream, err := initDb.GetStreamBySlug(streamUrl) + if err != nil { + return nil, err + } + + return &StreamInstance{ + Database: initDb, + Info: stream, + }, nil +} + +func (instance *StreamInstance) LoadBalancer(previous *[]int, method string) (*http.Response, string, int, error) { debug := os.Getenv("DEBUG") == "true" m3uIndexes := utils.GetM3UIndexes() sort.Slice(m3uIndexes, func(i, j int) bool { - return db.ConcurrencyPriorityValue(i) > db.ConcurrencyPriorityValue(j) + return instance.Database.ConcurrencyPriorityValue(i) > instance.Database.ConcurrencyPriorityValue(j) }) maxLapsString := os.Getenv("MAX_RETRIES") @@ -45,13 +67,13 @@ func loadBalancer(stream database.StreamInfo, previous *[]int, method string) (* continue } - url, ok := stream.URLs[index] + url, ok := instance.Info.URLs[index] if !ok { - utils.SafeLogPrintf(nil, nil, "Channel not found from M3U_%d: %s\n", index+1, stream.Title) + utils.SafeLogPrintf(nil, nil, "Channel not found from M3U_%d: %s\n", index+1, instance.Info.Title) continue } - if db.CheckConcurrency(index) { + if instance.Database.CheckConcurrency(index) { utils.SafeLogPrintf(nil, &url, "Concurrency limit reached for M3U_%d: %s\n", index+1, url) continue } @@ -87,7 +109,7 @@ func loadBalancer(stream database.StreamInfo, previous *[]int, method string) (* return nil, "", -1, fmt.Errorf("Error fetching stream. Exhausted all streams.") } -func proxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http.Request, w http.ResponseWriter, statusChan chan int) { +func (instance *StreamInstance) ProxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http.Request, w http.ResponseWriter, statusChan chan int) { debug := os.Getenv("DEBUG") == "true" if r.Method != http.MethodGet || utils.EOFIsExpected(resp) { @@ -101,8 +123,8 @@ func proxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http return } - db.UpdateConcurrency(m3uIndex, true) - defer db.UpdateConcurrency(m3uIndex, false) + instance.Database.UpdateConcurrency(m3uIndex, true) + defer instance.Database.UpdateConcurrency(m3uIndex, false) bufferMbInt, err := strconv.Atoi(os.Getenv("BUFFER_MB")) if err != nil || bufferMbInt < 0 { @@ -221,7 +243,7 @@ func proxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http } } -func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance) { +func Handler(w http.ResponseWriter, r *http.Request) { debug := os.Getenv("DEBUG") == "true" ctx, cancel := context.WithCancel(r.Context()) @@ -236,7 +258,7 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance return } - stream, err := db.GetStreamBySlug(streamUrl) + stream, err := InitializeStream(strings.TrimPrefix(streamUrl, "/")) if err != nil { utils.SafeLogPrintf(r, nil, "Error retrieving stream for slug %s: %v\n", streamUrl, err) http.NotFound(w, r) @@ -257,7 +279,7 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance utils.SafeLogPrintf(r, nil, "Client disconnected: %s\n", r.RemoteAddr) return default: - resp, selectedUrl, selectedIndex, err = loadBalancer(stream, &testedIndexes, r.Method) + resp, selectedUrl, selectedIndex, err = stream.LoadBalancer(&testedIndexes, r.Method) if err != nil { utils.SafeLogPrintf(r, nil, "Error reloading stream for %s: %v\n", streamUrl, err) return @@ -283,7 +305,7 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance exitStatus := make(chan int) utils.SafeLogPrintf(r, &selectedUrl, "Proxying %s to %s\n", r.RemoteAddr, selectedUrl) - go proxyStream(ctx, selectedIndex, resp, r, w, exitStatus) + go stream.ProxyStream(ctx, selectedIndex, resp, r, w, exitStatus) testedIndexes = append(testedIndexes, selectedIndex) streamExitCode := <-exitStatus diff --git a/updater/updater.go b/updater/updater.go new file mode 100644 index 00000000..b4cfa22f --- /dev/null +++ b/updater/updater.go @@ -0,0 +1,171 @@ +package updater + +import ( + "context" + "fmt" + "log" + "m3u-stream-merger/database" + "m3u-stream-merger/m3u" + "maps" + "os" + "slices" + "strings" + "sync" + + "github.com/robfig/cron/v3" +) + +type Updater struct { + sync.Mutex + ctx context.Context + db *database.Instance + Cron *cron.Cron +} + +func Initialize(ctx context.Context) *Updater { + db, err := database.InitializeDb() + if err != nil { + log.Fatalf("Error initializing Redis database: %v", err) + } + + clearOnBoot := os.Getenv("CLEAR_ON_BOOT") + if len(strings.TrimSpace(clearOnBoot)) == 0 { + clearOnBoot = "false" + } + + if clearOnBoot == "true" { + log.Println("CLEAR_ON_BOOT enabled. Clearing current database.") + if err := db.ClearDb(); err != nil { + log.Fatalf("Error clearing database: %v", err) + } + } + + cacheOnSync := os.Getenv("CACHE_ON_SYNC") + if len(strings.TrimSpace(cacheOnSync)) == 0 { + cacheOnSync = "false" + } + + cronSched := os.Getenv("SYNC_CRON") + if len(strings.TrimSpace(cronSched)) == 0 { + log.Println("SYNC_CRON not initialized. Defaulting to 0 0 * * * (12am every day).") + cronSched = "0 0 * * *" + } + + updateInstance := &Updater{ + ctx: ctx, + db: db, + } + + c := cron.New() + _, err = c.AddFunc(cronSched, func() { + var wg sync.WaitGroup + + wg.Add(1) + go updateInstance.UpdateSources(ctx) + if cacheOnSync == "true" { + if _, ok := os.LookupEnv("BASE_URL"); !ok { + log.Println("BASE_URL is required for CACHE_ON_SYNC to work.") + } + wg.Wait() + log.Println("CACHE_ON_SYNC enabled. Building cache.") + m3u.InitCache(db) + } + }) + if err != nil { + log.Fatalf("Error initializing background processes: %v", err) + } + c.Start() + + syncOnBoot := os.Getenv("SYNC_ON_BOOT") + if len(strings.TrimSpace(syncOnBoot)) == 0 { + syncOnBoot = "true" + } + + if syncOnBoot == "true" { + log.Println("SYNC_ON_BOOT enabled. Starting initial M3U update.") + + var wg sync.WaitGroup + + wg.Add(1) + go updateInstance.UpdateSources(ctx) + if cacheOnSync == "true" { + if _, ok := os.LookupEnv("BASE_URL"); !ok { + log.Println("BASE_URL is required for CACHE_ON_SYNC to work.") + } + wg.Wait() + log.Println("CACHE_ON_SYNC enabled. Building cache.") + m3u.InitCache(db) + } + } + + updateInstance.Cron = c + + return updateInstance +} + +func (instance *Updater) UpdateSource(tmpStore map[string]*database.StreamInfo, m3uUrl string, index int) { + log.Printf("Background process: Updating M3U #%d from %s\n", index+1, m3uUrl) + err := m3u.ParseM3UFromURL(tmpStore, m3uUrl, index) + if err != nil { + log.Printf("Background process: Error updating M3U: %v\n", err) + } else { + log.Printf("Background process: Updated M3U #%d from %s\n", index+1, m3uUrl) + } +} + +func (instance *Updater) UpdateSources(ctx context.Context) { + // Ensure only one job is running at a time + instance.Lock() + defer instance.Unlock() + + db, err := database.InitializeDb() + if err != nil { + log.Println("Background process: Failed to initialize db connection.") + return + } + + tmpStore := map[string]*database.StreamInfo{} + + select { + case <-ctx.Done(): + return + default: + log.Println("Background process: Checking M3U_URLs...") + var wg sync.WaitGroup + index := 0 + + for { + m3uUrl, m3uExists := os.LookupEnv(fmt.Sprintf("M3U_URL_%d", index+1)) + if !m3uExists { + break + } + + log.Printf("Background process: Fetching M3U_URL_%d...\n", index+1) + wg.Add(1) + // Start the goroutine for periodic updates + go func(m3uUrl string, index int) { + log.Println(index) + defer wg.Done() + instance.UpdateSource(tmpStore, m3uUrl, index) + }(m3uUrl, index) + + index++ + } + wg.Wait() + + log.Printf("Background process: M3U fetching complete. Saving to database...\n") + + storeArray := []*database.StreamInfo{} + storeValues := maps.Values(tmpStore) + if storeValues != nil { + storeArray = slices.Collect(storeValues) + } + + err := db.SaveToDb(storeArray) + if err != nil { + log.Printf("Background process: Error updating M3U database: %v\n", err) + } + + log.Println("Background process: Updated M3U database.") + } +} diff --git a/utils/http.go b/utils/http.go index 5b53354a..ab2abd30 100644 --- a/utils/http.go +++ b/utils/http.go @@ -1,6 +1,11 @@ package utils -import "net/http" +import ( + "fmt" + "net/http" + "os" + "strings" +) func CustomHttpRequest(method string, url string) (*http.Response, error) { userAgent := GetEnv("USER_AGENT") @@ -22,3 +27,19 @@ func CustomHttpRequest(method string, url string) (*http.Response, error) { return resp, nil } + +func DetermineBaseURL(r *http.Request) string { + if r != nil { + if r.TLS == nil { + return fmt.Sprintf("http://%s/stream", r.Host) + } else { + return fmt.Sprintf("https://%s/stream", r.Host) + } + } + + if customBase, ok := os.LookupEnv("BASE_URL"); ok { + return fmt.Sprintf("%s/stream", strings.TrimSuffix(customBase, "/")) + } + + return "" +}