diff --git a/changelog/13917.txt b/changelog/13917.txt new file mode 100644 index 000000000000..aa166c330792 --- /dev/null +++ b/changelog/13917.txt @@ -0,0 +1,3 @@ +```release-note:improvement +identity/oidc: Adds proof key for code exchange (PKCE) support to OIDC providers. +``` diff --git a/vault/external_tests/identity/oidc_provider_test.go b/vault/external_tests/identity/oidc_provider_test.go index aeb5c0d17ff5..30350bfc814c 100644 --- a/vault/external_tests/identity/oidc_provider_test.go +++ b/vault/external_tests/identity/oidc_provider_test.go @@ -39,11 +39,12 @@ const ( ` ) -// TestOIDC_Auth_Code_Flow_CAP_Client tests the authorization code flow -// using a Vault OIDC provider. The test uses the CAP OIDC client to verify -// that the Vault OIDC provider's responses pass the various client-side -// validation requirements of the OIDC spec. -func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { +// TestOIDC_Auth_Code_Flow_Confidential_CAP_Client tests the authorization code +// flow using a Vault OIDC provider. The test uses the CAP OIDC client to verify +// that the Vault OIDC provider's responses pass the various client-side validation +// requirements of the OIDC spec. This test uses a confidential client which has +// a client secret and authenticates to the token endpoint. +func TestOIDC_Auth_Code_Flow_Confidential_CAP_Client(t *testing.T) { cluster := setupOIDCTestCluster(t, 2) defer cluster.Cleanup() active := cluster.Cores[0].Client @@ -131,8 +132,8 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { }) require.NoError(t, err) - // Create a client - _, err = active.Logical().Write("identity/oidc/client/test-client", map[string]interface{}{ + // Create a confidential client + _, err = active.Logical().Write("identity/oidc/client/confidential", map[string]interface{}{ "key": "test-key", "redirect_uris": []string{testRedirectURI}, "assignments": []string{"test-assignment"}, @@ -142,7 +143,7 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { require.NoError(t, err) // Read the client ID and secret in order to configure the OIDC client - resp, err = active.Logical().Read("identity/oidc/client/test-client") + resp, err = active.Logical().Read("identity/oidc/client/confidential") require.NoError(t, err) clientID := resp.Data["client_id"].(string) clientSecret := resp.Data["client_secret"].(string) @@ -191,6 +192,10 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { require.NoError(t, err) defer p.Done() + // Create the client-side PKCE code verifier + v, err := oidc.NewCodeVerifier() + require.NoError(t, err) + type args struct { useStandby bool options []oidc.Option @@ -255,6 +260,21 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { "auth_time": %d }`, discovery.Issuer, clientID, entityID, expectedAuthTime), }, + { + name: "active: authorization code flow with Proof Key for Code Exchange (PKCE)", + args: args{ + options: []oidc.Option{ + oidc.WithScopes("openid"), + oidc.WithPKCE(v), + }, + }, + expected: fmt.Sprintf(`{ + "iss": "%s", + "aud": "%s", + "sub": "%s", + "namespace": "root" + }`, discovery.Issuer, clientID, entityID), + }, { name: "standby: authorization code flow with additional scopes", args: args{ @@ -369,6 +389,342 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) { } } +// TestOIDC_Auth_Code_Flow_Public_CAP_Client tests the authorization code flow using +// a Vault OIDC provider. The test uses the CAP OIDC client to verify that the Vault +// OIDC provider's responses pass the various client-side validation requirements of +// the OIDC spec. This test uses a public client which does not have a client secret +// and always uses proof key for code exchange (PKCE). +func TestOIDC_Auth_Code_Flow_Public_CAP_Client(t *testing.T) { + cluster := setupOIDCTestCluster(t, 2) + defer cluster.Cleanup() + active := cluster.Cores[0].Client + standby := cluster.Cores[1].Client + + // Create an entity with some metadata + resp, err := active.Logical().Write("identity/entity", map[string]interface{}{ + "name": "test-entity", + "metadata": map[string]string{ + "email": "test@hashicorp.com", + "phone_number": "123-456-7890", + }, + }) + require.NoError(t, err) + entityID := resp.Data["id"].(string) + + // Create a group + resp, err = active.Logical().Write("identity/group", map[string]interface{}{ + "name": "engineering", + "member_entity_ids": []string{entityID}, + }) + require.NoError(t, err) + groupID := resp.Data["id"].(string) + + // Create a policy that allows updating the provider + err = active.Sys().PutPolicy("test-policy", ` + path "identity/oidc/provider/test-provider" { + capabilities = ["update"] + } + `) + require.NoError(t, err) + + // Enable userpass auth and create a user + err = active.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + require.NoError(t, err) + _, err = active.Logical().Write("auth/userpass/users/end-user", map[string]interface{}{ + "password": testPassword, + "token_policies": "test-policy", + }) + require.NoError(t, err) + + // Get the userpass mount accessor + mounts, err := active.Sys().ListAuth() + require.NoError(t, err) + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + require.NotEmpty(t, mountAccessor) + + // Create an entity alias + _, err = active.Logical().Write("identity/entity-alias", map[string]interface{}{ + "name": "end-user", + "canonical_id": entityID, + "mount_accessor": mountAccessor, + }) + require.NoError(t, err) + + // Create some custom scopes + _, err = active.Logical().Write("identity/oidc/scope/groups", map[string]interface{}{ + "template": testGroupScopeTemplate, + }) + require.NoError(t, err) + _, err = active.Logical().Write("identity/oidc/scope/user", map[string]interface{}{ + "template": fmt.Sprintf(testUserScopeTemplate, mountAccessor), + }) + require.NoError(t, err) + + // Create a key + _, err = active.Logical().Write("identity/oidc/key/test-key", map[string]interface{}{ + "allowed_client_ids": []string{"*"}, + "algorithm": "RS256", + }) + require.NoError(t, err) + + // Create an assignment + _, err = active.Logical().Write("identity/oidc/assignment/test-assignment", map[string]interface{}{ + "entity_ids": []string{entityID}, + "group_ids": []string{groupID}, + }) + require.NoError(t, err) + + // Create a public client + _, err = active.Logical().Write("identity/oidc/client/public", map[string]interface{}{ + "key": "test-key", + "redirect_uris": []string{testRedirectURI}, + "assignments": []string{"test-assignment"}, + "id_token_ttl": "1h", + "access_token_ttl": "30m", + "client_type": "public", + }) + require.NoError(t, err) + + // Read the client ID in order to configure the OIDC client + resp, err = active.Logical().Read("identity/oidc/client/public") + require.NoError(t, err) + clientID := resp.Data["client_id"].(string) + + // Create the OIDC provider + _, err = active.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{ + "allowed_client_ids": []string{clientID}, + "scopes_supported": []string{"user", "groups"}, + }) + require.NoError(t, err) + + // We aren't going to open up a browser to facilitate the login and redirect + // from this test, so we'll log in via userpass and set the client's token as + // the token that results from the authentication. + resp, err = active.Logical().Write("auth/userpass/login/end-user", map[string]interface{}{ + "password": testPassword, + }) + require.NoError(t, err) + clientToken := resp.Auth.ClientToken + + // Look up the token to get its creation time. This will be used for test + // cases that make assertions on the max_age parameter and auth_time claim. + resp, err = active.Logical().Write("auth/token/lookup", map[string]interface{}{ + "token": clientToken, + }) + require.NoError(t, err) + expectedAuthTime, err := strconv.Atoi(string(resp.Data["creation_time"].(json.Number))) + require.NoError(t, err) + + // Read the issuer from the OIDC provider's discovery document + var discovery struct { + Issuer string `json:"issuer"` + } + decodeRawRequest(t, active, http.MethodGet, + "/v1/identity/oidc/provider/test-provider/.well-known/openid-configuration", + nil, &discovery) + + // Create the client-side OIDC provider config with client secret intentionally empty + clientSecret := oidc.ClientSecret("") + pc, err := oidc.NewConfig(discovery.Issuer, clientID, clientSecret, []oidc.Alg{oidc.RS256}, + []string{testRedirectURI}, oidc.WithProviderCA(string(cluster.CACertPEM))) + require.NoError(t, err) + + // Create the client-side OIDC provider + p, err := oidc.NewProvider(pc) + require.NoError(t, err) + defer p.Done() + + type args struct { + useStandby bool + options []oidc.Option + } + tests := []struct { + name string + args args + expected string + }{ + { + name: "active: authorization code flow", + args: args{ + options: []oidc.Option{ + oidc.WithScopes("openid user"), + }, + }, + expected: fmt.Sprintf(`{ + "iss": "%s", + "aud": "%s", + "sub": "%s", + "namespace": "root", + "username": "end-user", + "contact": { + "email": "test@hashicorp.com", + "phone_number": "123-456-7890" + } + }`, discovery.Issuer, clientID, entityID), + }, + { + name: "active: authorization code flow with additional scopes", + args: args{ + options: []oidc.Option{ + oidc.WithScopes("openid user groups"), + }, + }, + expected: fmt.Sprintf(`{ + "iss": "%s", + "aud": "%s", + "sub": "%s", + "namespace": "root", + "username": "end-user", + "contact": { + "email": "test@hashicorp.com", + "phone_number": "123-456-7890" + }, + "groups": ["engineering"] + }`, discovery.Issuer, clientID, entityID), + }, + { + name: "active: authorization code flow with max_age parameter", + args: args{ + options: []oidc.Option{ + oidc.WithScopes("openid"), + oidc.WithMaxAge(60), + }, + }, + expected: fmt.Sprintf(`{ + "iss": "%s", + "aud": "%s", + "sub": "%s", + "namespace": "root", + "auth_time": %d + }`, discovery.Issuer, clientID, entityID, expectedAuthTime), + }, + { + name: "standby: authorization code flow with additional scopes", + args: args{ + useStandby: true, + options: []oidc.Option{ + oidc.WithScopes("openid user groups"), + }, + }, + expected: fmt.Sprintf(`{ + "iss": "%s", + "aud": "%s", + "sub": "%s", + "namespace": "root", + "username": "end-user", + "contact": { + "email": "test@hashicorp.com", + "phone_number": "123-456-7890" + }, + "groups": ["engineering"] + }`, discovery.Issuer, clientID, entityID), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := active + if tt.args.useStandby { + client = standby + } + client.SetToken(clientToken) + + // Update allowed client IDs before the authentication flow + _, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{ + "allowed_client_ids": []string{clientID}, + }) + require.NoError(t, err) + + // Create the required client-side PKCE code verifier. + v, err := oidc.NewCodeVerifier() + require.NoError(t, err) + options := append([]oidc.Option{oidc.WithPKCE(v)}, tt.args.options...) + + // Create the client-side OIDC request state + oidcRequest, err := oidc.NewRequest(10*time.Minute, testRedirectURI, options...) + require.NoError(t, err) + + // Get the URL for the authorization endpoint from the OIDC client + authURL, err := p.AuthURL(context.Background(), oidcRequest) + require.NoError(t, err) + parsedAuthURL, err := url.Parse(authURL) + require.NoError(t, err) + + // This replace only occurs because we're not using the browser in this test + authURLPath := strings.Replace(parsedAuthURL.Path, "/ui/vault/", "/v1/", 1) + + // Kick off the authorization code flow + var authResp struct { + Code string `json:"code"` + State string `json:"state"` + } + decodeRawRequest(t, client, http.MethodGet, authURLPath, parsedAuthURL.Query(), &authResp) + + // The returned state must match the OIDC client state + require.Equal(t, oidcRequest.State(), authResp.State) + + // Exchange the authorization code for an ID token and access token. + // The ID token signature is verified using the provider's public keys after + // the exchange takes place. The ID token is also validated according to the + // client-side requirements of the OIDC spec. See the validation code at: + // - https://github.com/hashicorp/cap/blob/main/oidc/provider.go#L240 + // - https://github.com/hashicorp/cap/blob/main/oidc/provider.go#L441 + token, err := p.Exchange(context.Background(), oidcRequest, authResp.State, authResp.Code) + require.NoError(t, err) + require.NotNil(t, token) + idToken := token.IDToken() + accessToken := token.StaticTokenSource() + + // Get the ID token claims + allClaims := make(map[string]interface{}) + require.NoError(t, idToken.Claims(&allClaims)) + + // Get the sub claim for userinfo validation + require.NotEmpty(t, allClaims["sub"]) + subject := allClaims["sub"].(string) + + // Request userinfo using the access token + err = p.UserInfo(context.Background(), accessToken, subject, &allClaims) + require.NoError(t, err) + + // Assert that claims computed during the flow (i.e., not known + // ahead of time in this test) are present as top-level keys + for _, claim := range []string{"iat", "exp", "nonce", "at_hash", "c_hash"} { + _, ok := allClaims[claim] + require.True(t, ok) + } + + // Assert that all other expected claims are populated + expectedClaims := make(map[string]interface{}) + require.NoError(t, json.Unmarshal([]byte(tt.expected), &expectedClaims)) + for k, expectedVal := range expectedClaims { + actualVal, ok := allClaims[k] + require.True(t, ok) + require.EqualValues(t, expectedVal, actualVal) + } + + // Assert that the access token is no longer able to obtain user info + // after removing the client from the provider's allowed client ids + _, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{ + "allowed_client_ids": []string{}, + }) + require.NoError(t, err) + err = p.UserInfo(context.Background(), accessToken, subject, &allClaims) + require.Error(t, err) + require.Equal(t, `Provider.UserInfo: provider UserInfo request failed: 403 Forbidden: {"error":"access_denied","error_description":"client is not authorized to use the provider"}`, + err.Error()) + }) + } +} + func setupOIDCTestCluster(t *testing.T, numCores int) *vault.TestCluster { t.Helper() diff --git a/vault/identity_store_oidc_provider.go b/vault/identity_store_oidc_provider.go index e2e911d93a20..a1f816115666 100644 --- a/vault/identity_store_oidc_provider.go +++ b/vault/identity_store_oidc_provider.go @@ -26,13 +26,15 @@ import ( const ( // OIDC-related constants - openIDScope = "openid" - scopesDelimiter = " " - accessTokenScopesMeta = "scopes" - accessTokenClientIDMeta = "client_id" - clientIDLength = 32 - clientSecretLength = 64 - clientSecretPrefix = "hvo_secret_" + openIDScope = "openid" + scopesDelimiter = " " + accessTokenScopesMeta = "scopes" + accessTokenClientIDMeta = "client_id" + clientIDLength = 32 + clientSecretLength = 64 + clientSecretPrefix = "hvo_secret_" + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" // Storage path constants oidcProviderPrefix = "oidc_provider/" @@ -95,12 +97,31 @@ type client struct { Key string `json:"key"` IDTokenTTL time.Duration `json:"id_token_ttl"` AccessTokenTTL time.Duration `json:"access_token_ttl"` + Type clientType `json:"type"` // Generated values that are used in OIDC endpoints ClientID string `json:"client_id"` ClientSecret string `json:"client_secret"` } +type clientType int + +const ( + confidential clientType = iota + public +) + +func (k clientType) String() string { + switch k { + case confidential: + return "confidential" + case public: + return "public" + default: + return "unknown" + } +} + type provider struct { Issuer string `json:"issuer"` AllowedClientIDs []string `json:"allowed_client_ids"` @@ -127,13 +148,15 @@ type providerDiscovery struct { } type authCodeCacheEntry struct { - provider string - clientID string - entityID string - redirectURI string - nonce string - scopes []string - authTime time.Time + provider string + clientID string + entityID string + redirectURI string + nonce string + scopes []string + authTime time.Time + codeChallenge string + codeChallengeMethod string } func oidcProviderPaths(i *IdentityStore) []*framework.Path { @@ -256,6 +279,11 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { Description: "The time-to-live for access tokens obtained by the client.", Default: "24h", }, + "client_type": { + Type: framework.TypeString, + Description: "The client type based on its ability to maintain confidentiality of credentials. The following client types are supported: 'confidential', 'public'. Defaults to 'confidential'.", + Default: "confidential", + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.UpdateOperation: &framework.PathOperation{ @@ -405,6 +433,15 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { Type: framework.TypeInt, Description: "The allowable elapsed time in seconds since the last time the end-user was actively authenticated.", }, + "code_challenge": { + Type: framework.TypeString, + Description: "The code challenge derived from the code verifier.", + }, + "code_challenge_method": { + Type: framework.TypeString, + Description: "The method that was used to derive the code challenge. The following methods are supported: 'S256', 'plain'. Defaults to 'plain'.", + Default: codeChallengeMethodPlain, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ @@ -443,10 +480,23 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { Description: "The callback location where the authentication response was sent.", Required: true, }, - // The client_id and client_secret are provided to the token endpoint via - // the client_secret_basic authentication method, which uses the HTTP Basic - // authentication scheme. See the OIDC spec for details at: + "code_verifier": { + Type: framework.TypeString, + Description: "The code verifier associated with the authorization code.", + }, + // For confidential clients, the client_id and client_secret are provided to + // the token endpoint via the 'client_secret_basic' authentication method, which + // uses the HTTP Basic authentication scheme. See the OIDC spec for details at: // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication + + // For public clients, the client_id is required and a client_secret does + // not exist. This means that public clients use the 'none' authentication + // method. However, public clients are required to use Proof Key for Code + // Exchange (PKCE) when using the authorization code flow. + "client_id": { + Type: framework.TypeString, + Description: "The ID of the requesting client.", + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.UpdateOperation: &framework.PathOperation{ @@ -984,6 +1034,22 @@ func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *log client.AccessTokenTTL = time.Duration(d.Get("access_token_ttl").(int)) * time.Second } + if clientTypeRaw, ok := d.GetOk("client_type"); ok { + clientType := clientTypeRaw.(string) + if req.Operation == logical.UpdateOperation && client.Type.String() != clientType { + return logical.ErrorResponse("client_type modification is not allowed"), nil + } + + switch clientType { + case confidential.String(): + client.Type = confidential + case public.String(): + client.Type = public + default: + return logical.ErrorResponse("invalid client_type %q", clientType), nil + } + } + if client.ClientID == "" { // generate client_id clientID, err := base62.Random(clientIDLength) @@ -993,7 +1059,8 @@ func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *log client.ClientID = clientID } - if client.ClientSecret == "" { + // client secrets are only generated for confidential clients + if client.Type == confidential && client.ClientSecret == "" { // generate client_secret clientSecret, err := base62.Random(clientSecretLength) if err != nil { @@ -1040,7 +1107,7 @@ func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Req return nil, nil } - return &logical.Response{ + resp := &logical.Response{ Data: map[string]interface{}{ "redirect_uris": client.RedirectURIs, "assignments": client.Assignments, @@ -1048,9 +1115,15 @@ func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Req "id_token_ttl": int64(client.IDTokenTTL.Seconds()), "access_token_ttl": int64(client.AccessTokenTTL.Seconds()), "client_id": client.ClientID, - "client_secret": client.ClientSecret, + "client_type": client.Type.String(), }, - }, nil + } + + if client.Type == confidential { + resp.Data["client_secret"] = client.ClientSecret + } + + return resp, nil } // pathOIDCDeleteClient is used to delete a client @@ -1561,6 +1634,37 @@ func (i *IdentityStore) pathOIDCAuthorize(ctx context.Context, req *logical.Requ scopes: scopes, } + // Validate the Proof Key for Code Exchange (PKCE) code challenge and code challenge + // method. PKCE is required for public clients and optional for confidential clients. + // See details at https://datatracker.ietf.org/doc/html/rfc7636. + codeChallengeRaw, okCodeChallenge := d.GetOk("code_challenge") + if !okCodeChallenge && client.Type == public { + return authResponse("", state, ErrAuthInvalidRequest, "PKCE is required for public clients") + } + if okCodeChallenge { + codeChallenge := codeChallengeRaw.(string) + + // Validate the code challenge method + codeChallengeMethod := d.Get("code_challenge_method").(string) + switch codeChallengeMethod { + case codeChallengeMethodPlain, codeChallengeMethodS256: + case "": + codeChallengeMethod = codeChallengeMethodPlain + default: + return authResponse("", state, ErrAuthInvalidRequest, "invalid code_challenge_method") + } + + // Validate the code challenge + if len(codeChallenge) < 43 || len(codeChallenge) > 128 { + return authResponse("", state, ErrAuthInvalidRequest, "invalid code_challenge") + } + + // Associate the code challenge and method with the authorization code. + // This will be used to verify the code verifier in the token exchange. + authCodeEntry.codeChallenge = codeChallenge + authCodeEntry.codeChallengeMethod = codeChallengeMethod + } + // Validate the optional max_age parameter to check if an active re-authentication // of the user should occur. Re-authentication will be requested if the last time // the token actively authenticated exceeds the given max_age requirement. Returning @@ -1662,13 +1766,13 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request, return tokenResponse(nil, ErrTokenInvalidRequest, "provider not found") } - // Authenticate the client using the client_secret_basic authentication method. - // The authentication method uses the HTTP Basic authentication scheme. Details at - // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication - headerReq := &http.Request{Header: req.Headers} - clientID, clientSecret, ok := headerReq.BasicAuth() - if !ok { - return tokenResponse(nil, ErrTokenInvalidRequest, "client failed to authenticate") + // Get the client ID + clientID, clientSecret, okBasicAuth := basicAuth(req) + if !okBasicAuth { + clientID = d.Get("client_id").(string) + if clientID == "" { + return tokenResponse(nil, ErrTokenInvalidRequest, "client_id parameter is required") + } } client, err := i.clientByID(ctx, req.Storage, clientID) if err != nil { @@ -1678,7 +1782,12 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request, i.Logger().Debug("client failed to authenticate with client not found", "client_id", clientID) return tokenResponse(nil, ErrTokenInvalidClient, "client failed to authenticate") } - if subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 0 { + + // Authenticate the client using the client_secret_basic authentication method if it's a + // confidential client. The authentication method uses the HTTP Basic authentication scheme. + // Details at https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication + if client.Type == confidential && + subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 0 { i.Logger().Debug("client failed to authenticate with invalid client secret", "client_id", clientID) return tokenResponse(nil, ErrTokenInvalidClient, "client failed to authenticate") } @@ -1771,6 +1880,28 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request, return tokenResponse(nil, ErrTokenInvalidRequest, "identity entity not authorized by client assignment") } + // Validate the PKCE code verifier. See details at + // https://datatracker.ietf.org/doc/html/rfc7636#section-4.6. + usedPKCE := authCodeUsedPKCE(authCodeEntry) + codeVerifier := d.Get("code_verifier").(string) + switch { + case !usedPKCE && client.Type == public: + return tokenResponse(nil, ErrTokenInvalidRequest, "PKCE is required for public clients") + case !usedPKCE && codeVerifier != "": + return tokenResponse(nil, ErrTokenInvalidRequest, "unexpected code_verifier for token exchange") + case usedPKCE && codeVerifier == "": + return tokenResponse(nil, ErrTokenInvalidRequest, "expected code_verifier for token exchange") + case usedPKCE: + codeChallenge, err := computeCodeChallenge(codeVerifier, authCodeEntry.codeChallengeMethod) + if err != nil { + return tokenResponse(nil, ErrTokenServerError, err.Error()) + } + + if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(authCodeEntry.codeChallenge)) == 0 { + return tokenResponse(nil, ErrTokenInvalidGrant, "invalid code_verifier for token exchange") + } + } + // The access token is a Vault batch token with a policy that only // provides access to the issuing provider's userinfo endpoint. accessTokenIssuedAt := time.Now() diff --git a/vault/identity_store_oidc_provider_test.go b/vault/identity_store_oidc_provider_test.go index af18c2792c81..bc367e5c0f57 100644 --- a/vault/identity_store_oidc_provider_test.go +++ b/vault/identity_store_oidc_provider_test.go @@ -270,6 +270,138 @@ func TestOIDC_Path_OIDC_Token(t *testing.T) { }, wantErr: ErrTokenInvalidRequest, }, + { + name: "invalid token request with empty code_verifier", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "plain" + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "" + return req + }(), + }, + wantErr: ErrTokenInvalidRequest, + }, + { + name: "invalid token request with code_verifier provided for non-PKCE flow", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: testAuthorizeReq(s, clientID), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "pkce_not_used_in_authorize_request" + return req + }(), + }, + wantErr: ErrTokenInvalidRequest, + }, + { + name: "invalid token request with incorrect plain code_verifier", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "plain" + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "wont_match_challenge" + return req + }(), + }, + wantErr: ErrTokenInvalidGrant, + }, + { + name: "invalid token request with incorrect S256 code_verifier", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "S256" + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "wont_hash_to_challenge" + return req + }(), + }, + wantErr: ErrTokenInvalidGrant, + }, + { + name: "valid token request with plain code_challenge_method", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "plain" + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + }, + }, + { + name: "valid token request with default plain code_challenge_method", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + // code_challenge_method intentionally not provided + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + }, + }, + { + name: "valid token request with S256 code_challenge_method", + args: args{ + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "S256" + req.Data["code_challenge"] = "hMn-5TBH-t3uN00FEaGsQtYPhyC4Otbx-9vDcPTYHmc" + return req + }(), + tokenReq: func() *logical.Request { + req := testTokenReq(s, "", clientID, clientSecret) + req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + }, + }, { name: "valid token request with max_age and auth_time claim", args: args{ @@ -712,6 +844,58 @@ func TestOIDC_Path_OIDC_Authorize(t *testing.T) { }, wantErr: ErrAuthInvalidRequest, }, + { + name: "invalid authorize request with invalid code_challenge_method", + args: args{ + entityID: entityID, + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "S512" + req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde" + return req + }(), + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with code_challenge length < 43 characters", + args: args{ + entityID: entityID, + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "S256" + req.Data["code_challenge"] = "" + return req + }(), + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with code_challenge length > 128 characters", + args: args{ + entityID: entityID, + clientReq: testClientReq(s), + providerReq: testProviderReq(s, clientID), + assignmentReq: testAssignmentReq(s, entityID, groupID), + authorizeReq: func() *logical.Request { + req := testAuthorizeReq(s, clientID) + req.Data["code_challenge_method"] = "S256" + req.Data["code_challenge"] = ` + 129_char_abcdefghijklmnopqrstuvwxyzabcd + 129_char_abcdefghijklmnopqrstuvwxyzabcd + 129_char_abcdefghijklmnopqrstuvwxyzabcd + ` + return req + }(), + }, + wantErr: ErrAuthInvalidRequest, + }, { name: "valid authorize request with empty nonce", args: args{ @@ -1323,6 +1507,127 @@ func TestOIDC_Path_OIDC_ProviderReadPublicKey(t *testing.T) { } } +func TestOIDC_Path_OIDC_Client_Type(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + storage := &logical.InmemStorage{} + + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key", + Operation: logical.CreateOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + + tests := []struct { + name string + createClientType clientType + updateClientType clientType + wantCreateErr bool + wantUpdateErr bool + }{ + { + name: "create confidential client and update to public client", + createClientType: confidential, + updateClientType: public, + wantUpdateErr: true, + }, + { + name: "create confidential client and update to confidential client", + createClientType: confidential, + updateClientType: confidential, + }, + { + name: "create public client and update to confidential client", + createClientType: public, + updateClientType: confidential, + wantUpdateErr: true, + }, + { + name: "create public client and update to public client", + createClientType: public, + updateClientType: public, + }, + { + name: "create an invalid client type", + createClientType: clientType(300), + wantCreateErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a client with the given client type + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.CreateOperation, + Storage: storage, + Data: map[string]interface{}{ + "key": "test-key", + "client_type": tt.createClientType.String(), + }, + }) + if tt.wantCreateErr { + expectError(t, resp, err) + return + } + expectSuccess(t, resp, err) + + // Read the client + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + + // Assert that the client type is properly set + clientType := resp.Data["client_type"].(string) + require.Equal(t, tt.createClientType.String(), clientType) + + // Assert that all client types have a client ID + clientID := resp.Data["client_id"].(string) + require.Len(t, clientID, clientIDLength) + + // Assert that confidential clients have a client secret + if tt.createClientType == confidential { + clientSecret := resp.Data["client_secret"].(string) + require.Contains(t, clientSecret, clientSecretPrefix) + } + + // Assert that public clients do not have a client secret + if tt.createClientType == public { + _, ok := resp.Data["client_secret"] + require.False(t, ok) + } + + // Update the client and expect error if the type is different + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "key": "test-key", + "client_type": tt.updateClientType.String(), + }, + }) + if tt.wantUpdateErr { + expectError(t, resp, err) + } else { + expectSuccess(t, resp, err) + } + + // Delete the client + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.DeleteOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + }) + } +} + // TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter tests that a client cannot // be created without a key parameter func TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter(t *testing.T) { @@ -1554,6 +1859,7 @@ func TestOIDC_Path_OIDC_ProviderClient(t *testing.T) { "access_token_ttl": int64(86400), "client_id": resp.Data["client_id"], "client_secret": resp.Data["client_secret"], + "client_type": confidential.String(), } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1608,6 +1914,7 @@ func TestOIDC_Path_OIDC_ProviderClient(t *testing.T) { "access_token_ttl": int64(60), "client_id": resp.Data["client_id"], "client_secret": resp.Data["client_secret"], + "client_type": confidential.String(), } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1667,6 +1974,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Deduplication(t *testing.T) { "id_token_ttl": "1m", "assignments": []string{"test-assignment1", "test-assignment1"}, "redirect_uris": []string{"http://example.com", "http://notduplicate.com", "http://example.com"}, + "client_type": public.String(), }, }) expectSuccess(t, resp, err) @@ -1685,7 +1993,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Deduplication(t *testing.T) { "id_token_ttl": int64(60), "access_token_ttl": int64(86400), "client_id": resp.Data["client_id"], - "client_secret": resp.Data["client_secret"], + "client_type": public.String(), } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1747,6 +2055,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Update(t *testing.T) { "access_token_ttl": int64(3600), "client_id": resp.Data["client_id"], "client_secret": resp.Data["client_secret"], + "client_type": confidential.String(), } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1780,6 +2089,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Update(t *testing.T) { "access_token_ttl": int64(60), "client_id": resp.Data["client_id"], "client_secret": resp.Data["client_secret"], + "client_type": confidential.String(), } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) diff --git a/vault/identity_store_oidc_provider_util.go b/vault/identity_store_oidc_provider_util.go index 76819d813b07..8f5f99b16569 100644 --- a/vault/identity_store_oidc_provider_util.go +++ b/vault/identity_store_oidc_provider_util.go @@ -6,9 +6,11 @@ import ( "encoding/base64" "fmt" "hash" + "net/http" "net/url" "github.com/hashicorp/go-secure-stdlib/strutil" + "github.com/hashicorp/vault/sdk/logical" "gopkg.in/square/go-jose.v2" ) @@ -75,3 +77,31 @@ func computeHashClaim(alg string, input string) (string, error) { sum := h.Sum(nil) return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2]), nil } + +// computeCodeChallenge computes a Proof Key for Code Exchange (PKCE) +// code challenge given a code verifier and code challenge method. +func computeCodeChallenge(verifier string, method string) (string, error) { + switch method { + case codeChallengeMethodPlain: + return verifier, nil + case codeChallengeMethodS256: + hf := sha256.New() + hf.Write([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(hf.Sum(nil)), nil + default: + return "", fmt.Errorf("invalid code challenge method %q", method) + } +} + +// authCodeUsedPKCE returns true if the given entry was granted using PKCE. +func authCodeUsedPKCE(entry *authCodeCacheEntry) bool { + return entry.codeChallenge != "" && entry.codeChallengeMethod != "" +} + +// basicAuth returns the username/password provided in the logical.Request's +// authorization header and a bool indicating if the request used basic +// authentication. +func basicAuth(req *logical.Request) (string, string, bool) { + headerReq := &http.Request{Header: req.Headers} + return headerReq.BasicAuth() +}