Skip to content

Commit

Permalink
Adding local file cache for credential helper
Browse files Browse the repository at this point in the history
This provides a JSON file cache that's serialized/deserialized
for each access. The local file cache tokens are kept for the half-life
of the token lifetime (i.e. a 12 hour token will get expired after 6
hours). Users may also opt-out of local cache with an environment
variable set of `AWS_ECR_DISABLE_CACHE`.
  • Loading branch information
sentientmonkey committed Jul 1, 2016
1 parent 10b5dec commit 192a492
Show file tree
Hide file tree
Showing 10 changed files with 706 additions and 17 deletions.
51 changes: 43 additions & 8 deletions ecr-login/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ import (
"encoding/base64"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecr/ecriface"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
log "github.com/cihub/seelog"
)

Expand All @@ -30,34 +32,67 @@ type Client interface {
GetCredentials(registry, image string) (string, string, error)
}
type defaultClient struct {
ecrClient ecriface.ECRAPI
ecrClient ecriface.ECRAPI
credentialCache cache.CredentialsCache
}

func (self *defaultClient) GetCredentials(registry, image string) (string, string, error) {
log.Debugf("GetCredentials for %s", registry)

cachedEntry := self.credentialCache.Get(registry)

if cachedEntry != nil {
if cachedEntry.IsValid(time.Now()) {
log.Debugf("Using cached token for %s", registry)
return extractToken(cachedEntry.AuthorizationToken)
} else {
log.Debugf("Cached token is no longer valid. RequestAt: %s, ExpiresAt: %s", cachedEntry.RequestedAt, cachedEntry.ExpiresAt)
}
}

log.Debugf("Calling ECR.GetAuthorizationToken for %s", registry)

input := &ecr.GetAuthorizationTokenInput{
RegistryIds: []*string{aws.String(registry)},
}

output, err := self.ecrClient.GetAuthorizationToken(input)
if err != nil {

if err != nil || output == nil {
if err == nil {
err = fmt.Errorf("Missing AuthorizationData in ECR response for %s", registry)
}

// if we have a cached token, fall back to avoid failing the request. This may result an expired token
// being returned, but if there is a 500 or timeout from the service side, we'd like to attempt to re-use an
// old token. We invalidate tokens prior to their expiration date to help mitigate this scenario.
if cachedEntry != nil {
log.Infof("Got error fetching authorization token. Falling back to cached token. Error was: %s", err)
return extractToken(cachedEntry.AuthorizationToken)
}

return "", "", err
}
if output == nil {
return "", "", fmt.Errorf("Missing AuthorizationData in ECR response for %s", registry)
}
for _, authData := range output.AuthorizationData {
if authData.ProxyEndpoint != nil &&
strings.HasPrefix(proxyEndpointScheme+image, aws.StringValue(authData.ProxyEndpoint)) &&
authData.AuthorizationToken != nil {
return extractToken(authData)
authEntry := cache.AuthEntry{
AuthorizationToken: aws.StringValue(authData.AuthorizationToken),
RequestedAt: time.Now(),
ExpiresAt: aws.TimeValue(authData.ExpiresAt),
ProxyEndpoint: aws.StringValue(authData.ProxyEndpoint),
}

self.credentialCache.Set(registry, &authEntry)
return extractToken(aws.StringValue(authData.AuthorizationToken))
}
}
return "", "", fmt.Errorf("No AuthorizationToken found for %s", registry)
}

func extractToken(authData *ecr.AuthorizationData) (string, string, error) {
decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(authData.AuthorizationToken))
func extractToken(token string) (string, string, error) {
decodedToken, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return "", "", err
}
Expand Down
188 changes: 181 additions & 7 deletions ecr-login/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ import (
"encoding/base64"
"errors"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api/mocks"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
Expand All @@ -36,11 +39,17 @@ func TestGetAuthConfigSuccess(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
ecrClient: ecrClient,
credentialCache: credentialCache,
}

testProxyEndpoint := proxyEndpointScheme + proxyEndpoint
authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))
expiresAt := time.Now().Add(12 * time.Hour)

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
func(input *ecr.GetAuthorizationTokenInput) {
if input == nil {
Expand All @@ -52,12 +61,26 @@ func TestGetAuthConfigSuccess(t *testing.T) {
}).Return(&ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{
&ecr.AuthorizationData{
ProxyEndpoint: aws.String(proxyEndpointScheme + proxyEndpoint),
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))),
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
},
},
}, nil)

authEntry := &cache.AuthEntry{
ProxyEndpoint: testProxyEndpoint,
RequestedAt: time.Now(),
ExpiresAt: expiresAt,
AuthorizationToken: authorizationToken,
}

credentialCache.EXPECT().Get(registryID).Return(nil)
credentialCache.EXPECT().Set(registryID, gomock.Any()).Do(
func(_ string, actual *cache.AuthEntry) {
compareAuthEntry(t, actual, authEntry)
})

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.Nil(t, err)
assert.Equal(t, username, expectedUsername)
Expand All @@ -68,9 +91,11 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
ecrClient: ecrClient,
credentialCache: credentialCache,
}

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
Expand All @@ -90,20 +115,112 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) {
},
}, nil)

credentialCache.EXPECT().Get(registryID).Return(nil)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.NotNil(t, err)
t.Log(err)
assert.Empty(t, username)
assert.Empty(t, password)
}

