diff --git a/internal/concierge/apiserver/apiserver.go b/internal/concierge/apiserver/apiserver.go index 1c7bf3f049..184b5e00b1 100644 --- a/internal/concierge/apiserver/apiserver.go +++ b/internal/concierge/apiserver/apiserver.go @@ -82,7 +82,7 @@ func (c completedConfig) New() (*PinnipedServer, error) { for _, f := range []func() (schema.GroupVersionResource, rest.Storage){ func() (schema.GroupVersionResource, rest.Storage) { tokenCredReqGVR := c.ExtraConfig.LoginConciergeGroupVersion.WithResource("tokencredentialrequests") - tokenCredStorage := credentialrequest.NewREST(c.ExtraConfig.Authenticator, c.ExtraConfig.Issuer, tokenCredReqGVR.GroupResource()) + tokenCredStorage := credentialrequest.NewREST(c.ExtraConfig.Authenticator, c.ExtraConfig.Issuer, tokenCredReqGVR.GroupResource(), plog.New()) return tokenCredReqGVR, tokenCredStorage }, func() (schema.GroupVersionResource, rest.Storage) { diff --git a/internal/controller/supervisorstorage/garbage_collector.go b/internal/controller/supervisorstorage/garbage_collector.go index 0ae8ac5ca3..0ed7231978 100644 --- a/internal/controller/supervisorstorage/garbage_collector.go +++ b/internal/controller/supervisorstorage/garbage_collector.go @@ -9,6 +9,7 @@ import ( "fmt" "time" + "github.com/ory/fosite" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" @@ -35,10 +36,12 @@ import ( const minimumRepeatInterval = 30 * time.Second type garbageCollectorController struct { - idpCache UpstreamOIDCIdentityProviderICache - secretInformer corev1informers.SecretInformer - kubeClient kubernetes.Interface - clock clock.Clock + idpCache UpstreamOIDCIdentityProviderICache + secretInformer corev1informers.SecretInformer + kubeClient kubernetes.Interface + clock clock.Clock + auditLogger plog.AuditLogger + timeOfMostRecentSweep time.Time } @@ -53,6 +56,7 @@ func GarbageCollectorController( kubeClient kubernetes.Interface, secretInformer corev1informers.SecretInformer, withInformer pinnipedcontroller.WithInformerOptionFunc, + auditLogger plog.AuditLogger, ) controllerlib.Controller { isSecretWithGCAnnotation := func(obj metav1.Object) bool { secret, ok := obj.(*corev1.Secret) @@ -70,6 +74,7 @@ func GarbageCollectorController( secretInformer: secretInformer, kubeClient: kubeClient, clock: clock, + auditLogger: auditLogger, }, }, withInformer( @@ -163,6 +168,7 @@ func (c *garbageCollectorController) Sync(ctx controllerlib.Context) error { plog.WarningErr("failed to garbage collect resource", err, logKV(secret)...) continue } + c.maybeAuditLogGC(storageType, secret) plog.Info("storage garbage collector deleted resource", logKV(secret)...) } @@ -192,7 +198,10 @@ func (c *garbageCollectorController) maybeRevokeUpstreamOIDCToken(ctx context.Co return nil } // When the downstream authcode was never used, then its storage must contain the latest upstream token. - return c.tryRevokeUpstreamOIDCToken(ctx, authorizeCodeSession.Request.Session.(*psession.PinnipedSession).Custom, secret) + return c.tryRevokeUpstreamOIDCToken(ctx, + authorizeCodeSession.Request.Session.(*psession.PinnipedSession).Custom, + authorizeCodeSession.Request, + secret) case accesstoken.TypeLabelValue: // For access token storage, check if the "offline_access" scope was granted on the downstream session. @@ -203,11 +212,13 @@ func (c *garbageCollectorController) maybeRevokeUpstreamOIDCToken(ctx context.Co if err != nil { return err } - pinnipedSession := accessTokenSession.Request.Session.(*psession.PinnipedSession) if accessTokenSession.Request.GetGrantedScopes().Has(oidcapi.ScopeOfflineAccess) { return nil } - return c.tryRevokeUpstreamOIDCToken(ctx, pinnipedSession.Custom, secret) + return c.tryRevokeUpstreamOIDCToken(ctx, + accessTokenSession.Request.Session.(*psession.PinnipedSession).Custom, + accessTokenSession.Request, + secret) case refreshtoken.TypeLabelValue: // For refresh token storage, always revoke its upstream token. This refresh token storage could be @@ -217,7 +228,10 @@ func (c *garbageCollectorController) maybeRevokeUpstreamOIDCToken(ctx context.Co if err != nil { return err } - return c.tryRevokeUpstreamOIDCToken(ctx, refreshTokenSession.Request.Session.(*psession.PinnipedSession).Custom, secret) + return c.tryRevokeUpstreamOIDCToken(ctx, + refreshTokenSession.Request.Session.(*psession.PinnipedSession).Custom, + refreshTokenSession.Request, + secret) case pkce.TypeLabelValue: // For PKCE storage, its very existence means that the downstream authcode was never exchanged, because @@ -237,7 +251,12 @@ func (c *garbageCollectorController) maybeRevokeUpstreamOIDCToken(ctx context.Co } } -func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Context, customSessionData *psession.CustomSessionData, secret *corev1.Secret) error { +func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken( + ctx context.Context, + customSessionData *psession.CustomSessionData, + request *fosite.Request, + secret *corev1.Secret, +) error { // When session was for another upstream IDP type, e.g. LDAP, there is no upstream OIDC token involved. if customSessionData.ProviderType != psession.ProviderTypeOIDC { return nil @@ -264,6 +283,8 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont if err != nil { return err } + c.auditLogger.Audit(plog.AuditEventUpstreamOIDCTokenRevoked, nil, request, + "type", upstreamprovider.RefreshTokenType) plog.Trace("garbage collector successfully revoked upstream OIDC refresh token (or provider has no revocation endpoint)", logKV(secret)...) } @@ -272,12 +293,56 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont if err != nil { return err } + c.auditLogger.Audit(plog.AuditEventUpstreamOIDCTokenRevoked, nil, request, + "type", upstreamprovider.AccessTokenType) plog.Trace("garbage collector successfully revoked upstream OIDC access token (or provider has no revocation endpoint)", logKV(secret)...) } return nil } +func (c *garbageCollectorController) maybeAuditLogGC(storageType string, secret *corev1.Secret) { + r, err := c.requestFromSecret(storageType, secret) + if err == nil && r != nil { + c.auditLogger.Audit(plog.AuditEventSessionGarbageCollected, nil, r, "storageType", storageType) + } +} + +func (c *garbageCollectorController) requestFromSecret(storageType string, secret *corev1.Secret) (*fosite.Request, error) { + switch storageType { + case authorizationcode.TypeLabelValue: + authorizeCodeSession, err := authorizationcode.ReadFromSecret(secret) + if err != nil { + return nil, err + } + return authorizeCodeSession.Request, nil + + case accesstoken.TypeLabelValue: + accessTokenSession, err := accesstoken.ReadFromSecret(secret) + if err != nil { + return nil, err + } + return accessTokenSession.Request, nil + + case refreshtoken.TypeLabelValue: + refreshTokenSession, err := refreshtoken.ReadFromSecret(secret) + if err != nil { + return nil, err + } + return refreshTokenSession.Request, nil + + case pkce.TypeLabelValue: + return nil, nil // if this still exists, then it means that the user never exchanged their authcode + + case openidconnect.TypeLabelValue: + return nil, nil // if this still exists, then it means that the user never exchanged their authcode + + default: + // There are no other storage types, so this should never happen in practice. + return nil, errors.New("garbage collector saw invalid label on Secret when trying to determine session ID") + } +} + func logKV(secret *corev1.Secret) []any { return []any{ "secretName", secret.Name, diff --git a/internal/controller/supervisorstorage/garbage_collector_test.go b/internal/controller/supervisorstorage/garbage_collector_test.go index 903ee7753f..436819e15f 100644 --- a/internal/controller/supervisorstorage/garbage_collector_test.go +++ b/internal/controller/supervisorstorage/garbage_collector_test.go @@ -31,6 +31,7 @@ import ( "go.pinniped.dev/internal/fositestorage/accesstoken" "go.pinniped.dev/internal/fositestorage/authorizationcode" "go.pinniped.dev/internal/fositestorage/refreshtoken" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -55,6 +56,7 @@ func TestGarbageCollectorControllerInformerFilters(t *testing.T) { nil, secretsInformer, observableWithInformerOption.WithInformer, // make it possible to observe the behavior of the Filters + plog.New(), ) secretsInformerFilter = observableWithInformerOption.GetFilterForInformer(secretsInformer) }) @@ -148,6 +150,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) { kubeClient, kubeInformers.Core().V1().Secrets(), controllerlib.WithInformer, + plog.New(), ) // Set this at the last second to support calling subject.Name(). diff --git a/internal/federationdomain/downstreamsession/downstream_session.go b/internal/federationdomain/downstreamsession/downstream_session.go index 512557afd5..5097a6adf3 100644 --- a/internal/federationdomain/downstreamsession/downstream_session.go +++ b/internal/federationdomain/downstreamsession/downstream_session.go @@ -34,20 +34,30 @@ type SessionConfig struct { ClientID string // The scopes that were granted for the new downstream session. GrantedScopes []string + // The identity provider used to authenticate the user. + IdentityProvider resolvedprovider.FederationDomainResolvedIdentityProvider + // The fosite Requester that is starting this session. + SessionIDGetter plog.SessionIDGetter } // NewPinnipedSession applies the configured FederationDomain identity transformations // and creates a downstream Pinniped session. func NewPinnipedSession( ctx context.Context, - idp resolvedprovider.FederationDomainResolvedIdentityProvider, + auditLogger plog.AuditLogger, c *SessionConfig, ) (*psession.PinnipedSession, error) { now := time.Now().UTC() + auditLogger.Audit(plog.AuditEventIdentityFromUpstreamIDP, ctx, c.SessionIDGetter, + "upstreamUsername", c.UpstreamIdentity.UpstreamUsername, + "upstreamGroups", c.UpstreamIdentity.UpstreamGroups) + downstreamUsername, downstreamGroups, err := applyIdentityTransformations(ctx, - idp.GetTransforms(), c.UpstreamIdentity.UpstreamUsername, c.UpstreamIdentity.UpstreamGroups) + c.IdentityProvider.GetTransforms(), c.UpstreamIdentity.UpstreamUsername, c.UpstreamIdentity.UpstreamGroups) if err != nil { + auditLogger.Audit(plog.AuditEventAuthenticationRejectedByTransforms, ctx, c.SessionIDGetter, + "err", err) return nil, err } @@ -55,12 +65,12 @@ func NewPinnipedSession( Username: downstreamUsername, UpstreamUsername: c.UpstreamIdentity.UpstreamUsername, UpstreamGroups: c.UpstreamIdentity.UpstreamGroups, - ProviderUID: idp.GetProvider().GetResourceUID(), - ProviderName: idp.GetProvider().GetResourceName(), - ProviderType: idp.GetSessionProviderType(), + ProviderUID: c.IdentityProvider.GetProvider().GetResourceUID(), + ProviderName: c.IdentityProvider.GetProvider().GetResourceName(), + ProviderType: c.IdentityProvider.GetSessionProviderType(), Warnings: c.UpstreamLoginExtras.Warnings, } - idp.ApplyIDPSpecificSessionDataToSession(customSessionData, c.UpstreamIdentity.IDPSpecificSessionData) + c.IdentityProvider.ApplyIDPSpecificSessionDataToSession(customSessionData, c.UpstreamIdentity.IDPSpecificSessionData) pinnipedSession := &psession.PinnipedSession{ Fosite: &openid.DefaultSession{ @@ -94,6 +104,13 @@ func NewPinnipedSession( pinnipedSession.IDTokenClaims().Extra = extras + auditLogger.Audit(plog.AuditEventSessionStarted, ctx, c.SessionIDGetter, + "username", downstreamUsername, + "groups", downstreamGroups, + "subject", c.UpstreamIdentity.DownstreamSubject, + "additionalClaims", c.UpstreamLoginExtras.DownstreamAdditionalClaims, + "warnings", c.UpstreamLoginExtras.Warnings) + return pinnipedSession, nil } diff --git a/internal/federationdomain/endpoints/auth/auth_handler.go b/internal/federationdomain/endpoints/auth/auth_handler.go index aa36835c93..a0e75bb464 100644 --- a/internal/federationdomain/endpoints/auth/auth_handler.go +++ b/internal/federationdomain/endpoints/auth/auth_handler.go @@ -13,6 +13,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" fositejwt "github.com/ory/fosite/token/jwt" + "k8s.io/apimachinery/pkg/util/sets" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" "go.pinniped.dev/internal/federationdomain/csrftoken" @@ -34,6 +35,21 @@ const ( promptParamNone = "none" ) +//nolint:gochecknoglobals // please treat this as a readonly const, do not mutate +var paramsSafeToLog = sets.New[string]( + // Standard params from https://openid.net/specs/openid-connect-core-1_0.html, some of which are ignored. + // Redacting state and nonce params, in case they contain any info that the client considers sensitive. + "scope", "response_type", "client_id", "redirect_uri", "response_mode", "display", "prompt", + "max_age", "ui_locales", "id_token_hint", "login_hint", "acr_values", "claims_locales", "claims", + "request", "request_uri", "registration", + // PKCE params from https://datatracker.ietf.org/doc/html/rfc7636. Let code_challenge be redacted. + "code_challenge_method", + // Custom Pinniped authorization params. + oidcapi.AuthorizeUpstreamIDPNameParamName, oidcapi.AuthorizeUpstreamIDPTypeParamName, + // Google-specific param that some client libraries will send anyway. Ignored by Pinniped but safe to log. + "access_type", +) + type authorizeHandler struct { downstreamIssuerURL string idpFinder federationdomainproviders.FederationDomainIdentityProvidersFinderI @@ -44,6 +60,7 @@ type authorizeHandler struct { generateNonce func() (nonce.Nonce, error) upstreamStateEncoder oidc.Encoder cookieCodec oidc.Codec + auditLogger plog.AuditLogger } func NewHandler( @@ -56,6 +73,7 @@ func NewHandler( generateNonce func() (nonce.Nonce, error), upstreamStateEncoder oidc.Encoder, cookieCodec oidc.Codec, + auditLogger plog.AuditLogger, ) http.Handler { h := &authorizeHandler{ downstreamIssuerURL: downstreamIssuerURL, @@ -67,6 +85,7 @@ func NewHandler( generateNonce: generateNonce, upstreamStateEncoder: upstreamStateEncoder, cookieCodec: cookieCodec, + auditLogger: auditLogger, } // During a response_mode=form_post auth request using the browser flow, the custom form_post html page may // be used to post certain errors back to the CLI from this handler's response, so allow the form_post @@ -83,9 +102,10 @@ func (h *authorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // The client set a username or password header, so they are trying to log in without using a browser. - requestedBrowserlessFlow := len(r.Header.Values(oidcapi.AuthorizeUsernameHeaderName)) > 0 || - len(r.Header.Values(oidcapi.AuthorizePasswordHeaderName)) > 0 + // If the client set a username or password header, they are trying to log in without using a browser. + hadUsernameHeader := len(r.Header.Values(oidcapi.AuthorizeUsernameHeaderName)) > 0 + hadPasswordHeader := len(r.Header.Values(oidcapi.AuthorizePasswordHeaderName)) > 0 + requestedBrowserlessFlow := hadUsernameHeader || hadPasswordHeader // Need to parse the request params, so we can get the IDP name. The style and text of the error is inspired by // fosite's implementation of NewAuthorizeRequest(). Fosite only calls ParseMultipartForm() there. However, @@ -112,6 +132,15 @@ func (h *authorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // Log if these headers were present, but don't log the actual values. The password is obviously sensitive, + // and sometimes users use their password as their username by mistake. + h.auditLogger.Audit(plog.AuditEventHTTPRequestCustomHeadersUsed, r.Context(), nil, + oidcapi.AuthorizeUsernameHeaderName, hadUsernameHeader, + oidcapi.AuthorizePasswordHeaderName, hadPasswordHeader) + + h.auditLogger.Audit(plog.AuditEventHTTPRequestParameters, r.Context(), nil, + "params", plog.SanitizeParams(r.Form, paramsSafeToLog)) + // Note that the client might have used oidcapi.AuthorizeUpstreamIDPNameParamName and // oidcapi.AuthorizeUpstreamIDPTypeParamName query (or form) params to request a certain upstream IDP. // The Pinniped CLI has been sending these params since v0.9.0. @@ -141,6 +170,12 @@ func (h *authorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + h.auditLogger.Audit(plog.AuditEventUsingUpstreamIDP, r.Context(), nil, + "displayName", idp.GetDisplayName(), + "resourceName", idp.GetProvider().GetResourceName(), + "resourceUID", idp.GetProvider().GetResourceUID(), + "type", idp.GetSessionProviderType()) + h.authorize(w, r, requestedBrowserlessFlow, idp) } @@ -203,11 +238,13 @@ func (h *authorizeHandler) authorizeWithoutBrowser( return err } - session, err := downstreamsession.NewPinnipedSession(r.Context(), idp, &downstreamsession.SessionConfig{ + session, err := downstreamsession.NewPinnipedSession(r.Context(), h.auditLogger, &downstreamsession.SessionConfig{ UpstreamIdentity: identity, UpstreamLoginExtras: loginExtras, ClientID: authorizeRequester.GetClient().GetID(), GrantedScopes: authorizeRequester.GetGrantedScopes(), + IdentityProvider: idp, + SessionIDGetter: authorizeRequester, }) if err != nil { return fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()) diff --git a/internal/federationdomain/endpoints/auth/auth_handler_test.go b/internal/federationdomain/endpoints/auth/auth_handler_test.go index 8b3aed1106..5bffb75dd5 100644 --- a/internal/federationdomain/endpoints/auth/auth_handler_test.go +++ b/internal/federationdomain/endpoints/auth/auth_handler_test.go @@ -36,6 +36,7 @@ import ( "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/here" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -3624,6 +3625,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo oauthHelperWithNullStorage, oauthHelperWithRealStorage, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder, + plog.New(), ) runOneTestCase(t, test, subject, kubeOauthStore, supervisorClient, kubeClient, secretsClient) }) @@ -3647,6 +3649,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo oauthHelperWithNullStorage, oauthHelperWithRealStorage, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder, + plog.New(), ) runOneTestCase(t, test, subject, kubeOauthStore, supervisorClient, kubeClient, secretsClient) diff --git a/internal/federationdomain/endpoints/callback/callback_handler.go b/internal/federationdomain/endpoints/callback/callback_handler.go index 9295b90efa..5c8d32165b 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler.go +++ b/internal/federationdomain/endpoints/callback/callback_handler.go @@ -24,6 +24,7 @@ func NewHandler( oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder, redirectURI string, + auditLogger plog.AuditLogger, ) http.Handler { handler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) @@ -69,11 +70,13 @@ func NewHandler( return err } - session, err := downstreamsession.NewPinnipedSession(r.Context(), idp, &downstreamsession.SessionConfig{ + session, err := downstreamsession.NewPinnipedSession(r.Context(), auditLogger, &downstreamsession.SessionConfig{ UpstreamIdentity: identity, UpstreamLoginExtras: loginExtras, ClientID: authorizeRequester.GetClient().GetID(), GrantedScopes: authorizeRequester.GetGrantedScopes(), + IdentityProvider: idp, + SessionIDGetter: authorizeRequester, }) if err != nil { plog.WarningErr("unable to create a Pinniped session", err, diff --git a/internal/federationdomain/endpoints/callback/callback_handler_test.go b/internal/federationdomain/endpoints/callback/callback_handler_test.go index 2dd34d078c..12e6d65e98 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_test.go @@ -27,6 +27,7 @@ import ( "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/federationdomain/upstreamprovider" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -1757,7 +1758,15 @@ func TestCallbackEndpoint(t *testing.T) { jwksProviderIsUnused := jwks.NewDynamicJWKSProvider() oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecretFunc, jwksProviderIsUnused, timeoutsConfiguration) - subject := NewHandler(test.idps.BuildFederationDomainIdentityProvidersListerFinder(), oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) + subject := NewHandler( + test.idps.BuildFederationDomainIdentityProvidersListerFinder(), + oauthHelper, + happyStateCodec, + happyCookieCodec, + happyUpstreamRedirectURI, + plog.New(), + ) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") req := httptest.NewRequest(test.method, test.path, nil).WithContext(reqContext) if test.csrfCookie != "" { diff --git a/internal/federationdomain/endpoints/login/post_login_handler.go b/internal/federationdomain/endpoints/login/post_login_handler.go index 07feb9f5ea..084a530aee 100644 --- a/internal/federationdomain/endpoints/login/post_login_handler.go +++ b/internal/federationdomain/endpoints/login/post_login_handler.go @@ -19,7 +19,12 @@ import ( "go.pinniped.dev/internal/plog" ) -func NewPostHandler(issuerURL string, upstreamIDPs federationdomainproviders.FederationDomainIdentityProvidersFinderI, oauthHelper fosite.OAuth2Provider) HandlerFunc { +func NewPostHandler( + issuerURL string, + upstreamIDPs federationdomainproviders.FederationDomainIdentityProvidersFinderI, + oauthHelper fosite.OAuth2Provider, + auditLogger plog.AuditLogger, +) HandlerFunc { return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { // Note that the login handler prevents this handler from being called with OIDC upstreams. idp, err := upstreamIDPs.FindUpstreamIDPByDisplayName(decodedState.UpstreamName) @@ -84,11 +89,13 @@ func NewPostHandler(issuerURL string, upstreamIDPs federationdomainproviders.Fed } } - session, err := downstreamsession.NewPinnipedSession(r.Context(), idp, &downstreamsession.SessionConfig{ + session, err := downstreamsession.NewPinnipedSession(r.Context(), auditLogger, &downstreamsession.SessionConfig{ UpstreamIdentity: identity, UpstreamLoginExtras: loginExtras, ClientID: authorizeRequester.GetClient().GetID(), GrantedScopes: authorizeRequester.GetGrantedScopes(), + IdentityProvider: idp, + SessionIDGetter: authorizeRequester, }) if err != nil { err = fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()) diff --git a/internal/federationdomain/endpoints/login/post_login_handler_test.go b/internal/federationdomain/endpoints/login/post_login_handler_test.go index cf97c7aa1d..003d3f02ef 100644 --- a/internal/federationdomain/endpoints/login/post_login_handler_test.go +++ b/internal/federationdomain/endpoints/login/post_login_handler_test.go @@ -25,6 +25,7 @@ import ( "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/storage" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -1146,7 +1147,7 @@ func TestPostLoginEndpoint(t *testing.T) { rsp := httptest.NewRecorder() - subject := NewPostHandler(downstreamIssuer, tt.idps.BuildFederationDomainIdentityProvidersListerFinder(), oauthHelper) + subject := NewPostHandler(downstreamIssuer, tt.idps.BuildFederationDomainIdentityProvidersListerFinder(), oauthHelper, plog.New()) err := subject(rsp, req, happyEncodedUpstreamState, tt.decodedState) if tt.wantErr != "" { diff --git a/internal/federationdomain/endpoints/token/token_handler.go b/internal/federationdomain/endpoints/token/token_handler.go index 038ee01fac..83bdf5f0d1 100644 --- a/internal/federationdomain/endpoints/token/token_handler.go +++ b/internal/federationdomain/endpoints/token/token_handler.go @@ -30,11 +30,23 @@ import ( "go.pinniped.dev/internal/psession" ) +//nolint:gochecknoglobals // please treat this as a readonly const, do not mutate +var paramsSafeToLog = sets.New[string]( + // Standard params from https://openid.net/specs/openid-connect-core-1_0.html for authcde and refresh grants. + // Redacting code, client_secret, refresh_token, and PKCE code_verifier params. + "grant_type", "client_id", "redirect_uri", "scope", + // Token exchange params from https://datatracker.ietf.org/doc/html/rfc8693. + // Redact subject_token and actor_token. + // We don't allow all of these, but they should be safe to log. + "audience", "resource", "scope", "requested_token_type", "actor_token_type", "subject_token_type", +) + func NewHandler( idpLister federationdomainproviders.FederationDomainIdentityProvidersListerI, oauthHelper fosite.OAuth2Provider, overrideAccessTokenLifespan timeouts.OverrideLifespan, overrideIDTokenLifespan timeouts.OverrideLifespan, + auditLogger plog.AuditLogger, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { session := psession.NewPinnipedSession() @@ -45,13 +57,17 @@ func NewHandler( return nil } + // Note that r.PostForm and accessRequest were populated by NewAccessRequest(). + auditLogger.Audit(plog.AuditEventHTTPRequestParameters, r.Context(), accessRequest, + "params", plog.SanitizeParams(r.PostForm, paramsSafeToLog)) + // Check if we are performing a refresh grant. if accessRequest.GetGrantTypes().ExactOne(oidcapi.GrantTypeRefreshToken) { // The above call to NewAccessRequest has loaded the session from storage into the accessRequest variable. // The session, requested scopes, and requested audience from the original authorize request was retrieved // from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may // have already been granted on the accessRequest. - err = upstreamRefresh(r.Context(), accessRequest, idpLister) + err = upstreamRefresh(r.Context(), accessRequest, idpLister, auditLogger) if err != nil { plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...) oauthHelper.WriteAccessError(r.Context(), w, accessRequest, err) @@ -128,6 +144,7 @@ func upstreamRefresh( ctx context.Context, accessRequest fosite.AccessRequester, idpLister federationdomainproviders.FederationDomainIdentityProvidersListerI, + auditLogger plog.AuditLogger, ) error { session := accessRequest.GetSession().(*psession.PinnipedSession) @@ -136,6 +153,7 @@ func upstreamRefresh( return errorsx.WithStack(errMissingUpstreamSessionInternalError()) } providerName := customSessionData.ProviderName + providerType := customSessionData.ProviderType providerUID := customSessionData.ProviderUID if providerUID == "" || providerName == "" { return errorsx.WithStack(errMissingUpstreamSessionInternalError()) @@ -188,6 +206,10 @@ func upstreamRefresh( return err } + auditLogger.Audit(plog.AuditEventIdentityRefreshedFromUpstreamIDP, ctx, accessRequest, + "upstreamUsername", refreshedIdentity.UpstreamUsername, + "upstreamGroups", refreshedIdentity.UpstreamGroups) + // If the idp wants to update the session with new information from the refresh, then update it. if refreshedIdentity.IDPSpecificSessionData != nil { idp.ApplyIDPSpecificSessionDataToSession(session.Custom, refreshedIdentity.IDPSpecificSessionData) @@ -203,24 +225,37 @@ func upstreamRefresh( refreshedIdentity.UpstreamGroups = oldUntransformedGroups } - refreshedTransformedGroups, err := applyIdentityTransformationsDuringRefresh(ctx, + refreshedTransformedUsername, refreshedTransformedGroups, err := applyIdentityTransformationsDuringRefresh(ctx, idp.GetTransforms(), - oldTransformedUsername, // this function validates that the old and new transformed usernames match refreshedIdentity.UpstreamUsername, refreshedIdentity.UpstreamGroups, - session.Custom.ProviderName, - session.Custom.ProviderType, + providerName, + providerType, ) if err != nil { + auditLogger.Audit(plog.AuditEventAuthenticationRejectedByTransforms, ctx, accessRequest, + "err", err) return err } + if oldTransformedUsername != refreshedTransformedUsername { + return errUpstreamRefreshError().WithHintf( + "Upstream refresh failed."). + WithTrace(errors.New("username in upstream refresh does not match previous value")). + WithDebugf("provider name: %q, provider type: %q", providerName, providerType) + } + if !skipGroups { warnIfGroupsChanged(ctx, oldTransformedGroups, refreshedTransformedGroups, oldTransformedUsername, accessRequest.GetClient().GetID()) // Replace the old value for the downstream groups in the user's session with the new value. session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = refreshedTransformedGroups } + auditLogger.Audit(plog.AuditEventSessionRefreshed, ctx, accessRequest, + "username", oldTransformedUsername, // not allowed to change above so must be the same as old + "groups", refreshedTransformedGroups, + "subject", previousIdentity.DownstreamSubject) + return nil } @@ -255,38 +290,30 @@ func validateSessionHasUsername(session *psession.PinnipedSession) error { } // applyIdentityTransformationsDuringRefresh is similar to downstreamsession.applyIdentityTransformations -// but with validation that the username has not changed, and with slightly different error messaging. +// but with slightly different error messaging. func applyIdentityTransformationsDuringRefresh( ctx context.Context, transforms *idtransform.TransformationPipeline, - oldTransformedUsername string, upstreamUsername string, upstreamGroups []string, providerName string, providerType psession.ProviderType, -) ([]string, error) { +) (string, []string, error) { transformationResult, err := transforms.Evaluate(ctx, upstreamUsername, upstreamGroups) if err != nil { - return nil, errUpstreamRefreshError().WithHintf( + return "", nil, errUpstreamRefreshError().WithHintf( "Upstream refresh error while applying configured identity transformations."). WithTrace(err). WithDebugf("provider name: %q, provider type: %q", providerName, providerType) } if !transformationResult.AuthenticationAllowed { - return nil, errUpstreamRefreshError().WithHintf( + return "", nil, errUpstreamRefreshError().WithHintf( "Upstream refresh rejected by configured identity policy: %s.", transformationResult.RejectedAuthenticationMessage). WithDebugf("provider name: %q, provider type: %q", providerName, providerType) } - if oldTransformedUsername != transformationResult.Username { - return nil, errUpstreamRefreshError().WithHintf( - "Upstream refresh failed."). - WithTrace(errors.New("username in upstream refresh does not match previous value")). - WithDebugf("provider name: %q, provider type: %q", providerName, providerType) - } - - return transformationResult.Groups, nil + return transformationResult.Username, transformationResult.Groups, nil } func validateAndGetDownstreamGroupsFromSession(session *psession.PinnipedSession) ([]string, error) { diff --git a/internal/federationdomain/endpoints/token/token_handler_test.go b/internal/federationdomain/endpoints/token/token_handler_test.go index 76cfe26836..3317343e73 100644 --- a/internal/federationdomain/endpoints/token/token_handler_test.go +++ b/internal/federationdomain/endpoints/token/token_handler_test.go @@ -61,6 +61,7 @@ import ( "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidcclientsecretstorage" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -4916,6 +4917,7 @@ func exchangeAuthcodeForTokens( oauthHelper, timeoutsConfiguration.OverrideDefaultAccessTokenLifespan, timeoutsConfiguration.OverrideDefaultIDTokenLifespan, + plog.New(), ) authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid") diff --git a/internal/federationdomain/endpoints/tokenexchange/token_exchange.go b/internal/federationdomain/endpoints/tokenexchange/token_exchange.go index cd48157b62..d68e73defe 100644 --- a/internal/federationdomain/endpoints/tokenexchange/token_exchange.go +++ b/internal/federationdomain/endpoints/tokenexchange/token_exchange.go @@ -46,17 +46,10 @@ type tokenExchangeHandler struct { var _ fosite.TokenEndpointHandler = (*tokenExchangeHandler)(nil) func (t *tokenExchangeHandler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error { + // Skip this request if it's for a different grant type. if !t.CanHandleTokenEndpointRequest(ctx, requester) { return errors.WithStack(fosite.ErrUnknownRequest) } - return nil -} - -func (t *tokenExchangeHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - // Skip this request if it's for a different grant type. - if err := t.HandleTokenEndpointRequest(ctx, requester); err != nil { - return errors.WithStack(err) - } // Validate the basic RFC8693 parameters we support. params, err := t.validateParams(requester.GetRequestForm()) @@ -64,7 +57,7 @@ func (t *tokenExchangeHandler) PopulateTokenEndpointResponse(ctx context.Context return errors.WithStack(err) } - // Validate the incoming access token and lookup the information about the original authorize request. + // Validate the incoming access token and lookup the information about the original authorize request from storage. originalRequester, err := t.validateAccessToken(ctx, requester, params.subjectAccessToken) if err != nil { return errors.WithStack(err) @@ -95,8 +88,28 @@ func (t *tokenExchangeHandler) PopulateTokenEndpointResponse(ctx context.Context return errors.WithStack(err) } + // Copy the original session ID from storage. + requester.SetID(originalRequester.GetID()) + // Copy the original session details from storage, which will be used by PopulateTokenEndpointResponse() to mint a token. + requester.SetSession(originalRequester.GetSession().Clone()) + // Maybe not needed, but just to be safe, copy these too, similar to how flow_refresh.go copies them. + requester.SetRequestedScopes(originalRequester.GetRequestedScopes()) + requester.SetRequestedAudience(originalRequester.GetRequestedAudience()) + + return nil +} + +func (t *tokenExchangeHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + // Skip this request if it's for a different grant type. + if !t.CanHandleTokenEndpointRequest(ctx, requester) { + return errors.WithStack(fosite.ErrUnknownRequest) + } + + // Get the requested audience parameter again, which was already validated by HandleTokenEndpointRequest() above. + requestedNewAudience := requester.GetRequestForm().Get("audience") + // Use the original authorize request information, along with the requested audience, to mint a new JWT. - responseToken, err := t.mintJWT(ctx, originalRequester, params.requestedAudience) + responseToken, err := t.mintJWT(ctx, requester, requestedNewAudience) if err != nil { return errors.WithStack(err) } @@ -108,15 +121,15 @@ func (t *tokenExchangeHandler) PopulateTokenEndpointResponse(ctx context.Context return nil } -func (t *tokenExchangeHandler) mintJWT(ctx context.Context, requester fosite.Requester, audience string) (string, error) { - downscoped := fosite.NewAccessRequest(requester.GetSession()) - downscoped.Client.(*fosite.DefaultClient).ID = audience +func (t *tokenExchangeHandler) mintJWT(ctx context.Context, requester fosite.Requester, newAudience string) (string, error) { + requestWithNewAudience := fosite.NewAccessRequest(requester.GetSession()) + requestWithNewAudience.Client.(*fosite.DefaultClient).ID = newAudience // Note: if we wanted to support clients with custom token lifespans, then we would need to call // fosite.GetEffectiveLifespan() to determine the lifespan here. idTokenLifespan := t.fositeConfig.GetIDTokenLifespan(ctx) - return t.idTokenStrategy.GenerateIDToken(ctx, idTokenLifespan, downscoped) + return t.idTokenStrategy.GenerateIDToken(ctx, idTokenLifespan, requestWithNewAudience) } func (t *tokenExchangeHandler) validateSession(requester fosite.Requester) error { diff --git a/internal/federationdomain/endpointsmanager/manager.go b/internal/federationdomain/endpointsmanager/manager.go index ea5f6e016b..f3104a494c 100644 --- a/internal/federationdomain/endpointsmanager/manager.go +++ b/internal/federationdomain/endpointsmanager/manager.go @@ -25,6 +25,7 @@ import ( "go.pinniped.dev/internal/federationdomain/idplister" "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" + "go.pinniped.dev/internal/federationdomain/requestlogger" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/httputil/requestutil" "go.pinniped.dev/internal/plog" @@ -40,12 +41,13 @@ type Manager struct { mu sync.RWMutex providers []*federationdomainproviders.FederationDomainIssuer providerHandlers map[string]http.Handler // map of all routes for all providers - nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request + handlerChain http.Handler // http handlers dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data upstreamIDPs idplister.UpstreamIdentityProvidersLister // in-memory cache of upstream IDPs secretCache *secret.Cache // in-memory cache of cryptographic material secretsClient corev1client.SecretInterface oidcClientsClient v1alpha1.OIDCClientInterface + auditLogger plog.AuditLogger } // NewManager returns an empty Manager. @@ -59,16 +61,24 @@ func NewManager( secretCache *secret.Cache, secretsClient corev1client.SecretInterface, oidcClientsClient v1alpha1.OIDCClientInterface, + auditLogger plog.AuditLogger, ) *Manager { - return &Manager{ + m := &Manager{ providerHandlers: make(map[string]http.Handler), - nextHandler: nextHandler, dynamicJWKSProvider: dynamicJWKSProvider, upstreamIDPs: upstreamIDPs, secretCache: secretCache, secretsClient: secretsClient, oidcClientsClient: oidcClientsClient, + auditLogger: auditLogger, } + // nextHandler is the next handler in the chain, called when this manager didn't know how to handle a request + m.buildHandlerChain(nextHandler) + return m +} + +func (m *Manager) HandlerChain() http.Handler { + return m.handlerChain } // SetFederationDomains adds or updates all the given providerHandlers using each provider's issuer string @@ -77,7 +87,7 @@ func NewManager( // It also removes any providerHandlers that were previously added but were not passed in to // the current invocation. // -// This method assumes that all of the FederationDomainIssuer arguments have already been validated +// This method assumes that all the FederationDomainIssuer arguments have already been validated // by someone else before they are passed to this method. func (m *Manager) SetFederationDomains(federationDomains ...*federationdomainproviders.FederationDomainIssuer) { m.mu.Lock() @@ -143,6 +153,7 @@ func (m *Manager) SetFederationDomains(federationDomains ...*federationdomainpro nonce.Generate, upstreamStateEncoder, csrfCookieEncoder, + m.auditLogger, ) m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler( @@ -151,6 +162,7 @@ func (m *Manager) SetFederationDomains(federationDomains ...*federationdomainpro upstreamStateEncoder, csrfCookieEncoder, issuerURL+oidc.CallbackEndpointPath, + m.auditLogger, ) m.providerHandlers[(issuerHostWithPath + oidc.ChooseIDPEndpointPath)] = chooseidp.NewHandler( @@ -163,38 +175,49 @@ func (m *Manager) SetFederationDomains(federationDomains ...*federationdomainpro oauthHelperWithKubeStorage, timeoutsConfiguration.OverrideDefaultAccessTokenLifespan, timeoutsConfiguration.OverrideDefaultIDTokenLifespan, + m.auditLogger, ) m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler( upstreamStateEncoder, csrfCookieEncoder, login.NewGetHandler(incomingFederationDomain.IssuerPath()+oidc.PinnipedLoginPath), - login.NewPostHandler(issuerURL, idpLister, oauthHelperWithKubeStorage), + login.NewPostHandler(issuerURL, idpLister, oauthHelperWithKubeStorage, m.auditLogger), ) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuerURL) } } -// ServeHTTP implements the http.Handler interface. -func (m *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { - requestHandler := m.findHandler(req) - - // Using Info level so the user can safely configure a production Supervisor to show this message if they choose. - plog.Info("received incoming request", - "proto", req.Proto, - "method", req.Method, - "host", req.Host, - "requestSNIServerName", requestutil.SNIServerName(req), - "path", req.URL.Path, - "remoteAddr", req.RemoteAddr, - "foundFederationDomainRequestHandler", requestHandler != nil, - ) +func (m *Manager) buildHandlerChain(nextHandler http.Handler) { + handler := m.buildManagerHandler(nextHandler) // build the basic handler for FederationDomain endpoints + handler = requestlogger.WithHTTPRequestAuditLogging(handler, m.auditLogger) // log all requests, including audit ID + handler = requestlogger.WithAuditID(handler) // add random audit ID to request context and response headers + m.handlerChain = handler +} - if requestHandler == nil { - requestHandler = m.nextHandler // couldn't find an issuer to handle the request - } - requestHandler.ServeHTTP(resp, req) +func (m *Manager) buildManagerHandler(nextHandler http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + requestHandler := m.findHandler(req) + + // TODO: Should this old log message change in light of the new audit logs? Or do we not want to force people to enable audit logs to debug this SNI stuff? + // Using Info level so the user can safely configure a production Supervisor to show this message if they choose. + plog.Info("received incoming request", + "proto", req.Proto, + "method", req.Method, + "host", req.Host, + "requestSNIServerName", requestutil.SNIServerName(req), + "path", req.URL.Path, + "remoteAddr", req.RemoteAddr, + "userAgent", req.UserAgent(), + "foundFederationDomainRequestHandler", requestHandler != nil, + ) + + if requestHandler == nil { + requestHandler = nextHandler // couldn't find an issuer to handle the request + } + requestHandler.ServeHTTP(resp, req) + }) } func (m *Manager) findHandler(req *http.Request) http.Handler { diff --git a/internal/federationdomain/endpointsmanager/manager_test.go b/internal/federationdomain/endpointsmanager/manager_test.go index dfde19130b..d951b9767f 100644 --- a/internal/federationdomain/endpointsmanager/manager_test.go +++ b/internal/federationdomain/endpointsmanager/manager_test.go @@ -26,6 +26,7 @@ import ( "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/idtransform" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/secret" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -83,7 +84,7 @@ func TestManager(t *testing.T) { requireDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedIssuer string) { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.WellKnownEndpointPath+requestURLSuffix)) + subject.HandlerChain().ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.WellKnownEndpointPath+requestURLSuffix)) r.False(fallbackHandlerWasCalled) @@ -101,7 +102,7 @@ func TestManager(t *testing.T) { requirePinnipedIDPsDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix string, expectedIDPNames []string, expectedIDPTypes string, expectedFlows []string) { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.PinnipedIDPsPathV1Alpha1+requestURLSuffix)) + subject.HandlerChain().ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.PinnipedIDPsPathV1Alpha1+requestURLSuffix)) r.False(fallbackHandlerWasCalled) @@ -145,7 +146,7 @@ func TestManager(t *testing.T) { "response_type": []string{"bat"}, } - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.ChooseIDPEndpointPath+"?"+requiredParams.Encode())) + subject.HandlerChain().ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.ChooseIDPEndpointPath+"?"+requiredParams.Encode())) r.False(fallbackHandlerWasCalled) @@ -164,7 +165,7 @@ func TestManager(t *testing.T) { requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) + subject.HandlerChain().ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) r.False(fallbackHandlerWasCalled) @@ -202,7 +203,7 @@ func TestManager(t *testing.T) { Name: "__Host-pinniped-csrf", Value: csrfCookieValue, }) - subject.ServeHTTP(recorder, getRequest) + subject.HandlerChain().ServeHTTP(recorder, getRequest) r.False(fallbackHandlerWasCalled) @@ -242,7 +243,7 @@ func TestManager(t *testing.T) { "code_verifier": []string{downstreamPKCECodeVerifier}, "grant_type": []string{"authorization_code"}, }.Encode() - subject.ServeHTTP(recorder, newPostRequest(requestIssuer+oidc.TokenEndpointPath, tokenRequestBody)) + subject.HandlerChain().ServeHTTP(recorder, newPostRequest(requestIssuer+oidc.TokenEndpointPath, tokenRequestBody)) r.False(fallbackHandlerWasCalled) @@ -272,7 +273,7 @@ func TestManager(t *testing.T) { requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) *jose.JSONWebKeySet { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.JWKSEndpointPath+requestURLSuffix)) + subject.HandlerChain().ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.JWKSEndpointPath+requestURLSuffix)) r.False(fallbackHandlerWasCalled) @@ -358,13 +359,13 @@ func TestManager(t *testing.T) { cache.SetStateEncoderHashKey(issuer2, []byte("some-state-encoder-hash-key-2")) cache.SetStateEncoderBlockKey(issuer2, []byte("16-bytes-STATE02")) - subject = NewManager(nextHandler, dynamicJWKSProvider, idpLister, &cache, secretsClient, oidcClientsClient) + subject = NewManager(nextHandler, dynamicJWKSProvider, idpLister, &cache, secretsClient, oidcClientsClient, plog.New()) }) when("given no providers via SetFederationDomains()", func() { it("sends all requests to the nextHandler", func() { r.False(fallbackHandlerWasCalled) - subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("/anything")) + subject.HandlerChain().ServeHTTP(httptest.NewRecorder(), newGetRequest("/anything")) r.True(fallbackHandlerWasCalled) }) }) @@ -507,19 +508,19 @@ func TestManager(t *testing.T) { it("sends all non-matching host requests to the nextHandler", func() { r.False(fallbackHandlerWasCalled) wrongHostURL := strings.ReplaceAll(issuer1+oidc.WellKnownEndpointPath, "example.com", "wrong-host.com") - subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(wrongHostURL)) + subject.HandlerChain().ServeHTTP(httptest.NewRecorder(), newGetRequest(wrongHostURL)) r.True(fallbackHandlerWasCalled) }) it("sends all non-matching path requests to the nextHandler", func() { r.False(fallbackHandlerWasCalled) - subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("https://example.com/path-does-not-match-any-provider")) + subject.HandlerChain().ServeHTTP(httptest.NewRecorder(), newGetRequest("https://example.com/path-does-not-match-any-provider")) r.True(fallbackHandlerWasCalled) }) it("sends requests which match the issuer prefix but do not match any of that provider's known paths to the nextHandler", func() { r.False(fallbackHandlerWasCalled) - subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(issuer1+"/unhandled-sub-path")) + subject.HandlerChain().ServeHTTP(httptest.NewRecorder(), newGetRequest(issuer1+"/unhandled-sub-path")) r.True(fallbackHandlerWasCalled) }) diff --git a/internal/federationdomain/requestlogger/request_logger.go b/internal/federationdomain/requestlogger/request_logger.go new file mode 100644 index 0000000000..6f9dc03ed2 --- /dev/null +++ b/internal/federationdomain/requestlogger/request_logger.go @@ -0,0 +1,141 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package requestlogger + +import ( + "bufio" + "net" + "net/http" + "time" + + "github.com/google/uuid" + "k8s.io/apimachinery/pkg/types" + apisaudit "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/audit" + "k8s.io/apiserver/pkg/endpoints/responsewriter" + + "go.pinniped.dev/internal/httputil/requestutil" + "go.pinniped.dev/internal/plog" +) + +func WithAuditID(handler http.Handler) http.Handler { + return withAuditID(handler, func() string { + return uuid.New().String() + }) +} + +func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := audit.WithAuditContext(r.Context()) + r = r.WithContext(ctx) + + auditID := newAuditIDFunc() + audit.WithAuditID(ctx, types.UID(auditID)) + + // Send the Audit-ID response header. + w.Header().Set(apisaudit.HeaderAuditID, auditID) + + handler.ServeHTTP(w, r) + }) +} + +func WithHTTPRequestAuditLogging(handler http.Handler, auditLogger plog.AuditLogger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + rl := newRequestLogger(req, w, auditLogger, time.Now()) + + rl.LogRequestReceived() + defer rl.LogRequestComplete() + + statusCodeCapturingResponseWriter := responsewriter.WrapForHTTP1Or2(rl) + handler.ServeHTTP(statusCodeCapturingResponseWriter, req) + }) +} + +type requestLogger struct { + startTime time.Time + + hijacked bool + statusRecorded bool + status int + + req *http.Request + userAgent string + w http.ResponseWriter + + auditLogger plog.AuditLogger +} + +func newRequestLogger(req *http.Request, w http.ResponseWriter, auditLogger plog.AuditLogger, startTime time.Time) *requestLogger { + return &requestLogger{ + req: req, + w: w, + startTime: startTime, + userAgent: req.UserAgent(), // cache this from the req to avoid any possibility of concurrent read/write problems with headers map + auditLogger: auditLogger, + } +} + +func (rl *requestLogger) LogRequestReceived() { + r := rl.req + rl.auditLogger.Audit(plog.AuditEventHTTPRequestReceived, + r.Context(), + nil, // no session available yet in this context + "proto", r.Proto, + "method", r.Method, + "host", r.Host, + "serverName", requestutil.SNIServerName(r), + "path", r.URL.Path, + "userAgent", rl.userAgent, + "remoteAddr", r.RemoteAddr, + ) +} + +func (rl *requestLogger) LogRequestComplete() { + r := rl.req + rl.auditLogger.Audit(plog.AuditEventHTTPRequestCompleted, + r.Context(), + nil, // no session available yet in this context + "path", r.URL.Path, // include the path again to make it easy to "grep -v healthz" to watch all other audit events + "latency", time.Since(rl.startTime), + "responseStatus", rl.status, + ) +} + +// Unwrap implements responsewriter.UserProvidedDecorator. +func (rl *requestLogger) Unwrap() http.ResponseWriter { + return rl.w +} + +// Header implements http.ResponseWriter. +func (rl *requestLogger) Header() http.Header { + return rl.w.Header() +} + +// Write implements http.ResponseWriter. +func (rl *requestLogger) Write(b []byte) (int, error) { + if !rl.statusRecorded { + rl.recordStatus(http.StatusOK) // Default if WriteHeader hasn't been called + } + return rl.w.Write(b) +} + +// WriteHeader implements http.ResponseWriter. +func (rl *requestLogger) WriteHeader(status int) { + rl.recordStatus(status) + rl.w.WriteHeader(status) +} + +// Hijack implements http.Hijacker. +func (rl *requestLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rl.hijacked = true + + // the outer ResponseWriter object returned by WrapForHTTP1Or2 implements + // http.Hijacker if the inner object (rl.w) implements http.Hijacker. + return rl.w.(http.Hijacker).Hijack() +} + +func (rl *requestLogger) recordStatus(status int) { + rl.status = status + rl.statusRecorded = true +} diff --git a/internal/plog/audit_event.go b/internal/plog/audit_event.go new file mode 100644 index 0000000000..eb4d823afd --- /dev/null +++ b/internal/plog/audit_event.go @@ -0,0 +1,47 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package plog + +import ( + "net/url" + + "k8s.io/apimachinery/pkg/util/sets" +) + +type AuditEventMessage string + +const ( + AuditEventHTTPRequestReceived AuditEventMessage = "HTTP Request Received" + AuditEventHTTPRequestCompleted AuditEventMessage = "HTTP Request Completed" + AuditEventHTTPRequestParameters AuditEventMessage = "HTTP Request Parameters" + AuditEventHTTPRequestCustomHeadersUsed AuditEventMessage = "HTTP Request Custom Headers Used" + AuditEventUsingUpstreamIDP AuditEventMessage = "Using Upstream IDP" + AuditEventIdentityFromUpstreamIDP AuditEventMessage = "Identity From Upstream IDP" + AuditEventIdentityRefreshedFromUpstreamIDP AuditEventMessage = "Identity Refreshed From Upstream IDP" + AuditEventSessionStarted AuditEventMessage = "Session Started" + AuditEventSessionRefreshed AuditEventMessage = "Session Refreshed" + AuditEventAuthenticationRejectedByTransforms AuditEventMessage = "Authentication RejectedBy Transforms" + AuditEventUpstreamOIDCTokenRevoked AuditEventMessage = "Upstream OIDC Token Revoked" //nolint:gosec // this is not a credential + AuditEventSessionGarbageCollected AuditEventMessage = "Session Garbage Collected" + AuditEventTokenCredentialRequest AuditEventMessage = "TokenCredentialRequest" //nolint:gosec // this is not a credential +) + +// SanitizeParams can be used to redact all params not included in the allowedKeys set. +// Useful when audit logging AuditEventHTTPRequestParameters events. +func SanitizeParams(params url.Values, allowedKeys sets.Set[string]) string { + if len(params) == 0 { + return "" + } + sanitized := url.Values{} + for key := range params { + if allowedKeys.Has(key) { + sanitized[key] = params[key] + } else { + for range params[key] { + sanitized.Add(key, "redacted") + } + } + } + return sanitized.Encode() +} diff --git a/internal/plog/audit_event_test.go b/internal/plog/audit_event_test.go new file mode 100644 index 0000000000..9cb031b6a5 --- /dev/null +++ b/internal/plog/audit_event_test.go @@ -0,0 +1,75 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package plog + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestSanitizeParams(t *testing.T) { + tests := []struct { + name string + params url.Values + allowedKeys sets.Set[string] + want string + }{ + { + name: "nil values", + params: nil, + allowedKeys: nil, + want: "", + }, + { + name: "empty values", + params: url.Values{}, + allowedKeys: nil, + want: "", + }, + { + name: "all allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New("foo", "bar"), + want: "bar=d&bar=e&bar=f&foo=a&foo=b&foo=c", + }, + { + name: "all allowed values with single values", + params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, + allowedKeys: sets.New("foo", "bar"), + want: "bar=d&foo=a", + }, + { + name: "some allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New("foo"), + want: "bar=redacted&bar=redacted&bar=redacted&foo=a&foo=b&foo=c", + }, + { + name: "some allowed values with single values", + params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, + allowedKeys: sets.New("foo"), + want: "bar=redacted&foo=a", + }, + { + name: "no allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New[string](), + want: "bar=redacted&bar=redacted&bar=redacted&foo=redacted&foo=redacted&foo=redacted", + }, + { + name: "nil allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: nil, + want: "bar=redacted&bar=redacted&bar=redacted&foo=redacted&foo=redacted&foo=redacted", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, SanitizeParams(tt.params, tt.allowedKeys)) + }) + } +} diff --git a/internal/plog/plog.go b/internal/plog/plog.go index 1f989c25c7..894b0acc7d 100644 --- a/internal/plog/plog.go +++ b/internal/plog/plog.go @@ -28,19 +28,37 @@ package plog import ( + "context" "os" "slices" "github.com/go-logr/logr" + "k8s.io/apiserver/pkg/audit" ) const errorKey = "error" // this matches zapr's default for .Error calls (which is asserted via tests) +type SessionIDGetter interface { + GetID() string +} + +// AuditLogger is only the audit logging part of Logger. There is no global function for Audit because +// that would make unit testing of audit logs harder. +type AuditLogger interface { + // Audit writes an audit event to the log. + // reqCtx and session may be null. + // When possible, pass the http request's context as reqCtx, so we may read the audit ID from the context. + // When possible, pass the fosite.Requester or fosite.Request as the session, so we can log the session ID. + Audit(msg AuditEventMessage, reqCtx context.Context, session SessionIDGetter, keysAndValues ...any) +} + // Logger implements the plog logging convention described above. The global functions in this package // such as Info should be used when one does not intend to write tests assertions for specific log messages. // If test assertions are desired, Logger should be passed in as an input. New should be used as the // production implementation and TestLogger should be used to write test assertions. type Logger interface { + AuditLogger + Error(msg string, err error, keysAndValues ...any) Warning(msg string, keysAndValues ...any) WarningErr(msg string, err error, keysAndValues ...any) @@ -79,10 +97,47 @@ func New() Logger { return pLogger{} } +// Error logs show in the pod log output as `"level":"error","message":"some error msg"` +// where the message text comes from the err parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Error logs cannot be suppressed by the global log level configuration. func (p pLogger) Error(msg string, err error, keysAndValues ...any) { p.logr().WithCallDepth(p.depth+1).Error(err, msg, keysAndValues...) } +// Audit logs show in the pod log output as `"level":"info","message":"some msg","auditEvent":true` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Audit logs cannot be suppressed by the global log level configuration, but rather can be disabled +// by their own separate configuration. This is because Audit logs should always be printed when they are desired +// by the admin, regardless of global log level, yet the admin should also have a way to entirely disable them +// when they want to avoid potential PII (e.g. usernames) in their pod logs. +// TODO: Add a way to disable output of audit logs, separate from the log level config. +func (p pLogger) Audit(msg AuditEventMessage, reqCtx context.Context, session SessionIDGetter, keysAndValues ...any) { + // Always add a key/value auditEvent=true. + keysAndValues = slices.Concat([]any{"auditEvent", true}, keysAndValues) + + var auditID string + if reqCtx != nil { + auditID = audit.GetAuditIDTruncated(reqCtx) + } + if len(auditID) > 0 { + keysAndValues = slices.Concat([]any{"auditID", auditID}, keysAndValues) + } + + var sessionID string + if session != nil { + sessionID = session.GetID() + } + if len(sessionID) > 0 { + keysAndValues = slices.Concat([]any{"sessionID", sessionID}, keysAndValues) + } + + p.logr().V(klogLevelWarning).WithCallDepth(p.depth+1).Info(string(msg), keysAndValues...) +} + func (p pLogger) warningDepth(msg string, depth int, keysAndValues ...any) { if p.logr().V(klogLevelWarning).Enabled() { // klog's structured logging has no concept of a warning (i.e. no WarningS function) @@ -94,10 +149,20 @@ func (p pLogger) warningDepth(msg string, depth int, keysAndValues ...any) { } } +// Warning logs show in the pod log output as `"level":"info","message":"some msg","warning":true` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Warning logs cannot be suppressed by the global log level configuration. func (p pLogger) Warning(msg string, keysAndValues ...any) { p.warningDepth(msg, p.depth+1, keysAndValues...) } +// WarningErr logs show in the pod log output as `"level":"info","message":"some msg","warning":true,"error":"some error msg"` +// where the message text comes from the msg parameter and the error text comes from the err parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// WarningErr logs cannot be suppressed by the global log level configuration. func (p pLogger) WarningErr(msg string, err error, keysAndValues ...any) { p.warningDepth(msg, p.depth+1, slices.Concat([]any{errorKey, err}, keysAndValues)...) } @@ -108,10 +173,20 @@ func (p pLogger) infoDepth(msg string, depth int, keysAndValues ...any) { } } +// Info logs show in the pod log output as `"level":"info","message":"some msg"` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Info logs are suppressed by the global log level configuration, unless it is set to "info" or above. func (p pLogger) Info(msg string, keysAndValues ...any) { p.infoDepth(msg, p.depth+1, keysAndValues...) } +// InfoErr logs show in the pod log output as `"level":"info","message":"some msg","error":"some error msg"` +// where the message text comes from the msg parameter and the error text comes from the err parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// InfoErr logs are suppressed by the global log level configuration, unless it is set to "info" or above. func (p pLogger) InfoErr(msg string, err error, keysAndValues ...any) { p.infoDepth(msg, p.depth+1, slices.Concat([]any{errorKey, err}, keysAndValues)...) } @@ -122,10 +197,20 @@ func (p pLogger) debugDepth(msg string, depth int, keysAndValues ...any) { } } +// Debug logs show in the pod log output as `"level":"debug","message":"some msg"` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Debug logs are suppressed by the global log level configuration, unless it is set to "debug" or above. func (p pLogger) Debug(msg string, keysAndValues ...any) { p.debugDepth(msg, p.depth+1, keysAndValues...) } +// DebugErr logs show in the pod log output as `"level":"debug","message":"some msg","error":"some error msg"` +// where the message text comes from the msg parameter and the error text comes from the err parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// DebugErr logs are suppressed by the global log level configuration, unless it is set to "debug" or above. func (p pLogger) DebugErr(msg string, err error, keysAndValues ...any) { p.debugDepth(msg, p.depth+1, slices.Concat([]any{errorKey, err}, keysAndValues)...) } @@ -136,20 +221,39 @@ func (p pLogger) traceDepth(msg string, depth int, keysAndValues ...any) { } } +// Trace logs show in the pod log output as `"level":"trace","message":"some msg"` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Trace logs are suppressed by the global log level configuration, unless it is set to "trace" or above. func (p pLogger) Trace(msg string, keysAndValues ...any) { p.traceDepth(msg, p.depth+1, keysAndValues...) } +// TraceErr logs show in the pod log output as `"level":"trace","message":"some msg","error":"some error msg"` +// where the message text comes from the msg parameter and the error text comes from the err parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// TraceErr logs are suppressed by the global log level configuration, unless it is set to "trace" or above. func (p pLogger) TraceErr(msg string, err error, keysAndValues ...any) { p.traceDepth(msg, p.depth+1, slices.Concat([]any{errorKey, err}, keysAndValues)...) } +// All logs show in the pod log output as `"level":"all","message":"some msg"` +// where the message text comes from the msg parameter. +// They also contain the standard `timestamp` and `caller` keys, along with any other keysAndValues. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// All logs are suppressed by the global log level configuration, unless it is set to "all" or above. func (p pLogger) All(msg string, keysAndValues ...any) { if p.logr().V(klogLevelAll).Enabled() { p.logr().V(klogLevelAll).WithCallDepth(p.depth+1).Info(msg, keysAndValues...) } } +// Always logs show in the pod log output exactly the same as an Info() message, +// except Always logs are always logged regardless of log level configuration. +// Only when the global log level is configured to "trace" or "all", then they will also include a `stacktrace` key. +// Always logs cannot be suppressed by the global log level configuration. func (p pLogger) Always(msg string, keysAndValues ...any) { p.logr().WithCallDepth(p.depth+1).Info(msg, keysAndValues...) } diff --git a/internal/registry/credentialrequest/rest.go b/internal/registry/credentialrequest/rest.go index 477cdde680..447a5500ce 100644 --- a/internal/registry/credentialrequest/rest.go +++ b/internal/registry/credentialrequest/rest.go @@ -22,6 +22,7 @@ import ( loginapi "go.pinniped.dev/generated/latest/apis/concierge/login" "go.pinniped.dev/internal/clientcertissuer" + "go.pinniped.dev/internal/plog" ) // clientCertificateTTL is the TTL for short-lived client certificates returned by this API. @@ -31,11 +32,17 @@ type TokenCredentialRequestAuthenticator interface { AuthenticateTokenCredentialRequest(ctx context.Context, req *loginapi.TokenCredentialRequest) (user.Info, error) } -func NewREST(authenticator TokenCredentialRequestAuthenticator, issuer clientcertissuer.ClientCertIssuer, resource schema.GroupResource) *REST { +func NewREST( + authenticator TokenCredentialRequestAuthenticator, + issuer clientcertissuer.ClientCertIssuer, + resource schema.GroupResource, + auditLogger plog.AuditLogger, +) *REST { return &REST{ authenticator: authenticator, issuer: issuer, tableConvertor: rest.NewDefaultTableConvertor(resource), + auditLogger: auditLogger, } } @@ -43,6 +50,7 @@ type REST struct { authenticator TokenCredentialRequestAuthenticator issuer clientcertissuer.ClientCertIssuer tableConvertor rest.TableConvertor + auditLogger plog.AuditLogger } // Assert that our *REST implements all the optional interfaces that we expect it to implement. @@ -123,6 +131,12 @@ func (r *REST) Create(ctx context.Context, obj runtime.Object, createValidation traceSuccess(t, userInfo, true) + r.auditLogger.Audit(plog.AuditEventTokenCredentialRequest, ctx, nil, + "username", userInfo.GetName(), + "groups", userInfo.GetGroups(), + "authenticated", true, + "expires", expires.Format(time.RFC3339)) + return &loginapi.TokenCredentialRequest{ Status: loginapi.TokenCredentialRequestStatus{ Credential: &loginapi.ClusterCredential{ diff --git a/internal/registry/credentialrequest/rest_test.go b/internal/registry/credentialrequest/rest_test.go index 5054c5cc92..20f8b6c414 100644 --- a/internal/registry/credentialrequest/rest_test.go +++ b/internal/registry/credentialrequest/rest_test.go @@ -28,11 +28,12 @@ import ( "go.pinniped.dev/internal/clientcertissuer" "go.pinniped.dev/internal/mocks/mockcredentialrequest" "go.pinniped.dev/internal/mocks/mockissuer" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/testutil" ) func TestNew(t *testing.T) { - r := NewREST(nil, nil, schema.GroupResource{Group: "bears", Resource: "panda"}) + r := NewREST(nil, nil, schema.GroupResource{Group: "bears", Resource: "panda"}, plog.New()) require.NotNil(t, r) require.False(t, r.NamespaceScoped()) require.Equal(t, []string{"pinniped"}, r.Categories()) @@ -103,7 +104,7 @@ func TestCreate(t *testing.T) { 5*time.Minute, ).Return([]byte("test-cert"), []byte("test-key"), nil) - storage := NewREST(requestAuthenticator, clientCertIssuer, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, clientCertIssuer, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -142,7 +143,7 @@ func TestCreate(t *testing.T) { IssueClientCertPEM(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil, fmt.Errorf("some certificate authority error")) - storage := NewREST(requestAuthenticator, clientCertIssuer, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, clientCertIssuer, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) requireSuccessfulResponseWithAuthenticationFailureMessage(t, err, response) @@ -155,7 +156,7 @@ func TestCreate(t *testing.T) { requestAuthenticator := mockcredentialrequest.NewMockTokenCredentialRequestAuthenticator(ctrl) requestAuthenticator.EXPECT().AuthenticateTokenCredentialRequest(gomock.Any(), req).Return(nil, nil) - storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -170,7 +171,7 @@ func TestCreate(t *testing.T) { requestAuthenticator.EXPECT().AuthenticateTokenCredentialRequest(gomock.Any(), req). Return(nil, errors.New("some webhook error")) - storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -185,7 +186,7 @@ func TestCreate(t *testing.T) { requestAuthenticator.EXPECT().AuthenticateTokenCredentialRequest(gomock.Any(), req). Return(&user.DefaultInfo{Name: ""}, nil) - storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -204,7 +205,7 @@ func TestCreate(t *testing.T) { Groups: []string{"test-group-1", "test-group-2"}, }, nil) - storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -223,7 +224,7 @@ func TestCreate(t *testing.T) { Extra: map[string][]string{"test-key": {"test-val-1", "test-val-2"}}, }, nil) - storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}) + storage := NewREST(requestAuthenticator, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, req) @@ -233,7 +234,7 @@ func TestCreate(t *testing.T) { it("CreateFailsWhenGivenTheWrongInputType", func() { notACredentialRequest := runtime.Unknown{} - response, err := NewREST(nil, nil, schema.GroupResource{}).Create( + response, err := NewREST(nil, nil, schema.GroupResource{}, plog.New()).Create( genericapirequest.NewContext(), ¬ACredentialRequest, rest.ValidateAllObjectFunc, @@ -244,7 +245,7 @@ func TestCreate(t *testing.T) { }) it("CreateFailsWhenTokenValueIsEmptyInRequest", func() { - storage := NewREST(nil, nil, schema.GroupResource{}) + storage := NewREST(nil, nil, schema.GroupResource{}, plog.New()) response, err := callCreate(context.Background(), storage, credentialRequest(loginapi.TokenCredentialRequestSpec{ Token: "", })) @@ -255,7 +256,7 @@ func TestCreate(t *testing.T) { }) it("CreateFailsWhenValidationFails", func() { - storage := NewREST(nil, nil, schema.GroupResource{}) + storage := NewREST(nil, nil, schema.GroupResource{}, plog.New()) response, err := storage.Create( context.Background(), validCredentialRequest(), @@ -275,7 +276,7 @@ func TestCreate(t *testing.T) { requestAuthenticator.EXPECT().AuthenticateTokenCredentialRequest(gomock.Any(), req.DeepCopy()). Return(&user.DefaultInfo{Name: "test-user"}, nil) - storage := NewREST(requestAuthenticator, successfulIssuer(ctrl), schema.GroupResource{}) + storage := NewREST(requestAuthenticator, successfulIssuer(ctrl), schema.GroupResource{}, plog.New()) response, err := storage.Create( context.Background(), req, @@ -296,7 +297,7 @@ func TestCreate(t *testing.T) { requestAuthenticator.EXPECT().AuthenticateTokenCredentialRequest(gomock.Any(), req.DeepCopy()). Return(&user.DefaultInfo{Name: "test-user"}, nil) - storage := NewREST(requestAuthenticator, successfulIssuer(ctrl), schema.GroupResource{}) + storage := NewREST(requestAuthenticator, successfulIssuer(ctrl), schema.GroupResource{}, plog.New()) validationFunctionWasCalled := false var validationFunctionSawTokenValue string response, err := storage.Create( @@ -316,7 +317,7 @@ func TestCreate(t *testing.T) { }) it("CreateFailsWhenRequestOptionsDryRunIsNotEmpty", func() { - response, err := NewREST(nil, nil, schema.GroupResource{}).Create( + response, err := NewREST(nil, nil, schema.GroupResource{}, plog.New()).Create( genericapirequest.NewContext(), validCredentialRequest(), rest.ValidateAllObjectFunc, @@ -330,7 +331,7 @@ func TestCreate(t *testing.T) { }) it("CreateFailsWhenNamespaceIsNotEmpty", func() { - response, err := NewREST(nil, nil, schema.GroupResource{}).Create( + response, err := NewREST(nil, nil, schema.GroupResource{}, plog.New()).Create( genericapirequest.WithNamespace(genericapirequest.NewContext(), "some-ns"), validCredentialRequest(), rest.ValidateAllObjectFunc, diff --git a/internal/supervisor/server/server.go b/internal/supervisor/server/server.go index 83c28664a2..8fc9b9d58b 100644 --- a/internal/supervisor/server/server.go +++ b/internal/supervisor/server/server.go @@ -167,6 +167,7 @@ func prepareControllers( kubeClient, secretInformer, controllerlib.WithInformer, + plog.New(), ), singletonWorker, ). @@ -483,6 +484,7 @@ func runSupervisor(ctx context.Context, podInfo *downward.PodInfo, cfg *supervis &secretCache, clientWithoutLeaderElection.Kubernetes.CoreV1().Secrets(serverInstallationNamespace), // writes to kube storage are allowed for non-leaders client.PinnipedSupervisor.ConfigV1alpha1().OIDCClients(serverInstallationNamespace), + plog.New(), ) // Get the "real" name of the client secret supervisor API group (i.e., the API group name with the @@ -544,7 +546,7 @@ func runSupervisor(ctx context.Context, podInfo *downward.PodInfo, cfg *supervis } defer func() { _ = httpListener.Close() }() - startServer(ctx, shutdown, httpListener, oidProvidersManager) + startServer(ctx, shutdown, httpListener, oidProvidersManager.HandlerChain()) plog.Debug("supervisor http listener started", "address", httpListener.Addr().String()) } @@ -601,7 +603,7 @@ func runSupervisor(ctx context.Context, podInfo *downward.PodInfo, cfg *supervis } defer func() { _ = httpsListener.Close() }() - startServer(ctx, shutdown, httpsListener, oidProvidersManager) + startServer(ctx, shutdown, httpsListener, oidProvidersManager.HandlerChain()) plog.Debug("supervisor https listener started", "address", httpsListener.Addr().String()) }