From b9430ec83ba4f97567d344101e5c8666291cc486 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 17 Mar 2024 12:32:30 -0300 Subject: [PATCH 1/7] fix(middleware/cors): categorise requests correctly --- middleware/cors/cors.go | 18 ++++--- middleware/cors/cors_test.go | 101 +++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 8 deletions(-) diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index e27e74cba8..9159a1340c 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -172,8 +172,9 @@ func New(config ...Config) fiber.Handler { // Get originHeader header originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) - // If the request does not have an Origin header, the request is outside the scope of CORS - if originHeader == "" { + // If the request does not have Origin and Access-Control-Request-Method + // headers, the request is outside the scope of CORS + if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" { return c.Next() } @@ -211,8 +212,9 @@ func New(config ...Config) fiber.Handler { } // Simple request + // Ommit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { - setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg) return c.Next() } @@ -233,14 +235,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' - if allowOrigin != "*" && allowOrigin != "" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } else if allowOrigin == "*" { + if allowOrigin == "*" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") + } else if allowOrigin != "" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") } - } else if len(allowOrigin) > 0 { + } else if allowOrigin != "" { // For non-credential requests, it's safe to set to '*' or specific origins c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index ff5cdd7c25..c56d3c503b 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -35,6 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) @@ -49,6 +50,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -59,6 +61,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -87,6 +90,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -101,6 +105,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) handler(ctx) utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) @@ -128,6 +133,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -141,6 +147,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) @@ -226,6 +233,7 @@ func Test_CORS_Subdomain(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -240,6 +248,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with domain only (disallowed) ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) @@ -252,6 +261,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com") handler(ctx) @@ -366,6 +376,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin) handler(ctx) @@ -422,6 +433,90 @@ func Test_CORS_Next(t *testing.T) { utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) } +// go test -run Test_CORS_Headers_BasedOnRequestType +func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New(Config{})) + + methods := []string{ + fiber.MethodGet, + fiber.MethodPost, + fiber.MethodPut, + fiber.MethodDelete, + fiber.MethodPatch, + fiber.MethodHead, + } + + // Get handler pointer + handler := app.Handler() + + t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make request without origin header, and without Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) { + // Make request with origin header, but without Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) { + // Make request without origin header, but with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make preflight request with origin header and with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) + handler(ctx) + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") + utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)") + } + }) + + t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make non-preflight request with origin header and with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/api/action") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) + handler(ctx) + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)") + } + }) +} + func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { t.Parallel() // New fiber instance @@ -440,6 +535,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -454,6 +550,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com") handler(ctx) @@ -466,6 +563,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -505,6 +603,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -652,6 +751,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) @@ -742,6 +842,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) From de3f4abb7df1c1688f84b47f8fc327c8d1f80563 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 17 Mar 2024 12:53:21 -0300 Subject: [PATCH 2/7] test(middleware/cors): improve test coverage for request types --- middleware/cors/cors_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index c56d3c503b..6ab0202e71 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -438,6 +438,9 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{})) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) methods := []string{ fiber.MethodGet, @@ -458,6 +461,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.Header.SetMethod(method) ctx.Request.SetRequestURI("https://example.com/") handler(ctx) + utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") } }) @@ -470,6 +474,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.SetRequestURI("https://example.com/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) + utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") } }) @@ -482,6 +487,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.SetRequestURI("https://example.com/") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) handler(ctx) + utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") } }) @@ -495,6 +501,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) handler(ctx) + utils.AssertEqual(t, 204, ctx.Response.StatusCode(), "Status code should be 204") utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)") @@ -510,6 +517,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) handler(ctx) + utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)") From b5b6c9f64bf53b6e81e1bd8d186b9cf4430ebfdd Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 17 Mar 2024 17:37:33 -0300 Subject: [PATCH 3/7] test(middleware/cors): Add subdomain matching tests --- middleware/cors/utils_test.go | 52 +++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go index adc729d05f..02c632825d 100644 --- a/middleware/cors/utils_test.go +++ b/middleware/cors/utils_test.go @@ -2,6 +2,8 @@ package cors import ( "testing" + + "github.com/gofiber/fiber/v2/utils" ) // go test -run -v Test_normalizeOrigin @@ -16,6 +18,9 @@ func Test_normalizeOrigin(t *testing.T) { {"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved. {"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed. {"http://", false, ""}, // Invalid origin should not be accepted. + {"file:///etc/passwd", false, ""}, // File scheme should not be accepted. + {"https://*example.com", false, ""}, // Wildcard domain should not be accepted. + {"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted. {"http://example.com/path", false, ""}, // Path should not be accepted. {"http://example.com?query=123", false, ""}, // Query should not be accepted. {"http://example.com#fragment", false, ""}, // Fragment should not be accepted. @@ -105,3 +110,50 @@ func Test_normalizeDomain(t *testing.T) { } } } + +func TestSubdomainMatch(t *testing.T) { + tests := []struct { + name string + sub subdomain + origin string + expected bool + }{ + { + name: "match with valid subdomain", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.service.example.com", + expected: true, + }, + { + name: "no match with invalid prefix", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://service.example.com", + expected: false, + }, + { + name: "no match with invalid suffix", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.example.org", + expected: false, + }, + { + name: "no match with empty origin", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "", + expected: false, + }, + { + name: "partial match not considered a match", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.sub.match(tt.origin) + utils.AssertEqual(t, tt.expected, got, "subdomain.match()") + }) + } +} From 2b650d151a47afd726db04a30fd79fc225fb4fcc Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 17 Mar 2024 17:50:15 -0300 Subject: [PATCH 4/7] test(middleware/cors): parallel tests for CORS headers based on request type --- middleware/cors/cors_test.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 6ab0202e71..932e56c476 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -455,6 +455,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { handler := app.Handler() t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make request without origin header, and without Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -467,6 +468,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) { + t.Parallel() // Make request with origin header, but without Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -480,6 +482,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) { + t.Parallel() // Make request without origin header, but with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -493,6 +496,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -509,6 +513,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make non-preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -676,7 +681,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed", Config: Config{ AllowOrigins: "http://aaa.com", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return true }, }, @@ -687,7 +692,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed", Config: Config{ AllowOrigins: "http://aaa.com", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return true }, }, @@ -698,7 +703,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed", Config: Config{ AllowOrigins: "http://aaa.com", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return false }, }, @@ -709,7 +714,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed", Config: Config{ AllowOrigins: "http://aaa.com", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return false }, }, @@ -729,7 +734,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed", Config: Config{ AllowOrigins: "", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return true }, }, @@ -740,7 +745,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed", Config: Config{ AllowOrigins: "", - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return false }, }, @@ -782,7 +787,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { Name: "AllowOriginsFuncDefined", Config: Config{ AllowCredentials: true, - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return true }, }, @@ -795,7 +800,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials", Config: Config{ AllowCredentials: true, - AllowOriginsFunc: func(origin string) bool { + AllowOriginsFunc: func(_ string) bool { return true }, }, From 08fb3f357707165830e1050433cc1468d49d3fa8 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 18 Mar 2024 20:58:32 -0300 Subject: [PATCH 5/7] test(middleware/cors): Add benchmark for CORS subdomain matching --- middleware/cors/utils_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go index 02c632825d..ba7c0c9c68 100644 --- a/middleware/cors/utils_test.go +++ b/middleware/cors/utils_test.go @@ -157,3 +157,20 @@ func TestSubdomainMatch(t *testing.T) { }) } } + +// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4 +func Benchmark_CORS_SubdomainMatch(b *testing.B) { + s := subdomain{ + prefix: "www", + suffix: ".example.com", + } + + o := "www.example.com" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + s.match(o) + } +} From eaa110c1849f69135bd9c63871f096877cdc96a1 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 18 Mar 2024 21:38:20 -0300 Subject: [PATCH 6/7] test(middleware/cors): cover additiona test cases --- middleware/cors/cors_test.go | 10 ++++++ middleware/cors/utils_test.go | 65 ++++++++++++++++++++++------------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 932e56c476..57f5f91205 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -280,6 +280,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { reqOrigin: "http://example.com", shouldAllowOrigin: true, }, + { + pattern: "HTTP://EXAMPLE.COM", + reqOrigin: "http://example.com", + shouldAllowOrigin: true, + }, { pattern: "https://example.com", reqOrigin: "https://example.com", @@ -310,6 +315,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { reqOrigin: "http://aaa.example.com:8080", shouldAllowOrigin: true, }, + { + pattern: "http://*.example.com", + reqOrigin: "http://1.2.aaa.example.com", + shouldAllowOrigin: true, + }, { pattern: "http://example.com", reqOrigin: "http://gofiber.com", diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go index ba7c0c9c68..47dddc2c69 100644 --- a/middleware/cors/utils_test.go +++ b/middleware/cors/utils_test.go @@ -111,40 +111,76 @@ func Test_normalizeDomain(t *testing.T) { } } -func TestSubdomainMatch(t *testing.T) { +// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4 +func Benchmark_CORS_SubdomainMatch(b *testing.B) { + s := subdomain{ + prefix: "www", + suffix: ".example.com", + } + + o := "www.example.com" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + s.match(o) + } +} + +func Test_CORS_SubdomainMatch(t *testing.T) { tests := []struct { name string sub subdomain origin string expected bool }{ + { + name: "match with different scheme", + sub: subdomain{prefix: "http://api.", suffix: ".example.com"}, + origin: "https://api.service.example.com", + expected: false, + }, + { + name: "match with different scheme", + sub: subdomain{prefix: "https://", suffix: ".example.com"}, + origin: "http://api.service.example.com", + expected: false, + }, { name: "match with valid subdomain", - sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + sub: subdomain{prefix: "https://", suffix: ".example.com"}, origin: "https://api.service.example.com", expected: true, }, + { + name: "match with valid nested subdomain", + sub: subdomain{prefix: "https://", suffix: ".example.com"}, + origin: "https://1.2.api.service.example.com", + expected: true, + }, + { name: "no match with invalid prefix", - sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + sub: subdomain{prefix: "https://abc.", suffix: ".example.com"}, origin: "https://service.example.com", expected: false, }, { name: "no match with invalid suffix", - sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + sub: subdomain{prefix: "https://", suffix: ".example.com"}, origin: "https://api.example.org", expected: false, }, { name: "no match with empty origin", - sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + sub: subdomain{prefix: "https://", suffix: ".example.com"}, origin: "", expected: false, }, { name: "partial match not considered a match", - sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + sub: subdomain{prefix: "https://service.", suffix: ".example.com"}, origin: "https://api.example.com", expected: false, }, @@ -157,20 +193,3 @@ func TestSubdomainMatch(t *testing.T) { }) } } - -// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4 -func Benchmark_CORS_SubdomainMatch(b *testing.B) { - s := subdomain{ - prefix: "www", - suffix: ".example.com", - } - - o := "www.example.com" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - s.match(o) - } -} From 38ab39b937990f44196b2e0402ef02eed9baab75 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 18 Mar 2024 21:53:13 -0300 Subject: [PATCH 7/7] refactor(middleware/cors): origin validation and normalization --- middleware/cors/cors.go | 22 ++++++---------------- middleware/cors/cors_test.go | 2 ++ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 9159a1340c..7debfdfaa0 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -119,33 +119,23 @@ func New(config ...Config) fiber.Handler { allowSOrigins := []subdomain{} allowAllOrigins := false - // processOrigin processes an origin string, normalizes it and checks its validity - // it will panic if the origin is invalid - processOrigin := func(origin string) (string, bool) { - trimmedOrigin := strings.TrimSpace(origin) - isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) - if !isValid { - log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) - panic("[CORS] Invalid origin provided in configuration") - } - return normalizedOrigin, true - } - // Validate and normalize static AllowOrigins if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { origins := strings.Split(cfg.AllowOrigins, ",") for _, origin := range origins { if i := strings.Index(origin, "://*."); i != -1 { - normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:]) + trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:]) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { - continue + panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin) } sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]} allowSOrigins = append(allowSOrigins, sd) } else { - normalizedOrigin, isValid := processOrigin(origin) + trimmedOrigin := strings.TrimSpace(origin) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { - continue + panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin) } allowOrigins = append(allowOrigins, normalizedOrigin) } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 57f5f91205..2e0b5c2244 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -190,7 +190,9 @@ func Test_CORS_Invalid_Origins_Panic(t *testing.T) { "http://foo.[a-z]*.example.com", "http://*", "https://*", + "http://*.com*", "invalid url", + "http://origin.com,invalid url", // add more invalid origins as needed }