From 3041d011424c7e225e8bdae3f482e4812d1ffbfd Mon Sep 17 00:00:00 2001 From: Antoine Niek Date: Fri, 26 Jun 2020 09:51:46 -0400 Subject: [PATCH] Expose *gin.Context to AllowOriginFunc --- config.go | 8 +++--- cors.go | 2 +- cors_test.go | 65 +++++++++++++++++++++++---------------------- examples/example.go | 2 +- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/config.go b/config.go index d4fc118..b86e49d 100644 --- a/config.go +++ b/config.go @@ -10,7 +10,7 @@ import ( type cors struct { allowAllOrigins bool allowCredentials bool - allowOriginFunc func(string) bool + allowOriginFunc func(string, *gin.Context) bool allowOrigins []string exposeHeaders []string normalHeaders http.Header @@ -68,7 +68,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if !cors.validateOrigin(origin, c) { c.AbortWithStatus(http.StatusForbidden) return } @@ -101,7 +101,7 @@ func (cors *cors) validateWildcardOrigin(origin string) bool { return false } -func (cors *cors) validateOrigin(origin string) bool { +func (cors *cors) validateOrigin(origin string, c *gin.Context) bool { if cors.allowAllOrigins { return true } @@ -114,7 +114,7 @@ func (cors *cors) validateOrigin(origin string) bool { return true } if cors.allowOriginFunc != nil { - return cors.allowOriginFunc(origin) + return cors.allowOriginFunc(origin, c) } return false } diff --git a/cors.go b/cors.go index d6d06de..eacc631 100644 --- a/cors.go +++ b/cors.go @@ -20,7 +20,7 @@ type Config struct { // AllowOriginFunc is a custom function to validate the origin. It take the origin // as argument and returns true if allowed or false otherwise. If this option is // set, the content of AllowOrigins is ignored. - AllowOriginFunc func(origin string) bool + AllowOriginFunc func(origin string, c *gin.Context) bool // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET and POST) diff --git a/cors_test.go b/cors_test.go index abce415..4fe3340 100644 --- a/cors_test.go +++ b/cors_test.go @@ -81,7 +81,7 @@ func TestBadConfig(t *testing.T) { assert.Panics(t, func() { New(Config{ AllowAllOrigins: true, - AllowOriginFunc: func(origin string) bool { return false }, + AllowOriginFunc: func(origin string, c *gin.Context) bool { return false }, }) }) assert.Panics(t, func() { @@ -200,66 +200,67 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { } func TestValidateOrigin(t *testing.T) { + emptyContext := &gin.Context{} cors := newCors(Config{ AllowAllOrigins: true, }) - assert.True(t, cors.validateOrigin("http://google.com")) - assert.True(t, cors.validateOrigin("https://google.com")) - assert.True(t, cors.validateOrigin("example.com")) - assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id")) + assert.True(t, cors.validateOrigin("http://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("https://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("example.com", emptyContext)) + assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext)) cors = newCors(Config{ AllowOrigins: []string{"https://google.com", "https://github.com"}, - AllowOriginFunc: func(origin string) bool { + AllowOriginFunc: func(origin string, c *gin.Context) bool { return (origin == "http://news.ycombinator.com") }, AllowBrowserExtensions: true, }) - assert.False(t, cors.validateOrigin("http://google.com")) - assert.True(t, cors.validateOrigin("https://google.com")) - assert.True(t, cors.validateOrigin("https://github.com")) - assert.True(t, cors.validateOrigin("http://news.ycombinator.com")) - assert.False(t, cors.validateOrigin("http://example.com")) - assert.False(t, cors.validateOrigin("google.com")) - assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id")) + assert.False(t, cors.validateOrigin("http://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("https://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("https://github.com", emptyContext)) + assert.True(t, cors.validateOrigin("http://news.ycombinator.com", emptyContext)) + assert.False(t, cors.validateOrigin("http://example.com", emptyContext)) + assert.False(t, cors.validateOrigin("google.com", emptyContext)) + assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext)) cors = newCors(Config{ AllowOrigins: []string{"https://google.com", "https://github.com"}, }) - assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id")) - assert.False(t, cors.validateOrigin("file://some-dangerous-file.js")) - assert.False(t, cors.validateOrigin("wss://socket-connection")) + assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext)) + assert.False(t, cors.validateOrigin("file://some-dangerous-file.js", emptyContext)) + assert.False(t, cors.validateOrigin("wss://socket-connection", emptyContext)) cors = newCors(Config{ AllowOrigins: []string{"chrome-extension://*", "safari-extension://my-extension-*-app", "*.some-domain.com"}, AllowBrowserExtensions: true, AllowWildcard: true, }) - assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id")) - assert.True(t, cors.validateOrigin("chrome-extension://another-one")) - assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app")) - assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app")) - assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow")) - assert.True(t, cors.validateOrigin("http://api.some-domain.com")) - assert.False(t, cors.validateOrigin("http://api.another-domain.com")) + assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext)) + assert.True(t, cors.validateOrigin("chrome-extension://another-one", emptyContext)) + assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app", emptyContext)) + assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app", emptyContext)) + assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow", emptyContext)) + assert.True(t, cors.validateOrigin("http://api.some-domain.com", emptyContext)) + assert.False(t, cors.validateOrigin("http://api.another-domain.com", emptyContext)) cors = newCors(Config{ AllowOrigins: []string{"file://safe-file.js", "wss://some-session-layer-connection"}, AllowFiles: true, AllowWebSockets: true, }) - assert.True(t, cors.validateOrigin("file://safe-file.js")) - assert.False(t, cors.validateOrigin("file://some-dangerous-file.js")) - assert.True(t, cors.validateOrigin("wss://some-session-layer-connection")) - assert.False(t, cors.validateOrigin("ws://not-what-we-expected")) + assert.True(t, cors.validateOrigin("file://safe-file.js", emptyContext)) + assert.False(t, cors.validateOrigin("file://some-dangerous-file.js", emptyContext)) + assert.True(t, cors.validateOrigin("wss://some-session-layer-connection", emptyContext)) + assert.False(t, cors.validateOrigin("ws://not-what-we-expected", emptyContext)) cors = newCors(Config{ AllowOrigins: []string{"*"}, }) - assert.True(t, cors.validateOrigin("http://google.com")) - assert.True(t, cors.validateOrigin("https://google.com")) - assert.True(t, cors.validateOrigin("example.com")) - assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id")) + assert.True(t, cors.validateOrigin("http://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("https://google.com", emptyContext)) + assert.True(t, cors.validateOrigin("example.com", emptyContext)) + assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext)) } func TestPassesAllowOrigins(t *testing.T) { @@ -270,7 +271,7 @@ func TestPassesAllowOrigins(t *testing.T) { ExposeHeaders: []string{"Data", "x-User"}, AllowCredentials: false, MaxAge: 12 * time.Hour, - AllowOriginFunc: func(origin string) bool { + AllowOriginFunc: func(origin string, c *gin.Context) bool { return origin == "http://github.com" }, }) diff --git a/examples/example.go b/examples/example.go index e57303c..c4e232e 100644 --- a/examples/example.go +++ b/examples/example.go @@ -20,7 +20,7 @@ func main() { AllowHeaders: []string{"Origin"}, ExposeHeaders: []string{"Content-Length"}, AllowCredentials: true, - AllowOriginFunc: func(origin string) bool { + AllowOriginFunc: func(origin string, c *gin.Context) bool { return origin == "https://github.com" }, MaxAge: 12 * time.Hour,