Skip to content

Commit

Permalink
identity: adds generation of plugin identity tokens (#25219)
Browse files Browse the repository at this point in the history
* adds generation of plugin identity tokens

* adds constants

* fix namespace path when getting matching identity storage

* adds changelog

* adds godoc on test

* fix data race with default key generation by moving locks up

* Update changelog/25219.txt

Co-authored-by: Tom Proctor <[email protected]>

* use namespace from context instead of mount entry

* translate mount table entry from mounts to secret

* godoc on test

---------

Co-authored-by: Tom Proctor <[email protected]>
  • Loading branch information
austingebauer and tomhjp authored Feb 6, 2024
1 parent 003ee1c commit 98bffbe
Show file tree
Hide file tree
Showing 10 changed files with 526 additions and 68 deletions.
3 changes: 3 additions & 0 deletions changelog/25219.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:feature
**Plugin Workload Identity**: Vault can generate identity tokens for plugins to use in workload identity federation auth flows.
```
17 changes: 13 additions & 4 deletions vault/dynamic_system_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,19 @@ func (d dynamicSystemView) ClusterID(ctx context.Context) (string, error) {
return clusterInfo.ID, nil
}

func (d dynamicSystemView) GenerateIdentityToken(_ context.Context, _ *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) {
// TODO: implement plugin identity token generation using identity store
func (d dynamicSystemView) GenerateIdentityToken(ctx context.Context, req *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) {
storage := d.core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity)
if storage == nil {
return nil, fmt.Errorf("failed to find storage entry for identity mount")
}

token, ttl, err := d.core.IdentityStore().generatePluginIdentityToken(ctx, storage, d.mountEntry, req.Audience, req.TTL)
if err != nil {
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
}

return &pluginutil.IdentityTokenResponse{
Token: "unimplemented",
TTL: time.Duration(0),
Token: pluginutil.IdentityToken(token),
TTL: ttl,
}, nil
}
1 change: 1 addition & 0 deletions vault/identity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendCo
groupUpdater: core,
tokenStorer: core,
entityCreator: core,
mountLister: core,
mfaBackend: core.loginMFABackend,
}

Expand Down
236 changes: 208 additions & 28 deletions vault/identity_store_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/hashicorp/vault/sdk/logical"
"github.com/patrickmn/go-cache"
"golang.org/x/crypto/ed25519"
"golang.org/x/exp/maps"
)

type oidcConfig struct {
Expand Down Expand Up @@ -126,23 +127,12 @@ type oidcCache struct {
c *cache.Cache
}

var errNilNamespace = errors.New("nil namespace in oidc cache request")

const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
namedKeyCachePrefix = "namedKeys/"
oidcConfigStorageKey = oidcTokensPrefix + "config/"
namedKeyConfigPath = oidcTokensPrefix + "named_keys/"
publicKeysConfigPath = oidcTokensPrefix + "public_keys/"
roleConfigPath = oidcTokensPrefix + "roles/"
var (
errNilNamespace = errors.New("nil namespace in oidc cache request")

// Identity tokens have a base issuer and plugin issuer
baseIdentityTokenIssuer = ""
pluginIdentityTokenIssuer = "plugins"
)
// pseudo-namespace for cache items that don't belong to any real namespace.
noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}

var (
reservedClaims = []string{
"iat", "aud", "exp", "iss",
"sub", "namespace", "nonce",
Expand All @@ -159,8 +149,24 @@ var (
}
)

// pseudo-namespace for cache items that don't belong to any real namespace.
var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}
const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
namedKeyCachePrefix = "namedKeys/"
oidcConfigStorageKey = oidcTokensPrefix + "config/"
namedKeyConfigPath = oidcTokensPrefix + "named_keys/"
publicKeysConfigPath = oidcTokensPrefix + "public_keys/"
roleConfigPath = oidcTokensPrefix + "roles/"

// Identity tokens have a base issuer and plugin issuer
baseIdentityTokenIssuer = ""
pluginIdentityTokenIssuer = "plugins"

pluginTokenSubjectPrefix = "plugin-identity"
pluginTokenPrivateClaimKey = "vaultproject.io"
secretTableValue = "secret"
deleteKeyErrorFmt = "unable to delete key %q because it is currently referenced by these %s: %s"
)

// optionalChildIssuerRegex is a regex for optionally accepting a field in an
// API request as a single path segment. Adapted from framework.OptionalParamRegex
Expand Down Expand Up @@ -784,6 +790,56 @@ func (i *IdentityStore) roleNamesReferencingTargetKeyName(ctx context.Context, r
return names, nil
}

// listMounts returns all mount entries in the namespace.
// Returns an error if the namespace is nil.
func (i *IdentityStore) listMounts(ns *namespace.Namespace) ([]*MountEntry, error) {
if ns == nil {
return nil, errors.New("namespace must not be nil")
}

secretMounts, err := i.mountLister.ListMounts()
if err != nil {
return nil, err
}
authMounts, err := i.mountLister.ListAuths()
if err != nil {
return nil, err
}

var allMounts []*MountEntry
for _, mount := range append(authMounts, secretMounts...) {
if mount.NamespaceID == ns.ID {
allMounts = append(allMounts, mount)
}
}

return allMounts, nil
}

// mountsReferencingKey returns a sorted list of all mount entry paths referencing
// the key in the namespace. Returns an error if the namespace is nil.
func (i *IdentityStore) mountsReferencingKey(ns *namespace.Namespace, key string) ([]string, error) {
if ns == nil {
return nil, errors.New("namespace must not be nil")
}

allMounts, err := i.listMounts(ns)
if err != nil {
return nil, err
}

pathsWithKey := make(map[string]struct{})
for _, mount := range allMounts {
if mount.Config.IdentityTokenKey == key {
pathsWithKey[mount.Path] = struct{}{}
}
}

paths := maps.Keys(pathsWithKey)
sort.Strings(paths)
return paths, nil
}

// handleOIDCDeleteKey is used to delete a key
func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
ns, err := namespace.FromContext(ctx)
Expand All @@ -807,8 +863,8 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
}

if len(roleNames) > 0 {
errorMessage := fmt.Sprintf("unable to delete key %q because it is currently referenced by these roles: %s",
targetKeyName, strings.Join(roleNames, ", "))
errorMessage := fmt.Sprintf(deleteKeyErrorFmt,
targetKeyName, "roles", strings.Join(roleNames, ", "))
i.oidcLock.Unlock()
return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest
}
Expand All @@ -820,8 +876,20 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
}

if len(clientNames) > 0 {
errorMessage := fmt.Sprintf("unable to delete key %q because it is currently referenced by these clients: %s",
targetKeyName, strings.Join(clientNames, ", "))
errorMessage := fmt.Sprintf(deleteKeyErrorFmt,
targetKeyName, "clients", strings.Join(clientNames, ", "))
i.oidcLock.Unlock()
return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest
}

mounts, err := i.mountsReferencingKey(ns, targetKeyName)
if err != nil {
i.oidcLock.Unlock()
return nil, err
}
if len(mounts) > 0 {
errorMessage := fmt.Sprintf(deleteKeyErrorFmt,
targetKeyName, "mounts", strings.Join(mounts, ", "))
i.oidcLock.Unlock()
return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest
}
Expand Down Expand Up @@ -1028,6 +1096,99 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
return retResp, nil
}

func (i *IdentityStore) generatePluginIdentityToken(ctx context.Context, storage logical.Storage, me *MountEntry, audience string, ttl time.Duration) (string, time.Duration, error) {
ns, err := namespace.FromContext(ctx)
if err != nil {
return "", 0, err
}

if me == nil {
i.Logger().Error("unexpected nil mount entry when generating plugin identity token")
return "", 0, errors.New("mount entry must not be nil")
}

key := defaultKeyName
if me.Config.IdentityTokenKey != "" {
key = me.Config.IdentityTokenKey
}
if ttl == 0 {
ttl = time.Hour
}
namedKey, err := i.getNamedKey(ctx, storage, key)
if err != nil {
return "", 0, err
}
if namedKey == nil {
return "", 0, fmt.Errorf("key %q not found", key)
}

// Validate that the role is allowed to sign with its key (the key could have been updated)
if !strutil.StrListContains(namedKey.AllowedClientIDs, "*") && !strutil.StrListContains(namedKey.AllowedClientIDs, audience) {
return "", 0, fmt.Errorf("the key %q does not list %q as an allowed audience", key, audience)
}

config, err := i.getOIDCConfig(ctx, storage)
if err != nil {
return "", 0, err
}

// Cap the TTL to the key's verification TTL. This is the maximum amount of
// time the key will remain in the JWKS after it's been rotated.
if ttl > namedKey.VerificationTTL {
ttl = namedKey.VerificationTTL
}

// Tokens for plugins have a distinct issuer from Vault's identity token issuer
issuer, err := config.fullIssuer(pluginIdentityTokenIssuer)
if err != nil {
return "", 0, err
}

// The subject uniquely identifies the plugin
subject := fmt.Sprintf("%s:%s:%s:%s", pluginTokenSubjectPrefix, ns.ID,
translateTableClaim(me.Table), me.Accessor)

now := time.Now()
claims := map[string]any{
"iss": issuer,
"sub": subject,
"aud": []string{audience},
"nbf": now.Unix(),
"iat": now.Unix(),
"exp": now.Add(ttl).Unix(),
pluginTokenPrivateClaimKey: map[string]any{
"namespace_id": ns.ID,
"namespace_path": ns.Path,
"class": translateTableClaim(me.Table),
"plugin": me.Type,
"version": me.RunningVersion,
"path": me.Path,
"accessor": me.Accessor,
"local": me.Local,
},
}
payload, err := json.Marshal(claims)
if err != nil {
return "", 0, err
}

signedToken, err := namedKey.signPayload(payload)
if err != nil {
return "", 0, fmt.Errorf("error signing plugin identity token: %w", err)
}

return signedToken, ttl, nil
}

func translateTableClaim(table string) string {
switch table {
case mountTableType:
return secretTableValue
default:
return table
}
}

func (i *IdentityStore) getNamedKey(ctx context.Context, s logical.Storage, name string) (*namedKey, error) {
ns, err := namespace.FromContext(ctx)
if err != nil {
Expand Down Expand Up @@ -1804,14 +1965,16 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
return nil, err
}

// only return keys that are associated with a role
// Only return keys that are associated with a role or plugin mount
// by collecting and de-duplicating keys and key IDs for each
keyNames := make(map[string]struct{})
keyIDs := make(map[string]struct{})

// First collect the set of unique key names
roleNames, err := s.List(ctx, roleConfigPath)
if err != nil {
return nil, err
}

// collect and deduplicate the key IDs for all roles
keyIDs := make(map[string]struct{})
for _, roleName := range roleNames {
role, err := i.getOIDCRole(ctx, s, roleName)
if err != nil {
Expand All @@ -1821,13 +1984,30 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
continue
}

roleKeyIDs, err := i.keyIDsByName(ctx, s, role.Key)
keyNames[role.Key] = struct{}{}
}
mounts, err := i.listMounts(ns)
if err != nil {
return nil, err
}
for _, me := range mounts {
key := defaultKeyName
if me.Config.IdentityTokenKey != "" {
key = me.Config.IdentityTokenKey
}

keyNames[key] = struct{}{}
}

// Second collect the set of unique key IDs for each key name
for name := range keyNames {
ids, err := i.keyIDsByName(ctx, s, name)
if err != nil {
return nil, err
}

for _, keyID := range roleKeyIDs {
keyIDs[keyID] = struct{}{}
for _, id := range ids {
keyIDs[id] = struct{}{}
}
}

Expand Down
8 changes: 4 additions & 4 deletions vault/identity_store_oidc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2629,17 +2629,17 @@ func (i *IdentityStore) lazyGenerateDefaultKey(ctx context.Context, storage logi
return err
}

if err := i.oidcCache.Delete(ns, namedKeyCachePrefix+defaultKeyName); err != nil {
return err
}

entry, err := logical.StorageEntryJSON(namedKeyConfigPath+defaultKeyName, defaultKey)
if err != nil {
return err
}
if err := storage.Put(ctx, entry); err != nil {
return err
}

if err := i.oidcCache.Flush(ns); err != nil {
return err
}
}

return nil
Expand Down
7 changes: 4 additions & 3 deletions vault/identity_store_oidc_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,8 @@ func setupOIDCCommon(t *testing.T, c *Core, s logical.Storage) (string, string,
ctx := namespace.RootContext(nil)

// Create a key
resp, err := c.identityStore.HandleRequest(ctx, testKeyReq(s, []string{"*"}, "RS256"))
resp, err := c.identityStore.HandleRequest(ctx, testKeyReq(s, "test-key",
[]string{"*"}, "RS256"))
expectSuccess(t, resp, err)

// Create an entity
Expand Down Expand Up @@ -1359,10 +1360,10 @@ func testEntityReq(s logical.Storage) *logical.Request {
}
}

func testKeyReq(s logical.Storage, allowedClientIDs []string, alg string) *logical.Request {
func testKeyReq(s logical.Storage, name string, allowedClientIDs []string, alg string) *logical.Request {
return &logical.Request{
Storage: s,
Path: "oidc/key/test-key",
Path: fmt.Sprintf("oidc/key/%s", name),
Operation: logical.CreateOperation,
Data: map[string]interface{}{
"allowed_client_ids": allowedClientIDs,
Expand Down
Loading

0 comments on commit 98bffbe

Please sign in to comment.