Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 Feature: Add support for custom KeyLookup functions in the Keyauth middleware #3028

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
0b6b5e9
port over FallbackKeyLookups from v2 middleware to v3
dave-gray101 Jun 9, 2024
f118663
bot pointed out that I missed the format variable
dave-gray101 Jun 9, 2024
a432b80
fix lint and gofumpt issues
dave-gray101 Jun 10, 2024
4e061aa
major revision: instead of FallbackKeyLookups, expose CustomKeyLookup…
dave-gray101 Jun 10, 2024
397ff49
add more tests to boost coverage
dave-gray101 Jun 10, 2024
2968439
teardown code and cleanup
dave-gray101 Jun 10, 2024
5b24181
Merge branch 'main' into feat-keyauth-fallback-keylookup
gaby Jun 11, 2024
7bfc96d
test fixes
dave-gray101 Jun 11, 2024
8c13f25
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 12, 2024
7191f65
slight boost to test coverage
dave-gray101 Jun 16, 2024
1004aa0
Merge branch 'feat-keyauth-fallback-keylookup' of ghgray101:dave-gray…
dave-gray101 Jun 16, 2024
961e8de
docs: fix md table alignment
sixcolors Jun 16, 2024
293c01b
fix comments - change some names, expose functions, improve docs
dave-gray101 Jun 16, 2024
825e11a
Merge branch 'feat-keyauth-fallback-keylookup' of ghgray101:dave-gray…
dave-gray101 Jun 16, 2024
26bc132
missed one old name
dave-gray101 Jun 16, 2024
9588706
fix some suggestions from the bot - error messages, test coverage, ma…
dave-gray101 Jun 17, 2024
2711dc3
Merge branch 'main' into feat-keyauth-fallback-keylookup
gaby Jun 17, 2024
4da76e4
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 18, 2024
12eeca8
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 20, 2024
d6d5bfe
Merge branch 'main' into feat-keyauth-fallback-keylookup
ReneWerner87 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions docs/middleware/keyauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,15 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000

## Config

| Property | Type | Description | Default |
|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| Property | Type | Description | Default |
|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract the key from the request. | "header:Authorization" |
| CustomKeyLookup | `KeyLookupFunc` aka `func(c fiber.Ctx) (string, error)` | If more complex logic is required to extract the key from the request, an arbitrary function to extract it can be specified here. Utility helper functions are described below. | `nil` |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |

## Default Config

