Skip to content

Commit

Permalink
WIF support for AWS secrets engine (#24987)
Browse files Browse the repository at this point in the history
* add new plugin wif fields to AWS Secrets Engine

* add changelog

* go get awsutil v0.3.0

* fix up changelog

* fix test and field parsing helper

* godoc on new test

* require role arn when audience set

* make fmt

---------

Co-authored-by: Austin Gebauer <[email protected]>
Co-authored-by: Austin Gebauer <[email protected]>
  • Loading branch information
3 people authored and Monkeychip committed Jan 30, 2024
1 parent 108c989 commit eba8917
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 35 deletions.
4 changes: 2 additions & 2 deletions builtin/logical/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA
return b.iamClient, nil
}

iamClient, err := nonCachedClientIAM(ctx, s, b.Logger())
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
if err != nil {
return nil, err
}
Expand All @@ -168,7 +168,7 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST
return b.stsClient, nil
}

stsClient, err := nonCachedClientSTS(ctx, s, b.Logger())
stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger())
if err != nil {
return nil, err
}
Expand Down
71 changes: 65 additions & 6 deletions builtin/logical/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,25 @@ import (
"context"
"fmt"
"os"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"

"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
var maxRetries int = aws.UseServiceDefaultRetries
Expand All @@ -44,6 +50,26 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
case clientType == "sts" && config.STSEndpoint != "":
endpoint = *aws.String(config.STSEndpoint)
}

if config.IdentityTokenAudience != "" {
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
}

fetcher := &PluginIdentityTokenFetcher{
sys: b.System(),
logger: b.Logger(),
ns: ns,
audience: config.IdentityTokenAudience,
ttl: config.IdentityTokenTTL,
}

sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN
}
}

if credsConfig.Region == "" {
Expand Down Expand Up @@ -74,8 +100,8 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
}, nil
}

func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := getRootConfig(ctx, s, "iam", logger)
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfig(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
Expand All @@ -90,8 +116,8 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Log
return client, nil
}

func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := getRootConfig(ctx, s, "sts", logger)
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := b.getRootConfig(ctx, s, "sts", logger)
if err != nil {
return nil, err
}
Expand All @@ -105,3 +131,36 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Log
}
return client, nil
}

// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
// When the client's STS credentials expire, it will use this interface to fetch a new
// plugin identity token and exchange it for new STS credentials.
type PluginIdentityTokenFetcher struct {
sys logical.SystemView
logger hclog.Logger
audience string
ns *namespace.Namespace
ttl time.Duration
}

var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)

func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{
Audience: f.audience,
TTL: f.ttl,
})
if err != nil {
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
}
f.logger.Info("fetched new plugin identity token")

if resp.TTL < f.ttl {
f.logger.Debug("generated plugin identity token has shorter TTL than requested",
"requested", f.ttl, "actual", resp.TTL)
}

return []byte(resp.Token.Token()), nil
}
36 changes: 33 additions & 3 deletions builtin/logical/aws/path_config_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ import (
"context"

"github.com/aws/aws-sdk-go/aws"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/logical"
)

// A single default template that supports both the different credential types (IAM/STS) that are capped at differing length limits (64 chars/32 chars respectively)
const defaultUserNameTemplate = `{{ if (eq .Type "STS") }}{{ printf "vault-%s-%s" (unix_time) (random 20) | truncate 32 }}{{ else }}{{ printf "vault-%s-%s-%s" (printf "%s-%s" (.DisplayName) (.PolicyName) | truncate 42) (unix_time) (random 20) | truncate 64 }}{{ end }}`

