diff --git a/database/database_test.go b/database/database_test.go index de3f785b..1189745e 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -1,6 +1,8 @@ package database import ( + "os" + "path/filepath" "testing" ) @@ -36,12 +38,12 @@ func TestSaveAndLoadFromSQLite(t *testing.T) { }}, }} - err = SaveToSQLite(db, expected) // Insert test data into the database + err = db.SaveToSQLite(expected) // Insert test data into the database if err != nil { t.Errorf("SaveToSQLite returned error: %v", err) } - result, err := GetStreams(db) + result, err := db.GetStreams() if err != nil { t.Errorf("GetStreams returned error: %v", err) } @@ -56,17 +58,17 @@ func TestSaveAndLoadFromSQLite(t *testing.T) { } } - err = DeleteStreamByTitle(db, expected[1].Title) + err = db.DeleteStreamByTitle(expected[1].Title) if err != nil { t.Errorf("DeleteStreamByTitle returned error: %v", err) } - err = DeleteStreamURL(db, expected[0].URLs[0].DbId) + err = db.DeleteStreamURL(expected[0].URLs[0].DbId) if err != nil { t.Errorf("DeleteStreamURL returned error: %v", err) } - result, err = GetStreams(db) + result, err = db.GetStreams() if err != nil { t.Errorf("GetStreams returned error: %v", err) } @@ -84,10 +86,16 @@ func TestSaveAndLoadFromSQLite(t *testing.T) { } } - err = DeleteSQLite("test") + err = db.DeleteSQLite() if err != nil { t.Errorf("DeleteSQLite returned error: %v", err) } + + foldername := filepath.Join(".", "data") + err = os.RemoveAll(foldername) + if err != nil { + t.Errorf("Error deleting data folder: %v\n", err) + } } // streamInfoEqual checks if two StreamInfo objects are equal. diff --git a/database/db.go b/database/db.go index 746bb206..cd7ba66f 100644 --- a/database/db.go +++ b/database/db.go @@ -3,7 +3,6 @@ package database import ( "database/sql" "fmt" - "log" "os" "path/filepath" "sync" @@ -11,40 +10,39 @@ import ( _ "modernc.org/sqlite" ) -var mutex sync.Mutex +type Instance struct { + Sql *sql.DB + FileName string + Lock sync.Mutex +} -func InitializeSQLite(name string) (db *sql.DB, err error) { - mutex.Lock() - defer mutex.Unlock() +func InitializeSQLite(name string) (db *Instance, err error) { + db = new(Instance) - if db != nil { - err := db.Close() - if err == nil { - log.Printf("Database session has already been closed: %v\n", err) - } - } + db.Lock.Lock() + defer db.Lock.Unlock() foldername := filepath.Join(".", "data") - filename := filepath.Join(foldername, fmt.Sprintf("%s.db", name)) + db.FileName = filepath.Join(foldername, fmt.Sprintf("%s.db", name)) err = os.MkdirAll(foldername, 0755) if err != nil { return nil, fmt.Errorf("error creating data folder: %v\n", err) } - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0666) + file, err := os.OpenFile(db.FileName, os.O_RDWR|os.O_CREATE, 0666) if err != nil { return nil, fmt.Errorf("error creating database file: %v\n", err) } file.Close() - db, err = sql.Open("sqlite", filename) + db.Sql, err = sql.Open("sqlite", db.FileName) if err != nil { return nil, fmt.Errorf("error opening SQLite database: %v\n", err) } // Create table if not exists - _, err = db.Exec(` + _, err = db.Sql.Exec(` CREATE TABLE IF NOT EXISTS streams ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT UNIQUE, @@ -57,7 +55,7 @@ func InitializeSQLite(name string) (db *sql.DB, err error) { return nil, fmt.Errorf("error creating table: %v\n", err) } - _, err = db.Exec(` + _, err = db.Sql.Exec(` CREATE TABLE IF NOT EXISTS stream_urls ( id INTEGER PRIMARY KEY AUTOINCREMENT, stream_id INTEGER, @@ -74,39 +72,50 @@ func InitializeSQLite(name string) (db *sql.DB, err error) { } // DeleteSQLite deletes the SQLite database file. -func DeleteSQLite(name string) error { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) DeleteSQLite() error { + db.Lock.Lock() + defer db.Lock.Unlock() - foldername := filepath.Join(".", "data") - filename := filepath.Join(foldername, fmt.Sprintf("%s.db", name)) + _ = db.Sql.Close() - err := os.Remove(filename) + err := os.Remove(db.FileName) if err != nil { return fmt.Errorf("error deleting database file: %v\n", err) } + db.Sql = nil + return nil } -func RenameSQLite(prevName string, nextName string) error { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) RenameSQLite(newName string) error { + db.Lock.Lock() + defer db.Lock.Unlock() + + _ = db.Sql.Close() foldername := filepath.Join(".", "data") - prevFileName := filepath.Join(foldername, fmt.Sprintf("%s.db", prevName)) - nextFileName := filepath.Join(foldername, fmt.Sprintf("%s.db", nextName)) + nextFileName := filepath.Join(foldername, fmt.Sprintf("%s.db", newName)) - err := os.Rename(prevFileName, nextFileName) + err := os.Rename(db.FileName, nextFileName) + if err != nil { + return fmt.Errorf("error renaming database file: %v\n", err) + } + + db.FileName = nextFileName + db.Sql, err = sql.Open("sqlite", db.FileName) + if err != nil { + return fmt.Errorf("error opening SQLite database: %v\n", err) + } return err } -func SaveToSQLite(db *sql.DB, streams []StreamInfo) (err error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) SaveToSQLite(streams []StreamInfo) (err error) { + db.Lock.Lock() + defer db.Lock.Unlock() - tx, err := db.Begin() + tx, err := db.Sql.Begin() if err != nil { return fmt.Errorf("error beginning transaction: %v", err) } @@ -155,11 +164,11 @@ func SaveToSQLite(db *sql.DB, streams []StreamInfo) (err error) { return } -func InsertStream(db *sql.DB, s StreamInfo) (i int64, err error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) InsertStream(s StreamInfo) (i int64, err error) { + db.Lock.Lock() + defer db.Lock.Unlock() - tx, err := db.Begin() + tx, err := db.Sql.Begin() if err != nil { return -1, fmt.Errorf("error beginning transaction: %v", err) } @@ -192,11 +201,11 @@ func InsertStream(db *sql.DB, s StreamInfo) (i int64, err error) { return streamID, err } -func InsertStreamUrl(db *sql.DB, id int64, url StreamURL) (i int64, err error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) InsertStreamUrl(id int64, url StreamURL) (i int64, err error) { + db.Lock.Lock() + defer db.Lock.Unlock() - tx, err := db.Begin() + tx, err := db.Sql.Begin() if err != nil { return -1, fmt.Errorf("error beginning transaction: %v", err) } @@ -230,11 +239,11 @@ func InsertStreamUrl(db *sql.DB, id int64, url StreamURL) (i int64, err error) { return insertedId, err } -func DeleteStreamByTitle(db *sql.DB, title string) error { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) DeleteStreamByTitle(title string) error { + db.Lock.Lock() + defer db.Lock.Unlock() - tx, err := db.Begin() + tx, err := db.Sql.Begin() if err != nil { return fmt.Errorf("error beginning transaction: %v", err) } @@ -263,11 +272,11 @@ func DeleteStreamByTitle(db *sql.DB, title string) error { return nil } -func DeleteStreamURL(db *sql.DB, streamURLID int64) error { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) DeleteStreamURL(streamURLID int64) error { + db.Lock.Lock() + defer db.Lock.Unlock() - tx, err := db.Begin() + tx, err := db.Sql.Begin() if err != nil { return fmt.Errorf("error beginning transaction: %v", err) } @@ -296,11 +305,11 @@ func DeleteStreamURL(db *sql.DB, streamURLID int64) error { return nil } -func GetStreamByTitle(db *sql.DB, title string) (s StreamInfo, err error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) GetStreamByTitle(title string) (s StreamInfo, err error) { + db.Lock.Lock() + defer db.Lock.Unlock() - rows, err := db.Query("SELECT id, title, tvg_id, logo_url, group_name FROM streams WHERE title = ?", title) + rows, err := db.Sql.Query("SELECT id, title, tvg_id, logo_url, group_name FROM streams WHERE title = ?", title) if err != nil { return s, fmt.Errorf("error querying streams: %v", err) } @@ -312,7 +321,7 @@ func GetStreamByTitle(db *sql.DB, title string) (s StreamInfo, err error) { return s, fmt.Errorf("error scanning stream: %v", err) } - urlRows, err := db.Query("SELECT id, content, m3u_index FROM stream_urls WHERE stream_id = ? ORDER BY m3u_index ASC", s.DbId) + urlRows, err := db.Sql.Query("SELECT id, content, m3u_index FROM stream_urls WHERE stream_id = ? ORDER BY m3u_index ASC", s.DbId) if err != nil { return s, fmt.Errorf("error querying stream URLs: %v", err) } @@ -344,11 +353,11 @@ func GetStreamByTitle(db *sql.DB, title string) (s StreamInfo, err error) { return s, nil } -func GetStreamUrlByUrlAndIndex(db *sql.DB, url string, m3u_index int) (s StreamURL, err error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) GetStreamUrlByUrlAndIndex(url string, m3u_index int) (s StreamURL, err error) { + db.Lock.Lock() + defer db.Lock.Unlock() - rows, err := db.Query("SELECT id, content, m3u_index FROM stream_urls WHERE content = ? AND m3u_index = ? ORDER BY m3u_index ASC", url, m3u_index) + rows, err := db.Sql.Query("SELECT id, content, m3u_index FROM stream_urls WHERE content = ? AND m3u_index = ? ORDER BY m3u_index ASC", url, m3u_index) if err != nil { return s, fmt.Errorf("error querying streams: %v", err) } @@ -368,11 +377,11 @@ func GetStreamUrlByUrlAndIndex(db *sql.DB, url string, m3u_index int) (s StreamU return s, nil } -func GetStreams(db *sql.DB) ([]StreamInfo, error) { - mutex.Lock() - defer mutex.Unlock() +func (db *Instance) GetStreams() ([]StreamInfo, error) { + db.Lock.Lock() + defer db.Lock.Unlock() - rows, err := db.Query("SELECT id, title, tvg_id, logo_url, group_name FROM streams") + rows, err := db.Sql.Query("SELECT id, title, tvg_id, logo_url, group_name FROM streams") if err != nil { return nil, fmt.Errorf("error querying streams: %v", err) } @@ -386,7 +395,7 @@ func GetStreams(db *sql.DB) ([]StreamInfo, error) { return nil, fmt.Errorf("error scanning stream: %v", err) } - urlRows, err := db.Query("SELECT id, content, m3u_index FROM stream_urls WHERE stream_id = ? ORDER BY m3u_index ASC", s.DbId) + urlRows, err := db.Sql.Query("SELECT id, content, m3u_index FROM stream_urls WHERE stream_id = ? ORDER BY m3u_index ASC", s.DbId) if err != nil { return nil, fmt.Errorf("error querying stream URLs: %v", err) } diff --git a/m3u/generate.go b/m3u/generate.go index 9ea47e26..86141f95 100644 --- a/m3u/generate.go +++ b/m3u/generate.go @@ -1,7 +1,6 @@ package m3u import ( - "database/sql" "fmt" "log" "m3u-stream-merger/database" @@ -13,8 +12,8 @@ func generateStreamURL(baseUrl string, title string) string { return fmt.Sprintf("%s/%s.mp4\n", baseUrl, utils.GetStreamUID(title)) } -func GenerateM3UContent(w http.ResponseWriter, r *http.Request, db *sql.DB) { - streams, err := database.GetStreams(db) +func GenerateM3UContent(w http.ResponseWriter, r *http.Request, db *database.Instance) { + streams, err := db.GetStreams() if err != nil { log.Println(fmt.Errorf("GetStreams error: %v", err)) } diff --git a/m3u/m3u_test.go b/m3u/m3u_test.go index 3d3da0f2..af38534a 100644 --- a/m3u/m3u_test.go +++ b/m3u/m3u_test.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "m3u-stream-merger/database" @@ -25,7 +27,7 @@ func TestGenerateM3UContent(t *testing.T) { t.Errorf("InitializeSQLite returned error: %v", err) } - _, err = database.InsertStream(db, stream) + _, err = db.InsertStream(stream) if err != nil { t.Fatal(err) } @@ -67,10 +69,16 @@ func TestGenerateM3UContent(t *testing.T) { rr.Body.String(), expectedContent) } - err = database.DeleteSQLite("test") + err = db.DeleteSQLite() if err != nil { t.Errorf("DeleteSQLite returned error: %v", err) } + + foldername := filepath.Join(".", "data") + err = os.RemoveAll(foldername) + if err != nil { + t.Errorf("Error deleting data folder: %v\n", err) + } } func TestParseM3UFromURL(t *testing.T) { @@ -129,7 +137,7 @@ http://example.com/fox }}, } - storedStreams, err := database.GetStreams(db) + storedStreams, err := db.GetStreams() if err != nil { t.Fatalf("Error retrieving streams from database: %v", err) } @@ -157,10 +165,16 @@ http://example.com/fox } } - err = database.DeleteSQLite("test") + err = db.DeleteSQLite() if err != nil { t.Errorf("DeleteSQLite returned error: %v", err) } + + foldername := filepath.Join(".", "data") + err = os.RemoveAll(foldername) + if err != nil { + t.Errorf("Error deleting data folder: %v\n", err) + } } // streamInfoEqual checks if two StreamInfo objects are equal. diff --git a/m3u/parser.go b/m3u/parser.go index 1c58849e..9e6933a1 100644 --- a/m3u/parser.go +++ b/m3u/parser.go @@ -3,7 +3,6 @@ package m3u import ( "bufio" "bytes" - "database/sql" "fmt" "io" "log" @@ -76,8 +75,8 @@ func parseLine(line string, nextLine string, m3uIndex int) database.StreamInfo { return currentStream } -func insertStreamToDb(db *sql.DB, currentStream database.StreamInfo) error { - existingStream, err := database.GetStreamByTitle(db, currentStream.Title) +func insertStreamToDb(db *database.Instance, currentStream database.StreamInfo) error { + existingStream, err := db.GetStreamByTitle(currentStream.Title) if err != nil { return fmt.Errorf("GetStreamByTitle error (title: %s): %v", currentStream.Title, err) } @@ -87,7 +86,7 @@ func insertStreamToDb(db *sql.DB, currentStream database.StreamInfo) error { if os.Getenv("DEBUG") == "true" { log.Printf("Creating new database entry: %s\n", currentStream.Title) } - dbId, err = database.InsertStream(db, currentStream) + dbId, err = db.InsertStream(currentStream) if err != nil { return fmt.Errorf("InsertStream error (title: %s): %v", currentStream.Title, err) } @@ -103,13 +102,13 @@ func insertStreamToDb(db *sql.DB, currentStream database.StreamInfo) error { } for _, currentStreamUrl := range currentStream.URLs { - existingUrl, err := database.GetStreamUrlByUrlAndIndex(db, currentStreamUrl.Content, currentStreamUrl.M3UIndex) + existingUrl, err := db.GetStreamUrlByUrlAndIndex(currentStreamUrl.Content, currentStreamUrl.M3UIndex) if err != nil { return fmt.Errorf("GetStreamUrlByUrlAndIndex error (url: %s): %v", currentStreamUrl.Content, err) } if existingUrl.Content != currentStreamUrl.Content || existingUrl.M3UIndex != currentStreamUrl.M3UIndex { - _, err = database.InsertStreamUrl(db, dbId, currentStreamUrl) + _, err = db.InsertStreamUrl(dbId, currentStreamUrl) if err != nil { return fmt.Errorf("InsertStreamUrl error (title: %s): %v", currentStream.Title, err) } @@ -136,7 +135,7 @@ func downloadM3UToBuffer(m3uURL string, buffer *bytes.Buffer) (err error) { return nil } -func ParseM3UFromURL(db *sql.DB, m3uURL string, m3uIndex int) error { +func ParseM3UFromURL(db *database.Instance, m3uURL string, m3uIndex int) error { maxRetries := 10 var err error maxRetriesStr, maxRetriesExists := os.LookupEnv("MAX_RETRIES") diff --git a/main.go b/main.go index b6a0a114..4fee86d6 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "database/sql" "fmt" "log" "m3u-stream-merger/database" @@ -16,56 +15,54 @@ import ( "github.com/robfig/cron/v3" ) -var db *sql.DB +var db *database.Instance var cronMutex sync.Mutex +var swappingLock sync.Mutex + +func swapDb(newInstance *database.Instance) error { + swappingLock.Lock() + defer swappingLock.Unlock() + + if db == nil { + err := newInstance.RenameSQLite("current_streams") + if err != nil { + return fmt.Errorf("Error renaming next_streams to current_streams: %v\n", err) + } + + db = newInstance + newInstance = nil + + return nil + } -func swapDb() error { - // Generate a unique temporary name tempName := fmt.Sprintf("temp_%d", time.Now().UnixNano()) - // Rename the current database to a temporary name - err := database.RenameSQLite("current_streams", tempName) + err := db.RenameSQLite(tempName) if err != nil { return fmt.Errorf("Error renaming current_streams to temp: %v\n", err) } - // Rename the next database to current - err = database.RenameSQLite("next_streams", "current_streams") + err = newInstance.RenameSQLite("current_streams") if err != nil { - // If renaming fails, revert the previous renaming to maintain consistency - revertErr := database.RenameSQLite(tempName, "current_streams") + revertErr := db.RenameSQLite("current_streams") if revertErr != nil { return fmt.Errorf("Error renaming back to current_streams: %v\n", revertErr) } return fmt.Errorf("Error renaming next_streams to current_streams: %v\n", err) } - // Initialize the new current database - db, err = database.InitializeSQLite("current_streams") - if err != nil { - // If initialization fails, revert both renamings - revertErr := database.RenameSQLite(tempName, "current_streams") - if revertErr != nil { - return fmt.Errorf("Error renaming back to current_streams: %v\n", revertErr) - } - revertErr = database.RenameSQLite("current_streams", "next_streams") - if revertErr != nil { - return fmt.Errorf("Error renaming back to next_streams: %v\n", revertErr) - } - return fmt.Errorf("Error initializing current_streams: %v\n", err) - } - - // Delete the temporary database - err = database.DeleteSQLite(tempName) + err = db.DeleteSQLite() if err != nil { - // Log the error but do not return as this is not a critical error fmt.Printf("Error deleting temp database: %v\n", err) } + db = newInstance + newInstance = nil + return nil } -func updateSource(nextDb *sql.DB, m3uUrl string, index int) { +func updateSource(nextDb *database.Instance, m3uUrl string, index int) { log.Printf("Background process: Updating M3U #%d from %s\n", index, m3uUrl) err := m3u.ParseM3UFromURL(nextDb, m3uUrl, index) if err != nil { @@ -102,7 +99,7 @@ func updateSources(ctx context.Context) { log.Printf("Background process: Fetching M3U_URL_%d...\n", index) wg.Add(1) // Start the goroutine for periodic updates - go func(nextDb *sql.DB, m3uUrl string, index int) { + go func(nextDb *database.Instance, m3uUrl string, index int) { defer wg.Done() updateSource(nextDb, m3uUrl, index) }(nextDb, m3uUrl, index) @@ -111,7 +108,7 @@ func updateSources(ctx context.Context) { } wg.Wait() - err = swapDb() + err = swapDb(nextDb) if err != nil { log.Fatalf("swapDb: %v", err) } @@ -171,6 +168,9 @@ func main() { // HTTP handlers http.HandleFunc("/playlist.m3u", func(w http.ResponseWriter, r *http.Request) { + swappingLock.Lock() + defer swappingLock.Unlock() + m3u.GenerateM3UContent(w, r, db) }) http.HandleFunc("/stream/", func(w http.ResponseWriter, r *http.Request) { diff --git a/main_test.go b/main_test.go index 1ecbaeee..61c7a41d 100644 --- a/main_test.go +++ b/main_test.go @@ -6,10 +6,12 @@ import ( "fmt" "io" "m3u-stream-merger/database" + "m3u-stream-merger/m3u" "m3u-stream-merger/utils" "net/http" "net/http/httptest" "os" + "path/filepath" "sync" "testing" ) @@ -32,11 +34,26 @@ func TestMP4Handler(t *testing.T) { updateSources(ctx) - streams, err := database.GetStreams(db) + streams, err := db.GetStreams() if err != nil { t.Errorf("GetStreams returned error: %v", err) } + m3uReq := httptest.NewRequest("GET", "/playlist.m3u", nil) + m3uW := httptest.NewRecorder() + + func() { + swappingLock.Lock() + defer swappingLock.Unlock() + + m3u.GenerateM3UContent(m3uW, m3uReq, db) + }() + + m3uResp := m3uW.Result() + if m3uResp.StatusCode != http.StatusOK { + t.Errorf("Playlist Route - Expected status code %d, got %d", http.StatusOK, m3uResp.StatusCode) + } + var wg sync.WaitGroup for _, stream := range streams { wg.Add(1) @@ -72,8 +89,14 @@ func TestMP4Handler(t *testing.T) { wg.Wait() - err = database.DeleteSQLite("current_streams") + err = db.DeleteSQLite() if err != nil { t.Errorf("DeleteSQLite returned error: %v", err) } + + foldername := filepath.Join(".", "data") + err = os.RemoveAll(foldername) + if err != nil { + t.Errorf("Error deleting data folder: %v\n", err) + } } diff --git a/mp4_handler.go b/mp4_handler.go index 34157f66..81f5f096 100644 --- a/mp4_handler.go +++ b/mp4_handler.go @@ -1,7 +1,6 @@ package main import ( - "database/sql" "errors" "fmt" "io" @@ -100,7 +99,7 @@ func loadBalancer(stream database.StreamInfo) (resp *http.Response, selectedUrl return resp, selectedUrl, nil } -func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { +func mp4Handler(w http.ResponseWriter, r *http.Request, db *database.Instance) { ctx := r.Context() // Log the incoming request @@ -119,7 +118,7 @@ func mp4Handler(w http.ResponseWriter, r *http.Request, db *sql.DB) { return } - stream, err := database.GetStreamByTitle(db, streamName) + stream, err := db.GetStreamByTitle(streamName) if err != nil { http.NotFound(w, r) return