diff --git a/middleware.go b/middleware.go index 8dc2912..ed74f1a 100644 --- a/middleware.go +++ b/middleware.go @@ -139,33 +139,34 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM } func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - reqCtx := r.Context() + logger := h.logger(r.Context()) bearerToken := GetBearerTokenFromRequest(r) if bearerToken == "" { apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(r.Context())) return } + // Add the bearer token to the request context + r = r.WithContext(NewContextWithBearerToken(r.Context(), bearerToken)) var jwtClaims jwt.Claims if h.tokenIntrospector != nil { - if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil { + if introspectionResult, err := h.tokenIntrospector.IntrospectToken(r.Context(), bearerToken); err != nil { switch { case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded): // Do nothing. Access Token already contains all necessary information for authN/authZ. - h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token's introspection is not needed") }) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded) case errors.Is(err, idptoken.ErrTokenNotIntrospectable): // Token is not introspectable by some reason. // In this case, we will parse it as JWT and use it for authZ. - h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is", + logger.Warn("token is not introspectable, it will be used for authentication and authorization as is", log.Error(err)) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable) default: - logger := h.logger(reqCtx) logger.Error("token's introspection failed", log.Error(err)) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) @@ -174,14 +175,14 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } } else { if !introspectionResult.IsActive() { - h.logger(reqCtx).Warn("token was successfully introspected, but it is not active") + logger.Warn("token was successfully introspected, but it is not active") h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } jwtClaims = introspectionResult.GetClaims() - h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token was successfully introspected") }) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive) @@ -190,26 +191,27 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if jwtClaims == nil { var err error - if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil { - logger := h.logger(reqCtx) + if jwtClaims, err = h.jwtParser.Parse(r.Context(), bearerToken); err != nil { logger.Error("authentication failed", log.Error(err)) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } } + // Add the JWT claims to the request context + r = r.WithContext(NewContextWithJWTClaims(r.Context(), jwtClaims)) if h.verifyAccess != nil { + // By passing a *http.Request to verifyAccess, we allow its implementations + // to inject new key/value pairs into the request context. if !h.verifyAccess(r, jwtClaims) { apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed) - restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusForbidden, apiErr, logger) return } } - reqCtx = NewContextWithBearerToken(reqCtx, bearerToken) - reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims) - h.next.ServeHTTP(rw, r.WithContext(reqCtx)) + h.next.ServeHTTP(rw, r) } func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger { diff --git a/middleware_test.go b/middleware_test.go index 66d35f2..f048aea 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH. Released under MIT license. */ -package authkit +package authkit_test import ( "context" @@ -14,6 +14,7 @@ import ( "testing" "github.com/acronis/go-appkit/testutil" + "github.com/acronis/go-authkit" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" @@ -22,14 +23,21 @@ import ( "github.com/acronis/go-authkit/jwt" ) +const ( + errDomain = "TestDomain" + testBearerToken = "a.b.c" +) + type mockJWTAuthMiddlewareNextHandler struct { + ctx context.Context called int jwtClaims jwt.Claims } func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + h.ctx = r.Context() h.called++ - h.jwtClaims = GetJWTClaimsFromContext(r.Context()) + h.jwtClaims = authkit.GetJWTClaimsFromContext(r.Context()) } type mockJWTParser struct { @@ -59,21 +67,19 @@ func (i *mockTokenIntrospector) IntrospectToken(_ context.Context, token string) } func TestJWTAuthMiddleware(t *testing.T) { - const errDomain = "TestDomain" - t.Run("bearer token is missing", func(t *testing.T) { for _, headerVal := range []string{"", "foobar", "Bearer", "Bearer "} { parser := &mockJWTParser{} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) if headerVal != "" { - req.Header.Set(HeaderAuthorization, headerVal) + req.Header.Set(authkit.HeaderAuthorization, headerVal) } resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeBearerTokenMissing) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeBearerTokenMissing) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) require.Nil(t, next.jwtClaims) @@ -84,12 +90,12 @@ func TestJWTAuthMiddleware(t *testing.T) { parser := &mockJWTParser{errToReturn: errors.New("malformed JWT")} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer foobar") + withBearerToken(req, "foobar") resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 0, next.called) require.Nil(t, next.jwtClaims) @@ -100,10 +106,10 @@ func TestJWTAuthMiddleware(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, parser.parseCalled) @@ -120,15 +126,16 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) @@ -143,13 +150,14 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) @@ -169,13 +177,14 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) @@ -195,15 +204,16 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) @@ -219,13 +229,14 @@ func TestJWTAuthMiddleware(t *testing.T) { Active: true, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) @@ -239,22 +250,48 @@ func TestJWTAuthMiddleware(t *testing.T) { testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1) }) + + t.Run("context keys added by verifyAccess are preserved", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + withBearerToken(req, testBearerToken) + resp := httptest.NewRecorder() + + const ( + ctxKey = "verify-access-key" + ctxValue = "verify-access-value" + ) + var verifyAccess = func(r *http.Request, claims jwt.Claims) bool { + *r = *r.WithContext(context.WithValue(r.Context(), ctxKey, ctxValue)) + return true + } + + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.Equal(t, testBearerToken, authkit.GetBearerTokenFromContext(next.ctx), "context is missing bearer token") + require.Equal(t, ctxValue, next.ctx.Value(ctxKey), "context key added by verifyAccess is not preserved") + }) } func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { - const errDomain = "TestDomain" - t.Run("authorization failed", func(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"}) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, ErrCodeAuthorizationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, authkit.ErrCodeAuthorizationFailed) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 0, next.called) require.Nil(t, next.jwtClaims) @@ -265,11 +302,12 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{Scope: scope}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"}) + authkit.JWTAuthMiddleware(errDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, parser.parseCalled) @@ -278,3 +316,7 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { require.EqualValues(t, scope, next.jwtClaims.GetScope()) }) } + +func withBearerToken(r *http.Request, t string) { + r.Header.Set(authkit.HeaderAuthorization, "Bearer "+t) +}