Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an sts_region parameter to the AWS auth engine's client config #7922

Merged
merged 16 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
endpoint := aws.String("")
var maxRetries int = aws.UseServiceDefaultRetries
if config != nil {
// Override the default endpoint with the configured endpoint.
// Override the defaults with configured values.
switch {
case clientType == "ec2" && config.Endpoint != "":
endpoint = aws.String(config.Endpoint)
case clientType == "iam" && config.IAMEndpoint != "":
endpoint = aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "":
endpoint = aws.String(config.STSEndpoint)
case clientType == "sts":
if config.STSEndpoint != "" {
endpoint = aws.String(config.STSEndpoint)
}
if config.STSRegion != "" {
region = config.STSRegion
}
}

credsConfig.AccessKey = config.AccessKey
Expand Down
20 changes: 19 additions & 1 deletion builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
},

"sts_region": {
Type: framework.TypeString,
Default: "",
Description: "The region ID for the sts_endpoint, if set.",
},

"iam_server_id_header_value": {
Type: framework.TypeString,
Default: "",
Expand Down Expand Up @@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"endpoint": clientConfig.Endpoint,
"iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
},
Expand Down Expand Up @@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
stsEndpointStr, ok := data.GetOk("sts_endpoint")
if ok {
if configEntry.STSEndpoint != stsEndpointStr.(string) {
// We don't directly cache STS clients as they are ever directly used.
// We don't directly cache STS clients as they are never directly used.
// However, they are potentially indirectly used as credential providers
// for the EC2 and IAM clients, and thus we would be indirectly caching
// them there. So, if we change the STS endpoint, we should flush those
Expand All @@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
}

stsRegionStr, ok := data.GetOk("sts_region")
if ok {
if configEntry.STSRegion != stsRegionStr.(string) {
// Region is used when building STS clients. As such, all the comments
// regarding the sts_endpoint changing apply here as well.
changedCreds = true
configEntry.STSRegion = stsRegionStr.(string)
}
}

headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
Expand Down Expand Up @@ -281,6 +298,7 @@ type clientConfig struct {
Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"`
}
Expand Down
20 changes: 19 additions & 1 deletion builtin/credential/aws/path_config_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {

data := map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
"sts_region": "us-east-2",
"iam_server_id_header_value": "vault_server_identification_314159",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand All @@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
Data: data,
Storage: storage,
})

if err != nil {
t.Fatal(err)
}
Expand All @@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}

data = map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
"sts_region": "us-west-1",
"iam_server_id_header_value": "vault_server_identification_2718281",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand Down Expand Up @@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}
}
87 changes: 87 additions & 0 deletions builtin/credential/aws/path_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package awsauth

import (
"context"
"os"
"reflect"
"strings"
"testing"

"github.com/go-test/deep"
"github.com/hashicorp/vault/helper/awsutil"
vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) {
}
}

// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
// and verify its fix.
// Please run it at least 3 times to ensure that passing tests are due to actually
// passing, rather than the region being randomly chosen tying to the one in the
// test through luck.
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" {
t.Skip()
}

/* ARN of an AWS role that Vault can query during testing.
This role should exist in your current AWS account and your credentials
should have iam:GetRole permissions to query it.
*/
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
if assumableRoleArn == "" {
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
}

// Ensure aws credentials are available locally for testing.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity and TestBackendAcc_LoginWithCallerIdentity will actually override the AWS_* environment variables with TEST_AWS_* -- this creates a little bit of inconsistency. Maybe extract that overriding logic into a common function (e.g., setupTestCreds) to also be called here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, that has the potential to expand the scope of this PR to tuning tests. For instance, I'm looking at TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity and I'm noticing that here, the code verifies that the env vars aren't empty and subsequently uses them, but here it checks for an impossible condition and fails if it's found. I'm thinking it might make more sense to leave this test as-is because it's consistent with code already in Vault, and to make a separate ticket to circle back and resolve test inconsistencies.

credsConfig := &awsutil.CredentialsConfig{}
credsChain, err := credsConfig.GenerateCredentialChain()
if err != nil {
t.SkipNow()
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
}
_, err = credsChain.Get()
if err != nil {
t.SkipNow()
}

config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage

b, err := Backend(config)
if err != nil {
t.Fatal(err)
}

err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

// configure the client with an sts endpoint that should be used in creating the role
data := map[string]interface{}{
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
// Note - if you comment this out, you can reproduce the error shown
// in the linked GH issue above. This essentially reproduces the problem
// we had when we didn't have an sts_region field.
"sts_region": "eu-west-1",
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config/client",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}

data = map[string]interface{}{
"auth_type": iamAuthType,
"bound_iam_principal_arn": assumableRoleArn,
"resolve_aws_unique_ids": true,
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "role/MyRoleName",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}
}

func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
return "FakeUniqueId1", nil
}