From 6b2d5ac3dcc748c1494654cf232af78a46c01b99 Mon Sep 17 00:00:00 2001 From: Becca Petrin Date: Tue, 10 Dec 2019 16:02:04 -0800 Subject: [PATCH] Add an sts_region parameter to the AWS auth engine's client config (#7922) --- builtin/credential/aws/client.go | 11 ++- builtin/credential/aws/path_config_client.go | 20 ++++- .../credential/aws/path_config_client_test.go | 20 ++++- builtin/credential/aws/path_role_test.go | 87 +++++++++++++++++++ 4 files changed, 133 insertions(+), 5 deletions(-) diff --git a/builtin/credential/aws/client.go b/builtin/credential/aws/client.go index 90524c5c2a5b..e495dd257967 100644 --- a/builtin/credential/aws/client.go +++ b/builtin/credential/aws/client.go @@ -37,14 +37,19 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg endpoint := aws.String("") var maxRetries int = aws.UseServiceDefaultRetries if config != nil { - // Override the default endpoint with the configured endpoint. + // Override the defaults with configured values. switch { case clientType == "ec2" && config.Endpoint != "": endpoint = aws.String(config.Endpoint) case clientType == "iam" && config.IAMEndpoint != "": endpoint = aws.String(config.IAMEndpoint) - case clientType == "sts" && config.STSEndpoint != "": - endpoint = aws.String(config.STSEndpoint) + case clientType == "sts": + if config.STSEndpoint != "" { + endpoint = aws.String(config.STSEndpoint) + } + if config.STSRegion != "" { + region = config.STSRegion + } } credsConfig.AccessKey = config.AccessKey diff --git a/builtin/credential/aws/path_config_client.go b/builtin/credential/aws/path_config_client.go index 18cd8749e01c..228488e30220 100644 --- a/builtin/credential/aws/path_config_client.go +++ b/builtin/credential/aws/path_config_client.go @@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path { Description: "URL to override the default generated endpoint for making AWS STS API calls.", }, + "sts_region": { + Type: framework.TypeString, + Default: "", + Description: "The region ID for the sts_endpoint, if set.", + }, + "iam_server_id_header_value": { Type: framework.TypeString, Default: "", @@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request "endpoint": clientConfig.Endpoint, "iam_endpoint": clientConfig.IAMEndpoint, "sts_endpoint": clientConfig.STSEndpoint, + "sts_region": clientConfig.STSRegion, "iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, "max_retries": clientConfig.MaxRetries, }, @@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical stsEndpointStr, ok := data.GetOk("sts_endpoint") if ok { if configEntry.STSEndpoint != stsEndpointStr.(string) { - // We don't directly cache STS clients as they are ever directly used. + // We don't directly cache STS clients as they are never directly used. // However, they are potentially indirectly used as credential providers // for the EC2 and IAM clients, and thus we would be indirectly caching // them there. So, if we change the STS endpoint, we should flush those @@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical configEntry.STSEndpoint = data.Get("sts_endpoint").(string) } + stsRegionStr, ok := data.GetOk("sts_region") + if ok { + if configEntry.STSRegion != stsRegionStr.(string) { + // Region is used when building STS clients. As such, all the comments + // regarding the sts_endpoint changing apply here as well. + changedCreds = true + configEntry.STSRegion = stsRegionStr.(string) + } + } + headerValStr, ok := data.GetOk("iam_server_id_header_value") if ok { if configEntry.IAMServerIdHeaderValue != headerValStr.(string) { @@ -281,6 +298,7 @@ type clientConfig struct { Endpoint string `json:"endpoint"` IAMEndpoint string `json:"iam_endpoint"` STSEndpoint string `json:"sts_endpoint"` + STSRegion string `json:"sts_region"` IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` MaxRetries int `json:"max_retries"` } diff --git a/builtin/credential/aws/path_config_client_test.go b/builtin/credential/aws/path_config_client_test.go index ab6fa9b45a70..493d20d9df00 100644 --- a/builtin/credential/aws/path_config_client_test.go +++ b/builtin/credential/aws/path_config_client_test.go @@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) { data := map[string]interface{}{ "sts_endpoint": "https://my-custom-sts-endpoint.example.com", + "sts_region": "us-east-2", "iam_server_id_header_value": "vault_server_identification_314159", } resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) { Data: data, Storage: storage, }) - if err != nil { t.Fatal(err) } @@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) { t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'", data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"]) } + if resp.Data["sts_endpoint"] != data["sts_endpoint"] { + t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'", + data["sts_endpoint"], resp.Data["sts_endpoint"]) + } + if resp.Data["sts_region"] != data["sts_region"] { + t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'", + data["sts_region"], resp.Data["sts_region"]) + } data = map[string]interface{}{ + "sts_endpoint": "https://my-custom-sts-endpoint2.example.com", + "sts_region": "us-west-1", "iam_server_id_header_value": "vault_server_identification_2718281", } resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) { t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'", data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"]) } + if resp.Data["sts_endpoint"] != data["sts_endpoint"] { + t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'", + data["sts_endpoint"], resp.Data["sts_endpoint"]) + } + if resp.Data["sts_region"] != data["sts_region"] { + t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'", + data["sts_region"], resp.Data["sts_region"]) + } } diff --git a/builtin/credential/aws/path_role_test.go b/builtin/credential/aws/path_role_test.go index 13e58c173843..f3daec40b5dc 100644 --- a/builtin/credential/aws/path_role_test.go +++ b/builtin/credential/aws/path_role_test.go @@ -2,11 +2,14 @@ package awsauth import ( "context" + "os" "reflect" "strings" "testing" "github.com/go-test/deep" + "github.com/hashicorp/vault/helper/awsutil" + vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical" "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/logical" @@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) { } } +// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418 +// and verify its fix. +// Please run it at least 3 times to ensure that passing tests are due to actually +// passing, rather than the region being randomly chosen tying to the one in the +// test through luck. +func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) { + if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" { + t.Skip() + } + + /* ARN of an AWS role that Vault can query during testing. + This role should exist in your current AWS account and your credentials + should have iam:GetRole permissions to query it. + */ + assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN") + if assumableRoleArn == "" { + t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset") + } + + // Ensure aws credentials are available locally for testing. + credsConfig := &awsutil.CredentialsConfig{} + credsChain, err := credsConfig.GenerateCredentialChain() + if err != nil { + t.Fatal(err) + } + _, err = credsChain.Get() + if err != nil { + t.SkipNow() + } + + config := logical.TestBackendConfig() + storage := &logical.InmemStorage{} + config.StorageView = storage + + b, err := Backend(config) + if err != nil { + t.Fatal(err) + } + + err = b.Setup(context.Background(), config) + if err != nil { + t.Fatal(err) + } + + // configure the client with an sts endpoint that should be used in creating the role + data := map[string]interface{}{ + "sts_endpoint": "https://sts.eu-west-1.amazonaws.com", + // Note - if you comment this out, you can reproduce the error shown + // in the linked GH issue above. This essentially reproduces the problem + // we had when we didn't have an sts_region field. + "sts_region": "eu-west-1", + } + resp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: "config/client", + Data: data, + Storage: storage, + }) + if err != nil { + t.Fatal(err) + } + if resp != nil && resp.IsError() { + t.Fatalf("failed to create the role entry; resp: %#v", resp) + } + + data = map[string]interface{}{ + "auth_type": iamAuthType, + "bound_iam_principal_arn": assumableRoleArn, + "resolve_aws_unique_ids": true, + } + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: "role/MyRoleName", + Data: data, + Storage: storage, + }) + if err != nil { + t.Fatal(err) + } + if resp != nil && resp.IsError() { + t.Fatalf("failed to create the role entry; resp: %#v", resp) + } +} + func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) { return "FakeUniqueId1", nil }