From 186cf2b36295573239a3a83807b87dc06a98f58c Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 7 Dec 2021 20:20:38 +0200 Subject: [PATCH] casbin: add EnforceHandler to allow custom callback to handle enforcing. --- casbin/casbin.go | 70 ++++++++++++++++++++++++------------------- casbin/casbin_test.go | 26 ++++++++++++++++ 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/casbin/casbin.go b/casbin/casbin.go index fc27670..5f67b3f 100644 --- a/casbin/casbin.go +++ b/casbin/casbin.go @@ -45,11 +45,11 @@ Advanced example: package casbin import ( - "net/http" - + "errors" "github.com/casbin/casbin/v2" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "net/http" ) type ( @@ -59,11 +59,18 @@ type ( Skipper middleware.Skipper // Enforcer CasbinAuth main rule. - // Required. + // One of Enforcer or EnforceHandler fields is required. Enforcer *casbin.Enforcer + // EnforceHandler is custom callback to handle enforcing. + // One of Enforcer or EnforceHandler fields is required. + EnforceHandler func(c echo.Context, user string) (bool, error) + // Method to get the username - defaults to using basic auth UserGetter func(c echo.Context) (string, error) + + // Method to handle errors + ErrorHandler func(c echo.Context, internal error, proposedStatus int) error } ) @@ -75,6 +82,11 @@ var ( username, _, _ := c.Request().BasicAuth() return username, nil }, + ErrorHandler: func(c echo.Context, internal error, proposedStatus int) error { + err := echo.NewHTTPError(proposedStatus, internal.Error()) + err.Internal = internal + return err + }, } ) @@ -91,10 +103,23 @@ func Middleware(ce *casbin.Enforcer) echo.MiddlewareFunc { // MiddlewareWithConfig returns a CasbinAuth middleware with config. // See `Middleware()`. func MiddlewareWithConfig(config Config) echo.MiddlewareFunc { - // Defaults + if config.Enforcer == nil && config.EnforceHandler == nil { + panic("one of casbin middleware Enforcer or EnforceHandler fields must be set") + } if config.Skipper == nil { config.Skipper = DefaultConfig.Skipper } + if config.UserGetter == nil { + config.UserGetter = DefaultConfig.UserGetter + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultConfig.ErrorHandler + } + if config.EnforceHandler == nil { + config.EnforceHandler = func(c echo.Context, user string) (bool, error) { + return config.Enforcer.Enforce(user, c.Request().URL.Path, c.Request().Method) + } + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -102,33 +127,18 @@ func MiddlewareWithConfig(config Config) echo.MiddlewareFunc { return next(c) } - if pass, err := config.CheckPermission(c); err == nil && pass { - return next(c) - } else if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + user, err := config.UserGetter(c) + if err != nil { + return config.ErrorHandler(c, err, http.StatusForbidden) } - - return echo.ErrForbidden + pass, err := config.EnforceHandler(c, user) + if err != nil { + return config.ErrorHandler(c, err, http.StatusInternalServerError) + } + if !pass { + return config.ErrorHandler(c, errors.New("enforce did not pass"), http.StatusForbidden) + } + return next(c) } } } - -// GetUserName gets the user name from the request. -// It calls the UserGetter field of the Config struct that allows the caller to customize user identification. -func (a *Config) GetUserName(c echo.Context) (string, error) { - username, err := a.UserGetter(c) - return username, err -} - -// CheckPermission checks the user/method/path combination from the request. -// Returns true (permission granted) or false (permission forbidden) -func (a *Config) CheckPermission(c echo.Context) (bool, error) { - user, err := a.GetUserName(c) - if err != nil { - // Fail safe and do not propagate - return false, nil - } - method := c.Request().Method - path := c.Request().URL.Path - return a.Enforcer.Enforce(user, path, method) -} diff --git a/casbin/casbin_test.go b/casbin/casbin_test.go index 20f8d72..bcaaded 100644 --- a/casbin/casbin_test.go +++ b/casbin/casbin_test.go @@ -2,8 +2,10 @@ package casbin import ( "errors" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" + "strings" "testing" "github.com/casbin/casbin/v2" @@ -131,3 +133,27 @@ func TestUserGetterError(t *testing.T) { }) testRequest(t, h, "cathy", "/dataset1/item", "GET", 403) } + +func TestCustomEnforceHandler(t *testing.T) { + ce, err := casbin.NewEnforcer("auth_model.conf", "auth_policy.csv") + assert.NoError(t, err) + + _, err = ce.AddPolicy("bob", "/user/bob", "PATCH_SELF") + assert.NoError(t, err) + + cnf := Config{ + EnforceHandler: func(c echo.Context, user string) (bool, error) { + method := c.Request().Method + if strings.HasPrefix(c.Request().URL.Path, "/user/bob") { + method += "_SELF" + } + return ce.Enforce(user, c.Request().URL.Path, method) + }, + } + h := MiddlewareWithConfig(cnf)(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + testRequest(t, h, "bob", "/dataset2/resource1", "GET", http.StatusOK) + testRequest(t, h, "bob", "/user/alice", "PATCH", http.StatusForbidden) + testRequest(t, h, "bob", "/user/bob", "PATCH", http.StatusOK) +}