Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS SDK v2] Upgrade library common credentials #31225

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 71 additions & 56 deletions x-pack/libbeat/common/aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
package aws

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/url"
"strings"

awssdk "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/defaults"
"github.com/aws/aws-sdk-go-v2/aws/external"
"github.com/aws/aws-sdk-go-v2/aws/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/pkg/errors"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"

"github.com/elastic/beats/v7/libbeat/common/transport/httpcommon"
"github.com/elastic/beats/v7/libbeat/common/transport/tlscommon"
Expand Down Expand Up @@ -44,26 +43,26 @@ type ConfigAWS struct {
}

// InitializeAWSConfig function creates the awssdk.Config object from the provided config
func InitializeAWSConfig(config ConfigAWS) (awssdk.Config, error) {
AWSConfig, _ := GetAWSCredentials(config)
func InitializeAWSConfig(beatsConfig ConfigAWS) (awssdk.Config, error) {
AWSConfig, _ := GetAWSCredentials(beatsConfig)
if AWSConfig.Region == "" {
if config.DefaultRegion != "" {
AWSConfig.Region = config.DefaultRegion
if beatsConfig.DefaultRegion != "" {
AWSConfig.Region = beatsConfig.DefaultRegion
} else {
AWSConfig.Region = "us-east-1"
}
}
var proxy func(*http.Request) (*url.URL, error)
if config.ProxyUrl != "" {
proxyUrl, err := httpcommon.NewProxyURIFromString(config.ProxyUrl)
if beatsConfig.ProxyUrl != "" {
proxyUrl, err := httpcommon.NewProxyURIFromString(beatsConfig.ProxyUrl)
if err != nil {
return AWSConfig, err
}
proxy = http.ProxyURL(proxyUrl.URI())
}
var tlsConfig *tls.Config
if config.TLS != nil {
TLSConfig, _ := tlscommon.LoadTLSConfig(config.TLS)
if beatsConfig.TLS != nil {
TLSConfig, _ := tlscommon.LoadTLSConfig(beatsConfig.TLS)
tlsConfig = TLSConfig.ToConfig()
}
AWSConfig.HTTPClient = &http.Client{
Expand All @@ -80,102 +79,118 @@ func InitializeAWSConfig(config ConfigAWS) (awssdk.Config, error) {
// If access keys are not given, then load from AWS config file. If credential_profile_name is not
// given, default profile will be used.
// If role_arn is given, assume the IAM role either with access keys or default profile.
func GetAWSCredentials(config ConfigAWS) (awssdk.Config, error) {
func GetAWSCredentials(beatsConfig ConfigAWS) (awssdk.Config, error) {
// Check if accessKeyID or secretAccessKey or sessionToken is given from configuration
if config.AccessKeyID != "" || config.SecretAccessKey != "" || config.SessionToken != "" {
return getAccessKeys(config), nil
if beatsConfig.AccessKeyID != "" || beatsConfig.SecretAccessKey != "" || beatsConfig.SessionToken != "" {
return getConfigForKeys(beatsConfig), nil
}

return getSharedCredentialProfile(config)
return getConfigSharedCredentialProfile(beatsConfig)
}

func getAccessKeys(config ConfigAWS) awssdk.Config {
logger := logp.NewLogger("getAccessKeys")
awsConfig := defaults.Config()
// getConfigForKeys creates a default AWS config and adds a CredentialsProvider using the provided Beats config.
// Provided config must contain an accessKeyID, secretAccessKey and sessionToken to generate a valid CredentialsProfile
func getConfigForKeys(beatsConfig ConfigAWS) awssdk.Config {
logger := logp.NewLogger("getConfigForKeys")

config := awssdk.NewConfig()
awsCredentials := awssdk.Credentials{
AccessKeyID: config.AccessKeyID,
SecretAccessKey: config.SecretAccessKey,
AccessKeyID: beatsConfig.AccessKeyID,
SecretAccessKey: beatsConfig.SecretAccessKey,
}

if config.SessionToken != "" {
awsCredentials.SessionToken = config.SessionToken
if beatsConfig.SessionToken != "" {
awsCredentials.SessionToken = beatsConfig.SessionToken
}

awsConfig.Credentials = awssdk.StaticCredentialsProvider{
config.Credentials = credentials.StaticCredentialsProvider{
Value: awsCredentials,
}

// Assume IAM role if iam_role config parameter is given
if config.RoleArn != "" {
if beatsConfig.RoleArn != "" {
logger.Debug("Using role arn and access keys for AWS credential")
return getRoleArn(config, awsConfig)
addStaticCredentialsProvider(beatsConfig, config)
return *config
}

return awsConfig
return *config
}

func getSharedCredentialProfile(config ConfigAWS) (awssdk.Config, error) {
// If accessKeyID, secretAccessKey or sessionToken is not given, iam_role is not given, then load from default config
// Please see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html
// with more details.
// If credential_profile_name is empty, then default profile is used.
logger := logp.NewLogger("getSharedCredentialProfile")
var options []external.Config
if config.ProfileName != "" {
options = append(options, external.WithSharedConfigProfile(config.ProfileName))
// getConfigSharedCredentialProfile If accessKeyID, secretAccessKey or sessionToken is not given, iam_role is not given,
// then load from default config // Please see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html
// with more details. If credential_profile_name is empty, then default profile is used.
func getConfigSharedCredentialProfile(beatsConfig ConfigAWS) (awssdk.Config, error) {
logger := logp.NewLogger("WithSharedConfigProfile")

var options []func(*awsConfig.LoadOptions) error
if beatsConfig.ProfileName != "" {
options = append(options, awsConfig.WithSharedConfigProfile(beatsConfig.ProfileName))
}

// If shared_credential_file is empty, then external.LoadDefaultAWSConfig
// function will load AWS config from current user's home directory.
// Linux/OSX: "$HOME/.aws/credentials"
// Windows: "%USERPROFILE%\.aws\credentials"
if config.SharedCredentialFile != "" {
options = append(options, external.WithSharedConfigFiles([]string{config.SharedCredentialFile}))
if beatsConfig.SharedCredentialFile != "" {
options = append(options, awsConfig.WithSharedConfigFiles([]string{beatsConfig.SharedCredentialFile}))
}

awsConfig, err := external.LoadDefaultAWSConfig(options...)
cfg, err := awsConfig.LoadDefaultConfig(context.TODO(), options...)
if err != nil {
return awsConfig, errors.Wrap(err, "external.LoadDefaultAWSConfig failed with shared credential profile given")
return cfg, fmt.Errorf("awsConfig.LoadDefaultConfig failed with shared credential profile given: [%w]", err)
}

// Assume IAM role if iam_role config parameter is given
if config.RoleArn != "" {
if beatsConfig.RoleArn != "" {
logger.Debug("Using role arn and shared credential profile for AWS credential")
return getRoleArn(config, awsConfig), nil
addStaticCredentialsProvider(beatsConfig, &cfg)
return cfg, nil
}

logger.Debug("Using shared credential profile for AWS credential")
return awsConfig, nil
return cfg, nil
}

func getRoleArn(config ConfigAWS, awsConfig awssdk.Config) awssdk.Config {
stsSvc := sts.New(awsConfig)
stsCredProvider := stscreds.NewAssumeRoleProvider(stsSvc, config.RoleArn)
awsConfig.Credentials = stsCredProvider
return awsConfig
// addStaticCredentialsProvider adds a static credentials provider to the current AWS config by using the keys stored in Beats config
func addStaticCredentialsProvider(beatsConfig ConfigAWS, awsConfig *awssdk.Config) {
staticCredentialsProvider := credentials.NewStaticCredentialsProvider(
beatsConfig.AccessKeyID,
beatsConfig.SecretAccessKey,
beatsConfig.SessionToken)

awsConfig.Credentials = staticCredentialsProvider

return
}

// EnrichAWSConfigWithEndpoint function enabled endpoint resolver for AWS
// service clients when endpoint is given in config.
func EnrichAWSConfigWithEndpoint(endpoint string, serviceName string, regionName string, awsConfig awssdk.Config) awssdk.Config {
// EnrichAWSConfigWithEndpoint function enabled endpoint resolver for AWS service clients when endpoint is given in config.
func EnrichAWSConfigWithEndpoint(endpoint string, serviceName string, regionName string, beatsConfig awssdk.Config) (awssdk.Config, error) {
var eurl string
if endpoint != "" {
parsedEndpoint, _ := url.Parse(endpoint)

// Beats uses the provided endpoint if the scheme is present or...
if parsedEndpoint.Scheme != "" {
awsConfig.EndpointResolver = awssdk.ResolveWithEndpointURL(endpoint)
eurl = endpoint
} else {
// ...build one by using the scheme, service and region names.
if regionName == "" {
eurl = "https://" + serviceName + "." + endpoint
} else {
eurl = "https://" + serviceName + "." + regionName + "." + endpoint
}
awsConfig.EndpointResolver = awssdk.ResolveWithEndpointURL(eurl)
}

beatsConfig.EndpointResolverWithOptions = awssdk.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (awssdk.Endpoint, error) {
return awssdk.Endpoint{URL: eurl}, nil
})
}
return awsConfig
return beatsConfig, nil
}

//Create AWS service name based on Region and FIPS
// CreateServiceName based on Service name, Region and FIPS. Returns service name if Fips is not enabled.
func CreateServiceName(serviceName string, fipsEnabled bool, region string) string {
if fipsEnabled {
_, found := OptionalGovCloudFIPS[serviceName]
Expand Down
19 changes: 14 additions & 5 deletions x-pack/libbeat/common/aws/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestEnrichAWSConfigWithEndpoint(t *testing.T) {
"",
awssdk.Config{},
awssdk.Config{
EndpointResolver: awssdk.ResolveWithEndpointURL("https://ec2.amazonaws.com"),
EndpointResolverWithOptions: getEndpointResolverWithOptionsFunc("https://ec2.amazonaws.com"),
},
},
{
Expand All @@ -79,7 +79,7 @@ func TestEnrichAWSConfigWithEndpoint(t *testing.T) {
"us-west-1",
awssdk.Config{},
awssdk.Config{
EndpointResolver: awssdk.ResolveWithEndpointURL("https://cloudwatch.us-west-1.amazonaws.com"),
EndpointResolverWithOptions: getEndpointResolverWithOptionsFunc("https://cloudwatch.us-west-1.amazonaws.com"),
},
},
{
Expand All @@ -89,7 +89,7 @@ func TestEnrichAWSConfigWithEndpoint(t *testing.T) {
"",
awssdk.Config{},
awssdk.Config{
EndpointResolver: awssdk.ResolveWithEndpointURL("https://s3.test.com:9000"),
EndpointResolverWithOptions: getEndpointResolverWithOptionsFunc("https://s3.test.com:9000"),
},
},
{
Expand All @@ -99,18 +99,27 @@ func TestEnrichAWSConfigWithEndpoint(t *testing.T) {
"",
awssdk.Config{},
awssdk.Config{
EndpointResolver: awssdk.ResolveWithEndpointURL("http://testobjects.com:9000"),
EndpointResolverWithOptions: getEndpointResolverWithOptionsFunc("http://testobjects.com:9000"),
},
},
}
for _, c := range cases {
t.Run(c.title, func(t *testing.T) {
enrichedAWSConfig := EnrichAWSConfigWithEndpoint(c.endpoint, c.serviceName, c.region, c.awsConfig)
enrichedAWSConfig, err := EnrichAWSConfigWithEndpoint(c.endpoint, c.serviceName, c.region, c.awsConfig)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, c.expectedAWSConfig, enrichedAWSConfig)
})
}
}

func getEndpointResolverWithOptionsFunc(e string) awssdk.EndpointResolverWithOptionsFunc {
return func(service, region string, options ...interface{}) (awssdk.Endpoint, error) {
return awssdk.Endpoint{URL: e}, nil
}
}

func TestCreateServiceName(t *testing.T) {
cases := []struct {
title string
Expand Down