diff --git a/middleware/cors.go b/middleware/cors.go index 77a902036..4f90e88cd 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -15,7 +15,8 @@ type ( Skipper Skipper // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. + // Optional. If request header `Origin` is set, value is []string{""} + // else []string{"*"}. AllowOrigins []string `json:"allow_origins"` // AllowMethods defines a list methods allowed when accessing the resource. @@ -51,7 +52,6 @@ var ( // DefaultCORSConfig is the default CORS middleware config. DefaultCORSConfig = CORSConfig{ Skipper: defaultSkipper, - AllowOrigins: []string{"*"}, AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.PATCH, echo.POST, echo.DELETE}, } ) @@ -69,12 +69,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultCORSConfig.Skipper } - if len(config.AllowOrigins) == 0 { - config.AllowOrigins = DefaultCORSConfig.AllowOrigins - } if len(config.AllowMethods) == 0 { config.AllowMethods = DefaultCORSConfig.AllowMethods } + allowedOrigins := strings.Join(config.AllowOrigins, ",") allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") @@ -89,6 +87,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() + origin := req.Header.Get(echo.HeaderOrigin) + + if allowedOrigins == "" { + if origin != "" { + allowedOrigins = origin + } else { + if !config.AllowCredentials { + allowedOrigins = "*" + } + } + } // Simple request if req.Method != echo.OPTIONS { diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 87d1c7da1..9b4f01037 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -11,21 +11,21 @@ import ( func TestCORS(t *testing.T) { e := echo.New() + + // Origin origin req, _ := http.NewRequest(echo.GET, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - cors := CORSWithConfig(CORSConfig{ - AllowCredentials: true, - }) - h := cors(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + h := CORS()(echo.NotFoundHandler) + req.Header.Set(echo.HeaderOrigin, "localhost") + h(c) + assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) // Wildcard origin req, _ = http.NewRequest(echo.GET, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") + h = CORS()(echo.NotFoundHandler) h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -34,14 +34,7 @@ func TestCORS(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, "localhost") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + h = CORS()(echo.NotFoundHandler) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -51,6 +44,12 @@ func TestCORS(t *testing.T) { c = e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }) + h = cors(echo.NotFoundHandler) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) diff --git a/website/content/middleware/cors.md b/website/content/middleware/cors.md index d2a1c0fe8..3e7a91af7 100644 --- a/website/content/middleware/cors.md +++ b/website/content/middleware/cors.md @@ -33,39 +33,40 @@ e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ ```go CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `json:"allow_origins"` - - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - AllowMethods []string `json:"allow_methods"` - - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `json:"allow_headers"` - - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - AllowCredentials bool `json:"allow_credentials"` - - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `json:"expose_headers"` - - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `json:"max_age"` + // Skipper defines a function to skip middleware. + Skipper Skipper + + // AllowOrigin defines a list of origins that may access the resource. + // Optional. If request header `Origin` is set, value is []string{""} + // else []string{"*"}. + AllowOrigins []string `json:"allow_origins"` + + // AllowMethods defines a list methods allowed when accessing the resource. + // This is used in response to a preflight request. + // Optional. Default value DefaultCORSConfig.AllowMethods. + AllowMethods []string `json:"allow_methods"` + + // AllowHeaders defines a list of request headers that can be used when + // making the actual request. This in response to a preflight request. + // Optional. Default value []string{}. + AllowHeaders []string `json:"allow_headers"` + + // AllowCredentials indicates whether or not the response to the request + // can be exposed when the credentials flag is true. When used as part of + // a response to a preflight request, this indicates whether or not the + // actual request can be made using credentials. + // Optional. Default value false. + AllowCredentials bool `json:"allow_credentials"` + + // ExposeHeaders defines a whitelist headers that clients are allowed to + // access. + // Optional. Default value []string{}. + ExposeHeaders []string `json:"expose_headers"` + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached. + // Optional. Default value 0. + MaxAge int `json:"max_age"` } ```