Expand All @@ -237,6 +238,13 @@ var ConfigDefault = Config{
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}
```

## CustomKeyLookup

Two public utility functions are provided that may be useful when creating custom extraction:
* `DefaultKeyLookup(keyLookup string, authScheme string)`: This is the function that implements the default `KeyLookup` behavior, exposed to be used as a component of custom parsing logic
* `MultipleKeySourceLookup(keyLookups []string, authScheme string)`: Creates a CustomKeyLookup function that checks each listed source using the above function until a key is found or the options are all exhausted. For example, `MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "cookie:apikey"}, "Bearer")` would first check the standard Authorization header, checks the `x-api-key` header next, and finally checks for a cookie named `apikey`. If any of these contain a valid API key, the request continues. Otherwise, an error is returned.
9 changes: 7 additions & 2 deletions middleware/keyauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/gofiber/fiber/v3"
)

type KeyLookupFunc func(c fiber.Ctx) (string, error)

// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
Expand All @@ -32,6 +34,8 @@ type Config struct {
// - "cookie:<name>"
KeyLookup string

CustomKeyLookup KeyLookupFunc

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
Expand All @@ -51,8 +55,9 @@ var ConfigDefault = Config{
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
KeyLookup: "header:" + fiber.HeaderAuthorization,
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}

// Helper function to set default values
Expand Down
75 changes: 58 additions & 17 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import (
"errors"
"fmt"
"net/url"
"strings"

Expand Down Expand Up @@ -34,17 +35,12 @@
cfg := configDefault(config...)

// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
if cfg.CustomKeyLookup == nil {
var err error
cfg.CustomKeyLookup, err = DefaultKeyLookup(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("unable to create lookup function: %w", err))
}
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
}

// Return middleware handler
Expand All @@ -55,7 +51,7 @@
}

// Extract and verify key
key, err := extractor(c)
key, err := cfg.CustomKeyLookup(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
Expand All @@ -80,8 +76,53 @@
return token
}

// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
// Each element should be specified according to the format used in KeyLookup
func MultipleKeySourceLookup(keyLookups []string, authScheme string) (KeyLookupFunc, error) {
subExtractors := map[string]KeyLookupFunc{}
var err error
for _, keyLookup := range keyLookups {
subExtractors[keyLookup], err = DefaultKeyLookup(keyLookup, authScheme)
if err != nil {
return nil, err
}
}
return func(c fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
// Defensive Code - not currently possible to hit
return "", fmt.Errorf("[%s] %w", keyLookup, err)

Check warning on line 98 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L98

Added line #L98 was not covered by tests
}
}
return "", ErrMissingOrMalformedAPIKey
}, nil
}

func DefaultKeyLookup(keyLookup, authScheme string) (KeyLookupFunc, error) {
parts := strings.Split(keyLookup, ":")
if len(parts) <= 1 {
return nil, fmt.Errorf("invalid keyLookup: %q, expected format 'source:name'", keyLookup)
}
extractor := KeyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header:
switch parts[0] {
case query:
extractor = KeyFromQuery(parts[1])
case form:
extractor = KeyFromForm(parts[1])
case param:
extractor = KeyFromParam(parts[1])
case cookie:
extractor = KeyFromCookie(parts[1])
}
return extractor, nil
}

// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) {
func KeyFromHeader(header, authScheme string) KeyLookupFunc {
return func(c fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
Expand All @@ -96,7 +137,7 @@
}

// keyFromQuery returns a function that extracts api key from the query string.
func keyFromQuery(param string) func(c fiber.Ctx) (string, error) {
func KeyFromQuery(param string) KeyLookupFunc {
return func(c fiber.Ctx) (string, error) {
key := fiber.Query[string](c, param)
if key == "" {
Expand All @@ -107,7 +148,7 @@
}

// keyFromForm returns a function that extracts api key from the form.
func keyFromForm(param string) func(c fiber.Ctx) (string, error) {
func KeyFromForm(param string) KeyLookupFunc {
return func(c fiber.Ctx) (string, error) {
key := c.FormValue(param)
if key == "" {
Expand All @@ -118,7 +159,7 @@
}

// keyFromParam returns a function that extracts api key from the url param string.
func keyFromParam(param string) func(c fiber.Ctx) (string, error) {
func KeyFromParam(param string) KeyLookupFunc {
return func(c fiber.Ctx) (string, error) {
key, err := url.PathUnescape(c.Params(param))
if err != nil {
Expand All @@ -129,7 +170,7 @@
}

// keyFromCookie returns a function that extracts api key from the named cookie.
func keyFromCookie(name string) func(c fiber.Ctx) (string, error) {
func KeyFromCookie(name string) KeyLookupFunc {
return func(c fiber.Ctx) (string, error) {
key := c.Cookies(name)
if key == "" {
Expand Down
152 changes: 152 additions & 0 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,109 @@ func Test_AuthSources(t *testing.T) {
}
}

func TestPanicOnInvalidConfiguration(t *testing.T) {
require.Panics(t, func() {
authMiddleware := New(Config{
KeyLookup: "invalid",
})
// We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable.
app := fiber.New()
defer func() { // testing panics, defer block to ensure cleanup
err := app.Shutdown()
require.NoError(t, err)
}()
app.Use(authMiddleware)
}, "should panic if Validator is missing")

require.Panics(t, func() {
authMiddleware := New(Config{
KeyLookup: "invalid",
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return true, nil
},
})
// We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable.
app := fiber.New()
defer func() { // testing panics, defer block to ensure cleanup
err := app.Shutdown()
require.NoError(t, err)
}()
app.Use(authMiddleware)
}, "should panic if CustomKeyLookup is not set AND KeyLookup has an invalid value")
}

func TestCustomKeyUtilityFunctionErrors(t *testing.T) {
const (
scheme = "Bearer"
)

// Invalid element while parsing
_, err := DefaultKeyLookup("invalid", scheme)
require.Error(t, err, "DefaultKeyLookup should fail for 'invalid' keyLookup")

_, err = MultipleKeySourceLookup([]string{"header:key", "invalid"}, scheme)
require.Error(t, err, "MultipleKeySourceLookup should fail for 'invalid' keyLookup")
}

func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
scheme = "Bearer"
)

// setup the fiber endpoint
app := fiber.New()

customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme)
require.NoError(t, err)

authMiddleware := New(Config{
CustomKeyLookup: customKeyLookup,
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
app.Use(authMiddleware)
app.Get("/foo", func(c fiber.Ctx) error {
return c.SendString(success)
})

// construct the test HTTP request
var req *http.Request
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.NoError(t, err)
q := req.URL.Query()
q.Add("key", CorrectKey)
req.URL.RawQuery = q.Encode()

res, err := app.Test(req, -1)

require.NoError(t, err)

// test the body of the request
body, err := io.ReadAll(res.Body)
require.Equal(t, 200, res.StatusCode, desc)
// body
require.NoError(t, err)
require.Equal(t, success, string(body), desc)

err = res.Body.Close()
require.NoError(t, err)

// construct a second request without proper key
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.NoError(t, err)
res, err = app.Test(req, -1)
require.NoError(t, err)
errBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
}

func Test_MultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
Expand Down Expand Up @@ -376,6 +479,55 @@ func Test_CustomNextFunc(t *testing.T) {
require.Equal(t, string(body), ErrMissingOrMalformedAPIKey.Error())
}

func Test_TokenFromContext_None(t *testing.T) {
app := fiber.New()
// Define a test handler that checks TokenFromContext
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
})

// Verify a "" is sent back if nothing sets the token on the context.
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
// Send
res, err := app.Test(req)
require.NoError(t, err)

// Read the response body into a string
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Empty(t, body)
}

func Test_TokenFromContext(t *testing.T) {
app := fiber.New()
// Wire up keyauth middleware to set TokenFromContext now
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler that checks TokenFromContext
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
})

req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
// Send
res, err := app.Test(req)
require.NoError(t, err)

// Read the response body into a string
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
}

func Test_AuthSchemeToken(t *testing.T) {
app := fiber.New()

Expand Down
Loading