Skip to content

Commit

Permalink
In jwtAuthHandler.ServeHTTP(), no longer saving the request context o…
Browse files Browse the repository at this point in the history
…ff, to allow injecting key/value pairs by user defined functions, such as verifyAccess.
  • Loading branch information
Ruggero (Руджеро) Ferretti authored and Ruggero (Руджеро) Ferretti committed Dec 19, 2024
1 parent 9b0feff commit 2531d0b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 48 deletions.
32 changes: 17 additions & 15 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,33 +139,34 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM
}

func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
reqCtx := r.Context()
logger := h.logger(r.Context())

bearerToken := GetBearerTokenFromRequest(r)
if bearerToken == "" {
apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(r.Context()))
return
}
// Add the bearer token to the request context
r = r.WithContext(NewContextWithBearerToken(r.Context(), bearerToken))

var jwtClaims jwt.Claims
if h.tokenIntrospector != nil {
if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil {
if introspectionResult, err := h.tokenIntrospector.IntrospectToken(r.Context(), bearerToken); err != nil {
switch {
case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded):
// Do nothing. Access Token already contains all necessary information for authN/authZ.
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token's introspection is not needed")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded)
case errors.Is(err, idptoken.ErrTokenNotIntrospectable):
// Token is not introspectable by some reason.
// In this case, we will parse it as JWT and use it for authZ.
h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is",
logger.Warn("token is not introspectable, it will be used for authentication and authorization as is",
log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable)
default:
logger := h.logger(reqCtx)
logger.Error("token's introspection failed", log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
Expand All @@ -174,14 +175,14 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
} else {
if !introspectionResult.IsActive() {
h.logger(reqCtx).Warn("token was successfully introspected, but it is not active")
logger.Warn("token was successfully introspected, but it is not active")
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
jwtClaims = introspectionResult.GetClaims()
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token was successfully introspected")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive)
Expand All @@ -190,26 +191,27 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

if jwtClaims == nil {
var err error
if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil {
logger := h.logger(reqCtx)
if jwtClaims, err = h.jwtParser.Parse(r.Context(), bearerToken); err != nil {
logger.Error("authentication failed", log.Error(err))
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
}
// Add the JWT claims to the request context
r = r.WithContext(NewContextWithJWTClaims(r.Context(), jwtClaims))

if h.verifyAccess != nil {
// By passing a *http.Request to verifyAccess, we allow its implementations
// to inject new key/value pairs into the request context.
if !h.verifyAccess(r, jwtClaims) {
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed)
restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusForbidden, apiErr, logger)
return
}
}

reqCtx = NewContextWithBearerToken(reqCtx, bearerToken)
reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims)
h.next.ServeHTTP(rw, r.WithContext(reqCtx))
h.next.ServeHTTP(rw, r)
}

func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger {
Expand Down
108 changes: 75 additions & 33 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH.
Released under MIT license.
*/

package authkit
package authkit_test

import (
"context"
Expand All @@ -14,6 +14,7 @@ import (
"testing"

"github.com/acronis/go-appkit/testutil"
"github.com/acronis/go-authkit"
jwtgo "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"

Expand All @@ -22,14 +23,21 @@ import (
"github.com/acronis/go-authkit/jwt"
)

const (
errDomain = "TestDomain"
testBearerToken = "a.b.c"
)

type mockJWTAuthMiddlewareNextHandler struct {
ctx context.Context
called int
jwtClaims jwt.Claims
}

func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
h.ctx = r.Context()
h.called++
h.jwtClaims = GetJWTClaimsFromContext(r.Context())
h.jwtClaims = authkit.GetJWTClaimsFromContext(r.Context())
}

type mockJWTParser struct {
Expand Down Expand Up @@ -59,21 +67,19 @@ func (i *mockTokenIntrospector) IntrospectToken(_ context.Context, token string)
}

func TestJWTAuthMiddleware(t *testing.T) {
const errDomain = "TestDomain"

t.Run("bearer token is missing", func(t *testing.T) {
for _, headerVal := range []string{"", "foobar", "Bearer", "Bearer "} {
parser := &mockJWTParser{}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
if headerVal != "" {
req.Header.Set(HeaderAuthorization, headerVal)
req.Header.Set(authkit.HeaderAuthorization, headerVal)
}
resp := httptest.NewRecorder()

JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeBearerTokenMissing)
testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeBearerTokenMissing)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)
require.Nil(t, next.jwtClaims)
Expand All @@ -84,12 +90,12 @@ func TestJWTAuthMiddleware(t *testing.T) {
parser := &mockJWTParser{errToReturn: errors.New("malformed JWT")}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer foobar")
withBearerToken(req, "foobar")
resp := httptest.NewRecorder()

JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed)
require.Equal(t, 1, parser.parseCalled)
require.Equal(t, 0, next.called)
require.Nil(t, next.jwtClaims)
Expand All @@ -100,10 +106,10 @@ func TestJWTAuthMiddleware(t *testing.T) {
parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, parser.parseCalled)
Expand All @@ -120,15 +126,16 @@ func TestJWTAuthMiddleware(t *testing.T) {
introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).
ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)
Expand All @@ -143,13 +150,14 @@ func TestJWTAuthMiddleware(t *testing.T) {
introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).
ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, introspector.introspectCalled)
Expand All @@ -169,13 +177,14 @@ func TestJWTAuthMiddleware(t *testing.T) {
introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).
ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, introspector.introspectCalled)
Expand All @@ -195,15 +204,16 @@ func TestJWTAuthMiddleware(t *testing.T) {
introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).
ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)
Expand All @@ -219,13 +229,14 @@ func TestJWTAuthMiddleware(t *testing.T) {
Active: true, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).
ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, introspector.introspectCalled)
Expand All @@ -239,22 +250,48 @@ func TestJWTAuthMiddleware(t *testing.T) {
testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1)
})

t.Run("context keys added by verifyAccess are preserved", func(t *testing.T) {
const issuer = "my-idp.com"
parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

const (
ctxKey = "verify-access-key"
ctxValue = "verify-access-value"
)
var verifyAccess = func(r *http.Request, claims jwt.Claims) bool {
*r = *r.WithContext(context.WithValue(r.Context(), ctxKey, ctxValue))
return true
}

authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).
ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, parser.parseCalled)
require.Equal(t, 1, next.called)
require.Equal(t, testBearerToken, authkit.GetBearerTokenFromContext(next.ctx), "context is missing bearer token")
require.Equal(t, ctxValue, next.ctx.Value(ctxKey), "context key added by verifyAccess is not preserved")
})
}

func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) {
const errDomain = "TestDomain"

t.Run("authorization failed", func(t *testing.T) {
parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"})
JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req)
verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"})
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).
ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, ErrCodeAuthorizationFailed)
testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, authkit.ErrCodeAuthorizationFailed)
require.Equal(t, 1, parser.parseCalled)
require.Equal(t, 0, next.called)
require.Nil(t, next.jwtClaims)
Expand All @@ -265,11 +302,12 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) {
parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{Scope: scope}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
withBearerToken(req, testBearerToken)
resp := httptest.NewRecorder()

verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"})
JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req)
verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"})
authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).
ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, 1, parser.parseCalled)
Expand All @@ -278,3 +316,7 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) {
require.EqualValues(t, scope, next.jwtClaims.GetScope())
})
}

func withBearerToken(r *http.Request, t string) {
r.Header.Set(authkit.HeaderAuthorization, "Bearer "+t)
}

0 comments on commit 2531d0b

Please sign in to comment.