Skip to content

Commit

Permalink
casbin: add EnforceHandler to allow custom callback to handle enforcing.
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Dec 16, 2021
1 parent 4d116ee commit 186cf2b
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 30 deletions.
70 changes: 40 additions & 30 deletions casbin/casbin.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ Advanced example:
package casbin

import (
"net/http"

"errors"
"github.com/casbin/casbin/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
)

type (
Expand All @@ -59,11 +59,18 @@ type (
Skipper middleware.Skipper

// Enforcer CasbinAuth main rule.
// Required.
// One of Enforcer or EnforceHandler fields is required.
Enforcer *casbin.Enforcer

// EnforceHandler is custom callback to handle enforcing.
// One of Enforcer or EnforceHandler fields is required.
EnforceHandler func(c echo.Context, user string) (bool, error)

// Method to get the username - defaults to using basic auth
UserGetter func(c echo.Context) (string, error)

// Method to handle errors
ErrorHandler func(c echo.Context, internal error, proposedStatus int) error
}
)

Expand All @@ -75,6 +82,11 @@ var (
username, _, _ := c.Request().BasicAuth()
return username, nil
},
ErrorHandler: func(c echo.Context, internal error, proposedStatus int) error {
err := echo.NewHTTPError(proposedStatus, internal.Error())
err.Internal = internal
return err
},
}
)

Expand All @@ -91,44 +103,42 @@ func Middleware(ce *casbin.Enforcer) echo.MiddlewareFunc {
// MiddlewareWithConfig returns a CasbinAuth middleware with config.
// See `Middleware()`.
func MiddlewareWithConfig(config Config) echo.MiddlewareFunc {
// Defaults
if config.Enforcer == nil && config.EnforceHandler == nil {
panic("one of casbin middleware Enforcer or EnforceHandler fields must be set")
}
if config.Skipper == nil {
config.Skipper = DefaultConfig.Skipper
}
if config.UserGetter == nil {
config.UserGetter = DefaultConfig.UserGetter
}
if config.ErrorHandler == nil {
config.ErrorHandler = DefaultConfig.ErrorHandler
}
if config.EnforceHandler == nil {
config.EnforceHandler = func(c echo.Context, user string) (bool, error) {
return config.Enforcer.Enforce(user, c.Request().URL.Path, c.Request().Method)
}
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}

if pass, err := config.CheckPermission(c); err == nil && pass {
return next(c)
} else if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
user, err := config.UserGetter(c)
if err != nil {
return config.ErrorHandler(c, err, http.StatusForbidden)
}

return echo.ErrForbidden
pass, err := config.EnforceHandler(c, user)
if err != nil {
return config.ErrorHandler(c, err, http.StatusInternalServerError)
}
if !pass {
return config.ErrorHandler(c, errors.New("enforce did not pass"), http.StatusForbidden)
}
return next(c)
}
}
}

// GetUserName gets the user name from the request.
// It calls the UserGetter field of the Config struct that allows the caller to customize user identification.
func (a *Config) GetUserName(c echo.Context) (string, error) {
username, err := a.UserGetter(c)
return username, err
}

// CheckPermission checks the user/method/path combination from the request.
// Returns true (permission granted) or false (permission forbidden)
func (a *Config) CheckPermission(c echo.Context) (bool, error) {
user, err := a.GetUserName(c)
if err != nil {
// Fail safe and do not propagate
return false, nil
}
method := c.Request().Method
path := c.Request().URL.Path
return a.Enforcer.Enforce(user, path, method)
}
26 changes: 26 additions & 0 deletions casbin/casbin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package casbin

import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/casbin/casbin/v2"
Expand Down Expand Up @@ -131,3 +133,27 @@ func TestUserGetterError(t *testing.T) {
})
testRequest(t, h, "cathy", "/dataset1/item", "GET", 403)
}

func TestCustomEnforceHandler(t *testing.T) {
ce, err := casbin.NewEnforcer("auth_model.conf", "auth_policy.csv")
assert.NoError(t, err)

_, err = ce.AddPolicy("bob", "/user/bob", "PATCH_SELF")
assert.NoError(t, err)

cnf := Config{
EnforceHandler: func(c echo.Context, user string) (bool, error) {
method := c.Request().Method
if strings.HasPrefix(c.Request().URL.Path, "/user/bob") {
method += "_SELF"
}
return ce.Enforce(user, c.Request().URL.Path, method)
},
}
h := MiddlewareWithConfig(cnf)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
testRequest(t, h, "bob", "/dataset2/resource1", "GET", http.StatusOK)
testRequest(t, h, "bob", "/user/alice", "PATCH", http.StatusForbidden)
testRequest(t, h, "bob", "/user/bob", "PATCH", http.StatusOK)
}

0 comments on commit 186cf2b

Please sign in to comment.