Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify http cred provider resolving logic #2259

Merged
merged 20 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .changelog/0593bfc1f00841febcb73c7957839f01.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"id": "0593bfc1-f008-41fe-bcb7-3c7957839f01",
"type": "feature",
"collapse": true,
"description": "Modify http cred provider logic to support ecs/eks host and auth token file",
lucix-aws marked this conversation as resolved.
Show resolved Hide resolved
"modules": [
"config",
"credentials"
]
}
81 changes: 74 additions & 7 deletions config/resolve_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package config
import (
"context"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -21,11 +24,33 @@ import (

const (
// valid credential source values
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE"
)

// direct representation of the IPv4 address for the ECS container
// "169.254.170.2"
var ecsContainerIPv4 net.IP = []byte{
169, 254, 170, 2,
}

// direct representation of the IPv4 address for the EKS container
// "169.254.170.23"
var eksContainerIPv4 net.IP = []byte{
169, 254, 170, 23,
}

// direct representation of the IPv6 address for the EKS container
// "fd00:ec2::23"
var eksContainerIPv6 net.IP = []byte{
0xFD, 0, 0xE, 0xC2,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0x23,
}

var (
ecsContainerEndpoint = "http://169.254.170.2" // not constant to allow for swapping during unit-testing
)
Expand Down Expand Up @@ -222,6 +247,36 @@ func processCredentials(ctx context.Context, cfg *aws.Config, sharedConfig *Shar
return nil
}

// isAllowedHost allows host to be loopback or known ECS/EKS container IPs
//
// host can either be an IP address OR an unresolved hostname - resolution will
// be automatically performed in the latter case
func isAllowedHost(host string) (bool, error) {
if ip := net.ParseIP(host); ip != nil {
return isIPAllowed(ip), nil
}

addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}

for _, addr := range addrs {
if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) {
return false, nil
}
}

return true, nil
}

func isIPAllowed(ip net.IP) bool {
return ip.IsLoopback() ||
ip.Equal(ecsContainerIPv4) ||
ip.Equal(eksContainerIPv4) ||
ip.Equal(eksContainerIPv6)
}

func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpointURL, authToken string, configs configs) error {
var resolveErr error

Expand All @@ -232,10 +287,12 @@ func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpoint
host := parsed.Hostname()
if len(host) == 0 {
resolveErr = fmt.Errorf("unable to parse host from local HTTP cred provider URL")
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback hosts are allowed", host)
} else if parsed.Scheme == "http" {
if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil {
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, allowHostErr)
} else if !isAllowedHost {
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed", host)
}
}
}

Expand All @@ -252,6 +309,16 @@ func resolveHTTPCredProvider(ctx context.Context, cfg *aws.Config, url, authToke
if len(authToken) != 0 {
options.AuthorizationToken = authToken
}
if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" {
options.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) {
var contents []byte
var err error
if contents, err = ioutil.ReadFile(authFilePath); err != nil {
return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err)
}
return string(contents), nil
})
}
options.APIOptions = cfg.APIOptions
if cfg.Retryer != nil {
options.Retryer = cfg.Retryer()
Expand Down
58 changes: 57 additions & 1 deletion credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"context"
"fmt"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client"
Expand Down Expand Up @@ -81,7 +82,37 @@ type Options struct {

// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
//
// When constructed from environment, the provider will use the value of
// AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token
//
// Will be overridden if AuthorizationTokenProvider is configured
AuthorizationToken string

// Optional auth provider func to dynamically load the auth token from a file
// everytime a credential is retrieved
//
// When constructed from environment, the provider will read and use the content
// of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable
// as the auth token everytime credentials are retrieved
//
// Will override AuthorizationToken if configured
AuthorizationTokenProvider AuthTokenProvider
}

// AuthTokenProvider defines an interface to dynamically load a value to be passed
// for the Authorization header of a credentials request.
type AuthTokenProvider interface {
GetToken() (string, error)
}

// TokenProviderFunc is a func type implementing AuthTokenProvider interface
// and enables customizing token provider behavior
type TokenProviderFunc func() (string, error)

// GetToken func retrieves auth token according to TokenProviderFunc implementation
func (p TokenProviderFunc) GetToken() (string, error) {
return p()
}

// New returns a credentials Provider for retrieving AWS credentials
Expand Down Expand Up @@ -132,5 +163,30 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
}

func (p *Provider) getCredentials(ctx context.Context) (*client.GetCredentialsOutput, error) {
return p.client.GetCredentials(ctx, &client.GetCredentialsInput{AuthorizationToken: p.options.AuthorizationToken})
authToken, err := p.resolveAuthToken()
if err != nil {
return nil, fmt.Errorf("resolve auth token: %v", err)
}

return p.client.GetCredentials(ctx, &client.GetCredentialsInput{
AuthorizationToken: authToken,
})
}

func (p *Provider) resolveAuthToken() (string, error) {
authToken := p.options.AuthorizationToken
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend you compartmentalize this-- e.g.

authToken, err := p.resolveAuthToken()
if err != nil {
    return nil, fmt.Errorf("resolve auth token: %v", err)
}


var err error
if p.options.AuthorizationTokenProvider != nil {
authToken, err = p.options.AuthorizationTokenProvider.GetToken()
if err != nil {
return "", err
}
}

if strings.ContainsAny(authToken, "\r\n") {
return "", fmt.Errorf("authorization token contains invalid newline sequence")
}

return authToken, nil
}
84 changes: 84 additions & 0 deletions credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,90 @@ func TestRetrieveStaticCredentials(t *testing.T) {
}
}

func TestAuthTokenProvider(t *testing.T) {
cases := map[string]struct {
AuthToken string
AuthTokenProvider endpointcreds.AuthTokenProvider
ExpectAuthToken string
ExpectError bool
}{
"AuthToken": {
AuthToken: "Basic abc123",
ExpectAuthToken: "Basic abc123",
},
"AuthFileToken": {
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "Hello %20world", nil
}),
ExpectAuthToken: "Hello %20world",
},
"RetrieveFileTokenError": {
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "", fmt.Errorf("test error")
}),
ExpectAuthToken: "Hello %20world",
ExpectError: true,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
orig := sdk.NowTime
defer func() { sdk.NowTime = orig }()

var actualToken string
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
actualToken = r.Header["Authorization"][0]
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET"
}`))),
}, nil
})
o.AuthorizationToken = c.AuthToken
o.AuthorizationTokenProvider = c.AuthTokenProvider
})
creds, err := p.Retrieve(context.Background())

if err != nil && !c.ExpectError {
t.Errorf("expect no error, got %v", err)
} else if err == nil && c.ExpectError {
t.Errorf("expect error, got nil")
}

if c.ExpectError {
return
}

if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := c.ExpectAuthToken, actualToken; e != a {
t.Errorf("Expect %v, got %v", e, a)
}

sdk.NowTime = func() time.Time {
return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
}

if creds.Expired() {
t.Errorf("expect not to be expired")
}
})
}
}

func TestFailedRetrieveCredentials(t *testing.T) {
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
Expand Down