diff --git a/server/server.go b/server/server.go
index 62c8714b9f..a69d90d5f0 100644
--- a/server/server.go
+++ b/server/server.go
@@ -2,13 +2,9 @@ package server
import (
"context"
- "embed"
"fmt"
- "io/fs"
"net"
"net/http"
- "path"
- "regexp"
"strings"
"time"
@@ -47,9 +43,6 @@ import (
versionutils "github.com/argoproj/argo-rollouts/utils/version"
)
-//go:embed static/*
-var static embed.FS //nolint
-
var backoff = wait.Backoff{
Steps: 5,
Duration: 500 * time.Millisecond,
@@ -81,13 +74,6 @@ func NewServer(o ServerOptions) *ArgoRolloutsServer {
return &ArgoRolloutsServer{Options: o}
}
-var re = regexp.MustCompile(``)
-
-func withRootPath(fileContent []byte, rootpath string) []byte {
- var temp = re.ReplaceAllString(string(fileContent), ``)
- return []byte(temp)
-}
-
func (s *ArgoRolloutsServer) newHTTPServer(ctx context.Context, port int) *http.Server {
mux := http.NewServeMux()
endpoint := fmt.Sprintf("0.0.0.0:%d", port)
@@ -117,109 +103,13 @@ func (s *ArgoRolloutsServer) newHTTPServer(ctx context.Context, port int) *http.
panic(err)
}
- var handler http.Handler = gwmux
-
- mux.Handle("/api/", handler)
-
- mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
- requestedURI := path.Clean(r.RequestURI)
- rootPath := path.Clean("/" + s.Options.RootPath)
-
- if requestedURI == "/" {
- http.Redirect(w, r, rootPath+"/", http.StatusFound)
- return
- }
-
- //If the rootPath is not in the prefix 404
- if !strings.HasPrefix(requestedURI, rootPath) {
- http.NotFound(w, r)
- return
- }
- //If the rootPath is the requestedURI, serve index.html
- if requestedURI == rootPath {
- fileBytes, openErr := s.readIndexHtml()
- if openErr != nil {
- log.Errorf("Error opening file index.html: %v", openErr)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- w.Write(fileBytes)
- return
- }
-
- embedPath := path.Join("static", strings.TrimPrefix(requestedURI, rootPath))
- file, openErr := static.Open(embedPath)
- if openErr != nil {
- fErr := openErr.(*fs.PathError)
- //If the file is not found, serve index.html
- if fErr.Err == fs.ErrNotExist {
- fileBytes, openErr := s.readIndexHtml()
- if openErr != nil {
- log.Errorf("Error opening file index.html: %v", openErr)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- w.Write(fileBytes)
- return
- } else {
- log.Errorf("Error opening file %s: %v", embedPath, openErr)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- }
- defer file.Close()
-
- stat, statErr := file.Stat()
- if statErr != nil {
- log.Errorf("Failed to stat file or dir %s: %v", embedPath, err)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
-
- fileBytes := make([]byte, stat.Size())
- _, err = file.Read(fileBytes)
- if err != nil {
- log.Errorf("Failed to read file %s: %v", embedPath, err)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
-
- w.Write(fileBytes)
- })
+ var apiHandler http.Handler = gwmux
+ mux.Handle("/api/", apiHandler)
+ mux.HandleFunc("/", s.staticFileHttpHandler)
return &httpS
}
-func (s *ArgoRolloutsServer) readIndexHtml() ([]byte, error) {
- file, err := static.Open("static/index.html")
- if err != nil {
- log.Errorf("Failed to open file %s: %v", "static/index.html", err)
- return nil, err
- }
- defer func() {
- if file != nil {
- if err := file.Close(); err != nil {
- log.Errorf("Error closing file: %v", err)
- }
- }
- }()
-
- stat, err := file.Stat()
- if err != nil {
- log.Errorf("Failed to stat file or dir %s: %v", "static/index.html", err)
- return nil, err
- }
-
- fileBytes := make([]byte, stat.Size())
- _, err = file.Read(fileBytes)
- if err != nil {
- log.Errorf("Failed to read file %s: %v", "static/index.html", err)
- return nil, err
- }
-
- return withRootPath(fileBytes, s.Options.RootPath), nil
-}
-
func (s *ArgoRolloutsServer) newGRPCServer() *grpc.Server {
grpcS := grpc.NewServer()
var rolloutsServer rollout.RolloutServiceServer = NewServer(s.Options)
diff --git a/server/server_static.go b/server/server_static.go
new file mode 100644
index 0000000000..44f2fb78f6
--- /dev/null
+++ b/server/server_static.go
@@ -0,0 +1,101 @@
+package server
+
+import (
+ "embed"
+ "errors"
+ "io/fs"
+ "mime"
+ "net/http"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ //go:embed static/*
+ static embed.FS //nolint
+ staticBasePath = "static"
+ indexHtmlFile = staticBasePath + "/index.html"
+)
+
+const (
+ ContentType = "Content-Type"
+ ContentLength = "Content-Length"
+)
+
+func (s *ArgoRolloutsServer) staticFileHttpHandler(w http.ResponseWriter, r *http.Request) {
+ requestedURI := path.Clean(r.RequestURI)
+ rootPath := path.Clean("/" + s.Options.RootPath)
+
+ if requestedURI == "/" {
+ http.Redirect(w, r, rootPath+"/", http.StatusFound)
+ return
+ }
+
+ //If the rootPath is not in the prefix 404
+ if !strings.HasPrefix(requestedURI, rootPath) {
+ http.NotFound(w, r)
+ return
+ }
+
+ embedPath := path.Join(staticBasePath, strings.TrimPrefix(requestedURI, rootPath))
+
+ //If the rootPath is the requestedURI, serve index.html
+ if requestedURI == rootPath {
+ embedPath = indexHtmlFile
+ }
+
+ fileBytes, err := static.ReadFile(embedPath)
+ if err != nil {
+ if fileNotExistsOrIsDirectoryError(err) {
+ // send index.html, because UI will use path based router in React
+ fileBytes, _ = static.ReadFile(indexHtmlFile)
+ embedPath = indexHtmlFile
+ } else {
+ log.Errorf("Error reading file %s: %v", embedPath, err)
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ }
+
+ if embedPath == indexHtmlFile {
+ fileBytes = withRootPath(fileBytes, s.Options.RootPath)
+ }
+
+ w.Header().Set(ContentType, determineMimeType(embedPath))
+ w.Header().Set(ContentLength, strconv.Itoa(len(fileBytes)))
+ w.WriteHeader(http.StatusOK)
+ _, err = w.Write(fileBytes)
+ if err != nil {
+ log.Errorf("Error writing response: %v", err)
+ }
+}
+
+func fileNotExistsOrIsDirectoryError(err error) bool {
+ if errors.Is(err, fs.ErrNotExist) {
+ return true
+ }
+ pathErr, isPathError := err.(*fs.PathError)
+ return isPathError && strings.Contains(pathErr.Error(), "is a directory")
+}
+
+func determineMimeType(fileName string) string {
+ idx := strings.LastIndex(fileName, ".")
+ if idx >= 0 {
+ mimeType := mime.TypeByExtension(fileName[idx:])
+ if len(mimeType) > 0 {
+ return mimeType
+ }
+ }
+ return "text/plain"
+}
+
+var re = regexp.MustCompile(``)
+
+func withRootPath(fileContent []byte, rootpath string) []byte {
+ var temp = re.ReplaceAllString(string(fileContent), ``)
+ return []byte(temp)
+}
diff --git a/server/server_static_test.go b/server/server_static_test.go
new file mode 100644
index 0000000000..82b5b19ac9
--- /dev/null
+++ b/server/server_static_test.go
@@ -0,0 +1,113 @@
+package server
+
+import (
+ "embed"
+ "io"
+ "mime"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/tj/assert"
+)
+
+const TestRootPath = "/test-root"
+
+var (
+ //go:embed static_test/*
+ staticTestData embed.FS //nolint
+ mockServer ArgoRolloutsServer
+)
+
+func init() {
+ static = staticTestData
+ staticBasePath = "static_test"
+ indexHtmlFile = staticBasePath + "/index.html"
+ mockServer = mockArgoRolloutServer()
+}
+
+func TestIndexHtmlIsServed(t *testing.T) {
+ tests := []struct {
+ requestPath string
+ }{
+ {TestRootPath + "/"},
+ {TestRootPath + "/index.html"},
+ {TestRootPath + "/nonsense/../index.html"},
+ {TestRootPath + "/test-dir/test.css"},
+ }
+ for _, test := range tests {
+ t.Run(test.requestPath, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, test.requestPath, nil)
+ w := httptest.NewRecorder()
+ mockServer.staticFileHttpHandler(w, req)
+ res := w.Result()
+ defer res.Body.Close()
+ data, err := io.ReadAll(res.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, res.StatusCode, http.StatusOK)
+ if strings.HasSuffix(test.requestPath, ".css") {
+ assert.Equal(t, res.Header.Get(ContentType), mime.TypeByExtension(".css"))
+ assert.Contains(t, string(data), "empty by intent")
+ } else {
+ assert.Equal(t, res.Header.Get(ContentType), mime.TypeByExtension(".html"))
+ assert.Contains(t, string(data), "
index-title")
+ }
+ })
+ }
+}
+
+func TestWhenFileNotFoundSendIndexPageForUiReactRouter(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, TestRootPath+"/namespace-default", nil)
+ w := httptest.NewRecorder()
+ mockServer.staticFileHttpHandler(w, req)
+ res := w.Result()
+ defer res.Body.Close()
+ data, err := io.ReadAll(res.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, res.StatusCode, http.StatusOK)
+ assert.Contains(t, string(data), "index-title")
+}
+
+func TestSlashWillBeRedirectedToRootPath(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ w := httptest.NewRecorder()
+ mockServer.staticFileHttpHandler(w, req)
+ res := w.Result()
+ defer res.Body.Close()
+ _, err := io.ReadAll(res.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, res.StatusCode, http.StatusFound)
+ assert.Contains(t, res.Header.Get("Location"), TestRootPath)
+}
+
+func TestInvalidFilesOrHackingAttemptReturn404(t *testing.T) {
+ tests := []struct {
+ requestPath string
+ }{
+ {"/index.html"}, // should fail, because not prefixed with Option.RootPath
+ {"/etc/passwd"},
+ {TestRootPath + "/../etc/passwd"},
+ {TestRootPath + "/../../etc/passwd"},
+ {TestRootPath + "/../../../etc/passwd"},
+ }
+ for _, test := range tests {
+ t.Run(test.requestPath, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, test.requestPath, nil)
+ w := httptest.NewRecorder()
+ mockServer.staticFileHttpHandler(w, req)
+ res := w.Result()
+ defer res.Body.Close()
+ assert.Equal(t, res.StatusCode, http.StatusNotFound)
+ })
+ }
+}
+
+func mockArgoRolloutServer() ArgoRolloutsServer {
+ s := ArgoRolloutsServer{
+ Options: ServerOptions{
+ RootPath: TestRootPath,
+ },
+ }
+ return s
+}
diff --git a/server/static_test/index.html b/server/static_test/index.html
new file mode 100644
index 0000000000..03753f5f4f
--- /dev/null
+++ b/server/static_test/index.html
@@ -0,0 +1,9 @@
+
+
+
+ index-title
+
+
+index-body
+
+
diff --git a/server/static_test/test-dir/test.css b/server/static_test/test-dir/test.css
new file mode 100644
index 0000000000..1c10d48411
--- /dev/null
+++ b/server/static_test/test-dir/test.css
@@ -0,0 +1 @@
+/* empty by intent */