func TestGetAuthConfigGetCacheSuccess(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
credentialCache: credentialCache,
}

testProxyEndpoint := proxyEndpointScheme + proxyEndpoint
authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))
expiresAt := time.Now().Add(12 * time.Hour)

authEntry := &cache.AuthEntry{
ProxyEndpoint: testProxyEndpoint,
ExpiresAt: expiresAt,
AuthorizationToken: authorizationToken,
}

credentialCache.EXPECT().Get(registryID).Return(authEntry)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.Nil(t, err)
assert.Equal(t, username, expectedUsername)
assert.Equal(t, password, expectedPassword)
}

func TestGetAuthConfigSuccessInvalidCacheHit(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
credentialCache: credentialCache,
}

testProxyEndpoint := proxyEndpointScheme + proxyEndpoint
authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))
expiresAt := time.Now().Add(12 * time.Hour)

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
func(input *ecr.GetAuthorizationTokenInput) {
if input == nil {
t.Fatal("Called with nil input")
}
if len(input.RegistryIds) != 1 {
t.Fatalf("Unexpected number of RegistryIds, expected 1 but got %d", len(input.RegistryIds))
}
}).Return(&ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{
&ecr.AuthorizationData{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
},
},
}, nil)

expiredAuthEntry := &cache.AuthEntry{
ProxyEndpoint: testProxyEndpoint,
RequestedAt: time.Now().Add(-12 * time.Hour),
ExpiresAt: time.Now().Add(-6 * time.Hour),
AuthorizationToken: authorizationToken,
}

authEntry := &cache.AuthEntry{
ProxyEndpoint: testProxyEndpoint,
RequestedAt: time.Now(),
ExpiresAt: expiresAt,
AuthorizationToken: authorizationToken,
}

credentialCache.EXPECT().Get(registryID).Return(expiredAuthEntry)
credentialCache.EXPECT().Set(registryID, gomock.Any()).Do(
func(_ string, actual *cache.AuthEntry) {
compareAuthEntry(t, actual, authEntry)
})

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.Nil(t, err)
assert.Equal(t, username, expectedUsername)
assert.Equal(t, password, expectedPassword)
}

func TestGetAuthConfigBadBase64(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
ecrClient: ecrClient,
credentialCache: credentialCache,
}

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
Expand All @@ -123,6 +240,8 @@ func TestGetAuthConfigBadBase64(t *testing.T) {
},
}, nil)

credentialCache.EXPECT().Get(registryID).Return(nil)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.NotNil(t, err)
t.Log(err)
Expand All @@ -134,9 +253,11 @@ func TestGetAuthConfigMissingResponse(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
ecrClient: ecrClient,
credentialCache: credentialCache,
}

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
Expand All @@ -149,6 +270,8 @@ func TestGetAuthConfigMissingResponse(t *testing.T) {
}
})

credentialCache.EXPECT().Get(registryID).Return(nil)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.NotNil(t, err)
t.Log(err)
Expand All @@ -160,9 +283,11 @@ func TestGetAuthConfigECRError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
ecrClient: ecrClient,
credentialCache: credentialCache,
}

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
Expand All @@ -175,9 +300,58 @@ func TestGetAuthConfigECRError(t *testing.T) {
}
}).Return(nil, errors.New("test error"))

credentialCache.EXPECT().Get(registryID).Return(nil)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.NotNil(t, err)
t.Log(err)
assert.Empty(t, username)
assert.Empty(t, password)
}

func TestGetAuthConfigSuccessInvalidCacheHitFallback(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ecrClient := mock_ecriface.NewMockECRAPI(ctrl)
credentialCache := mock_cache.NewMockCredentialsCache(ctrl)

client := &defaultClient{
ecrClient: ecrClient,
credentialCache: credentialCache,
}

testProxyEndpoint := proxyEndpointScheme + proxyEndpoint
authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))

ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do(
func(input *ecr.GetAuthorizationTokenInput) {
if input == nil {
t.Fatal("Called with nil input")
}
if len(input.RegistryIds) != 1 {
t.Fatalf("Unexpected number of RegistryIds, expected 1 but got %d", len(input.RegistryIds))
}
}).Return(nil, errors.New("Service eror"))

expiredAuthEntry := &cache.AuthEntry{
ProxyEndpoint: testProxyEndpoint,
RequestedAt: time.Now().Add(-12 * time.Hour),
ExpiresAt: time.Now().Add(-6 * time.Hour),
AuthorizationToken: authorizationToken,
}

credentialCache.EXPECT().Get(registryID).Return(expiredAuthEntry)

username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage")
assert.Nil(t, err)
assert.Equal(t, username, expectedUsername)
assert.Equal(t, password, expectedPassword)
}

func compareAuthEntry(t *testing.T, actual *cache.AuthEntry, expected *cache.AuthEntry) {
assert.NotNil(t, actual)
assert.Equal(t, expected.AuthorizationToken, actual.AuthorizationToken)
assert.Equal(t, expected.ProxyEndpoint, actual.ProxyEndpoint)
assert.Equal(t, expected.ExpiresAt, actual.ExpiresAt)
assert.WithinDuration(t, expected.RequestedAt, actual.RequestedAt, 5*time.Second)
}
Loading

0 comments on commit 192a492

Please sign in to comment.