Skip to content

Commit

Permalink
Merge pull request #133 from sonroyaalmerol/proxy-content-length
Browse files Browse the repository at this point in the history
Support non-GET request forwarding and include all HTTP headers excluding Content-Length
  • Loading branch information
sonroyaalmerol authored Aug 26, 2024
2 parents a50447d + 84f3ebb commit eee55d7
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 40 deletions.
61 changes: 44 additions & 17 deletions m3u/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -112,41 +114,66 @@ func downloadM3UToBuffer(m3uURL string, buffer *bytes.Buffer) (err error) {
utils.SafeLogPrintf(nil, &m3uURL, "[DEBUG] Downloading M3U from: %s\n", m3uURL)
}

var file io.Reader
var tempFilePath string

if strings.HasPrefix(m3uURL, "file://") {
// Handle local file directly
localPath := strings.TrimPrefix(m3uURL, "file://")
utils.SafeLogPrintf(nil, &localPath, "Reading M3U from local file: %s\n", localPath)
utils.SafeLogPrintf(nil, &localPath, "[DEBUG] Reading M3U from local file: %s\n", localPath)
tempFilePath = localPath
} else {
// Create temporary file path
fileName := filepath.Base(m3uURL)
tempFilePath = fmt.Sprintf("/tmp/%s.m3u-incomplete", fileName)

localFile, err := os.Open(localPath)
// Make HTTP request to download the file
resp, err := http.Get(m3uURL)
if err != nil {
return fmt.Errorf("Error opening file: %v", err)
return fmt.Errorf("HTTP GET error: %v", err)
}
defer localFile.Close()
defer resp.Body.Close()

file = localFile
} else {
utils.SafeLogPrintf(nil, &m3uURL, "Downloading M3U from URL: %s\n", m3uURL)
resp, err := utils.CustomHttpRequest("GET", m3uURL)
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Failed to download M3U: HTTP status %d", resp.StatusCode)
}

// Create temporary file
tempFile, err := os.Create(tempFilePath)
if err != nil {
return fmt.Errorf("HTTP GET error: %v", err)
return fmt.Errorf("Error creating temp file: %v", err)
}
defer tempFile.Close()

defer func() {
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
// Download to the temporary file
_, err = io.Copy(tempFile, resp.Body)
if err != nil {
return fmt.Errorf("Error writing to temp file: %v", err)
}

// Rename the file to remove the "-incomplete" suffix
finalFilePath := strings.TrimSuffix(tempFilePath, "-incomplete")
err = os.Rename(tempFilePath, finalFilePath)
if err != nil {
return fmt.Errorf("Error renaming temp file: %v", err)
}

file = resp.Body
tempFilePath = finalFilePath
}

// Read the final file into the buffer
file, err := os.Open(tempFilePath)
if err != nil {
return fmt.Errorf("Error opening file: %v", err)
}
defer file.Close()

_, err = io.Copy(buffer, file)
if err != nil {
return fmt.Errorf("Error reading file: %v", err)
return fmt.Errorf("Error reading file to buffer: %v", err)
}

if debug {
log.Println("[DEBUG] Successfully copied M3U content to buffer")
log.Println("[DEBUG] Successfully read M3U content into buffer")
}

return nil
Expand Down
33 changes: 18 additions & 15 deletions stream_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"time"
)

func loadBalancer(stream database.StreamInfo, previous *[]int) (*http.Response, string, int, error) {
func loadBalancer(stream database.StreamInfo, previous *[]int, method string) (*http.Response, string, int, error) {
debug := os.Getenv("DEBUG") == "true"

m3uIndexes := utils.GetM3UIndexes()
Expand Down Expand Up @@ -57,7 +57,7 @@ func loadBalancer(stream database.StreamInfo, previous *[]int) (*http.Response,

allSkipped = false // At least one URL is not skipped

resp, err := utils.CustomHttpRequest("GET", url)
resp, err := utils.CustomHttpRequest(method, url, "")
if err == nil {
if debug {
utils.SafeLogPrintf(nil, &url, "[DEBUG] Successfully fetched stream from %s\n", url)
Expand Down Expand Up @@ -89,6 +89,12 @@ func loadBalancer(stream database.StreamInfo, previous *[]int) (*http.Response,
func proxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http.Request, w http.ResponseWriter, statusChan chan int) {
debug := os.Getenv("DEBUG") == "true"

if r.Method == http.MethodHead {
statusChan <- 4
resp.Body.Close()
return
}

db.UpdateConcurrency(m3uIndex, true)
defer db.UpdateConcurrency(m3uIndex, false)

Expand Down Expand Up @@ -204,12 +210,6 @@ func proxyStream(ctx context.Context, m3uIndex int, resp *http.Response, r *http
func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance) {
debug := os.Getenv("DEBUG") == "true"

if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
_, _ = w.Write([]byte(fmt.Sprintf("HTTP method %q not allowed", r.Method)))
return
}

ctx, cancel := context.WithCancel(r.Context())
defer cancel()

Expand Down Expand Up @@ -250,21 +250,21 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance
utils.SafeLogPrintf(r, nil, "Client disconnected: %s\n", r.RemoteAddr)
return
default:
resp, selectedUrl, selectedIndex, err = loadBalancer(stream, &testedIndexes)
resp, selectedUrl, selectedIndex, err = loadBalancer(stream, &testedIndexes, r.Method)
if err != nil {
utils.SafeLogPrintf(r, nil, "Error reloading stream for %s: %v\n", streamSlug, err)
return
}

// HTTP header initialization
if firstWrite {
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Access-Control-Allow-Origin", "*")
for k, v := range resp.Header {
if strings.ToLower(k) != "content-length" {
for _, val := range v {
w.Header().Set(k, val)
}
if strings.ToLower(k) == "content-length" && !utils.EOFIsExpected(resp) {
continue
}

for _, val := range v {
w.Header().Set(k, val)
}
}
if debug {
Expand All @@ -288,6 +288,9 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance
} else if streamExitCode == 1 || streamExitCode == 2 {
// Retry on server-side connection errors
utils.SafeLogPrintf(r, nil, "Retrying other servers...\n")
} else if streamExitCode == 4 {
utils.SafeLogPrintf(r, nil, "Successfully proxied HEAD request: %s\n", r.RemoteAddr)
cancel()
} else {
// Consider client-side connection errors as complete closure
utils.SafeLogPrintf(r, nil, "Client has closed the stream: %s\n", r.RemoteAddr)
Expand Down
14 changes: 6 additions & 8 deletions utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@ package utils

import "net/http"

func CustomHttpRequest(method string, url string) (*http.Response, error) {
func CustomHttpRequest(method string, url string, rangeHeader string) (*http.Response, error) {
userAgent := GetEnv("USER_AGENT")

// Create a new HTTP client with a custom User-Agent header
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Follow redirects while preserving the custom User-Agent header
req.Header.Set("User-Agent", userAgent)
return nil
},
}
client := &http.Client{}

req, err := http.NewRequest(method, url, nil)
if err != nil {
Expand All @@ -21,6 +15,10 @@ func CustomHttpRequest(method string, url string) (*http.Response, error) {

req.Header.Set("User-Agent", userAgent)

if rangeHeader != "" {
req.Header.Set("Range", rangeHeader)
}

resp, err := client.Do(req)
if err != nil {
return nil, err
Expand Down

0 comments on commit eee55d7

Please sign in to comment.