From 192a4928ba97539621c6f0cffcc21d6853bb5921 Mon Sep 17 00:00:00 2001 From: Scott Windsor Date: Mon, 6 Jun 2016 23:56:00 -0700 Subject: [PATCH] Adding local file cache for credential helper This provides a JSON file cache that's serialized/deserialized for each access. The local file cache tokens are kept for the half-life of the token lifetime (i.e. a 12 hour token will get expired after 6 hours). Users may also opt-out of local cache with an environment variable set of `AWS_ECR_DISABLE_CACHE`. --- ecr-login/api/client.go | 51 ++++++-- ecr-login/api/client_test.go | 188 ++++++++++++++++++++++++++- ecr-login/api/factory.go | 55 +++++++- ecr-login/cache/credentials.go | 36 +++++ ecr-login/cache/file.go | 153 ++++++++++++++++++++++ ecr-login/cache/file_test.go | 91 +++++++++++++ ecr-login/cache/generate_mocks.go | 16 +++ ecr-login/cache/mocks/cache_mocks.go | 69 ++++++++++ ecr-login/cache/null.go | 30 +++++ ecr-login/cache/null_test.go | 34 +++++ 10 files changed, 706 insertions(+), 17 deletions(-) create mode 100644 ecr-login/cache/credentials.go create mode 100644 ecr-login/cache/file.go create mode 100644 ecr-login/cache/file_test.go create mode 100644 ecr-login/cache/generate_mocks.go create mode 100644 ecr-login/cache/mocks/cache_mocks.go create mode 100644 ecr-login/cache/null.go create mode 100644 ecr-login/cache/null_test.go diff --git a/ecr-login/api/client.go b/ecr-login/api/client.go index 89a7989d..560add7b 100644 --- a/ecr-login/api/client.go +++ b/ecr-login/api/client.go @@ -17,10 +17,12 @@ import ( "encoding/base64" "fmt" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecr" "github.com/aws/aws-sdk-go/service/ecr/ecriface" + "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache" log "github.com/cihub/seelog" ) @@ -30,34 +32,67 @@ type Client interface { GetCredentials(registry, image string) (string, string, error) } type defaultClient struct { - ecrClient ecriface.ECRAPI + ecrClient ecriface.ECRAPI + credentialCache cache.CredentialsCache } func (self *defaultClient) GetCredentials(registry, image string) (string, string, error) { + log.Debugf("GetCredentials for %s", registry) + + cachedEntry := self.credentialCache.Get(registry) + + if cachedEntry != nil { + if cachedEntry.IsValid(time.Now()) { + log.Debugf("Using cached token for %s", registry) + return extractToken(cachedEntry.AuthorizationToken) + } else { + log.Debugf("Cached token is no longer valid. RequestAt: %s, ExpiresAt: %s", cachedEntry.RequestedAt, cachedEntry.ExpiresAt) + } + } + log.Debugf("Calling ECR.GetAuthorizationToken for %s", registry) + input := &ecr.GetAuthorizationTokenInput{ RegistryIds: []*string{aws.String(registry)}, } output, err := self.ecrClient.GetAuthorizationToken(input) - if err != nil { + + if err != nil || output == nil { + if err == nil { + err = fmt.Errorf("Missing AuthorizationData in ECR response for %s", registry) + } + + // if we have a cached token, fall back to avoid failing the request. This may result an expired token + // being returned, but if there is a 500 or timeout from the service side, we'd like to attempt to re-use an + // old token. We invalidate tokens prior to their expiration date to help mitigate this scenario. + if cachedEntry != nil { + log.Infof("Got error fetching authorization token. Falling back to cached token. Error was: %s", err) + return extractToken(cachedEntry.AuthorizationToken) + } + return "", "", err } - if output == nil { - return "", "", fmt.Errorf("Missing AuthorizationData in ECR response for %s", registry) - } for _, authData := range output.AuthorizationData { if authData.ProxyEndpoint != nil && strings.HasPrefix(proxyEndpointScheme+image, aws.StringValue(authData.ProxyEndpoint)) && authData.AuthorizationToken != nil { - return extractToken(authData) + authEntry := cache.AuthEntry{ + AuthorizationToken: aws.StringValue(authData.AuthorizationToken), + RequestedAt: time.Now(), + ExpiresAt: aws.TimeValue(authData.ExpiresAt), + ProxyEndpoint: aws.StringValue(authData.ProxyEndpoint), + } + + self.credentialCache.Set(registry, &authEntry) + return extractToken(aws.StringValue(authData.AuthorizationToken)) } } return "", "", fmt.Errorf("No AuthorizationToken found for %s", registry) } -func extractToken(authData *ecr.AuthorizationData) (string, string, error) { - decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(authData.AuthorizationToken)) +func extractToken(token string) (string, string, error) { + decodedToken, err := base64.StdEncoding.DecodeString(token) if err != nil { return "", "", err } diff --git a/ecr-login/api/client_test.go b/ecr-login/api/client_test.go index 36885ece..f5061fc9 100644 --- a/ecr-login/api/client_test.go +++ b/ecr-login/api/client_test.go @@ -17,10 +17,13 @@ import ( "encoding/base64" "errors" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecr" "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api/mocks" + "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache" + "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -36,11 +39,17 @@ func TestGetAuthConfigSuccess(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) client := &defaultClient{ - ecrClient: ecrClient, + ecrClient: ecrClient, + credentialCache: credentialCache, } + testProxyEndpoint := proxyEndpointScheme + proxyEndpoint + authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword)) + expiresAt := time.Now().Add(12 * time.Hour) + ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( func(input *ecr.GetAuthorizationTokenInput) { if input == nil { @@ -52,12 +61,26 @@ func TestGetAuthConfigSuccess(t *testing.T) { }).Return(&ecr.GetAuthorizationTokenOutput{ AuthorizationData: []*ecr.AuthorizationData{ &ecr.AuthorizationData{ - ProxyEndpoint: aws.String(proxyEndpointScheme + proxyEndpoint), - AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))), + ProxyEndpoint: aws.String(testProxyEndpoint), + ExpiresAt: aws.Time(expiresAt), + AuthorizationToken: aws.String(authorizationToken), }, }, }, nil) + authEntry := &cache.AuthEntry{ + ProxyEndpoint: testProxyEndpoint, + RequestedAt: time.Now(), + ExpiresAt: expiresAt, + AuthorizationToken: authorizationToken, + } + + credentialCache.EXPECT().Get(registryID).Return(nil) + credentialCache.EXPECT().Set(registryID, gomock.Any()).Do( + func(_ string, actual *cache.AuthEntry) { + compareAuthEntry(t, actual, authEntry) + }) + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") assert.Nil(t, err) assert.Equal(t, username, expectedUsername) @@ -68,9 +91,11 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) client := &defaultClient{ - ecrClient: ecrClient, + ecrClient: ecrClient, + credentialCache: credentialCache, } ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( @@ -90,6 +115,8 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) { }, }, nil) + credentialCache.EXPECT().Get(registryID).Return(nil) + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") assert.NotNil(t, err) t.Log(err) @@ -97,13 +124,103 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) { assert.Empty(t, password) } +func TestGetAuthConfigGetCacheSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) + + client := &defaultClient{ + ecrClient: ecrClient, + credentialCache: credentialCache, + } + + testProxyEndpoint := proxyEndpointScheme + proxyEndpoint + authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword)) + expiresAt := time.Now().Add(12 * time.Hour) + + authEntry := &cache.AuthEntry{ + ProxyEndpoint: testProxyEndpoint, + ExpiresAt: expiresAt, + AuthorizationToken: authorizationToken, + } + + credentialCache.EXPECT().Get(registryID).Return(authEntry) + + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") + assert.Nil(t, err) + assert.Equal(t, username, expectedUsername) + assert.Equal(t, password, expectedPassword) +} + +func TestGetAuthConfigSuccessInvalidCacheHit(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) + + client := &defaultClient{ + ecrClient: ecrClient, + credentialCache: credentialCache, + } + + testProxyEndpoint := proxyEndpointScheme + proxyEndpoint + authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword)) + expiresAt := time.Now().Add(12 * time.Hour) + + ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( + func(input *ecr.GetAuthorizationTokenInput) { + if input == nil { + t.Fatal("Called with nil input") + } + if len(input.RegistryIds) != 1 { + t.Fatalf("Unexpected number of RegistryIds, expected 1 but got %d", len(input.RegistryIds)) + } + }).Return(&ecr.GetAuthorizationTokenOutput{ + AuthorizationData: []*ecr.AuthorizationData{ + &ecr.AuthorizationData{ + ProxyEndpoint: aws.String(testProxyEndpoint), + ExpiresAt: aws.Time(expiresAt), + AuthorizationToken: aws.String(authorizationToken), + }, + }, + }, nil) + + expiredAuthEntry := &cache.AuthEntry{ + ProxyEndpoint: testProxyEndpoint, + RequestedAt: time.Now().Add(-12 * time.Hour), + ExpiresAt: time.Now().Add(-6 * time.Hour), + AuthorizationToken: authorizationToken, + } + + authEntry := &cache.AuthEntry{ + ProxyEndpoint: testProxyEndpoint, + RequestedAt: time.Now(), + ExpiresAt: expiresAt, + AuthorizationToken: authorizationToken, + } + + credentialCache.EXPECT().Get(registryID).Return(expiredAuthEntry) + credentialCache.EXPECT().Set(registryID, gomock.Any()).Do( + func(_ string, actual *cache.AuthEntry) { + compareAuthEntry(t, actual, authEntry) + }) + + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") + assert.Nil(t, err) + assert.Equal(t, username, expectedUsername) + assert.Equal(t, password, expectedPassword) +} + func TestGetAuthConfigBadBase64(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) client := &defaultClient{ - ecrClient: ecrClient, + ecrClient: ecrClient, + credentialCache: credentialCache, } ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( @@ -123,6 +240,8 @@ func TestGetAuthConfigBadBase64(t *testing.T) { }, }, nil) + credentialCache.EXPECT().Get(registryID).Return(nil) + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") assert.NotNil(t, err) t.Log(err) @@ -134,9 +253,11 @@ func TestGetAuthConfigMissingResponse(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) client := &defaultClient{ - ecrClient: ecrClient, + ecrClient: ecrClient, + credentialCache: credentialCache, } ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( @@ -149,6 +270,8 @@ func TestGetAuthConfigMissingResponse(t *testing.T) { } }) + credentialCache.EXPECT().Get(registryID).Return(nil) + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") assert.NotNil(t, err) t.Log(err) @@ -160,9 +283,11 @@ func TestGetAuthConfigECRError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) client := &defaultClient{ - ecrClient: ecrClient, + ecrClient: ecrClient, + credentialCache: credentialCache, } ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( @@ -175,9 +300,58 @@ func TestGetAuthConfigECRError(t *testing.T) { } }).Return(nil, errors.New("test error")) + credentialCache.EXPECT().Get(registryID).Return(nil) + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") assert.NotNil(t, err) t.Log(err) assert.Empty(t, username) assert.Empty(t, password) } + +func TestGetAuthConfigSuccessInvalidCacheHitFallback(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ecrClient := mock_ecriface.NewMockECRAPI(ctrl) + credentialCache := mock_cache.NewMockCredentialsCache(ctrl) + + client := &defaultClient{ + ecrClient: ecrClient, + credentialCache: credentialCache, + } + + testProxyEndpoint := proxyEndpointScheme + proxyEndpoint + authorizationToken := base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword)) + + ecrClient.EXPECT().GetAuthorizationToken(gomock.Any()).Do( + func(input *ecr.GetAuthorizationTokenInput) { + if input == nil { + t.Fatal("Called with nil input") + } + if len(input.RegistryIds) != 1 { + t.Fatalf("Unexpected number of RegistryIds, expected 1 but got %d", len(input.RegistryIds)) + } + }).Return(nil, errors.New("Service eror")) + + expiredAuthEntry := &cache.AuthEntry{ + ProxyEndpoint: testProxyEndpoint, + RequestedAt: time.Now().Add(-12 * time.Hour), + ExpiresAt: time.Now().Add(-6 * time.Hour), + AuthorizationToken: authorizationToken, + } + + credentialCache.EXPECT().Get(registryID).Return(expiredAuthEntry) + + username, password, err := client.GetCredentials(registryID, proxyEndpoint+"/myimage") + assert.Nil(t, err) + assert.Equal(t, username, expectedUsername) + assert.Equal(t, password, expectedPassword) +} + +func compareAuthEntry(t *testing.T, actual *cache.AuthEntry, expected *cache.AuthEntry) { + assert.NotNil(t, actual) + assert.Equal(t, expected.AuthorizationToken, actual.AuthorizationToken) + assert.Equal(t, expected.ProxyEndpoint, actual.ProxyEndpoint) + assert.Equal(t, expected.ExpiresAt, actual.ExpiresAt) + assert.WithinDuration(t, expected.RequestedAt, actual.RequestedAt, 5*time.Second) +} diff --git a/ecr-login/api/factory.go b/ecr-login/api/factory.go index d4132159..73a62721 100644 --- a/ecr-login/api/factory.go +++ b/ecr-login/api/factory.go @@ -14,9 +14,19 @@ package api import ( + "crypto/md5" + "encoding/base64" + "fmt" + "os" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecr" + "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache" + "github.com/mitchellh/go-homedir" + + log "github.com/cihub/seelog" ) type ClientFactory interface { @@ -24,8 +34,49 @@ type ClientFactory interface { } type DefaultClientFactory struct{} -func (DefaultClientFactory) NewClient(region string) Client { +func (defaultClientFactory DefaultClientFactory) NewClient(region string) Client { + awsSession := session.New() + return &defaultClient{ - ecrClient: ecr.New(session.New(), &aws.Config{Region: aws.String(region)}), + ecrClient: ecr.New(awsSession, &aws.Config{Region: aws.String(region)}), + credentialCache: defaultClientFactory.buildCredentialsCache(awsSession, region), + } +} + +func (defaultClientFactory DefaultClientFactory) buildCredentialsCache(awsSession *session.Session, region string) cache.CredentialsCache { + if os.Getenv("AWS_ECR_DISABLE_CACHE") != "" { + log.Debug("Cache disabled due to AWS_ECR_DISABLE_CACHE") + return cache.NewNullCredentialsCache() + } + + cacheDir, err := homedir.Expand("~/.ecr") + if err != nil { + log.Debugf("Could expand cache path: %s", err) + log.Debug("Disabling cache") + return cache.NewNullCredentialsCache() } + + cacheFilename := "cache.json" + + credentials, err := awsSession.Config.Credentials.Get() + if err != nil { + log.Debugf("Could fetch credentials for cache prefix: %s", err) + log.Debug("Disabling cache") + return cache.NewNullCredentialsCache() + } + + return cache.NewFileCredentialsCache(cacheDir, cacheFilename, defaultClientFactory.credentialsCachePrefix(region, &credentials)) +} + +// Determine a key prefix for a credentials cache. Because auth tokens are scoped to an account and region, rely on provided +// region, as well as hash of the access key. +func (defaultClientFactory DefaultClientFactory) credentialsCachePrefix(region string, credentials *credentials.Value) string { + return fmt.Sprintf("%s-%s-", region, checksum(credentials.AccessKeyID)) +} + +// Base64 encodes an MD5 checksum. Relied on for uniqueness, and not for cryptographic security. +func checksum(text string) string { + hasher := md5.New() + data := hasher.Sum([]byte(text)) + return base64.StdEncoding.EncodeToString(data) } diff --git a/ecr-login/cache/credentials.go b/ecr-login/cache/credentials.go new file mode 100644 index 00000000..a81b9343 --- /dev/null +++ b/ecr-login/cache/credentials.go @@ -0,0 +1,36 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +import "time" + +type CredentialsCache interface { + Get(registry string) *AuthEntry + Set(registry string, entry *AuthEntry) + Clear() +} + +type AuthEntry struct { + AuthorizationToken string + RequestedAt time.Time + ExpiresAt time.Time + ProxyEndpoint string +} + +// Checks if AuthEntry is still valid at testTime. AuthEntries expire at 1/2 of their original +// requested window. +func (authEntry *AuthEntry) IsValid(testTime time.Time) bool { + window := authEntry.ExpiresAt.Sub(authEntry.RequestedAt) + return authEntry.ExpiresAt.After(testTime.Add(-1 * (window / time.Duration(2)))) +} diff --git a/ecr-login/cache/file.go b/ecr-login/cache/file.go new file mode 100644 index 00000000..955e29d5 --- /dev/null +++ b/ecr-login/cache/file.go @@ -0,0 +1,153 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + + log "github.com/cihub/seelog" +) + +const registryCacheVersion = "1.0" + +type RegistryCache struct { + Registries map[string]*AuthEntry + Version string +} + +type fileCredentialCache struct { + path string + filename string + cachePrefixKey string +} + +func newRegistryCache() *RegistryCache { + return &RegistryCache{ + Registries: make(map[string]*AuthEntry), + Version: registryCacheVersion, + } +} + +// NewFileCredentialsCache returns a new file credentials cache. +// +// path is used for temporary files during save, and filename should be a relative filename +// in the same directory where the cache is serialized and deserialized. +// +// cachePrefixKey is used for scoping credentials for a given credential cache (i.e. region and +// accessKey). +func NewFileCredentialsCache(path string, filename string, cachePrefixKey string) CredentialsCache { + return &fileCredentialCache{path: path, filename: filename, cachePrefixKey: cachePrefixKey} +} + +func (f *fileCredentialCache) Get(registry string) *AuthEntry { + log.Debugf("Checking file cache for %s", registry) + registryCache, err := f.load() + if err != nil { + log.Infof("Could not load existing cache: %v", err) + f.Clear() + registryCache = newRegistryCache() + } + return registryCache.Registries[f.cachePrefixKey+registry] +} + +func (f *fileCredentialCache) Set(registry string, entry *AuthEntry) { + log.Debugf("Saving credentials to file cache for %s", registry) + registryCache, err := f.load() + if err != nil { + log.Infof("Could not load existing cache: %v", err) + f.Clear() + registryCache = newRegistryCache() + } + + registryCache.Registries[f.cachePrefixKey+registry] = entry + + err = f.save(registryCache) + if err != nil { + log.Infof("Could not save cache: %s", err) + } +} + +func (f *fileCredentialCache) Clear() { + err := os.Remove(f.fullFilePath()) + if err != nil { + log.Infof("Could not clear cache: %s") + } +} + +func (f *fileCredentialCache) fullFilePath() string { + return filepath.Join(f.path, f.filename) +} + +// Saves credential cache to disk. This writes to a temporary file first, then moves the file to the config location. +// This elminates from reading partially written credential files, and reduces (but does not eliminate) concurrent +// file access. There is not guarantee here for handling multiple writes at once since there is no out of process locking. +func (f *fileCredentialCache) save(registryCache *RegistryCache) error { + defer log.Flush() + file, err := ioutil.TempFile(f.path, ".config.json.tmp") + if err != nil { + return err + } + + buff, err := json.MarshalIndent(registryCache, "", " ") + if err != nil { + file.Close() + os.Remove(file.Name()) + return err + } + + _, err = file.Write(buff) + + if err != nil { + file.Close() + os.Remove(file.Name()) + return err + } + + file.Close() + // note this is only atomic when relying on linux syscalls + os.Rename(file.Name(), f.fullFilePath()) + return err +} + +// Loading a cache from disk will return errors for malformed or incompatible cache files. +func (f *fileCredentialCache) load() (*RegistryCache, error) { + registryCache := newRegistryCache() + + file, err := os.Open(f.fullFilePath()) + if os.IsNotExist(err) { + return registryCache, nil + } + + if err != nil { + return nil, err + } + + defer file.Close() + + if err = json.NewDecoder(file).Decode(®istryCache); err != nil { + return nil, err + } + + if registryCache.Version != registryCacheVersion { + return nil, fmt.Errorf("Registry cache version %#v is not compatible with %#v. Ignoring existing cache.", + registryCache.Version, + registryCacheVersion) + } + + return registryCache, nil +} diff --git a/ecr-login/cache/file_test.go b/ecr-login/cache/file_test.go new file mode 100644 index 00000000..5f282c8a --- /dev/null +++ b/ecr-login/cache/file_test.go @@ -0,0 +1,91 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var testAuthEntry = AuthEntry{ + AuthorizationToken: "testToken", + RequestedAt: time.Now().Add(-6 * time.Hour), + ExpiresAt: time.Now().Add(6 * time.Hour), + ProxyEndpoint: "testEndpoint", +} + +var testRegistryName = "testRegistry" + +var testCachePrefixKey = "prefix-" +var testPath = os.TempDir() +var testFilename = "test.json" +var testFullFillename = filepath.Join(testPath, testFilename) + +func TestAuthEntryValid(t *testing.T) { + assert.True(t, testAuthEntry.IsValid(time.Now())) +} + +func TestAuthEntryInValid(t *testing.T) { + assert.True(t, testAuthEntry.IsValid(time.Now().Add(time.Second))) +} + +func TestCredentials(t *testing.T) { + credentialCache := NewFileCredentialsCache(testPath, testFilename, testCachePrefixKey) + + credentialCache.Set(testRegistryName, &testAuthEntry) + + entry := credentialCache.Get(testRegistryName) + assert.Equal(t, &testAuthEntry, entry) + + credentialCache.Clear() + + entry = credentialCache.Get(testRegistryName) + assert.Nil(t, entry) +} + +func TestPreviousVersionCache(t *testing.T) { + credentialCache := NewFileCredentialsCache(testPath, testFilename, testCachePrefixKey) + + registryCache := newRegistryCache() + registryCache.Version = "0.1" + registryCache.Registries[testRegistryName] = &testAuthEntry + credentialCache.(*fileCredentialCache).save(registryCache) + + entry := credentialCache.Get(testRegistryName) + assert.Nil(t, entry) + + credentialCache.Clear() +} + +const testBadJson = "{nope not good json at all." + +func TestInvalidCache(t *testing.T) { + credentialCache := NewFileCredentialsCache(testPath, testFilename, testCachePrefixKey) + + file, err := os.Create(testFullFillename) + assert.NoError(t, err) + + file.WriteString(testBadJson) + err = file.Close() + assert.NoError(t, err) + + entry := credentialCache.Get(testRegistryName) + assert.Nil(t, entry) + + credentialCache.Clear() +} diff --git a/ecr-login/cache/generate_mocks.go b/ecr-login/cache/generate_mocks.go new file mode 100644 index 00000000..e974850c --- /dev/null +++ b/ecr-login/cache/generate_mocks.go @@ -0,0 +1,16 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +//go:generate mockgen.sh github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache CredentialsCache mocks/cache_mocks.go diff --git a/ecr-login/cache/mocks/cache_mocks.go b/ecr-login/cache/mocks/cache_mocks.go new file mode 100644 index 00000000..4263449c --- /dev/null +++ b/ecr-login/cache/mocks/cache_mocks.go @@ -0,0 +1,69 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +// Automatically generated by MockGen. DO NOT EDIT! +// Source: github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache (interfaces: CredentialsCache) + +package mock_cache + +import ( + cache "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache" + gomock "github.com/golang/mock/gomock" +) + +// Mock of CredentialsCache interface +type MockCredentialsCache struct { + ctrl *gomock.Controller + recorder *_MockCredentialsCacheRecorder +} + +// Recorder for MockCredentialsCache (not exported) +type _MockCredentialsCacheRecorder struct { + mock *MockCredentialsCache +} + +func NewMockCredentialsCache(ctrl *gomock.Controller) *MockCredentialsCache { + mock := &MockCredentialsCache{ctrl: ctrl} + mock.recorder = &_MockCredentialsCacheRecorder{mock} + return mock +} + +func (_m *MockCredentialsCache) EXPECT() *_MockCredentialsCacheRecorder { + return _m.recorder +} + +func (_m *MockCredentialsCache) Clear() { + _m.ctrl.Call(_m, "Clear") +} + +func (_mr *_MockCredentialsCacheRecorder) Clear() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Clear") +} + +func (_m *MockCredentialsCache) Get(_param0 string) *cache.AuthEntry { + ret := _m.ctrl.Call(_m, "Get", _param0) + ret0, _ := ret[0].(*cache.AuthEntry) + return ret0 +} + +func (_mr *_MockCredentialsCacheRecorder) Get(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Get", arg0) +} + +func (_m *MockCredentialsCache) Set(_param0 string, _param1 *cache.AuthEntry) { + _m.ctrl.Call(_m, "Set", _param0, _param1) +} + +func (_mr *_MockCredentialsCacheRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Set", arg0, arg1) +} diff --git a/ecr-login/cache/null.go b/ecr-login/cache/null.go new file mode 100644 index 00000000..1a5de8ee --- /dev/null +++ b/ecr-login/cache/null.go @@ -0,0 +1,30 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +type nullCredentialsCache struct{} + +func NewNullCredentialsCache() CredentialsCache { + return &nullCredentialsCache{} +} + +func (nullCache *nullCredentialsCache) Get(registry string) *AuthEntry { + return nil +} + +func (nullCache *nullCredentialsCache) Set(registry string, entry *AuthEntry) { +} + +func (nullCache *nullCredentialsCache) Clear() { +} diff --git a/ecr-login/cache/null_test.go b/ecr-login/cache/null_test.go new file mode 100644 index 00000000..3b0410b1 --- /dev/null +++ b/ecr-login/cache/null_test.go @@ -0,0 +1,34 @@ +// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cache + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNullCache(t *testing.T) { + credentialCache := NewNullCredentialsCache() + + entry := credentialCache.Get(testRegistryName) + assert.Nil(t, entry) + + credentialCache.Set(testRegistryName, &testAuthEntry) + + entry = credentialCache.Get(testRegistryName) + assert.Nil(t, entry) + + credentialCache.Clear() +}