Skip to content

Commit

Permalink
Add support for CORS allowCredentials (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
komuw authored Sep 11, 2023
1 parent cccd396 commit 90ee0c3
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Most recent version is listed first.


# v0.0.81
- ong/middleware: Add support for CORS allowCredentials: https://github.com/komuw/ong/pull/385

# v0.0.80
- ong/middleware: Validate secretKeys a bit more: https://github.com/komuw/ong/pull/384

Expand Down
14 changes: 13 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ func (o Opts) GoString() string {
// If it is less than 1, [config.DefaultLoadShedMinSampleSize] is used instead.
// loadShedBreachLatency is the p99 latency at which point we start dropping(loadshedding) requests. If it is less than 1nanosecond, [config.DefaultLoadShedBreachLatency] is used instead.
//
// allowedOrigins, allowedMethods, allowedHeaders & corsCacheDuration are used by the CORS middleware.
// allowedOrigins, allowedMethods, allowedHeaders, allowCredentials & corsCacheDuration are used by the CORS middleware.
// If allowedOrigins is nil, all origins are allowed. You can also use []string{"*"} to allow all.
// If allowedMethods is nil, "GET", "POST", "HEAD" are allowed. Use []string{"*"} to allow all.
// If allowedHeaders is nil, "Origin", "Accept", "Content-Type", "X-Requested-With" are allowed. Use []string{"*"} to allow all.
// allowCredentials indicates whether the request can include user credentials like cookies, HTTP authentication or client side SSL certificates.
// corsCacheDuration is the duration that preflight responses will be cached. If it is less than 1second, [config.DefaultCorsCacheDuration] is used instead.
//
// csrfTokenDuration is the duration that csrf cookie will be valid for. If it is less than 1second, [config.DefaultCsrfCookieDuration] is used instead.
Expand Down Expand Up @@ -215,6 +216,7 @@ func New(
allowedOrigins []string,
allowedMethods []string,
allowedHeaders []string,
allowCredentials bool,
corsCacheDuration time.Duration,
csrfTokenDuration time.Duration,
sessionCookieDuration time.Duration,
Expand Down Expand Up @@ -247,6 +249,7 @@ func New(
allowedOrigins,
allowedMethods,
allowedHeaders,
allowCredentials,
corsCacheDuration,
csrfTokenDuration,
sessionCookieDuration,
Expand Down Expand Up @@ -483,6 +486,7 @@ type middlewareOpts struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
AllowCredentials bool
CorsCacheDuration time.Duration

// csrf
Expand All @@ -508,6 +512,7 @@ func (m middlewareOpts) String() string {
AllowedOrigins: %v,
AllowedMethods: %v,
AllowedHeaders: %v,
AllowCredentials: %v,
CorsCacheDuration: %v,
CsrfTokenDuration: %v,
SessionCookieDuration: %v,
Expand All @@ -525,6 +530,7 @@ func (m middlewareOpts) String() string {
m.AllowedOrigins,
m.AllowedMethods,
m.AllowedHeaders,
m.AllowCredentials,
m.CorsCacheDuration,
m.CsrfTokenDuration,
m.SessionCookieDuration,
Expand All @@ -550,6 +556,7 @@ func newMiddlewareOpts(
allowedOrigins []string,
allowedMethods []string,
allowedHeaders []string,
allowCredentials bool,
corsCacheDuration time.Duration,
csrfTokenDuration time.Duration,
sessionCookieDuration time.Duration,
Expand Down Expand Up @@ -590,6 +597,7 @@ func newMiddlewareOpts(
AllowedOrigins: allowedOrigins,
AllowedMethods: allowedMethods,
AllowedHeaders: allowedHeaders,
AllowCredentials: allowCredentials,
CorsCacheDuration: corsCacheDuration,

// csrf
Expand Down Expand Up @@ -625,6 +633,7 @@ func withMiddlewareOpts(
nil,
nil,
nil,
false,
DefaultCorsCacheDuration,
DefaultCsrfCookieDuration,
DefaultSessionCookieDuration,
Expand Down Expand Up @@ -882,6 +891,9 @@ func (o Opts) Equal(other Opts) bool {
if !slices.Equal(o.middlewareOpts.AllowedHeaders, other.middlewareOpts.AllowedHeaders) {
return false
}
if o.middlewareOpts.AllowCredentials != other.middlewareOpts.AllowCredentials {
return false
}
if o.CorsCacheDuration != other.CorsCacheDuration {
return false
}
Expand Down
2 changes: 2 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func TestNewMiddlewareOpts(t *testing.T) {
nil,
nil,
nil,
false,
DefaultCorsCacheDuration,
DefaultCsrfCookieDuration,
DefaultSessionCookieDuration,
Expand All @@ -89,6 +90,7 @@ func TestNewMiddlewareOpts(t *testing.T) {
nil,
nil,
nil,
false,
DefaultCorsCacheDuration,
DefaultCsrfCookieDuration,
DefaultSessionCookieDuration,
Expand Down
2 changes: 2 additions & 0 deletions config/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func ExampleNew() {
[]string{"GET", "POST"},
// Allow all http headers for CORs.
[]string{"*"},
// Do not allow requests to include user credentials like cookies, HTTP authentication or client side SSL certificates
false,
// Cache CORs preflight requests for 1day.
24*time.Hour,
// Expire csrf cookie after 3days.
Expand Down
18 changes: 15 additions & 3 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ const (
// credentials are cookies, authorization headers, or tls client certificates
// The only valid value of this header is `true`(`false` is not valid, omit the header entirely instead.)
acacHeader = "Access-Control-Allow-Credentials"
_ = acacHeader
// header to allow CORS to resources in a private network(eg behind a VPN)
// you can set this header to `true` when you receive a preflight request if you want to allow access.
// Otherwise omit it entirely(as we will in this library)
Expand All @@ -63,6 +62,7 @@ func cors(
allowedOrigins []string,
allowedMethods []string,
allowedHeaders []string,
allowCredentials bool,
corsCacheDuration time.Duration,
) http.HandlerFunc {
allowedOrigins, allowedWildcardOrigins := getOrigins(allowedOrigins)
Expand All @@ -76,14 +76,14 @@ func cors(
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions && r.Header.Get(acrmHeader) != "" {
// handle preflight request
handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders, corsCacheDuration)
handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders, allowCredentials, corsCacheDuration)
// Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example
// is authentication middleware ; OPTIONS requests won't carry authentication headers.
w.WriteHeader(http.StatusNoContent)
} else {
// handle actual request
handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods)
handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowCredentials)
wrappedHandler.ServeHTTP(w, r)
}
}
Expand All @@ -96,6 +96,7 @@ func handlePreflight(
allowedWildcardOrigins []wildcard,
allowedMethods []string,
allowedHeaders []string,
allowCredentials bool,
corsCacheDuration time.Duration,
) {
headers := w.Header()
Expand Down Expand Up @@ -160,6 +161,11 @@ func handlePreflight(

// (d)
headers.Set(acmaHeader, fmt.Sprintf("%d", int(corsCacheDuration.Seconds())))

// (e)
if allowCredentials {
headers.Set(acacHeader, "true")
}
}

func handleActualRequest(
Expand All @@ -168,6 +174,7 @@ func handleActualRequest(
allowedOrigins []string,
allowedWildcardOrigins []wildcard,
allowedMethods []string,
allowCredentials bool,
) {
headers := w.Header()
origin := r.Header.Get(originHeader)
Expand Down Expand Up @@ -196,6 +203,11 @@ func handleActualRequest(
} else {
headers.Set(acaoHeader, origin)
}

// (b)
if allowCredentials {
headers.Set(acacHeader, "true")
}
}

type wildcard struct {
Expand Down
22 changes: 11 additions & 11 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand All @@ -45,7 +45,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, http.MethodGet) // preflight request header set
Expand Down Expand Up @@ -76,7 +76,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
// preflight request header NOT set
Expand Down Expand Up @@ -133,7 +133,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestCorsPreflight(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -271,7 +271,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -303,7 +303,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
wrappedHandler.ServeHTTP(rec, req)
Expand All @@ -322,7 +322,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, "http://example.com")
Expand Down Expand Up @@ -382,7 +382,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, tt.origin)
Expand Down Expand Up @@ -446,7 +446,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}, false, config.DefaultCorsCacheDuration)
rec := httptest.NewRecorder()
req := httptest.NewRequest(tt.method, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -475,7 +475,7 @@ func TestCorsActualRequest(t *testing.T) {
msg := "hello"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, config.DefaultCorsCacheDuration)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil, false, config.DefaultCorsCacheDuration)

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down
2 changes: 2 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func allDefaultMiddlewares(
allowedOrigins := o.AllowedOrigins
allowedMethods := o.AllowedOrigins
allowedHeaders := o.AllowedHeaders
allowCredentials := o.AllowCredentials
corsCacheDuration := o.CorsCacheDuration

// csrf
Expand Down Expand Up @@ -151,6 +152,7 @@ func allDefaultMiddlewares(
allowedOrigins,
allowedMethods,
allowedHeaders,
allowCredentials,
corsCacheDuration,
),
domain,
Expand Down
1 change: 1 addition & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ func BenchmarkAllMiddlewares(b *testing.B) {
nil,
nil,
nil,
false,
config.DefaultCorsCacheDuration,
config.DefaultCsrfCookieDuration,
config.DefaultSessionCookieDuration,
Expand Down

0 comments on commit 90ee0c3

Please sign in to comment.