diff --git a/changelog/12780.txt b/changelog/12780.txt new file mode 100644 index 000000000000..61a2c5d4f8cc --- /dev/null +++ b/changelog/12780.txt @@ -0,0 +1,3 @@ +```release-note:improvement +identity/token: Only return keys from the `.well-known/keys` endpoint that are being used by roles to sign/verify tokens. +``` diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 8dd8be260688..ce4d628b869e 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -613,6 +613,27 @@ func (i *IdentityStore) pathOIDCReadKey(ctx context.Context, req *logical.Reques }, nil } +// keyIDsByName will return a slice of key IDs for the given key name +func (i *IdentityStore) keyIDsByName(ctx context.Context, s logical.Storage, name string) ([]string, error) { + var keyIDs []string + entry, err := s.Get(ctx, namedKeyConfigPath+name) + if err != nil { + return keyIDs, err + } + if entry == nil { + return keyIDs, nil + } + + var key namedKey + if err := entry.DecodeJSON(&key); err != nil { + return keyIDs, err + } + for _, k := range key.KeyRing { + keyIDs = append(keyIDs, k.KeyID) + } + return keyIDs, nil +} + // rolesReferencingTargetKeyName returns a map of role names to roles // referencing targetKeyName. // @@ -1538,21 +1559,37 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag return nil, err } - keyIDs, err := listOIDCPublicKeys(ctx, s) + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, 0), + } + + // only return keys that are associated with a role + roleNames, err := s.List(ctx, roleConfigPath) if err != nil { return nil, err } - jwks := &jose.JSONWebKeySet{ - Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), - } + for _, roleName := range roleNames { + role, err := i.getOIDCRole(ctx, s, roleName) + if err != nil { + return nil, err + } + if role == nil { + continue + } - for _, keyID := range keyIDs { - key, err := loadOIDCPublicKey(ctx, s, keyID) + keyIDs, err := i.keyIDsByName(ctx, s, role.Key) if err != nil { return nil, err } - jwks.Keys = append(jwks.Keys, *key) + + for _, keyID := range keyIDs { + key, err := loadOIDCPublicKey(ctx, s, keyID) + if err != nil { + return nil, err + } + jwks.Keys = append(jwks.Keys, *key) + } } if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil { diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 45a5da3ee9ef..ba63a940f896 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -1,6 +1,7 @@ package vault import ( + "context" "crypto/rand" "crypto/rsa" "encoding/json" @@ -637,37 +638,73 @@ func TestOIDC_Path_OIDCKey_DeleteWithExistingClient(t *testing.T) { expectError(t, resp, err) } -// TestOIDC_PublicKeys tests that public keys are updated by -// key creation, rotation, and deletion -func TestOIDC_PublicKeys(t *testing.T) { +// TestOIDC_PublicKeys_NoRole tests that public keys are not returned by the +// oidc/.well-known/keys endpoint when they are not associated with a role +func TestOIDC_PublicKeys_NoRole(t *testing.T) { c, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) - storage := &logical.InmemStorage{} + s := &logical.InmemStorage{} // Create a test key "test-key" - c.identityStore.HandleRequest(ctx, &logical.Request{ + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/key/test-key", Operation: logical.CreateOperation, - Storage: storage, + Storage: s, }) + expectSuccess(t, resp, err) - // .well-known/keys should contain 2 public keys + // .well-known/keys should contain 0 public keys + assertPublicKeyCount(t, ctx, s, c, 0) +} + +func assertPublicKeyCount(t *testing.T, ctx context.Context, s logical.Storage, c *Core, keyCount int) { + t.Helper() + + // .well-known/keys should contain keyCount public keys resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", Operation: logical.ReadOperation, - Storage: storage, + Storage: s, }) expectSuccess(t, resp, err) // parse response responseJWKS := &jose.JSONWebKeySet{} json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 2 { - t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != keyCount { + t.Fatalf("expected %d public keys but instead got %d", keyCount, len(responseJWKS.Keys)) } +} + +// TestOIDC_PublicKeys tests that public keys are updated by +// key creation, rotation, and deletion +func TestOIDC_PublicKeys(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + storage := &logical.InmemStorage{} + + // Create a test key "test-key" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key", + Operation: logical.CreateOperation, + Storage: storage, + }) + + // Create a test role "test-role" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "key": "test-key", + }, + Storage: storage, + }) + + // .well-known/keys should contain 2 public keys + assertPublicKeyCount(t, ctx, storage, c, 2) // rotate test-key a few times, each rotate should increase the length of public keys returned // by the .well-known endpoint - resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/key/test-key/rotate", Operation: logical.UpdateOperation, Storage: storage, @@ -681,45 +718,47 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // .well-known/keys should contain 4 public keys + assertPublicKeyCount(t, ctx, storage, c, 4) + + // create another named key "test-key2" resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/.well-known/keys", - Operation: logical.ReadOperation, + Path: "oidc/key/test-key2", + Operation: logical.CreateOperation, Storage: storage, }) expectSuccess(t, resp, err) - // parse response - json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 4 { - t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys)) - } - // create another named key - c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/key/test-key2", + // Create a test role "test-role2" + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role2", Operation: logical.CreateOperation, - Storage: storage, + Data: map[string]interface{}{ + "key": "test-key2", + }, + Storage: storage, }) + expectSuccess(t, resp, err) + // .well-known/keys should contain 6 public keys + assertPublicKeyCount(t, ctx, storage, c, 6) - // delete test key - c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/key/test-key", + // delete test role that references "test-key" + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", Operation: logical.DeleteOperation, Storage: storage, }) - - // .well-known/keys should contain 2 public key, all of the public keys - // from named key "test-key" should have been deleted + expectSuccess(t, resp, err) + // delete test key resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/.well-known/keys", - Operation: logical.ReadOperation, + Path: "oidc/key/test-key", + Operation: logical.DeleteOperation, Storage: storage, }) expectSuccess(t, resp, err) - // parse response - json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 2 { - t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) - } + + // .well-known/keys should contain 2 public keys, all of the public keys + // from named key "test-key" should have been deleted + assertPublicKeyCount(t, ctx, storage, c, 2) } // TestOIDC_SignIDToken tests acquiring a signed token and verifying the public portion