func pathConfigRoot(b *backend) *framework.Path {
return &framework.Path{
p := &framework.Path{
Pattern: "config/root",

DisplayAttrs: &framework.DisplayAttributes{
Expand Down Expand Up @@ -54,6 +56,10 @@ func pathConfigRoot(b *backend) *framework.Path {
Type: framework.TypeString,
Description: "Template to generate custom IAM usernames",
},
"role_arn": {
Type: framework.TypeString,
Description: "Role ARN to assume for plugin identity token federation",
},
},

Operations: map[logical.Operation]framework.OperationHandler{
Expand All @@ -75,6 +81,9 @@ func pathConfigRoot(b *backend) *framework.Path {
HelpSynopsis: pathConfigRootHelpSyn,
HelpDescription: pathConfigRootHelpDesc,
}
pluginidentityutil.AddPluginIdentityTokenFields(p.Fields)

return p
}

func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
Expand Down Expand Up @@ -102,7 +111,10 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request,
"sts_endpoint": config.STSEndpoint,
"max_retries": config.MaxRetries,
"username_template": config.UsernameTemplate,
"role_arn": config.RoleARN,
}

config.PopulatePluginIdentityTokenData(configData)
return &logical.Response{
Data: configData,
}, nil
Expand All @@ -113,6 +125,7 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
iamendpoint := data.Get("iam_endpoint").(string)
stsendpoint := data.Get("sts_endpoint").(string)
maxretries := data.Get("max_retries").(int)
roleARN := data.Get("role_arn").(string)
usernameTemplate := data.Get("username_template").(string)
if usernameTemplate == "" {
usernameTemplate = defaultUserNameTemplate
Expand All @@ -121,15 +134,29 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
b.clientMutex.Lock()
defer b.clientMutex.Unlock()

entry, err := logical.StorageEntryJSON("config/root", rootConfig{
rc := rootConfig{
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
IAMEndpoint: iamendpoint,
STSEndpoint: stsendpoint,
Region: region,
MaxRetries: maxretries,
UsernameTemplate: usernameTemplate,
})
RoleARN: roleARN,
}
if err := rc.ParsePluginIdentityTokenFields(data); err != nil {
return logical.ErrorResponse(err.Error()), nil
}

if rc.IdentityTokenAudience != "" && rc.AccessKey != "" {
return logical.ErrorResponse("only one of 'access_key' or 'identity_token_audience' can be set"), nil
}

if rc.IdentityTokenAudience != "" && rc.RoleARN == "" {
return logical.ErrorResponse("missing required 'role_arn' when 'identity_token_audience' is set"), nil
}

entry, err := logical.StorageEntryJSON("config/root", rc)
if err != nil {
return nil, err
}
Expand All @@ -147,13 +174,16 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
}

type rootConfig struct {
pluginidentityutil.PluginIdentityTokenParams

AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
Region string `json:"region"`
MaxRetries int `json:"max_retries"`
UsernameTemplate string `json:"username_template"`
RoleARN string `json:"role_arn"`
}

const pathConfigRootHelpSyn = `
Expand Down
114 changes: 107 additions & 7 deletions builtin/logical/aws/path_config_root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ package aws
import (
"context"
"reflect"
"strings"
"testing"

"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)

func TestBackend_PathConfigRoot(t *testing.T) {
Expand All @@ -21,13 +23,16 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}

configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}

configReq := &logical.Request{
Expand All @@ -52,7 +57,102 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}

delete(configData, "secret_key")
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
}

// TestBackend_PathConfigRoot_PluginIdentityToken tests parsing and validation of
// configuration used to set the secret engine up for web identity federation using
// plugin identity tokens.
func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}

b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}

configData := map[string]interface{}{
"identity_token_ttl": int64(10),
"identity_token_audience": "test-aud",
"role_arn": "test-role-arn",
}

configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}

resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}

resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}

// Grab the subset of fields from the response we care to look at for this case
got := map[string]interface{}{
"identity_token_ttl": resp.Data["identity_token_ttl"],
"identity_token_audience": resp.Data["identity_token_audience"],
"role_arn": resp.Data["role_arn"],
}

if !reflect.DeepEqual(got, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}

// mutually exclusive fields must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
"access_key": "ASIAIO10230XVB",
}

configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}

resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError := "only one of 'access_key' or 'identity_token_audience' can be set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}

// missing role arn with audience must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
}

configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}

resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError = "missing required 'role_arn' when 'identity_token_audience' is set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}
}
3 changes: 3 additions & 0 deletions changelog/24987.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:feature
**Plugin Identity Tokens**: Adds secret-less configuration of AWS secret engine using web identity federation.
```
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ require (
github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a
github.com/hashicorp/go-retryablehttp v0.7.4
github.com/hashicorp/go-rootcerts v1.0.2
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-secure-stdlib/gatedwriter v0.1.1
github.com/hashicorp/go-secure-stdlib/kv-builder v0.1.2
Expand Down
Loading

0 comments on commit eba8917

Please sign in to comment.