From eba8917d095b053f4e315cf37e3608903a4e029e Mon Sep 17 00:00:00 2001 From: vinay-gopalan <86625824+vinay-gopalan@users.noreply.github.com> Date: Mon, 29 Jan 2024 11:34:57 -0800 Subject: [PATCH] WIF support for AWS secrets engine (#24987) * add new plugin wif fields to AWS Secrets Engine * add changelog * go get awsutil v0.3.0 * fix up changelog * fix test and field parsing helper * godoc on new test * require role arn when audience set * make fmt --------- Co-authored-by: Austin Gebauer Co-authored-by: Austin Gebauer <34121980+austingebauer@users.noreply.github.com> --- builtin/logical/aws/backend.go | 4 +- builtin/logical/aws/client.go | 71 ++++++++++- builtin/logical/aws/path_config_root.go | 36 +++++- builtin/logical/aws/path_config_root_test.go | 114 ++++++++++++++++-- changelog/24987.txt | 3 + go.mod | 2 +- go.sum | 4 +- sdk/helper/pluginidentityutil/fields.go | 7 -- sdk/helper/pluginidentityutil/fields_test.go | 17 ++- .../raft/raft_binary/raft_test.go | 1 - 10 files changed, 224 insertions(+), 35 deletions(-) create mode 100644 changelog/24987.txt diff --git a/builtin/logical/aws/backend.go b/builtin/logical/aws/backend.go index ed8ac00c9dff..b33fb1b4d693 100644 --- a/builtin/logical/aws/backend.go +++ b/builtin/logical/aws/backend.go @@ -141,7 +141,7 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA return b.iamClient, nil } - iamClient, err := nonCachedClientIAM(ctx, s, b.Logger()) + iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger()) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST return b.stsClient, nil } - stsClient, err := nonCachedClientSTS(ctx, s, b.Logger()) + stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger()) if err != nil { return nil, err } diff --git a/builtin/logical/aws/client.go b/builtin/logical/aws/client.go index 33dc86c51781..c65b2469eaf2 100644 --- a/builtin/logical/aws/client.go +++ b/builtin/logical/aws/client.go @@ -7,19 +7,25 @@ import ( "context" "fmt" "os" + "strconv" + "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/awsutil" + + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" ) // NOTE: The caller is required to ensure that b.clientMutex is at least read locked -func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { +func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { credsConfig := &awsutil.CredentialsConfig{} var endpoint string var maxRetries int = aws.UseServiceDefaultRetries @@ -44,6 +50,26 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo case clientType == "sts" && config.STSEndpoint != "": endpoint = *aws.String(config.STSEndpoint) } + + if config.IdentityTokenAudience != "" { + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get namespace from context: %w", err) + } + + fetcher := &PluginIdentityTokenFetcher{ + sys: b.System(), + logger: b.Logger(), + ns: ns, + audience: config.IdentityTokenAudience, + ttl: config.IdentityTokenTTL, + } + + sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10) + credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix) + credsConfig.WebIdentityTokenFetcher = fetcher + credsConfig.RoleARN = config.RoleARN + } } if credsConfig.Region == "" { @@ -74,8 +100,8 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo }, nil } -func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { - awsConfig, err := getRootConfig(ctx, s, "iam", logger) +func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { + awsConfig, err := b.getRootConfig(ctx, s, "iam", logger) if err != nil { return nil, err } @@ -90,8 +116,8 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Log return client, nil } -func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { - awsConfig, err := getRootConfig(ctx, s, "sts", logger) +func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { + awsConfig, err := b.getRootConfig(ctx, s, "sts", logger) if err != nil { return nil, err } @@ -105,3 +131,36 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Log } return client, nil } + +// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided +// to the AWS SDK client to keep assumed role credentials refreshed through expiration. +// When the client's STS credentials expire, it will use this interface to fetch a new +// plugin identity token and exchange it for new STS credentials. +type PluginIdentityTokenFetcher struct { + sys logical.SystemView + logger hclog.Logger + audience string + ns *namespace.Namespace + ttl time.Duration +} + +var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil) + +func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) { + nsCtx := namespace.ContextWithNamespace(ctx, f.ns) + resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{ + Audience: f.audience, + TTL: f.ttl, + }) + if err != nil { + return nil, fmt.Errorf("failed to generate plugin identity token: %w", err) + } + f.logger.Info("fetched new plugin identity token") + + if resp.TTL < f.ttl { + f.logger.Debug("generated plugin identity token has shorter TTL than requested", + "requested", f.ttl, "actual", resp.TTL) + } + + return []byte(resp.Token.Token()), nil +} diff --git a/builtin/logical/aws/path_config_root.go b/builtin/logical/aws/path_config_root.go index 5b5e3f1ce6fa..d2c64bbaa51e 100644 --- a/builtin/logical/aws/path_config_root.go +++ b/builtin/logical/aws/path_config_root.go @@ -7,7 +7,9 @@ import ( "context" "github.com/aws/aws-sdk-go/aws" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/pluginidentityutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -15,7 +17,7 @@ import ( const defaultUserNameTemplate = `{{ if (eq .Type "STS") }}{{ printf "vault-%s-%s" (unix_time) (random 20) | truncate 32 }}{{ else }}{{ printf "vault-%s-%s-%s" (printf "%s-%s" (.DisplayName) (.PolicyName) | truncate 42) (unix_time) (random 20) | truncate 64 }}{{ end }}` func pathConfigRoot(b *backend) *framework.Path { - return &framework.Path{ + p := &framework.Path{ Pattern: "config/root", DisplayAttrs: &framework.DisplayAttributes{ @@ -54,6 +56,10 @@ func pathConfigRoot(b *backend) *framework.Path { Type: framework.TypeString, Description: "Template to generate custom IAM usernames", }, + "role_arn": { + Type: framework.TypeString, + Description: "Role ARN to assume for plugin identity token federation", + }, }, Operations: map[logical.Operation]framework.OperationHandler{ @@ -75,6 +81,9 @@ func pathConfigRoot(b *backend) *framework.Path { HelpSynopsis: pathConfigRootHelpSyn, HelpDescription: pathConfigRootHelpDesc, } + pluginidentityutil.AddPluginIdentityTokenFields(p.Fields) + + return p } func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { @@ -102,7 +111,10 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request, "sts_endpoint": config.STSEndpoint, "max_retries": config.MaxRetries, "username_template": config.UsernameTemplate, + "role_arn": config.RoleARN, } + + config.PopulatePluginIdentityTokenData(configData) return &logical.Response{ Data: configData, }, nil @@ -113,6 +125,7 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, iamendpoint := data.Get("iam_endpoint").(string) stsendpoint := data.Get("sts_endpoint").(string) maxretries := data.Get("max_retries").(int) + roleARN := data.Get("role_arn").(string) usernameTemplate := data.Get("username_template").(string) if usernameTemplate == "" { usernameTemplate = defaultUserNameTemplate @@ -121,7 +134,7 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, b.clientMutex.Lock() defer b.clientMutex.Unlock() - entry, err := logical.StorageEntryJSON("config/root", rootConfig{ + rc := rootConfig{ AccessKey: data.Get("access_key").(string), SecretKey: data.Get("secret_key").(string), IAMEndpoint: iamendpoint, @@ -129,7 +142,21 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, Region: region, MaxRetries: maxretries, UsernameTemplate: usernameTemplate, - }) + RoleARN: roleARN, + } + if err := rc.ParsePluginIdentityTokenFields(data); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if rc.IdentityTokenAudience != "" && rc.AccessKey != "" { + return logical.ErrorResponse("only one of 'access_key' or 'identity_token_audience' can be set"), nil + } + + if rc.IdentityTokenAudience != "" && rc.RoleARN == "" { + return logical.ErrorResponse("missing required 'role_arn' when 'identity_token_audience' is set"), nil + } + + entry, err := logical.StorageEntryJSON("config/root", rc) if err != nil { return nil, err } @@ -147,6 +174,8 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, } type rootConfig struct { + pluginidentityutil.PluginIdentityTokenParams + AccessKey string `json:"access_key"` SecretKey string `json:"secret_key"` IAMEndpoint string `json:"iam_endpoint"` @@ -154,6 +183,7 @@ type rootConfig struct { Region string `json:"region"` MaxRetries int `json:"max_retries"` UsernameTemplate string `json:"username_template"` + RoleARN string `json:"role_arn"` } const pathConfigRootHelpSyn = ` diff --git a/builtin/logical/aws/path_config_root_test.go b/builtin/logical/aws/path_config_root_test.go index 3de47fcca3ab..0e1c0186025e 100644 --- a/builtin/logical/aws/path_config_root_test.go +++ b/builtin/logical/aws/path_config_root_test.go @@ -6,9 +6,11 @@ package aws import ( "context" "reflect" + "strings" "testing" "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" ) func TestBackend_PathConfigRoot(t *testing.T) { @@ -21,13 +23,16 @@ func TestBackend_PathConfigRoot(t *testing.T) { } configData := map[string]interface{}{ - "access_key": "AKIAEXAMPLE", - "secret_key": "RandomData", - "region": "us-west-2", - "iam_endpoint": "https://iam.amazonaws.com", - "sts_endpoint": "https://sts.us-west-2.amazonaws.com", - "max_retries": 10, - "username_template": defaultUserNameTemplate, + "access_key": "AKIAEXAMPLE", + "secret_key": "RandomData", + "region": "us-west-2", + "iam_endpoint": "https://iam.amazonaws.com", + "sts_endpoint": "https://sts.us-west-2.amazonaws.com", + "max_retries": 10, + "username_template": defaultUserNameTemplate, + "role_arn": "", + "identity_token_audience": "", + "identity_token_ttl": int64(0), } configReq := &logical.Request{ @@ -52,7 +57,102 @@ func TestBackend_PathConfigRoot(t *testing.T) { } delete(configData, "secret_key") + require.Equal(t, configData, resp.Data) if !reflect.DeepEqual(resp.Data, configData) { t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data) } } + +// TestBackend_PathConfigRoot_PluginIdentityToken tests parsing and validation of +// configuration used to set the secret engine up for web identity federation using +// plugin identity tokens. +func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "identity_token_ttl": int64(10), + "identity_token_audience": "test-aud", + "role_arn": "test-role-arn", + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err := b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err) + } + + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Storage: config.StorageView, + Path: "config/root", + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err) + } + + // Grab the subset of fields from the response we care to look at for this case + got := map[string]interface{}{ + "identity_token_ttl": resp.Data["identity_token_ttl"], + "identity_token_audience": resp.Data["identity_token_audience"], + "role_arn": resp.Data["role_arn"], + } + + if !reflect.DeepEqual(got, configData) { + t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data) + } + + // mutually exclusive fields must result in an error + configData = map[string]interface{}{ + "identity_token_audience": "test-aud", + "access_key": "ASIAIO10230XVB", + } + + configReq = &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if !resp.IsError() { + t.Fatalf("expected an error but got nil") + } + expectedError := "only one of 'access_key' or 'identity_token_audience' can be set" + if !strings.Contains(resp.Error().Error(), expectedError) { + t.Fatalf("expected err %s, got %s", expectedError, resp.Error()) + } + + // missing role arn with audience must result in an error + configData = map[string]interface{}{ + "identity_token_audience": "test-aud", + } + + configReq = &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if !resp.IsError() { + t.Fatalf("expected an error but got nil") + } + expectedError = "missing required 'role_arn' when 'identity_token_audience' is set" + if !strings.Contains(resp.Error().Error(), expectedError) { + t.Fatalf("expected err %s, got %s", expectedError, resp.Error()) + } +} diff --git a/changelog/24987.txt b/changelog/24987.txt new file mode 100644 index 000000000000..2eecf033f4d8 --- /dev/null +++ b/changelog/24987.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Plugin Identity Tokens**: Adds secret-less configuration of AWS secret engine using web identity federation. +``` diff --git a/go.mod b/go.mod index 96e8ab54dcf3..7ff7e7df3ddb 100644 --- a/go.mod +++ b/go.mod @@ -101,7 +101,7 @@ require ( github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hashicorp/go-rootcerts v1.0.2 - github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 + github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/gatedwriter v0.1.1 github.com/hashicorp/go-secure-stdlib/kv-builder v0.1.2 diff --git a/go.sum b/go.sum index 2185a69e005d..0f37fcc0452c 100644 --- a/go.sum +++ b/go.sum @@ -2164,8 +2164,8 @@ github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5 github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= -github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 h1:AAQ6Vmo/ncfrZYtbpjhO+g0Qt+iNpYtl3UWT1NLmbYY= -github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= +github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 h1:I8bynUKMh9I7JdwtW9voJ0xmHvBpxQtLjrMFDYmhOxY= +github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= diff --git a/sdk/helper/pluginidentityutil/fields.go b/sdk/helper/pluginidentityutil/fields.go index 27a692b10421..3d97537ecc94 100644 --- a/sdk/helper/pluginidentityutil/fields.go +++ b/sdk/helper/pluginidentityutil/fields.go @@ -4,7 +4,6 @@ package pluginidentityutil import ( - "errors" "fmt" "time" @@ -25,16 +24,10 @@ func (p *PluginIdentityTokenParams) ParsePluginIdentityTokenFields(d *framework. if tokenTTLRaw, ok := d.GetOk("identity_token_ttl"); ok { p.IdentityTokenTTL = time.Duration(tokenTTLRaw.(int)) * time.Second } - if p.IdentityTokenTTL == 0 { - p.IdentityTokenTTL = time.Hour - } if tokenAudienceRaw, ok := d.GetOk("identity_token_audience"); ok { p.IdentityTokenAudience = tokenAudienceRaw.(string) } - if p.IdentityTokenAudience == "" { - return errors.New("missing required identity_token_audience") - } return nil } diff --git a/sdk/helper/pluginidentityutil/fields_test.go b/sdk/helper/pluginidentityutil/fields_test.go index f66196b0d31f..96c844971d25 100644 --- a/sdk/helper/pluginidentityutil/fields_test.go +++ b/sdk/helper/pluginidentityutil/fields_test.go @@ -39,7 +39,7 @@ func TestParsePluginIdentityTokenFields(t *testing.T) { want map[string]interface{} }{ { - name: "basic", + name: "all input", d: identityTokenFieldData(map[string]interface{}{ fieldIDTokenTTL: 10, fieldIDTokenAudience: "test-aud", @@ -50,19 +50,24 @@ func TestParsePluginIdentityTokenFields(t *testing.T) { }, }, { - name: "empty-ttl", + name: "empty ttl", d: identityTokenFieldData(map[string]interface{}{ fieldIDTokenAudience: "test-aud", }), want: map[string]interface{}{ - fieldIDTokenTTL: time.Hour, + fieldIDTokenTTL: time.Duration(0), fieldIDTokenAudience: "test-aud", }, }, { - name: "empty-audience", - d: identityTokenFieldData(map[string]interface{}{}), - wantErr: true, + name: "empty audience", + d: identityTokenFieldData(map[string]interface{}{ + fieldIDTokenTTL: 10, + }), + want: map[string]interface{}{ + fieldIDTokenTTL: time.Duration(10) * time.Second, + fieldIDTokenAudience: "", + }, }, } diff --git a/vault/external_tests/raft/raft_binary/raft_test.go b/vault/external_tests/raft/raft_binary/raft_test.go index cced2f37963a..135ff4fb28b6 100644 --- a/vault/external_tests/raft/raft_binary/raft_test.go +++ b/vault/external_tests/raft/raft_binary/raft_test.go @@ -82,7 +82,6 @@ func stabilizeAndPromote(t *testing.T, client *api.Client, nodeID string) { var err error for time.Now().Before(deadline) { state, err = client.Sys().RaftAutopilotState() - // If the state endpoint gets called during a leader election, we'll get an error about // there not being an active cluster node. Rather than erroring out of this loop, just // ignore the error and keep trying. It should resolve in a few seconds. There's a