Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 Feature: Add support for multiple keys in the KeyAuth middleware #3027

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/middleware/keyauth.md
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. If multiple keys are required, they can be delimited with a pipe (`|`) character, and the first matching key will be selected. | "header:Authorization" |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(*fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| ContextKey | `interface{}` | Context key to store the bearer token from the token into context. | "token" |
Expand Down
62 changes: 46 additions & 16 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,37 @@
cookie = "cookie"
)

type extractorFunc func(c *fiber.Ctx) (string, error)

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)

// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])

parts := strings.Split(cfg.KeyLookup, "|")
gaby marked this conversation as resolved.
Show resolved Hide resolved

var extractor extractorFunc
if len(parts) <= 1 {
extractor = parseSingleExtractor(cfg.KeyLookup, cfg.AuthScheme)
} else {
subExtractors := []extractorFunc{}
for _, keyLookup := range parts {
subExtractors = append(subExtractors, parseSingleExtractor(keyLookup, cfg.AuthScheme))
}
extractor = func(c *fiber.Ctx) (string, error) {
for _, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", err
}
}
return "", ErrMissingOrMalformedAPIKey
}
gaby marked this conversation as resolved.
Show resolved Hide resolved
}

// Return middleware handler
Expand All @@ -61,8 +75,24 @@
}
}

func parseSingleExtractor(keyLookup string, authScheme string) extractorFunc {

Check failure on line 78 in middleware/keyauth/keyauth.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed with `-extra` (gofumpt)
parts := strings.Split(keyLookup, ":")
extractor := keyFromHeader(parts[1], authScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
}
return extractor
}
gaby marked this conversation as resolved.
Show resolved Hide resolved

// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c *fiber.Ctx) (string, error) {
func keyFromHeader(header, authScheme string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
Expand All @@ -77,7 +107,7 @@
}

// keyFromQuery returns a function that extracts api key from the query string.
func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) {
func keyFromQuery(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.Query(param)
if key == "" {
Expand All @@ -88,7 +118,7 @@
}

// keyFromForm returns a function that extracts api key from the form.
func keyFromForm(param string) func(c *fiber.Ctx) (string, error) {
func keyFromForm(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.FormValue(param)
if key == "" {
Expand All @@ -99,7 +129,7 @@
}

// keyFromParam returns a function that extracts api key from the url param string.
func keyFromParam(param string) func(c *fiber.Ctx) (string, error) {
func keyFromParam(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key, err := url.PathUnescape(c.Params(param))
if err != nil {
Expand All @@ -110,7 +140,7 @@
}

// keyFromCookie returns a function that extracts api key from the named cookie.
func keyFromCookie(name string) func(c *fiber.Ctx) (string, error) {
func keyFromCookie(name string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.Cookies(name)
if key == "" {
Expand Down
45 changes: 45 additions & 0 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,51 @@ func TestAuthSources(t *testing.T) {
}
}

func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
)

// setup the fiber endpoint
app := fiber.New()
authMiddleware := New(Config{
KeyLookup: "header:key|query:key",
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
app.Use(authMiddleware)
app.Get("/foo", func(c *fiber.Ctx) error {
return c.SendString(success)
})

// construct the test HTTP request
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
utils.AssertEqual(t, err, nil)
q := req.URL.Query()
q.Add("key", CorrectKey)
req.URL.RawQuery = q.Encode()

res, err := app.Test(req, -1)

utils.AssertEqual(t, nil, err, desc)

// test the body of the request
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, 200, res.StatusCode, desc)
// body
utils.AssertEqual(t, nil, err, desc)
utils.AssertEqual(t, success, string(body), desc)

err = res.Body.Close()
utils.AssertEqual(t, err, nil)
}

func TestMultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
Expand Down
Loading