diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index daea6b2e08..1b6077b7f6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -14,6 +14,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/hashicorp/golang-lru/v2" "golang.org/x/oauth2" "github.com/dexidp/dex/connector" @@ -22,8 +23,9 @@ import ( ) const ( - codeChallengeMethodPlain = "plain" - codeChallengeMethodS256 = "S256" + defaultPkceMaxConcurrentConnections = 256 + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" ) func contains(arr []string, item string) bool { @@ -93,8 +95,13 @@ type Config struct { // PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent) PromptType *string `json:"promptType"` + // PKCEChallenge specifies which PKCE algorithm will be used + // If not setted it will be auto-detected the best-fit for the connector. PKCEChallenge string `json:"pkceChallenge"` + // PKCEMaxConcurrentConnections specifies the maximum number of concurrent connections for the PKCE code verify. + PKCEMaxConcurrentConnections int `json:"pkceMaxConcurrentConnections"` + // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -304,7 +311,19 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, logger.Warn("provided PKCEChallenge method not supported by the connector") } } - pkceVerifier := "" + + // if PKCE will be used, create a state cache for verifier + var pkceVerifierCache *lru.Cache[string, string] + if c.PKCEChallenge != "" { + pkceCacheSize := c.PKCEMaxConcurrentConnections + if pkceCacheSize == 0 { + pkceCacheSize = defaultPkceMaxConcurrentConnections + } + pkceVerifierCache, err = lru.New[string, string](pkceCacheSize) + if err != nil { + logger.Warn("Unable to create PKCE Verifier cache") + } + } clientID := c.ClientID return &oidcConnector{ @@ -321,25 +340,26 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ctx, // Pass our ctx with customized http.Client &oidc.Config{ClientID: clientID}, ), - logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), - cancel: cancel, - httpClient: httpClient, - insecureSkipEmailVerified: c.InsecureSkipEmailVerified, - insecureEnableGroups: c.InsecureEnableGroups, - allowedGroups: c.AllowedGroups, - acrValues: c.AcrValues, - getUserInfo: c.GetUserInfo, - promptType: promptType, - userIDKey: c.UserIDKey, - userNameKey: c.UserNameKey, - overrideClaimMapping: c.OverrideClaimMapping, - preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, - emailKey: c.ClaimMapping.EmailKey, - groupsKey: c.ClaimMapping.GroupsKey, - newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, - groupsFilter: groupsFilter, - pkceChallenge: c.PKCEChallenge, - pkceVerifier: pkceVerifier, + logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), + cancel: cancel, + httpClient: httpClient, + insecureSkipEmailVerified: c.InsecureSkipEmailVerified, + insecureEnableGroups: c.InsecureEnableGroups, + allowedGroups: c.AllowedGroups, + acrValues: c.AcrValues, + getUserInfo: c.GetUserInfo, + promptType: promptType, + userIDKey: c.UserIDKey, + userNameKey: c.UserNameKey, + overrideClaimMapping: c.OverrideClaimMapping, + preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, + emailKey: c.ClaimMapping.EmailKey, + groupsKey: c.ClaimMapping.GroupsKey, + newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, + pkceMaxConcurrentConnections: c.PKCEMaxConcurrentConnections, + pkceVerifierCache: pkceVerifierCache, }, nil } @@ -349,29 +369,30 @@ var ( ) type oidcConnector struct { - provider *oidc.Provider - redirectURI string - oauth2Config *oauth2.Config - verifier *oidc.IDTokenVerifier - cancel context.CancelFunc - logger *slog.Logger - httpClient *http.Client - insecureSkipEmailVerified bool - insecureEnableGroups bool - allowedGroups []string - acrValues []string - getUserInfo bool - promptType string - userIDKey string - userNameKey string - overrideClaimMapping bool - preferredUsernameKey string - emailKey string - groupsKey string - newGroupFromClaims []NewGroupFromClaims - groupsFilter *regexp.Regexp - pkceChallenge string - pkceVerifier string + provider *oidc.Provider + redirectURI string + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cancel context.CancelFunc + logger *slog.Logger + httpClient *http.Client + insecureSkipEmailVerified bool + insecureEnableGroups bool + allowedGroups []string + acrValues []string + getUserInfo bool + promptType string + userIDKey string + userNameKey string + overrideClaimMapping bool + preferredUsernameKey string + emailKey string + groupsKey string + newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp + pkceChallenge string + pkceMaxConcurrentConnections int + pkceVerifierCache *lru.Cache[string, string] } func (c *oidcConnector) Close() error { @@ -398,11 +419,13 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if c.pkceChallenge != "" { switch c.pkceChallenge { case codeChallengeMethodPlain: - c.pkceVerifier = oauth2.GenerateVerifier() - opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + pkceVerifier := oauth2.GenerateVerifier() + c.pkceVerifierCache.Add(state, pkceVerifier) + opts = append(opts, oauth2.VerifierOption(pkceVerifier)) case codeChallengeMethodS256: - c.pkceVerifier = oauth2.GenerateVerifier() - opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier)) + pkceVerifier := oauth2.GenerateVerifier() + c.pkceVerifierCache.Add(state, pkceVerifier) + opts = append(opts, oauth2.S256ChallengeOption(pkceVerifier)) default: c.logger.Warn("unknown PKCEChallenge method") } @@ -440,8 +463,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) var opts []oauth2.AuthCodeOption - if c.pkceVerifier != "" { - opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + if c.pkceChallenge != "" { + state := q.Get("state") + if state == "" { + return identity, fmt.Errorf("oidc: missing state in callback") + } + pkceVerifier, found := c.pkceVerifierCache.Get(state) + if !found { + return identity, fmt.Errorf("oidc: received state not in callback cache") + } + + c.pkceVerifierCache.Remove(state) + if pkceVerifier == "" { + return identity, fmt.Errorf("oidc: invalid state in pkce verifier cache") + } + opts = append(opts, oauth2.VerifierOption(pkceVerifier)) } token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) diff --git a/go.mod b/go.mod index e2100a396b..96671dcd8b 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.3 // indirect github.com/googleapis/gax-go/v2 v2.13.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index cbebe635ec..5cbdb3cf41 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=