diff --git a/aws_config.go b/aws_config.go index 4b80a628..7a1c6502 100644 --- a/aws_config.go +++ b/aws_config.go @@ -2,12 +2,10 @@ package awsbase import ( "context" - "crypto/tls" "errors" "fmt" "log" "net" - "net/http" "os" "strings" "time" @@ -22,7 +20,7 @@ import ( "github.com/aws/smithy-go/middleware" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints" - "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" ) func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { @@ -40,8 +38,12 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { Retryer: retryer, } - loadOptions := append( - commonLoadOptions(c), + loadOptions, err := commonLoadOptions(c) + if err != nil { + return aws.Config{}, err + } + loadOptions = append( + loadOptions, config.WithCredentialsProvider(credentialsProvider), config.WithRetryer(func() aws.Retryer { return retryer @@ -119,13 +121,10 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, skip return "", endpoints.PartitionForRegion(awsConfig.Region), nil } -func commonLoadOptions(c *Config) []func(*config.LoadOptions) error { - httpClient := cleanhttp.DefaultClient() - if c.Insecure { - transport := httpClient.Transport.(*http.Transport) - transport.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } +func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { + httpClient, err := httpclient.DefaultHttpClient(c) + if err != nil { + return nil, err } apiOptions := make([]func(*middleware.Stack) error, 0) @@ -171,5 +170,5 @@ func commonLoadOptions(c *Config) []func(*config.LoadOptions) error { os.Setenv("AWS_EC2_METADATA_DISABLED", "true") } - return loadOptions + return loadOptions, nil } diff --git a/config.go b/config.go index c54df0b9..bf997e84 100644 --- a/config.go +++ b/config.go @@ -1,44 +1,16 @@ package awsbase -type Config struct { - AccessKey string - APNInfo *APNInfo - AssumeRole *AssumeRole - CallerDocumentationURL string - CallerName string - DebugLogging bool - IamEndpoint string - Insecure bool - MaxRetries int - Profile string - Region string - SecretKey string - SharedCredentialsFiles []string - SharedConfigFiles []string - SkipCredsValidation bool - SkipMetadataApiCheck bool - StsEndpoint string - Token string -} +import ( + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" +) -type APNInfo struct { - PartnerName string - Products []APNProduct -} +// Config, APNInfo, APNProduct, and AssumeRole are aliased to an internal package to break a dependency cycle +// in internal/httpclient. -type APNProduct struct { - Name string - Version string - Comment string -} +type Config = config.Config -type AssumeRole struct { - RoleARN string - DurationSeconds int - ExternalID string - Policy string - PolicyARNs []string - SessionName string - Tags map[string]string - TransitiveTagKeys []string -} +type APNInfo = config.APNInfo + +type APNProduct = config.APNProduct + +type AssumeRole = config.AssumeRole diff --git a/credentials.go b/credentials.go index 105f3d27..63a47a42 100644 --- a/credentials.go +++ b/credentials.go @@ -15,8 +15,12 @@ import ( ) func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProvider, error) { - loadOptions := append( - commonLoadOptions(c), + loadOptions, err := commonLoadOptions(c) + if err != nil { + return nil, err + } + loadOptions = append( + loadOptions, config.WithSharedConfigProfile(c.Profile), // Bypass retries when validating authentication config.WithRetryer(func() aws.Retryer { diff --git a/errors.go b/errors.go index 1a1bf324..0c154dad 100644 --- a/errors.go +++ b/errors.go @@ -2,34 +2,12 @@ package awsbase import ( "errors" - "fmt" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" ) // CannotAssumeRoleError occurs when AssumeRole cannot complete. -type CannotAssumeRoleError struct { - Config *Config - Err error -} - -func (e CannotAssumeRoleError) Error() string { - if e.Config == nil || e.Config.AssumeRole == nil { - return fmt.Sprintf("cannot assume role: %s", e.Err) - } - - return fmt.Sprintf(`IAM Role (%s) cannot be assumed. - -There are a number of possible causes of this - the most common are: - * The credentials used in order to assume the role are invalid - * The credentials do not have appropriate permission to assume the role - * The role ARN is not valid - -Error: %s -`, e.Config.AssumeRole.RoleARN, e.Err) -} - -func (e CannotAssumeRoleError) Unwrap() error { - return e.Err -} +type CannotAssumeRoleError = config.CannotAssumeRoleError // IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type. func IsCannotAssumeRoleError(err error) bool { @@ -37,40 +15,11 @@ func IsCannotAssumeRoleError(err error) bool { return errors.As(err, &e) } -func (c *Config) NewCannotAssumeRoleError(err error) CannotAssumeRoleError { - return CannotAssumeRoleError{Config: c, Err: err} -} - // NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results. -type NoValidCredentialSourcesError struct { - Config *Config - Err error -} - -func (e NoValidCredentialSourcesError) Error() string { - if e.Config == nil { - return fmt.Sprintf("no valid credential sources found: %s", e.Err) - } - - return fmt.Sprintf(`no valid credential sources for %[1]s found. - -Please see %[2]s -for more information about providing credentials. - -Error: %[3]s -`, e.Config.CallerName, e.Config.CallerDocumentationURL, e.Err) -} - -func (e NoValidCredentialSourcesError) Unwrap() error { - return e.Err -} +type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError // IsNoValidCredentialSourcesError returns true if the error contains the NoValidCredentialSourcesError type. func IsNoValidCredentialSourcesError(err error) bool { var e NoValidCredentialSourcesError return errors.As(err, &e) } - -func (c *Config) NewNoValidCredentialSourcesError(err error) NoValidCredentialSourcesError { - return NoValidCredentialSourcesError{Config: c, Err: err} -} diff --git a/internal/config/apn_info.go b/internal/config/apn_info.go new file mode 100644 index 00000000..ae5c5fd6 --- /dev/null +++ b/internal/config/apn_info.go @@ -0,0 +1,29 @@ +package config + +import ( + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// Builds the user-agent string for APN +func (apn APNInfo) BuildUserAgentString() string { + builder := smithyhttp.NewUserAgentBuilder() + builder.AddKeyValue("APN", "1.0") + builder.AddKeyValue(apn.PartnerName, "1.0") + for _, p := range apn.Products { + p.buildUserAgentPart(builder) + } + return builder.Build() +} + +func (p APNProduct) buildUserAgentPart(b *smithyhttp.UserAgentBuilder) { + if p.Name != "" { + if p.Version != "" { + b.AddKeyValue(p.Name, p.Version) + } else { + b.AddKey(p.Name) + } + } + if p.Comment != "" { + b.AddKey("(" + p.Comment + ")") + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..ab4c398c --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,45 @@ +package config + +type Config struct { + AccessKey string + APNInfo *APNInfo + AssumeRole *AssumeRole + CallerDocumentationURL string + CallerName string + DebugLogging bool + HTTPProxy string + IamEndpoint string + Insecure bool + MaxRetries int + Profile string + Region string + SecretKey string + SharedCredentialsFiles []string + SharedConfigFiles []string + SkipCredsValidation bool + SkipMetadataApiCheck bool + StsEndpoint string + Token string +} + +type APNInfo struct { + PartnerName string + Products []APNProduct +} + +type APNProduct struct { + Name string + Version string + Comment string +} + +type AssumeRole struct { + RoleARN string + DurationSeconds int + ExternalID string + Policy string + PolicyARNs []string + SessionName string + Tags map[string]string + TransitiveTagKeys []string +} diff --git a/internal/config/errors.go b/internal/config/errors.go new file mode 100644 index 00000000..7b0546e2 --- /dev/null +++ b/internal/config/errors.go @@ -0,0 +1,63 @@ +package config + +import ( + "fmt" +) + +// CannotAssumeRoleError occurs when AssumeRole cannot complete. +type CannotAssumeRoleError struct { + Config *Config + Err error +} + +func (e CannotAssumeRoleError) Error() string { + if e.Config == nil || e.Config.AssumeRole == nil { + return fmt.Sprintf("cannot assume role: %s", e.Err) + } + + return fmt.Sprintf(`IAM Role (%s) cannot be assumed. + +There are a number of possible causes of this - the most common are: + * The credentials used in order to assume the role are invalid + * The credentials do not have appropriate permission to assume the role + * The role ARN is not valid + +Error: %s +`, e.Config.AssumeRole.RoleARN, e.Err) +} + +func (e CannotAssumeRoleError) Unwrap() error { + return e.Err +} + +func (c *Config) NewCannotAssumeRoleError(err error) CannotAssumeRoleError { + return CannotAssumeRoleError{Config: c, Err: err} +} + +// NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results. +type NoValidCredentialSourcesError struct { + Config *Config + Err error +} + +func (e NoValidCredentialSourcesError) Error() string { + if e.Config == nil { + return fmt.Sprintf("no valid credential sources found: %s", e.Err) + } + + return fmt.Sprintf(`no valid credential sources for %[1]s found. + +Please see %[2]s +for more information about providing credentials. + +Error: %[3]s +`, e.Config.CallerName, e.Config.CallerDocumentationURL, e.Err) +} + +func (e NoValidCredentialSourcesError) Unwrap() error { + return e.Err +} + +func (c *Config) NewNoValidCredentialSourcesError(err error) NoValidCredentialSourcesError { + return NoValidCredentialSourcesError{Config: c, Err: err} +} diff --git a/internal/httpclient/http_client.go b/internal/httpclient/http_client.go new file mode 100644 index 00000000..4719be45 --- /dev/null +++ b/internal/httpclient/http_client.go @@ -0,0 +1,38 @@ +package httpclient + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/go-cleanhttp" +) + +func DefaultHttpClient(c *config.Config) (*http.Client, error) { + httpClient := cleanhttp.DefaultClient() + transport := httpClient.Transport.(*http.Transport) + + tlsConfig := transport.TLSClientConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{} + transport.TLSClientConfig = tlsConfig + } + tlsConfig.MinVersion = tls.VersionTLS12 + + if c.Insecure { + tlsConfig.InsecureSkipVerify = true + } + + if c.HTTPProxy != "" { + proxyUrl, err := url.Parse(c.HTTPProxy) + if err != nil { + return nil, fmt.Errorf("error parsing HTTP proxy URL: %w", err) + } + + transport.Proxy = http.ProxyURL(proxyUrl) + } + + return httpClient, nil +} diff --git a/user_agent.go b/user_agent.go index d572bc70..080bd717 100644 --- a/user_agent.go +++ b/user_agent.go @@ -8,30 +8,6 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) -// Builds the user-agent string for APN -func (apn APNInfo) BuildUserAgentString() string { - builder := smithyhttp.NewUserAgentBuilder() - builder.AddKeyValue("APN", "1.0") - builder.AddKeyValue(apn.PartnerName, "1.0") - for _, p := range apn.Products { - p.buildUserAgentPart(builder) - } - return builder.Build() -} - -func (p APNProduct) buildUserAgentPart(b *smithyhttp.UserAgentBuilder) { - if p.Name != "" { - if p.Version != "" { - b.AddKeyValue(p.Name, p.Version) - } else { - b.AddKey(p.Name) - } - } - if p.Comment != "" { - b.AddKey("(" + p.Comment + ")") - } -} - func apnUserAgentMiddleware(apn APNInfo) middleware.BuildMiddleware { return middleware.BuildMiddlewareFunc("APNUserAgent", func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) { diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 422db4ee..aa43aabd 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -4,6 +4,7 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "context" "fmt" "log" + "net/http" "os" awsv2 "github.com/aws/aws-sdk-go-v2/aws" @@ -14,7 +15,7 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim awsbase "github.com/hashicorp/aws-sdk-go-base/v2" "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" - "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" ) // getSessionOptions attempts to return valid AWS Go SDK session authentication @@ -26,6 +27,13 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, return nil, fmt.Errorf("error accessing credentials: %w", err) } + httpClient, ok := awsC.HTTPClient.(*http.Client) + if !ok { // This is unlikely, but technically possible + httpClient, err = httpclient.DefaultHttpClient(c) + if err != nil { + return nil, err + } + } options := &session.Options{ Config: aws.Config{ Credentials: credentials.NewStaticCredentials( @@ -34,7 +42,7 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, creds.SessionToken, ), EndpointResolver: endpointResolver(c), - HTTPClient: cleanhttp.DefaultClient(), + HTTPClient: httpClient, MaxRetries: aws.Int(0), Region: aws.String(awsC.Region), },