Skip to content

Commit

Permalink
feat: principal is optionally propagated in AuthOptOut handlers (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielpeach authored Aug 10, 2023
1 parent 5235811 commit e0e166a
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 18 deletions.
42 changes: 26 additions & 16 deletions server/authn_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,34 @@ import (
"net/http"
)

func ginAuthMiddleware(as AuthService, log *zap.SugaredLogger) gin.HandlerFunc {
// ginEnforceAuthMiddleware extracts an iam.ArmoryCloudPrincipal from the incoming HTTP request.
// If a principal cannot be extracted from the request, the middleware aborts the middleware chain
// and returns a 401.
func ginEnforceAuthMiddleware(as AuthService, log *zap.SugaredLogger) gin.HandlerFunc {
return func(c *gin.Context) {
// extract access token from request
auth, err := iam.ExtractBearerToken(c.Request)
if err != nil {
apiErr := serr.NewSimpleErrorWithStatusCode(
"Failed to extract access token from request", http.StatusUnauthorized, err)
writeAndLogApiErrorThenAbort(c, apiErr, log)
if err := extractPrincipalFromHTTPRequestAndSetContext(c, as); err != nil {
writeAndLogApiErrorThenAbort(c, err, log)
c.Abort()
return
}
// verify principal from access token
if err := as.VerifyPrincipalAndSetContext(auth, c); err != nil {
apiErr := serr.NewSimpleErrorWithStatusCode(
"Failed to verify principal from access token", http.StatusUnauthorized, err)
writeAndLogApiErrorThenAbort(c, apiErr, log)
c.Abort()
return
}
}
}

// ginAttemptAuthMiddleware attempts to extract an iam.ArmoryCloudPrincipal from the incoming HTTP request,
// but does not abort the middleware chain if it cannot do so.
func ginAttemptAuthMiddleware(as AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
_ = extractPrincipalFromHTTPRequestAndSetContext(c, as)
}
}

func extractPrincipalFromHTTPRequestAndSetContext(c *gin.Context, as AuthService) serr.Error {
auth, err := iam.ExtractBearerToken(c.Request)
if err != nil {
return serr.NewSimpleErrorWithStatusCode("Failed to extract access token from request", http.StatusUnauthorized, err)
}

if err := as.VerifyPrincipalAndSetContext(auth, c); err != nil {
return serr.NewSimpleErrorWithStatusCode("Failed to verify principal from access token", http.StatusUnauthorized, err)
}
return nil
}
160 changes: 160 additions & 0 deletions server/authn_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package server

import (
"errors"
"github.com/armory-io/go-commons/iam"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"net/http"
"net/http/httptest"
"testing"
)

func TestGinEnforceAuthMiddleware(t *testing.T) {
cases := []struct {
name string
headers map[string][]string
principal *iam.ArmoryCloudPrincipal
verifyPrincipalError error
assertion func(t *testing.T, ctx *gin.Context, response *http.Response)
}{
{
name: "no bearer token: returns 401",
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusUnauthorized, response.StatusCode)
assert.True(t, ctx.IsAborted())
},
},
{
name: "no principal: returns 401",
headers: map[string][]string{
"Authorization": {"Bearer <token>"},
},
verifyPrincipalError: errors.New("invalid principal"),
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusUnauthorized, response.StatusCode)
assert.True(t, ctx.IsAborted())
},
},
{
name: "principal is propagated into Gin context: returns 200",
headers: map[string][]string{
"Authorization": {"Bearer <token>"},
},
principal: &iam.ArmoryCloudPrincipal{
Name: "America's #1 Principal",
OrgId: "org-id",
EnvId: "env-id",
},
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusOK, response.StatusCode)

principal, err := iam.ExtractPrincipalFromContext(ctx.Request.Context())
assert.NoError(t, err)
assert.Equal(t, "America's #1 Principal", principal.Name)
assert.False(t, ctx.IsAborted())
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
logger := zap.S()

recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = &http.Request{
Header: c.headers,
}

authService := mockAuthService{
principal: c.principal,
error: c.verifyPrincipalError,
}
ginEnforceAuthMiddleware(authService, logger)(ctx)
c.assertion(t, ctx, recorder.Result())
})
}
}

func TestGinAttemptAuthMiddleware(t *testing.T) {
cases := []struct {
name string
headers map[string][]string
principal *iam.ArmoryCloudPrincipal
verifyPrincipalError error
assertion func(t *testing.T, ctx *gin.Context, response *http.Response)
}{
{
name: "no bearer token: returns 200",
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.False(t, ctx.IsAborted())
},
},
{
name: "no principal: returns 200",
headers: map[string][]string{
"Authorization": {"Bearer <token>"},
},
verifyPrincipalError: errors.New("invalid principal"),
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.False(t, ctx.IsAborted())
},
},
{
name: "principal is propagated into Gin context: returns 200",
headers: map[string][]string{
"Authorization": {"Bearer <token>"},
},
principal: &iam.ArmoryCloudPrincipal{
Name: "America's #1 Principal",
OrgId: "org-id",
EnvId: "env-id",
},
assertion: func(t *testing.T, ctx *gin.Context, response *http.Response) {
assert.Equal(t, http.StatusOK, response.StatusCode)

principal, err := iam.ExtractPrincipalFromContext(ctx.Request.Context())
assert.NoError(t, err)
assert.Equal(t, "America's #1 Principal", principal.Name)
assert.False(t, ctx.IsAborted())
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = &http.Request{
Header: c.headers,
}

authService := mockAuthService{
principal: c.principal,
error: c.verifyPrincipalError,
}
ginAttemptAuthMiddleware(authService)(ctx)
c.assertion(t, ctx, recorder.Result())
})
}
}

type mockAuthService struct {
principal *iam.ArmoryCloudPrincipal
error error
}

func (m mockAuthService) VerifyPrincipalAndSetContext(tokenOrRawHeader string, c *gin.Context) error {
if m.error != nil {
return m.error
}

if m.principal != nil {
c.Request = c.Request.WithContext(iam.DangerouslyWriteUnverifiedPrincipalToContext(c.Request.Context(), m.principal))
}
return nil
}
2 changes: 1 addition & 1 deletion server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type (
Default bool
// StatusCode The default status code to return when the request is successful, can be overridden by the handler by setting Response.StatusCode in the handler
StatusCode int
// AuthOptOut Set this to true if the handler should skip AuthZ and AuthN, this will cause the principal to be nil in the request context
// AuthOptOut Set this to true if the handler should skip AuthZ and AuthN.
AuthOptOut bool
// AuthZValidator see AuthZValidatorFn
AuthZValidator AuthZValidatorFn
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,15 @@ func configureServer(
}

authNotEnforcedGroup := g.Group(httpConfig.Prefix)
authNotEnforcedGroup.Use(ginAttemptAuthMiddleware(as))

// Allow a web-app to serve a single page application (SPA), such as react, vue, angular, etc.
if spaConfig.Enabled {
g.Use(spaMiddleware(spaConfig))
}

authRequiredGroup := g.Group(httpConfig.Prefix)
authRequiredGroup.Use(ginAuthMiddleware(as, logger))
authRequiredGroup.Use(ginEnforceAuthMiddleware(as, logger))

handlerRegistry, err := newHandlerRegistry(name, logger, requestValidator, controllers)
if err != nil {
Expand Down

0 comments on commit e0e166a

Please sign in to comment.