diff --git a/backend/docker-compose.test.yaml b/backend/docker-compose.test.yaml
index 05fca6149..64737ac6e 100644
--- a/backend/docker-compose.test.yaml
+++ b/backend/docker-compose.test.yaml
@@ -28,4 +28,4 @@ services:
condition: service_healthy
environment:
- NEBRASKA_DB_URL=postgres://postgres:nebraska@postgres:5432/nebraska_tests?sslmode=disable&connect_timeout=10
- command: sh -c "/nebraska/nebraska --auth-mode=noop --http-static-dir=/nebraska/static"
+ command: sh -c "/nebraska/nebraska --auth-mode=noop --http-static-dir=/nebraska/static --api-endpoint-suffix=/"
diff --git a/backend/pkg/server/server.go b/backend/pkg/server/server.go
index bb7decd52..df6f69825 100644
--- a/backend/pkg/server/server.go
+++ b/backend/pkg/server/server.go
@@ -82,10 +82,17 @@ func New(conf *config.Config, db *db.API) (*echo.Echo, error) {
}
// setup middlewares
- if conf.APIEndpointSuffix != "" {
- e.Pre(custommiddleware.OmahaSecret(conf.APIEndpointSuffix))
- }
e.Pre(middleware.RemoveTrailingSlash())
+
+ // remove trailing slash from the endpoint secret
+ endpointSuffix := strings.TrimSuffix(conf.APIEndpointSuffix, "/")
+ if endpointSuffix != "" {
+ // if endpoint secret doesn't start with slash prepend it
+ if !strings.HasPrefix(endpointSuffix, "/") {
+ endpointSuffix = fmt.Sprintf("/%s", endpointSuffix)
+ }
+ e.Pre(custommiddleware.OmahaSecret(endpointSuffix))
+ }
e.Use(middleware.Recover())
e.Use(middleware.RequestID())
e.Use(middleware.CORS())
diff --git a/backend/test/api/api_secret_test.go b/backend/test/api/api_secret_test.go
new file mode 100644
index 000000000..a0f6be5a5
--- /dev/null
+++ b/backend/test/api/api_secret_test.go
@@ -0,0 +1,159 @@
+package api_test
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/jinzhu/copier"
+ "github.com/kinvolk/go-omaha/omaha"
+ "github.com/kinvolk/nebraska/backend/pkg/config"
+ "github.com/kinvolk/nebraska/backend/pkg/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ testServerURL = "http://localhost:6000"
+ serverPort = uint(6000)
+)
+
+var serverPortStr = fmt.Sprintf(":%d", serverPort)
+
+var conf = &config.Config{
+ EnableSyncer: true,
+ NebraskaURL: testServerURL,
+ HTTPLog: true,
+ AuthMode: "noop",
+ Debug: true,
+ ServerPort: serverPort,
+}
+
+func TestAPIEndpointSecret(t *testing.T) {
+
+ // establish db connection
+ db := newDBForTest(t)
+ defer db.Close()
+
+ app := getAppWithInstance(t, db)
+
+ // increase max update for the group
+ group := app.Groups[0]
+ group.PolicyMaxUpdatesPerPeriod = 1000
+ err := db.UpdateGroup(group)
+ require.NoError(t, err)
+
+ tt := []struct {
+ name string
+ secret string
+ url string
+ expectedStatusCode int
+ }{
+ {
+ "success_with_slash_as_secret",
+ "/",
+ fmt.Sprintf("%s/v1/update", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_with_slash_as_secret_and_path",
+ "/",
+ fmt.Sprintf("%s/v1/update/", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_secret_with_no_pre_slash",
+ "test/this",
+ fmt.Sprintf("%s/v1/update/test/this", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_secret_with_two_pre_slash",
+ "//test/this",
+ fmt.Sprintf("%s/v1/update//test/this", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_secret_with_two_pre_slash_and_path_with_trailing_slash",
+ "//test/this",
+ fmt.Sprintf("%s/v1/update//test/this/", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_with_two_trailing_slash",
+ "/test//",
+ fmt.Sprintf("%s/v1/update/test//", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_with_secret",
+ "/test",
+ fmt.Sprintf("%s/v1/update/test", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "failure_with_secret",
+ "/test",
+ fmt.Sprintf("%s/v1/update/failure", testServerURL),
+ http.StatusNotImplemented,
+ },
+ {
+ "success_secret_and_path_with_trailing_slash",
+ "/test/",
+ fmt.Sprintf("%s/v1/update/test/", testServerURL),
+ http.StatusOK,
+ },
+ {
+ "success_secret_with_trailing_slash",
+ "/test/",
+ fmt.Sprintf("%s/v1/update/test", testServerURL),
+ http.StatusOK,
+ },
+ }
+
+ for _, tc := range tt {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ track := group.Track
+
+ var testConfig config.Config
+ err := copier.Copy(&testConfig, conf)
+ require.NoError(t, err)
+
+ testConfig.APIEndpointSuffix = tc.secret
+ server, err := server.New(&testConfig, db)
+ assert.NoError(t, err)
+
+ //nolint:errcheck
+ go server.Start(serverPortStr)
+
+ //nolint:errcheck
+ defer server.Shutdown(context.Background())
+ _, err = waitServerReady(testConfig.NebraskaURL)
+ require.NoError(t, err)
+
+ method := "POST"
+
+ instanceID := uuid.New().String()
+ payload := strings.NewReader(fmt.Sprintf(`
+
+
+
+
+
+ `, app.ID, track, instanceID))
+
+ // response
+ if tc.expectedStatusCode == http.StatusOK {
+ var omahaResp omaha.Response
+ httpDo(t, tc.url, method, payload, tc.expectedStatusCode, "xml", &omahaResp)
+ assert.Equal(t, "ok", omahaResp.Apps[0].Ping.Status)
+ } else {
+ httpDo(t, tc.url, method, payload, tc.expectedStatusCode, "", nil)
+ }
+ })
+ }
+}
diff --git a/backend/test/api/omaha_test.go b/backend/test/api/omaha_test.go
index 0b3180876..1a65b4c15 100644
--- a/backend/test/api/omaha_test.go
+++ b/backend/test/api/omaha_test.go
@@ -53,6 +53,37 @@ func TestOmaha(t *testing.T) {
assert.NotNil(t, instance)
})
+ t.Run("success_with_trailing_slash", func(t *testing.T) {
+ track := app.Groups[0].Track
+
+ url := fmt.Sprintf("%s/v1/update/", os.Getenv("NEBRASKA_TEST_SERVER_URL"))
+
+ method := "POST"
+
+ instanceID := uuid.New().String()
+ payload := strings.NewReader(fmt.Sprintf(`
+
+
+
+
+
+
+
+ `, app.ID, track, instanceID))
+
+ // response
+ var omahaResp omaha.Response
+
+ httpDo(t, url, method, payload, 200, "xml", &omahaResp)
+
+ assert.Equal(t, "ok", omahaResp.Apps[0].Ping.Status)
+
+ // check if instance exists in the DB
+ instance, err := db.GetInstance(instanceID, app.ID)
+ assert.NoError(t, err)
+ assert.NotNil(t, instance)
+ })
+
t.Run("large_request_body", func(t *testing.T) {
url := fmt.Sprintf("%s/v1/update", os.Getenv("NEBRASKA_TEST_SERVER_URL"))