diff --git a/backend/docker-compose.test.yaml b/backend/docker-compose.test.yaml index 05fca6149..64737ac6e 100644 --- a/backend/docker-compose.test.yaml +++ b/backend/docker-compose.test.yaml @@ -28,4 +28,4 @@ services: condition: service_healthy environment: - NEBRASKA_DB_URL=postgres://postgres:nebraska@postgres:5432/nebraska_tests?sslmode=disable&connect_timeout=10 - command: sh -c "/nebraska/nebraska --auth-mode=noop --http-static-dir=/nebraska/static" + command: sh -c "/nebraska/nebraska --auth-mode=noop --http-static-dir=/nebraska/static --api-endpoint-suffix=/" diff --git a/backend/pkg/server/server.go b/backend/pkg/server/server.go index bb7decd52..df6f69825 100644 --- a/backend/pkg/server/server.go +++ b/backend/pkg/server/server.go @@ -82,10 +82,17 @@ func New(conf *config.Config, db *db.API) (*echo.Echo, error) { } // setup middlewares - if conf.APIEndpointSuffix != "" { - e.Pre(custommiddleware.OmahaSecret(conf.APIEndpointSuffix)) - } e.Pre(middleware.RemoveTrailingSlash()) + + // remove trailing slash from the endpoint secret + endpointSuffix := strings.TrimSuffix(conf.APIEndpointSuffix, "/") + if endpointSuffix != "" { + // if endpoint secret doesn't start with slash prepend it + if !strings.HasPrefix(endpointSuffix, "/") { + endpointSuffix = fmt.Sprintf("/%s", endpointSuffix) + } + e.Pre(custommiddleware.OmahaSecret(endpointSuffix)) + } e.Use(middleware.Recover()) e.Use(middleware.RequestID()) e.Use(middleware.CORS()) diff --git a/backend/test/api/api_secret_test.go b/backend/test/api/api_secret_test.go new file mode 100644 index 000000000..a0f6be5a5 --- /dev/null +++ b/backend/test/api/api_secret_test.go @@ -0,0 +1,159 @@ +package api_test + +import ( + "context" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/jinzhu/copier" + "github.com/kinvolk/go-omaha/omaha" + "github.com/kinvolk/nebraska/backend/pkg/config" + "github.com/kinvolk/nebraska/backend/pkg/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testServerURL = "http://localhost:6000" + serverPort = uint(6000) +) + +var serverPortStr = fmt.Sprintf(":%d", serverPort) + +var conf = &config.Config{ + EnableSyncer: true, + NebraskaURL: testServerURL, + HTTPLog: true, + AuthMode: "noop", + Debug: true, + ServerPort: serverPort, +} + +func TestAPIEndpointSecret(t *testing.T) { + + // establish db connection + db := newDBForTest(t) + defer db.Close() + + app := getAppWithInstance(t, db) + + // increase max update for the group + group := app.Groups[0] + group.PolicyMaxUpdatesPerPeriod = 1000 + err := db.UpdateGroup(group) + require.NoError(t, err) + + tt := []struct { + name string + secret string + url string + expectedStatusCode int + }{ + { + "success_with_slash_as_secret", + "/", + fmt.Sprintf("%s/v1/update", testServerURL), + http.StatusOK, + }, + { + "success_with_slash_as_secret_and_path", + "/", + fmt.Sprintf("%s/v1/update/", testServerURL), + http.StatusOK, + }, + { + "success_secret_with_no_pre_slash", + "test/this", + fmt.Sprintf("%s/v1/update/test/this", testServerURL), + http.StatusOK, + }, + { + "success_secret_with_two_pre_slash", + "//test/this", + fmt.Sprintf("%s/v1/update//test/this", testServerURL), + http.StatusOK, + }, + { + "success_secret_with_two_pre_slash_and_path_with_trailing_slash", + "//test/this", + fmt.Sprintf("%s/v1/update//test/this/", testServerURL), + http.StatusOK, + }, + { + "success_with_two_trailing_slash", + "/test//", + fmt.Sprintf("%s/v1/update/test//", testServerURL), + http.StatusOK, + }, + { + "success_with_secret", + "/test", + fmt.Sprintf("%s/v1/update/test", testServerURL), + http.StatusOK, + }, + { + "failure_with_secret", + "/test", + fmt.Sprintf("%s/v1/update/failure", testServerURL), + http.StatusNotImplemented, + }, + { + "success_secret_and_path_with_trailing_slash", + "/test/", + fmt.Sprintf("%s/v1/update/test/", testServerURL), + http.StatusOK, + }, + { + "success_secret_with_trailing_slash", + "/test/", + fmt.Sprintf("%s/v1/update/test", testServerURL), + http.StatusOK, + }, + } + + for _, tc := range tt { + tc := tc + t.Run(tc.name, func(t *testing.T) { + track := group.Track + + var testConfig config.Config + err := copier.Copy(&testConfig, conf) + require.NoError(t, err) + + testConfig.APIEndpointSuffix = tc.secret + server, err := server.New(&testConfig, db) + assert.NoError(t, err) + + //nolint:errcheck + go server.Start(serverPortStr) + + //nolint:errcheck + defer server.Shutdown(context.Background()) + _, err = waitServerReady(testConfig.NebraskaURL) + require.NoError(t, err) + + method := "POST" + + instanceID := uuid.New().String() + payload := strings.NewReader(fmt.Sprintf(` + + + + + + `, app.ID, track, instanceID)) + + // response + if tc.expectedStatusCode == http.StatusOK { + var omahaResp omaha.Response + httpDo(t, tc.url, method, payload, tc.expectedStatusCode, "xml", &omahaResp) + assert.Equal(t, "ok", omahaResp.Apps[0].Ping.Status) + } else { + httpDo(t, tc.url, method, payload, tc.expectedStatusCode, "", nil) + } + }) + } +} diff --git a/backend/test/api/omaha_test.go b/backend/test/api/omaha_test.go index 0b3180876..1a65b4c15 100644 --- a/backend/test/api/omaha_test.go +++ b/backend/test/api/omaha_test.go @@ -53,6 +53,37 @@ func TestOmaha(t *testing.T) { assert.NotNil(t, instance) }) + t.Run("success_with_trailing_slash", func(t *testing.T) { + track := app.Groups[0].Track + + url := fmt.Sprintf("%s/v1/update/", os.Getenv("NEBRASKA_TEST_SERVER_URL")) + + method := "POST" + + instanceID := uuid.New().String() + payload := strings.NewReader(fmt.Sprintf(` + + + + + + + + `, app.ID, track, instanceID)) + + // response + var omahaResp omaha.Response + + httpDo(t, url, method, payload, 200, "xml", &omahaResp) + + assert.Equal(t, "ok", omahaResp.Apps[0].Ping.Status) + + // check if instance exists in the DB + instance, err := db.GetInstance(instanceID, app.ID) + assert.NoError(t, err) + assert.NotNil(t, instance) + }) + t.Run("large_request_body", func(t *testing.T) { url := fmt.Sprintf("%s/v1/update", os.Getenv("NEBRASKA_TEST_SERVER_URL"))