Skip to content

Commit

Permalink
Fix csrf middleware behavior with header key lookup (#2063)
Browse files Browse the repository at this point in the history
* 🐛 [Bug]: Strange CSRF middleware behavior with header KeyLookup configuration #2045
  • Loading branch information
ReneWerner87 authored Aug 30, 2022
1 parent 6026560 commit ec96d16
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 7 deletions.
6 changes: 4 additions & 2 deletions middleware/csrf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,17 @@ type Config struct {
Extractor func(c *fiber.Ctx) (string, error)
}

const HeaderName = "X-Csrf-Token"

// ConfigDefault is the default config
var ConfigDefault = Config{
KeyLookup: "header:X-Csrf-Token",
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader("X-Csrf-Token"),
Extractor: CsrfFromHeader(HeaderName),
}

// default ErrorHandler that process return error from fiber.Handler
Expand Down
114 changes: 111 additions & 3 deletions middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Test_CSRF(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())

Expand All @@ -55,7 +55,7 @@ func Test_CSRF(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", token)
ctx.Request.Header.Set(HeaderName, token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
}
Expand Down Expand Up @@ -305,7 +305,7 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx)
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
Expand Down Expand Up @@ -340,3 +340,111 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
}

// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
//func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// app := fiber.New()
//
// app.Use(New())
// app.Get("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Get("/test", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Post("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
//
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
//
// var token string
// for _, c := range resp.Cookies() {
// if c.Name != ConfigDefault.CookieName {
// continue
// }
// token = c.Value
// break
// }
//
// fmt.Println("token", token)
//
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
// getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
//
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
// getReq.Header.Set(HeaderName, token)
//
// resp, err = app.Test(getReq)
//
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
// getReq.Header.Del(HeaderName)
// resp, err = app.Test(getReq)
//
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// postReq.Header.Set(HeaderName, token)
// resp, err = app.Test(postReq)
//}

// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
app := fiber.New()

app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})

fctx := &fasthttp.RequestCtx{}
h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(HeaderName, token)

b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
h(fctx)
}

utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}

// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
app := fiber.New()

app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})

fctx := &fasthttp.RequestCtx{}
h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
h(fctx)
}

utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}
7 changes: 5 additions & 2 deletions middleware/csrf/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
"github.com/gofiber/fiber/v2/utils"
)

// go:generate msgp
Expand Down Expand Up @@ -88,7 +89,8 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
_ = m.storage.Set(key, raw, exp)
}
} else {
m.memory.Set(key, it, exp)
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), it, exp)
}
}

Expand All @@ -97,7 +99,8 @@ func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), raw, exp)
}
}

Expand Down

0 comments on commit ec96d16

Please sign in to comment.