diff --git a/.changelog/0593bfc1f00841febcb73c7957839f01.json b/.changelog/0593bfc1f00841febcb73c7957839f01.json new file mode 100644 index 00000000000..1cc8b95077b --- /dev/null +++ b/.changelog/0593bfc1f00841febcb73c7957839f01.json @@ -0,0 +1,10 @@ +{ + "id": "0593bfc1-f008-41fe-bcb7-3c7957839f01", + "type": "feature", + "collapse": true, + "description": "Add support for dynamic auth token from file and EKS container host in absolute/relative URIs in the HTTP credential provider.", + "modules": [ + "config", + "credentials" + ] +} \ No newline at end of file diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index b21cd30804d..89368520f3f 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -3,7 +3,10 @@ package config import ( "context" "fmt" + "io/ioutil" + "net" "net/url" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -21,11 +24,33 @@ import ( const ( // valid credential source values - credSourceEc2Metadata = "Ec2InstanceMetadata" - credSourceEnvironment = "Environment" - credSourceECSContainer = "EcsContainer" + credSourceEc2Metadata = "Ec2InstanceMetadata" + credSourceEnvironment = "Environment" + credSourceECSContainer = "EcsContainer" + httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" ) +// direct representation of the IPv4 address for the ECS container +// "169.254.170.2" +var ecsContainerIPv4 net.IP = []byte{ + 169, 254, 170, 2, +} + +// direct representation of the IPv4 address for the EKS container +// "169.254.170.23" +var eksContainerIPv4 net.IP = []byte{ + 169, 254, 170, 23, +} + +// direct representation of the IPv6 address for the EKS container +// "fd00:ec2::23" +var eksContainerIPv6 net.IP = []byte{ + 0xFD, 0, 0xE, 0xC2, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0x23, +} + var ( ecsContainerEndpoint = "http://169.254.170.2" // not constant to allow for swapping during unit-testing ) @@ -222,6 +247,36 @@ func processCredentials(ctx context.Context, cfg *aws.Config, sharedConfig *Shar return nil } +// isAllowedHost allows host to be loopback or known ECS/EKS container IPs +// +// host can either be an IP address OR an unresolved hostname - resolution will +// be automatically performed in the latter case +func isAllowedHost(host string) (bool, error) { + if ip := net.ParseIP(host); ip != nil { + return isIPAllowed(ip), nil + } + + addrs, err := lookupHostFn(host) + if err != nil { + return false, err + } + + for _, addr := range addrs { + if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) { + return false, nil + } + } + + return true, nil +} + +func isIPAllowed(ip net.IP) bool { + return ip.IsLoopback() || + ip.Equal(ecsContainerIPv4) || + ip.Equal(eksContainerIPv4) || + ip.Equal(eksContainerIPv6) +} + func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpointURL, authToken string, configs configs) error { var resolveErr error @@ -232,10 +287,12 @@ func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpoint host := parsed.Hostname() if len(host) == 0 { resolveErr = fmt.Errorf("unable to parse host from local HTTP cred provider URL") - } else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil { - resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, loopbackErr) - } else if !isLoopback { - resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback hosts are allowed", host) + } else if parsed.Scheme == "http" { + if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil { + resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, allowHostErr) + } else if !isAllowedHost { + resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed", host) + } } } @@ -252,6 +309,16 @@ func resolveHTTPCredProvider(ctx context.Context, cfg *aws.Config, url, authToke if len(authToken) != 0 { options.AuthorizationToken = authToken } + if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" { + options.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) { + var contents []byte + var err error + if contents, err = ioutil.ReadFile(authFilePath); err != nil { + return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err) + } + return string(contents), nil + }) + } options.APIOptions = cfg.APIOptions if cfg.Retryer != nil { options.Retryer = cfg.Retryer() diff --git a/credentials/endpointcreds/provider.go b/credentials/endpointcreds/provider.go index adc7fc6b000..0c3c4d68266 100644 --- a/credentials/endpointcreds/provider.go +++ b/credentials/endpointcreds/provider.go @@ -36,6 +36,7 @@ import ( "context" "fmt" "net/http" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client" @@ -81,7 +82,37 @@ type Options struct { // Optional authorization token value if set will be used as the value of // the Authorization header of the endpoint credential request. + // + // When constructed from environment, the provider will use the value of + // AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token + // + // Will be overridden if AuthorizationTokenProvider is configured AuthorizationToken string + + // Optional auth provider func to dynamically load the auth token from a file + // everytime a credential is retrieved + // + // When constructed from environment, the provider will read and use the content + // of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable + // as the auth token everytime credentials are retrieved + // + // Will override AuthorizationToken if configured + AuthorizationTokenProvider AuthTokenProvider +} + +// AuthTokenProvider defines an interface to dynamically load a value to be passed +// for the Authorization header of a credentials request. +type AuthTokenProvider interface { + GetToken() (string, error) +} + +// TokenProviderFunc is a func type implementing AuthTokenProvider interface +// and enables customizing token provider behavior +type TokenProviderFunc func() (string, error) + +// GetToken func retrieves auth token according to TokenProviderFunc implementation +func (p TokenProviderFunc) GetToken() (string, error) { + return p() } // New returns a credentials Provider for retrieving AWS credentials @@ -132,5 +163,30 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { } func (p *Provider) getCredentials(ctx context.Context) (*client.GetCredentialsOutput, error) { - return p.client.GetCredentials(ctx, &client.GetCredentialsInput{AuthorizationToken: p.options.AuthorizationToken}) + authToken, err := p.resolveAuthToken() + if err != nil { + return nil, fmt.Errorf("resolve auth token: %v", err) + } + + return p.client.GetCredentials(ctx, &client.GetCredentialsInput{ + AuthorizationToken: authToken, + }) +} + +func (p *Provider) resolveAuthToken() (string, error) { + authToken := p.options.AuthorizationToken + + var err error + if p.options.AuthorizationTokenProvider != nil { + authToken, err = p.options.AuthorizationTokenProvider.GetToken() + if err != nil { + return "", err + } + } + + if strings.ContainsAny(authToken, "\r\n") { + return "", fmt.Errorf("authorization token contains invalid newline sequence") + } + + return authToken, nil } diff --git a/credentials/endpointcreds/provider_test.go b/credentials/endpointcreds/provider_test.go index 15a539518af..6f39ea4899c 100644 --- a/credentials/endpointcreds/provider_test.go +++ b/credentials/endpointcreds/provider_test.go @@ -108,6 +108,90 @@ func TestRetrieveStaticCredentials(t *testing.T) { } } +func TestAuthTokenProvider(t *testing.T) { + cases := map[string]struct { + AuthToken string + AuthTokenProvider endpointcreds.AuthTokenProvider + ExpectAuthToken string + ExpectError bool + }{ + "AuthToken": { + AuthToken: "Basic abc123", + ExpectAuthToken: "Basic abc123", + }, + "AuthFileToken": { + AuthToken: "Basic abc123", + AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) { + return "Hello %20world", nil + }), + ExpectAuthToken: "Hello %20world", + }, + "RetrieveFileTokenError": { + AuthToken: "Basic abc123", + AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) { + return "", fmt.Errorf("test error") + }), + ExpectAuthToken: "Hello %20world", + ExpectError: true, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + var actualToken string + p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) { + o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) { + actualToken = r.Header["Authorization"][0] + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{ + "AccessKeyID": "AKID", + "SecretAccessKey": "SECRET" +}`))), + }, nil + }) + o.AuthorizationToken = c.AuthToken + o.AuthorizationTokenProvider = c.AuthTokenProvider + }) + creds, err := p.Retrieve(context.Background()) + + if err != nil && !c.ExpectError { + t.Errorf("expect no error, got %v", err) + } else if err == nil && c.ExpectError { + t.Errorf("expect error, got nil") + } + + if c.ExpectError { + return + } + + if e, a := "AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if v := creds.SessionToken; len(v) != 0 { + t.Errorf("expect empty, got %v", v) + } + if e, a := c.ExpectAuthToken, actualToken; e != a { + t.Errorf("Expect %v, got %v", e, a) + } + + sdk.NowTime = func() time.Time { + return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC) + } + + if creds.Expired() { + t.Errorf("expect not to be expired") + } + }) + } +} + func TestFailedRetrieveCredentials(t *testing.T) { p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) { o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {