diff --git a/pkg/common/common.go b/pkg/common/common.go index 972f24593..c6725662f 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -12,10 +12,12 @@ import ( "os" "path" "path/filepath" + "strings" "syscall" "time" "unicode/utf8" + "github.com/gorilla/mux" "github.com/opencontainers/go-digest" ispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -31,8 +33,8 @@ const ( caCertFilename = "ca.crt" ) -func AllowedMethods(method string) []string { - return []string{http.MethodOptions, method} +func AllowedMethods(methods ...string) []string { + return append(methods, http.MethodOptions) } func Contains(slice []string, item string) bool { @@ -283,3 +285,30 @@ func GetManifestArtifactType(manifestContent ispec.Manifest) string { return manifestContent.Config.MediaType } + +func AddExtensionSecurityHeaders() mux.MiddlewareFunc { //nolint:varnamelen + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Content-Type-Options", "nosniff") + + next.ServeHTTP(resp, req) + }) + } +} + +func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc { + headerValue := strings.Join(allowedMethods, ",") + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("Access-Control-Allow-Methods", headerValue) + resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + + if req.Method == http.MethodOptions { + return + } + + next.ServeHTTP(resp, req) + }) + } +} diff --git a/pkg/extensions/extension_mgmt.go b/pkg/extensions/extension_mgmt.go index 8599a990a..35c3a7627 100644 --- a/pkg/extensions/extension_mgmt.go +++ b/pkg/extensions/extension_mgmt.go @@ -11,7 +11,7 @@ import ( "zotregistry.io/zot/pkg/api/config" "zotregistry.io/zot/pkg/api/constants" - "zotregistry.io/zot/pkg/common" + zcommon "zotregistry.io/zot/pkg/common" "zotregistry.io/zot/pkg/log" ) @@ -73,7 +73,7 @@ type mgmt struct { func (mgmt *mgmt) handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sanitizedConfig := mgmt.config.Sanitize() - buf, err := common.MarshalThroughStruct(sanitizedConfig, &StrippedConfig{}) + buf, err := zcommon.MarshalThroughStruct(sanitizedConfig, &StrippedConfig{}) if err != nil { mgmt.log.Error().Err(err).Msg("mgmt: couldn't marshal config response") w.WriteHeader(http.StatusInternalServerError) @@ -82,20 +82,17 @@ func (mgmt *mgmt) handler() http.Handler { }) } -func addMgmtSecurityHeaders(h http.Handler) http.HandlerFunc { //nolint:varnamelen - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Content-Type-Options", "nosniff") - - h.ServeHTTP(w, r) - } -} - func SetupMgmtRoutes(config *config.Config, router *mux.Router, log log.Logger) { if config.Extensions.Mgmt != nil && *config.Extensions.Mgmt.Enable { log.Info().Msg("setting up mgmt routes") mgmt := mgmt{config: config, log: log} - router.PathPrefix(constants.ExtMgmt).Methods("GET").Handler(addMgmtSecurityHeaders(mgmt.handler())) + allowedMethods := zcommon.AllowedMethods(http.MethodGet) + + mgmtRouter := router.PathPrefix(constants.ExtMgmt).Subrouter() + mgmtRouter.Use(zcommon.ACHeadersHandler(allowedMethods...)) + mgmtRouter.Use(zcommon.AddExtensionSecurityHeaders()) + mgmtRouter.Methods(allowedMethods...).Handler(mgmt.handler()) } } diff --git a/pkg/extensions/extension_search.go b/pkg/extensions/extension_search.go index affca5984..1a1054975 100644 --- a/pkg/extensions/extension_search.go +++ b/pkg/extensions/extension_search.go @@ -13,6 +13,7 @@ import ( "zotregistry.io/zot/pkg/api/config" "zotregistry.io/zot/pkg/api/constants" + zcommon "zotregistry.io/zot/pkg/common" "zotregistry.io/zot/pkg/extensions/search" cveinfo "zotregistry.io/zot/pkg/extensions/search/cve" "zotregistry.io/zot/pkg/extensions/search/gql_generated" @@ -165,14 +166,6 @@ func (trivyT *trivyTask) DoWork() error { return nil } -func addSearchSecurityHeaders(h http.Handler) http.HandlerFunc { //nolint:varnamelen - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Content-Type-Options", "nosniff") - - h.ServeHTTP(w, r) - } -} - func SetupSearchRoutes(config *config.Config, router *mux.Router, storeController storage.StoreController, repoDB repodb.RepoDB, cveInfo CveInfo, log log.Logger, ) { @@ -181,24 +174,12 @@ func SetupSearchRoutes(config *config.Config, router *mux.Router, storeControlle if config.Extensions.Search != nil && *config.Extensions.Search.Enable { resConfig := search.GetResolverConfig(log, storeController, repoDB, cveInfo) - extRouter := router.PathPrefix(constants.ExtSearch).Subrouter() - extRouter.Use(SearchACHeadersHandler()) - extRouter.Methods("GET", "POST", "OPTIONS"). - Handler(addSearchSecurityHeaders(gqlHandler.NewDefaultServer(gql_generated.NewExecutableSchema(resConfig)))) - } -} - -func SearchACHeadersHandler() mux.MiddlewareFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - resp.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + allowedMethods := zcommon.AllowedMethods(http.MethodGet, http.MethodPost) - if req.Method == http.MethodOptions { - return - } - - next.ServeHTTP(resp, req) - }) + extRouter := router.PathPrefix(constants.ExtSearch).Subrouter() + extRouter.Use(zcommon.ACHeadersHandler(allowedMethods...)) + extRouter.Use(zcommon.AddExtensionSecurityHeaders()) + extRouter.Methods(allowedMethods...). + Handler(gqlHandler.NewDefaultServer(gql_generated.NewExecutableSchema(resConfig))) } } diff --git a/pkg/extensions/extension_userprefs.go b/pkg/extensions/extension_userprefs.go index 01696440a..3531e285f 100644 --- a/pkg/extensions/extension_userprefs.go +++ b/pkg/extensions/extension_userprefs.go @@ -34,25 +34,12 @@ func SetupUserPreferencesRoutes(config *config.Config, router *mux.Router, store if config.Extensions.Search != nil && *config.Extensions.Search.Enable { log.Info().Msg("setting up user preferences routes") - userprefsRouter := router.PathPrefix(constants.ExtUserPreferences).Subrouter() - userprefsRouter.Use(UserPrefsACHeadersHandler()) - - userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(zcommon.AllowedMethods(http.MethodPut)...) - } -} + allowedMethods := zcommon.AllowedMethods(http.MethodPut) -func UserPrefsACHeadersHandler() mux.MiddlewareFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - resp.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,PUT,OPTIONS") - resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") - - if req.Method == http.MethodOptions { - return - } - - next.ServeHTTP(resp, req) - }) + userprefsRouter := router.PathPrefix(constants.ExtUserPreferences).Subrouter() + userprefsRouter.Use(zcommon.ACHeadersHandler(allowedMethods...)) + userprefsRouter.Use(zcommon.AddExtensionSecurityHeaders()) + userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(allowedMethods...) } } diff --git a/pkg/extensions/extension_userprefs_test.go b/pkg/extensions/extension_userprefs_test.go index e3ec50189..6e2c69118 100644 --- a/pkg/extensions/extension_userprefs_test.go +++ b/pkg/extensions/extension_userprefs_test.go @@ -29,7 +29,7 @@ import ( var ErrTestError = errors.New("TestError") -func TestAllowedMethodsHeader(t *testing.T) { +func TestAllowedMethodsHeaderUserPrefs(t *testing.T) { defaultVal := true Convey("Test http options response", t, func() { @@ -53,7 +53,7 @@ func TestAllowedMethodsHeader(t *testing.T) { resp, _ := resty.R().Options(baseURL + constants.FullUserPreferencesPrefix) So(resp, ShouldNotBeNil) - So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "HEAD,GET,POST,PUT,OPTIONS") + So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "PUT,OPTIONS") So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) }) } diff --git a/pkg/extensions/extensions_test.go b/pkg/extensions/extensions_test.go index 28d909f5b..65d573d09 100644 --- a/pkg/extensions/extensions_test.go +++ b/pkg/extensions/extensions_test.go @@ -669,3 +669,32 @@ func TestMgmtWithBearer(t *testing.T) { So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil) }) } + +func TestAllowedMethodsHeaderMgmt(t *testing.T) { + defaultVal := true + + Convey("Test http options response", t, func() { + conf := config.New() + port := test.GetFreePort() + conf.HTTP.Port = port + conf.Extensions = &extconf.ExtensionConfig{ + Mgmt: &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + }, + } + baseURL := test.GetBaseURL(port) + + ctlr := api.NewController(conf) + ctlr.Config.Storage.RootDirectory = t.TempDir() + + ctrlManager := test.NewControllerManager(ctlr) + + ctrlManager.StartAndWait(port) + defer ctrlManager.StopServer() + + resp, _ := resty.R().Options(baseURL + constants.FullMgmtPrefix) + So(resp, ShouldNotBeNil) + So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "GET,OPTIONS") + So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) + }) +}