diff --git a/README.md b/README.md index 521c013..d73b86b 100644 --- a/README.md +++ b/README.md @@ -22,31 +22,24 @@ import ( "github.com/gofiber/keyauth/v2" ) -const ( - apiKey = "my-super-secret-key" -) - var ( - errMissing = &fiber.Error{Code: 403, Message: "Missing API key"} - errInvalid = &fiber.Error{Code: 403, Message: "Invalid API key"} + APIKey = "correct horse battery staple" ) -func validateApiKey(ctx *fiber.Ctx, s string) (bool, error) { - if s == "" { - return false, errMissing - } - if s == apiKey { - return true, nil - } - return false, errInvalid +func validateAPIKey(c *fiber.Ctx, key string) (bool, error) { + if key == APIKey { + return true, nil + } + return false, keyauth.ErrMissingOrMalformedAPIKey } func main() { app := fiber.New() - + + // note that the keyauth middleware needs to be defined before the routes are defined! app.Use(keyauth.New(keyauth.Config{ - KeyLookup: "cookie:access_token", - Validator: validateApiKey, + KeyLookup: "cookie:access_token", + Validator: validateAPIKey, })) app.Get("/", func(c *fiber.Ctx) error { @@ -64,11 +57,11 @@ func main() { curl http://localhost:3000 #> missing or malformed API Key -curl --cookie "access_token=my-super-secret-key" http://localhost:3000 +curl --cookie "access_token=correct horse battery staple" http://localhost:3000 #> Successfully authenticated! curl --cookie "access_token=Clearly A Wrong Key" http://localhost:3000 -#> Invalid or expired API Key +#> missing or malformed API Key ``` For a more detailed example, see also the [`github.com/gofiber/recipes`](https://github.com/gofiber/recipes) repository and specifically the `fiber-envoy-extauthz` repository and the [`keyauth example`](https://github.com/gofiber/recipes/blob/master/fiber-envoy-extauthz/authz/main.go) code. @@ -85,24 +78,15 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/keyauth/v2" ) - -const ( - apiKey = "my-super-secret-key" -) - var ( - errMissing = &fiber.Error{Code: 403, Message: "Missing API key"} - errInvalid = &fiber.Error{Code: 403, Message: "Invalid API key"} + APIKey = "correct horse battery staple" ) -func validateApiKey(ctx *fiber.Ctx, s string) (bool, error) { - if s == "" { - return false, errMissing - } - if s == apiKey { - return true, nil - } - return false, errInvalid +func validateAPIKey(c *fiber.Ctx, key string) (bool, error) { + if key == APIKey { + return true, nil + } + return false, keyauth.ErrMissingOrMalformedAPIKey } func authFilter(c *fiber.Ctx) bool { @@ -116,8 +100,8 @@ func main() { app.Use(keyauth.New(keyauth.Config{ Filter: authFilter, - KeyLookup: "cookie:access_token", - Validator: validateApiKey, + KeyLookup: "cookie:access_token", + Validator: validateAPIKey, })) app.Get("/", func(c *fiber.Ctx) error { @@ -142,10 +126,10 @@ curl http://localhost:3000 #> Welcome # /authenticated needs to be authenticated -curl --cookie "access_token=my-super-secret-key" http://localhost:3000/authenticated +curl --cookie "access_token=correct horse battery staple" http://localhost:3000/authenticated #> Successfully authenticated! # /auth2 needs to be authenticated too -curl --cookie "access_token=my-super-secret-key" http://localhost:3000/auth2 +curl --cookie "access_token=correct horse battery staple" http://localhost:3000/auth2 #> Successfully authenticated 2! ``` diff --git a/main.go b/main.go index 6cb3e22..fa1d98c 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ package keyauth import ( "errors" + "net/url" "strings" "github.com/gofiber/fiber/v2" @@ -46,7 +47,6 @@ type Config struct { AuthScheme string // Validator is a function to validate key. - // Optional. Default: nil Validator func(*fiber.Ctx, string) (bool, error) // Context key to store the bearertoken from the token into context. @@ -70,7 +70,7 @@ func New(config ...Config) fiber.Handler { if cfg.ErrorHandler == nil { cfg.ErrorHandler = func(c *fiber.Ctx, err error) error { if err == ErrMissingOrMalformedAPIKey { - return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + return c.Status(fiber.StatusUnauthorized).SendString(err.Error()) } return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key") } @@ -83,9 +83,7 @@ func New(config ...Config) fiber.Handler { } } if cfg.Validator == nil { - cfg.Validator = func(c *fiber.Ctx, t string) (bool, error) { - return true, nil - } + panic("fiber: keyauth middleware requires a validator function") } if cfg.ContextKey == "" { cfg.ContextKey = "token" @@ -168,8 +166,8 @@ func keyFromForm(param string) func(c *fiber.Ctx) (string, error) { // keyFromParam returns a function that extracts api key from the url param string. func keyFromParam(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { - key := c.Params(param) - if key == "" { + key, err := url.PathUnescape(c.Params(param)) + if err != nil { return "", ErrMissingOrMalformedAPIKey } return key, nil diff --git a/main_test.go b/main_test.go index 47dd09e..3cb3d33 100644 --- a/main_test.go +++ b/main_test.go @@ -3,3 +3,268 @@ // 📝 Github Repository: https://github.com/gofiber/fiber package keyauth + +import ( + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) + + +func TestAuthSources(t *testing.T) { + + var CorrectKey = "specials: !$%,.#\"!?~`<>@$^*(){}[]|/\\123" + // define test cases + testSources := []string {"header", "cookie", "query", "param", "form"} + + tests := []struct { + route string + authTokenName string + description string + APIKey string + expectedCode int + expectedBody string + }{ + { + route: "/", + authTokenName: "access_token", + description: "auth with correct key", + APIKey: CorrectKey, + expectedCode: 200, + expectedBody: "Success!", + }, + { + route: "/", + authTokenName: "access_token", + description: "auth with no key", + APIKey: "", + expectedCode: 401, // 404 in case of param authentication + expectedBody: "missing or malformed API Key", + }, + { + route: "/", + authTokenName: "access_token", + description: "auth with wrong key", + APIKey: "WRONGKEY", + expectedCode: 401, + expectedBody: "missing or malformed API Key", + }, + } + + + for _, authSource := range testSources { + t.Run(authSource, func(t *testing.T) { + for _, test := range tests { + // setup the fiber endpoint + // note that if UnescapePath: false (the default) + // escaped characters (such as `\"`) will not be handled correctly in the tests + app := fiber.New(fiber.Config{UnescapePath: true}) + + authMiddleware := New(Config{ + KeyLookup: authSource + ":" + test.authTokenName, + Validator: func(c *fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + }) + + var route string + if authSource == "param" { + route = test.route + ":" + test.authTokenName + app.Use(route, authMiddleware) + } else { + route = test.route + app.Use(authMiddleware) + } + + app.Get(route, func(c *fiber.Ctx) error { + return c.SendString("Success!") + }) + + // construct the test HTTP request + var req *http.Request + req, _ = http.NewRequest("GET", test.route, nil) + + // setup the apikey for the different auth schemes + if authSource == "header" { + + req.Header.Set(test.authTokenName, test.APIKey) + + } else if authSource == "cookie" { + + req.Header.Set("Cookie", test.authTokenName + "=" + test.APIKey) + + } else if authSource == "query" || authSource == "form" { + + q := req.URL.Query() + q.Add(test.authTokenName, test.APIKey) + req.URL.RawQuery = q.Encode() + + } else if authSource == "param" { + + r := req.URL.Path + r = r + url.PathEscape(test.APIKey) + req.URL.Path = r + + } + + res, err := app.Test(req, -1) + + utils.AssertEqual(t, nil, err, test.description) + + // test the body of the request + body, err := ioutil.ReadAll(res.Body) + // for param authentication, the route would be /:access_token + // when the access_token is empty, it leads to a 404 (not found) + // not a 401 (auth error) + if authSource == "param" && test.APIKey == "" { + test.expectedCode = 404 + test.expectedBody = "Cannot GET /" + } + utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description) + + // body + utils.AssertEqual(t, nil, err, test.description) + utils.AssertEqual(t, test.expectedBody, string(body), test.description) + } + }) + } +} + + +func TestMultipleKeyAuth(t *testing.T) { + + // setup the fiber endpoint + app := fiber.New() + + // setup keyauth for /auth1 + app.Use(New(Config{ + Filter: func(c *fiber.Ctx) bool { + return c.OriginalURL() != "/auth1" + }, + KeyLookup: "header:key", + Validator: func(c *fiber.Ctx, key string) (bool, error) { + if key == "password1" { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + + // setup keyauth for /auth2 + app.Use(New(Config{ + Filter: func(c *fiber.Ctx) bool { + return c.OriginalURL() != "/auth2" + }, + KeyLookup: "header:key", + Validator: func(c *fiber.Ctx, key string) (bool, error) { + if key == "password2" { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("No auth needed!") + }) + + app.Get("/auth1", func(c *fiber.Ctx) error { + return c.SendString("Successfully authenticated for auth1!") + }) + + app.Get("/auth2", func(c *fiber.Ctx) error { + return c.SendString("Successfully authenticated for auth2!") + }) + + // define test cases + tests := []struct { + route string + description string + APIKey string + expectedCode int + expectedBody string + }{ + // No auth needed for / + { + route: "/", + description: "No password needed", + APIKey: "", + expectedCode: 200, + expectedBody: "No auth needed!", + }, + + // auth needed for auth1 + { + route: "/auth1", + description: "Normal Authentication Case", + APIKey: "password1", + expectedCode: 200, + expectedBody: "Successfully authenticated for auth1!", + }, + { + route: "/auth1", + description: "Wrong API Key", + APIKey: "WRONG KEY", + expectedCode: 401, + expectedBody: "missing or malformed API Key", + }, + { + route: "/auth1", + description: "Wrong API Key", + APIKey: "", // NO KEY + expectedCode: 401, + expectedBody: "missing or malformed API Key", + }, + + // Auth 2 has a different password + { + route: "/auth2", + description: "Normal Authentication Case for auth2", + APIKey: "password2", + expectedCode: 200, + expectedBody: "Successfully authenticated for auth2!", + }, + { + route: "/auth2", + description: "Wrong API Key", + APIKey: "WRONG KEY", + expectedCode: 401, + expectedBody: "missing or malformed API Key", + }, + { + route: "/auth2", + description: "Wrong API Key", + APIKey: "", // NO KEY + expectedCode: 401, + expectedBody: "missing or malformed API Key", + }, + } + + // run the tests + for _, test := range tests { + var req *http.Request + req, _ = http.NewRequest("GET", test.route, nil) + if test.APIKey != "" { + req.Header.Set("key", test.APIKey) + } + + res, err := app.Test(req, -1) + + utils.AssertEqual(t, nil, err, test.description) + + // test the body of the request + body, err := ioutil.ReadAll(res.Body) + utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description) + + // body + utils.AssertEqual(t, nil, err, test.description) + utils.AssertEqual(t, test.expectedBody, string(body), test.description) + } +}