diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..77e83777781 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -3,3 +3,5 @@ ### SDK Enhancements ### SDK Bugs +* `aws/credentials/ssocreds`: Implement SSO token provider to support for `sso-session` in AWS shared config. + * Fixes [4649](https://github.com/aws/aws-sdk-go/issues/4649) \ No newline at end of file diff --git a/aws/auth/bearer/token.go b/aws/auth/bearer/token.go new file mode 100644 index 00000000000..78b8a8d95fe --- /dev/null +++ b/aws/auth/bearer/token.go @@ -0,0 +1,49 @@ +package bearer + +import ( + "github.com/aws/aws-sdk-go/aws" + "time" +) + +type Token struct { + Value string + + CanExpire bool + Expires time.Time +} + +// Expired returns if the token's Expires time is before or equal to the time +// provided. If CanExpires is false, Expired will always return false. +func (t Token) Expired(now time.Time) bool { + if !t.CanExpire { + return false + } + now = now.Round(0) + return now.Equal(t.Expires) || now.After(t.Expires) +} + +// TokenProvider provides interface for retrieving bearer tokens. +type TokenProvider interface { + RetrieveBearerToken(aws.Context) (Token, error) +} + +// TokenProviderFunc provides a helper utility to wrap a function as a type +// that implements the TokenProvider interface. +type TokenProviderFunc func(aws.Context) (Token, error) + +// RetrieveBearerToken calls the wrapped function, returning the Token or +// error. +func (fn TokenProviderFunc) RetrieveBearerToken(ctx aws.Context) (Token, error) { + return fn(ctx) +} + +// StaticTokenProvider provides a utility for wrapping a static bearer token +// value within an implementation of a token provider. +type StaticTokenProvider struct { + Token Token +} + +// RetrieveBearerToken returns the static token specified. +func (s StaticTokenProvider) RetrieveBearerToken(aws.Context) (Token, error) { + return s.Token, nil +} diff --git a/aws/credentials/ssocreds/provider.go b/aws/credentials/ssocreds/provider.go index 6eda2a5557f..95de5520f09 100644 --- a/aws/credentials/ssocreds/provider.go +++ b/aws/credentials/ssocreds/provider.go @@ -4,7 +4,6 @@ import ( "crypto/sha1" "encoding/hex" "encoding/json" - "fmt" "io/ioutil" "path/filepath" "strings" @@ -123,25 +122,6 @@ func getCacheFileName(url string) (string, error) { return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil } -type rfc3339 time.Time - -func (r *rfc3339) UnmarshalJSON(bytes []byte) error { - var value string - - if err := json.Unmarshal(bytes, &value); err != nil { - return err - } - - parse, err := time.Parse(time.RFC3339, value) - if err != nil { - return fmt.Errorf("expected RFC3339 timestamp: %v", err) - } - - *r = rfc3339(parse) - - return nil -} - type token struct { AccessToken string `json:"accessToken"` ExpiresAt rfc3339 `json:"expiresAt"` diff --git a/aws/credentials/ssocreds/sso_cached_token.go b/aws/credentials/ssocreds/sso_cached_token.go new file mode 100644 index 00000000000..a7e76282b98 --- /dev/null +++ b/aws/credentials/ssocreds/sso_cached_token.go @@ -0,0 +1,229 @@ +package ssocreds + +import ( + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "github.com/aws/aws-sdk-go/internal/shareddefaults" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +var resolvedOsUserHomeDir = shareddefaults.UserHomeDir + +// StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or +// error if unable get derive the path. Key that will be used to compute a SHA1 +// value that is hex encoded. +// +// Derives the filepath using the Key as: +// +// ~/.aws/sso/cache/.json +func StandardCachedTokenFilepath(key string) (string, error) { + homeDir := resolvedOsUserHomeDir() + if len(homeDir) == 0 { + return "", fmt.Errorf("unable to get USER's home directory for cached token") + } + hash := sha1.New() + if _, err := hash.Write([]byte(key)); err != nil { + return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %v", err) + } + + cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json" + + return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil +} + +type tokenKnownFields struct { + AccessToken string `json:"accessToken,omitempty"` + ExpiresAt *rfc3339 `json:"expiresAt,omitempty"` + + RefreshToken string `json:"refreshToken,omitempty"` + ClientID string `json:"clientId,omitempty"` + ClientSecret string `json:"clientSecret,omitempty"` +} + +type cachedToken struct { + tokenKnownFields + UnknownFields map[string]interface{} `json:"-"` +} + +func (t cachedToken) MarshalJSON() ([]byte, error) { + fields := map[string]interface{}{} + + setTokenFieldString(fields, "accessToken", t.AccessToken) + setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt) + + setTokenFieldString(fields, "refreshToken", t.RefreshToken) + setTokenFieldString(fields, "clientId", t.ClientID) + setTokenFieldString(fields, "clientSecret", t.ClientSecret) + + for k, v := range t.UnknownFields { + if _, ok := fields[k]; ok { + return nil, fmt.Errorf("unknown token field %v, duplicates known field", k) + } + fields[k] = v + } + + return json.Marshal(fields) +} + +func setTokenFieldString(fields map[string]interface{}, key, value string) { + if value == "" { + return + } + fields[key] = value +} +func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) { + if value == nil { + return + } + fields[key] = value +} + +func (t *cachedToken) UnmarshalJSON(b []byte) error { + var fields map[string]interface{} + if err := json.Unmarshal(b, &fields); err != nil { + return nil + } + + t.UnknownFields = map[string]interface{}{} + + for k, v := range fields { + var err error + switch k { + case "accessToken": + err = getTokenFieldString(v, &t.AccessToken) + case "expiresAt": + err = getTokenFieldRFC3339(v, &t.ExpiresAt) + case "refreshToken": + err = getTokenFieldString(v, &t.RefreshToken) + case "clientId": + err = getTokenFieldString(v, &t.ClientID) + case "clientSecret": + err = getTokenFieldString(v, &t.ClientSecret) + default: + t.UnknownFields[k] = v + } + + if err != nil { + return fmt.Errorf("field %q, %v", k, err) + } + } + + return nil +} + +func getTokenFieldString(v interface{}, value *string) error { + var ok bool + *value, ok = v.(string) + if !ok { + return fmt.Errorf("expect value to be string, got %T", v) + } + return nil +} + +func getTokenFieldRFC3339(v interface{}, value **rfc3339) error { + var stringValue string + if err := getTokenFieldString(v, &stringValue); err != nil { + return err + } + + timeValue, err := parseRFC3339(stringValue) + if err != nil { + return err + } + + *value = &timeValue + return nil +} + +func loadCachedToken(filename string) (cachedToken, error) { + fileBytes, err := ioutil.ReadFile(filename) + if err != nil { + return cachedToken{}, fmt.Errorf("failed to read cached SSO token file, %v", err) + } + + var t cachedToken + if err := json.Unmarshal(fileBytes, &t); err != nil { + return cachedToken{}, fmt.Errorf("failed to parse cached SSO token file, %v", err) + } + + if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() { + return cachedToken{}, fmt.Errorf( + "cached SSO token must contain accessToken and expiresAt fields") + } + + return t, nil +} + +func storeCachedToken(filename string, t cachedToken, fileMode os.FileMode) (err error) { + tmpFilename := filename + ".tmp-" + strconv.FormatInt(nowTime().UnixNano(), 10) + if err := writeCacheFile(tmpFilename, fileMode, t); err != nil { + return err + } + + if err := os.Rename(tmpFilename, filename); err != nil { + return fmt.Errorf("failed to replace old cached SSO token file, %v", err) + } + + return nil +} + +func writeCacheFile(filename string, fileMode os.FileMode, t cachedToken) (err error) { + var f *os.File + f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode) + if err != nil { + return fmt.Errorf("failed to create cached SSO token file %v", err) + } + + defer func() { + closeErr := f.Close() + if err == nil && closeErr != nil { + err = fmt.Errorf("failed to close cached SSO token file, %v", closeErr) + } + }() + + encoder := json.NewEncoder(f) + + if err = encoder.Encode(t); err != nil { + return fmt.Errorf("failed to serialize cached SSO token, %v", err) + } + + return nil +} + +type rfc3339 time.Time + +func (r *rfc3339) UnmarshalJSON(bytes []byte) error { + var value string + var err error + + if err = json.Unmarshal(bytes, &value); err != nil { + return err + } + + *r, err = parseRFC3339(value) + return err +} + +func parseRFC3339(v string) (rfc3339, error) { + parsed, err := time.Parse(time.RFC3339, v) + if err != nil { + return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %v", err) + } + + return rfc3339(parsed), nil +} + +func (r *rfc3339) MarshalJSON() ([]byte, error) { + value := time.Time(*r).Format(time.RFC3339) + + // Use JSON unmarshal to unescape the quoted value making use of JSON's + // quoting rules. + return json.Marshal(value) +} diff --git a/aws/credentials/ssocreds/sso_cached_token_test.go b/aws/credentials/ssocreds/sso_cached_token_test.go new file mode 100644 index 00000000000..64197c9ac69 --- /dev/null +++ b/aws/credentials/ssocreds/sso_cached_token_test.go @@ -0,0 +1,191 @@ +//go:build go1.9 +// +build go1.9 + +package ssocreds + +import ( + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" +) + +func TestStandardSSOCacheTokenFilepath(t *testing.T) { + origHomeDur := resolvedOsUserHomeDir + defer func() { + resolvedOsUserHomeDir = origHomeDur + }() + + cases := map[string]struct { + key string + osUserHomeDir func() string + expectFilename string + expectErr string + }{ + "success": { + key: "https://example.awsapps.com/start", + osUserHomeDir: func() string { + return os.TempDir() + }, + expectFilename: filepath.Join(os.TempDir(), ".aws", "sso", "cache", + "e8be5486177c5b5392bd9aa76563515b29358e6e.json"), + }, + "failure": { + key: "https://example.awsapps.com/start", + osUserHomeDir: func() string { + return "" + }, + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + resolvedOsUserHomeDir = c.osUserHomeDir + + actual, err := StandardCachedTokenFilepath(c.key) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect error, got none") + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := c.expectFilename, actual; e != a { + t.Errorf("expect %v filename, got %v", e, a) + } + }) + } +} + +func TestLoadCachedToken(t *testing.T) { + cases := map[string]struct { + filename string + expectToken cachedToken + expectErr string + }{ + "file not found": { + filename: filepath.Join("testdata", "does_not_exist.json"), + expectErr: "failed to read cached SSO token file", + }, + "invalid json": { + filename: filepath.Join("testdata", "invalid_json.json"), + expectErr: "failed to parse cached SSO token file", + }, + "missing accessToken": { + filename: filepath.Join("testdata", "missing_accessToken.json"), + expectErr: "must contain accessToken and expiresAt fields", + }, + "missing expiresAt": { + filename: filepath.Join("testdata", "missing_expiresAt.json"), + expectErr: "must contain accessToken and expiresAt fields", + }, + "standard token": { + filename: filepath.Join("testdata", "valid_token.json"), + expectToken: cachedToken{ + tokenKnownFields: tokenKnownFields{ + AccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + ExpiresAt: (*rfc3339)(Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))), + ClientID: "client id", + ClientSecret: "client secret", + RefreshToken: "refresh token", + }, + UnknownFields: map[string]interface{}{ + "unknownField": "some value", + "registrationExpiresAt": "2044-04-04T07:00:01Z", + "region": "region", + "startURL": "start URL", + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualToken, err := loadCachedToken(c.filename) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect %v error, got none", c.expectErr) + } + if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v error, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !reflect.DeepEqual(c.expectToken, actualToken) { + t.Errorf("expect token file %v but got actual %v", c.expectToken, actualToken) + } + }) + } +} + +func TestStoreCachedToken(t *testing.T) { + tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-"+t.Name()) + if err != nil { + t.Fatalf("failed to create temporary test directory, %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("failed to cleanup temporary test directory, %v", err) + } + }() + + cases := map[string]struct { + token cachedToken + filename string + fileMode os.FileMode + }{ + "standard token": { + filename: filepath.Join(tempDir, "token_file.json"), + fileMode: 0600, + token: cachedToken{ + tokenKnownFields: tokenKnownFields{ + AccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + ExpiresAt: (*rfc3339)(Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))), + ClientID: "client id", + ClientSecret: "client secret", + RefreshToken: "refresh token", + }, + UnknownFields: map[string]interface{}{ + "unknownField": "some value", + "registrationExpiresAt": "2044-04-04T07:00:01Z", + "region": "region", + "startURL": "start URL", + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + err := storeCachedToken(c.filename, c.token, c.fileMode) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + actual, err := loadCachedToken(c.filename) + if err != nil { + t.Fatalf("failed to load stored token, %v", err) + } + + if !reflect.DeepEqual(c.token, actual) { + t.Errorf("expect token file %v but got actual %v", c.token, actual) + } + }) + } +} + +// Time returns a pointer value for the time.Time value passed in. +func Time(v time.Time) *time.Time { + return &v +} diff --git a/aws/credentials/ssocreds/testdata/expired_token.json b/aws/credentials/ssocreds/testdata/expired_token.json new file mode 100644 index 00000000000..7e648605571 --- /dev/null +++ b/aws/credentials/ssocreds/testdata/expired_token.json @@ -0,0 +1,8 @@ +{ + "accessToken": "expired access token", + "expiresAt": "2021-12-21T12:21:00Z", + "clientId": "client id", + "clientSecret": "client secret", + "refreshToken": "refresh token", + "unknownField": "some value" +} diff --git a/aws/credentials/ssocreds/testdata/invalid_json.json b/aws/credentials/ssocreds/testdata/invalid_json.json new file mode 100644 index 00000000000..98232c64fce --- /dev/null +++ b/aws/credentials/ssocreds/testdata/invalid_json.json @@ -0,0 +1 @@ +{ diff --git a/aws/credentials/ssocreds/testdata/missing_accessToken.json b/aws/credentials/ssocreds/testdata/missing_accessToken.json new file mode 100644 index 00000000000..dba6cace2ad --- /dev/null +++ b/aws/credentials/ssocreds/testdata/missing_accessToken.json @@ -0,0 +1,7 @@ +{ + "clientId": "client id", + "clientSecret": "client secret", + "refreshToken": "refresh token", + "missing_accessToken": "access token", + "expiresAt": "2044-04-04T07:00:01Z" +} diff --git a/aws/credentials/ssocreds/testdata/missing_clientId.json b/aws/credentials/ssocreds/testdata/missing_clientId.json new file mode 100644 index 00000000000..76dadfcfe42 --- /dev/null +++ b/aws/credentials/ssocreds/testdata/missing_clientId.json @@ -0,0 +1,7 @@ +{ + "missing_clientId": "client id", + "clientSecret": "client secret", + "refreshToken": "refresh token", + "accessToken": "access token", + "expiresAt": "2021-12-21T12:21:00Z" +} diff --git a/aws/credentials/ssocreds/testdata/missing_clientSecret.json b/aws/credentials/ssocreds/testdata/missing_clientSecret.json new file mode 100644 index 00000000000..aa28fc9f046 --- /dev/null +++ b/aws/credentials/ssocreds/testdata/missing_clientSecret.json @@ -0,0 +1,7 @@ +{ + "clientId": "client id", + "missing_clientSecret": "client secret", + "refreshToken": "refresh token", + "accessToken": "access token", + "expiresAt": "2021-12-21T12:21:00Z" +} diff --git a/aws/credentials/ssocreds/testdata/missing_expiresAt.json b/aws/credentials/ssocreds/testdata/missing_expiresAt.json new file mode 100644 index 00000000000..cd578891273 --- /dev/null +++ b/aws/credentials/ssocreds/testdata/missing_expiresAt.json @@ -0,0 +1,7 @@ +{ + "clientId": "client id", + "clientSecret": "client secret", + "refreshToken": "refresh token", + "accessToken": "access token", + "missing_expiresAt": "2044-04-04T07:00:01Z" +} diff --git a/aws/credentials/ssocreds/testdata/missing_refreshToken.json b/aws/credentials/ssocreds/testdata/missing_refreshToken.json new file mode 100644 index 00000000000..9afcff7465d --- /dev/null +++ b/aws/credentials/ssocreds/testdata/missing_refreshToken.json @@ -0,0 +1,7 @@ +{ + "clientId": "client id", + "clientSecret": "client secret", + "missing_refreshToken": "refresh token", + "accessToken": "access token", + "expiresAt": "2021-12-21T12:21:00Z" +} diff --git a/aws/credentials/ssocreds/testdata/valid_token.json b/aws/credentials/ssocreds/testdata/valid_token.json new file mode 100644 index 00000000000..528d11c4f10 --- /dev/null +++ b/aws/credentials/ssocreds/testdata/valid_token.json @@ -0,0 +1,13 @@ +{ + "accessToken": "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + "expiresAt": "2044-04-04T07:00:01Z", + + "refreshToken": "refresh token", + "clientId": "client id", + "clientSecret": "client secret", + + "unknownField": "some value", + "region": "region", + "registrationExpiresAt": "2044-04-04T07:00:01Z", + "startURL": "start URL" +} diff --git a/aws/credentials/ssocreds/token_provider.go b/aws/credentials/ssocreds/token_provider.go new file mode 100644 index 00000000000..2ca4babc936 --- /dev/null +++ b/aws/credentials/ssocreds/token_provider.go @@ -0,0 +1,139 @@ +package ssocreds + +import ( + "fmt" + "github.com/aws/aws-sdk-go/aws/auth/bearer" + "os" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ssooidc" +) + +// CreateTokenAPIClient provides the interface for the SSOTokenProvider's API +// client for calling CreateToken operation to refresh the SSO token. +type CreateTokenAPIClient interface { + CreateToken(input *ssooidc.CreateTokenInput) (*ssooidc.CreateTokenOutput, error) +} + +// SSOTokenProviderOptions provides the options for configuring the +// SSOTokenProvider. +type SSOTokenProviderOptions struct { + // Client that can be overridden + Client CreateTokenAPIClient + + // The path the file containing the cached SSO token will be read from. + // Initialized the NewSSOTokenProvider's cachedTokenFilepath parameter. + CachedTokenFilepath string +} + +// SSOTokenProvider provides a utility for refreshing SSO AccessTokens for +// Bearer Authentication. The SSOTokenProvider can only be used to refresh +// already cached SSO Tokens. This utility cannot perform the initial SSO +// create token. +// +// The initial SSO create token should be preformed with the AWS CLI before the +// Go application using the SSOTokenProvider will need to retrieve the SSO +// token. If the AWS CLI has not created the token cache file, this provider +// will return an error when attempting to retrieve the cached token. +// +// This provider will attempt to refresh the cached SSO token periodically if +// needed when RetrieveBearerToken is called. +// +// A utility such as the AWS CLI must be used to initially create the SSO +// session and cached token file. +// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html +type SSOTokenProvider struct { + options SSOTokenProviderOptions +} + +// NewSSOTokenProvider returns an initialized SSOTokenProvider that will +// periodically refresh the SSO token cached stored in the cachedTokenFilepath. +// The cachedTokenFilepath file's content will be rewritten by the token +// provider when the token is refreshed. +// +// The client must be configured for the AWS region the SSO token was created for. +func NewSSOTokenProvider(client CreateTokenAPIClient, cachedTokenFilepath string, optFns ...func(o *SSOTokenProviderOptions)) *SSOTokenProvider { + options := SSOTokenProviderOptions{ + Client: client, + CachedTokenFilepath: cachedTokenFilepath, + } + for _, fn := range optFns { + fn(&options) + } + + provider := &SSOTokenProvider{ + options: options, + } + + return provider +} + +// RetrieveBearerToken returns the SSO token stored in the cachedTokenFilepath +// the SSOTokenProvider was created with. If the token has expired +// RetrieveBearerToken will attempt to refresh it. If the token cannot be +// refreshed or is not present an error will be returned. +// +// A utility such as the AWS CLI must be used to initially create the SSO +// session and cached token file. https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html +func (p SSOTokenProvider) RetrieveBearerToken(ctx aws.Context) (bearer.Token, error) { + cachedToken, err := loadCachedToken(p.options.CachedTokenFilepath) + if err != nil { + return bearer.Token{}, err + } + + if cachedToken.ExpiresAt != nil && nowTime().After(time.Time(*cachedToken.ExpiresAt)) { + cachedToken, err = p.refreshToken(cachedToken) + if err != nil { + return bearer.Token{}, fmt.Errorf("refresh cached SSO token failed, %v", err) + } + } + + expiresAt := toTime((*time.Time)(cachedToken.ExpiresAt)) + return bearer.Token{ + Value: cachedToken.AccessToken, + CanExpire: !expiresAt.IsZero(), + Expires: expiresAt, + }, nil +} + +func (p SSOTokenProvider) refreshToken(token cachedToken) (cachedToken, error) { + if token.ClientSecret == "" || token.ClientID == "" || token.RefreshToken == "" { + return cachedToken{}, fmt.Errorf("cached SSO token is expired, or not present, and cannot be refreshed") + } + + createResult, err := p.options.Client.CreateToken(&ssooidc.CreateTokenInput{ + ClientId: &token.ClientID, + ClientSecret: &token.ClientSecret, + RefreshToken: &token.RefreshToken, + GrantType: aws.String("refresh_token"), + }) + if err != nil { + return cachedToken{}, fmt.Errorf("unable to refresh SSO token, %v", err) + } + + expiresAt := nowTime().Add(time.Duration(*createResult.ExpiresIn) * time.Second) + + token.AccessToken = *createResult.AccessToken + token.ExpiresAt = (*rfc3339)(&expiresAt) + token.RefreshToken = *createResult.RefreshToken + + fileInfo, err := os.Stat(p.options.CachedTokenFilepath) + if err != nil { + return cachedToken{}, fmt.Errorf("failed to stat cached SSO token file %v", err) + } + + if err = storeCachedToken(p.options.CachedTokenFilepath, token, fileInfo.Mode()); err != nil { + return cachedToken{}, fmt.Errorf("unable to cache refreshed SSO token, %v", err) + } + + return token, nil +} + +func toTime(p *time.Time) (v time.Time) { + if p == nil { + return v + } + + return *p +} diff --git a/aws/credentials/ssocreds/token_provider_test.go b/aws/credentials/ssocreds/token_provider_test.go new file mode 100644 index 00000000000..53cb265a7ba --- /dev/null +++ b/aws/credentials/ssocreds/token_provider_test.go @@ -0,0 +1,224 @@ +//go:build go1.16 +// +build go1.16 + +package ssocreds + +import ( + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/auth/bearer" + "github.com/aws/aws-sdk-go/service/ssooidc" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" +) + +func TestSSOTokenProvider(t *testing.T) { + restoreTime := swapNowTime(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC)) + defer restoreTime() + + tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-"+t.Name()) + if err != nil { + t.Fatalf("failed to create temporary test directory, %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("failed to cleanup temporary test directory, %v", err) + } + }() + + cases := map[string]struct { + setup func() error + postRetrieve func() error + client CreateTokenAPIClient + cacheFilePath string + optFns []func(*SSOTokenProviderOptions) + + expectToken bearer.Token + expectErr string + }{ + "no cache file": { + cacheFilePath: filepath.Join("testdata", "file_not_exists"), + expectErr: "failed to read cached SSO token file", + }, + "invalid json cache file": { + cacheFilePath: filepath.Join("testdata", "invalid_json.json"), + expectErr: "failed to parse cached SSO token file", + }, + "missing accessToken": { + cacheFilePath: filepath.Join("testdata", "missing_accessToken.json"), + expectErr: "must contain accessToken and expiresAt fields", + }, + "missing expiresAt": { + cacheFilePath: filepath.Join("testdata", "missing_expiresAt.json"), + expectErr: "must contain accessToken and expiresAt fields", + }, + "expired no clientSecret": { + cacheFilePath: filepath.Join("testdata", "missing_clientSecret.json"), + expectErr: "cached SSO token is expired, or not present", + }, + "expired no clientId": { + cacheFilePath: filepath.Join("testdata", "missing_clientId.json"), + expectErr: "cached SSO token is expired, or not present", + }, + "expired no refreshToken": { + cacheFilePath: filepath.Join("testdata", "missing_refreshToken.json"), + expectErr: "cached SSO token is expired, or not present", + }, + "valid sso token": { + cacheFilePath: filepath.Join("testdata", "valid_token.json"), + expectToken: bearer.Token{ + Value: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + CanExpire: true, + Expires: time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC), + }, + }, + "refresh expired token": { + setup: func() error { + testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json")) + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600) + }, + postRetrieve: func() error { + actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json")) + if err != nil { + return err + + } + expect := cachedToken{ + tokenKnownFields: tokenKnownFields{ + AccessToken: "updated access token", + ExpiresAt: (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC))), + + RefreshToken: "updated refresh token", + ClientID: "client id", + ClientSecret: "client secret", + }, + UnknownFields: map[string]interface{}{ + "unknownField": "some value", + }, + } + + if !reflect.DeepEqual(expect, actual) { + return fmt.Errorf("expect token file %v but got actual %v", expect, actual) + } + return nil + }, + cacheFilePath: filepath.Join(tempDir, "expired_token.json"), + client: &mockCreateTokenAPIClient{ + expectInput: &ssooidc.CreateTokenInput{ + ClientId: aws.String("client id"), + ClientSecret: aws.String("client secret"), + RefreshToken: aws.String("refresh token"), + GrantType: aws.String("refresh_token"), + }, + output: &ssooidc.CreateTokenOutput{ + AccessToken: aws.String("updated access token"), + ExpiresIn: aws.Int64(600), + RefreshToken: aws.String("updated refresh token"), + }, + }, + expectToken: bearer.Token{ + Value: "updated access token", + CanExpire: true, + Expires: time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC), + }, + }, + "fail refresh expired token": { + setup: func() error { + testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json")) + if err != nil { + return err + } + return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600) + }, + postRetrieve: func() error { + actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json")) + if err != nil { + return err + + } + expect := cachedToken{ + tokenKnownFields: tokenKnownFields{ + AccessToken: "access token", + ExpiresAt: (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC))), + + RefreshToken: "refresh token", + ClientID: "client id", + ClientSecret: "client secret", + }, + } + + if !reflect.DeepEqual(expect, actual) { + return fmt.Errorf("expect token file %v but got actual %v", expect, actual) + } + return nil + }, + cacheFilePath: filepath.Join(tempDir, "expired_token.json"), + client: &mockCreateTokenAPIClient{ + err: fmt.Errorf("sky is falling"), + }, + expectErr: "unable to refresh SSO token, sky is falling", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if c.setup != nil { + if err := c.setup(); err != nil { + t.Fatalf("failed to setup test, %v", err) + } + } + provider := NewSSOTokenProvider(c.client, c.cacheFilePath, c.optFns...) + + token, err := provider.RetrieveBearerToken(aws.BackgroundContext()) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect %v error, got none", c.expectErr) + } + if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v error, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !reflect.DeepEqual(c.expectToken, token) { + t.Errorf("expect %v, got %v", c.expectToken, token) + } + + if c.postRetrieve != nil { + if err := c.postRetrieve(); err != nil { + t.Fatalf("post retrieve failed, %v", err) + } + } + }) + } +} + +type mockCreateTokenAPIClient struct { + expectInput *ssooidc.CreateTokenInput + output *ssooidc.CreateTokenOutput + err error +} + +func (c *mockCreateTokenAPIClient) CreateToken(input *ssooidc.CreateTokenInput) ( + *ssooidc.CreateTokenOutput, error, +) { + if c.expectInput != nil { + if !reflect.DeepEqual(c.expectInput, input) { + return nil, fmt.Errorf("expect token file %v but got actual %v", c.expectInput, input) + } + } + + return c.output, c.err +}