Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an sts_region parameter to the AWS auth engine's client config #7922

Merged
merged 16 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
endpoint := aws.String("")
var maxRetries int = aws.UseServiceDefaultRetries
if config != nil {
// Override the default endpoint with the configured endpoint.
// Override the defaults with configured values.
switch {
case clientType == "ec2" && config.Endpoint != "":
endpoint = aws.String(config.Endpoint)
case clientType == "iam" && config.IAMEndpoint != "":
endpoint = aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "":
endpoint = aws.String(config.STSEndpoint)
case clientType == "sts":
if config.STSEndpoint != "" {
endpoint = aws.String(config.STSEndpoint)
}
if config.STSRegion != "" {
region = config.STSRegion
}
}

credsConfig.AccessKey = config.AccessKey
Expand Down
22 changes: 21 additions & 1 deletion builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
},

"sts_region": {
Type: framework.TypeString,
Default: "",
Description: "The region ID for the sts_endpoint, if set.",
},

"iam_server_id_header_value": {
Type: framework.TypeString,
Default: "",
Expand Down Expand Up @@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"endpoint": clientConfig.Endpoint,
"iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
},
Expand Down Expand Up @@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
stsEndpointStr, ok := data.GetOk("sts_endpoint")
if ok {
if configEntry.STSEndpoint != stsEndpointStr.(string) {
// We don't directly cache STS clients as they are ever directly used.
// We don't directly cache STS clients as they are never directly used.
// However, they are potentially indirectly used as credential providers
// for the EC2 and IAM clients, and thus we would be indirectly caching
// them there. So, if we change the STS endpoint, we should flush those
Expand All @@ -229,6 +236,18 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
}

stsRegionStr, ok := data.GetOk("sts_region")
if ok {
if configEntry.STSRegion != stsRegionStr.(string) {
// Region is used when building STS clients. As such, all the comments
// regarding the sts_endpoint changing apply here as well.
changedCreds = true
configEntry.STSRegion = stsRegionStr.(string)
}
} else if req.Operation == logical.CreateOperation {
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
configEntry.STSRegion = data.Get("sts_region").(string)
}

headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
Expand Down Expand Up @@ -281,6 +300,7 @@ type clientConfig struct {
Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"`
}
Expand Down
20 changes: 19 additions & 1 deletion builtin/credential/aws/path_config_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {

data := map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
"sts_region": "us-east-2",
"iam_server_id_header_value": "vault_server_identification_314159",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand All @@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
Data: data,
Storage: storage,
})

if err != nil {
t.Fatal(err)
}
Expand All @@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}

data = map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
"sts_region": "us-west-1",
"iam_server_id_header_value": "vault_server_identification_2718281",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand Down Expand Up @@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}
}
59 changes: 59 additions & 0 deletions builtin/credential/aws/path_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,65 @@ func TestAwsVersion(t *testing.T) {
}
}

// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
// and verify its fix.
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
t.Skip("skipping test because it hits real endpoints")
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved

config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage

b, err := Backend(config)
if err != nil {
t.Fatal(err)
}

err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

// configure the client with an sts endpoint that should be used in creating the role
data := map[string]interface{}{
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
// Note - if you comment this out, you can reproduce the error shown
// in the linked GH issue above. This essentially reproduces the problem
// we had when we didn't have an sts_region field.
"sts_region": "eu-west-1",
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config/client",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}

data = map[string]interface{}{
"auth_type": iamAuthType,
"bound_iam_principal_arn": "arn:aws:iam::123456789012:assumed-role/MyRole/foo",
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
"resolve_aws_unique_ids": true,
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "role/MyRoleName",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}
}

func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
return "FakeUniqueId1", nil
}