Skip to content

Commit

Permalink
Merge pull request #4853 from aws/feature-token-provider
Browse files Browse the repository at this point in the history
Add sso token provider
  • Loading branch information
wty-Bryant authored May 31, 2023
2 parents 0131d37 + 89c821a commit 1077388
Show file tree
Hide file tree
Showing 15 changed files with 891 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
### SDK Enhancements

### SDK Bugs
* `aws/credentials/ssocreds`: Implement SSO token provider to support for `sso-session` in AWS shared config.
* Fixes [4649](https://github.com/aws/aws-sdk-go/issues/4649)
49 changes: 49 additions & 0 deletions aws/auth/bearer/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package bearer

import (
"github.com/aws/aws-sdk-go/aws"
"time"
)

type Token struct {
Value string

CanExpire bool
Expires time.Time
}

// Expired returns if the token's Expires time is before or equal to the time
// provided. If CanExpires is false, Expired will always return false.
func (t Token) Expired(now time.Time) bool {
if !t.CanExpire {
return false
}
now = now.Round(0)
return now.Equal(t.Expires) || now.After(t.Expires)
}

// TokenProvider provides interface for retrieving bearer tokens.
type TokenProvider interface {
RetrieveBearerToken(aws.Context) (Token, error)
}

// TokenProviderFunc provides a helper utility to wrap a function as a type
// that implements the TokenProvider interface.
type TokenProviderFunc func(aws.Context) (Token, error)

// RetrieveBearerToken calls the wrapped function, returning the Token or
// error.
func (fn TokenProviderFunc) RetrieveBearerToken(ctx aws.Context) (Token, error) {
return fn(ctx)
}

// StaticTokenProvider provides a utility for wrapping a static bearer token
// value within an implementation of a token provider.
type StaticTokenProvider struct {
Token Token
}

// RetrieveBearerToken returns the static token specified.
func (s StaticTokenProvider) RetrieveBearerToken(aws.Context) (Token, error) {
return s.Token, nil
}
20 changes: 0 additions & 20 deletions aws/credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
Expand Down Expand Up @@ -123,25 +122,6 @@ func getCacheFileName(url string) (string, error) {
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}

type rfc3339 time.Time

func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string

if err := json.Unmarshal(bytes, &value); err != nil {
return err
}

parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("expected RFC3339 timestamp: %v", err)
}

*r = rfc3339(parse)

return nil
}

type token struct {
AccessToken string `json:"accessToken"`
ExpiresAt rfc3339 `json:"expiresAt"`
Expand Down
229 changes: 229 additions & 0 deletions aws/credentials/ssocreds/sso_cached_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package ssocreds

import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)

var resolvedOsUserHomeDir = shareddefaults.UserHomeDir

// StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
// error if unable get derive the path. Key that will be used to compute a SHA1
// value that is hex encoded.
//
// Derives the filepath using the Key as:
//
// ~/.aws/sso/cache/<sha1-hex-encoded-key>.json
func StandardCachedTokenFilepath(key string) (string, error) {
homeDir := resolvedOsUserHomeDir()
if len(homeDir) == 0 {
return "", fmt.Errorf("unable to get USER's home directory for cached token")
}
hash := sha1.New()
if _, err := hash.Write([]byte(key)); err != nil {
return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %v", err)
}

cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"

return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
}

type tokenKnownFields struct {
AccessToken string `json:"accessToken,omitempty"`
ExpiresAt *rfc3339 `json:"expiresAt,omitempty"`

RefreshToken string `json:"refreshToken,omitempty"`
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
}

type cachedToken struct {
tokenKnownFields
UnknownFields map[string]interface{} `json:"-"`
}

func (t cachedToken) MarshalJSON() ([]byte, error) {
fields := map[string]interface{}{}

setTokenFieldString(fields, "accessToken", t.AccessToken)
setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)

setTokenFieldString(fields, "refreshToken", t.RefreshToken)
setTokenFieldString(fields, "clientId", t.ClientID)
setTokenFieldString(fields, "clientSecret", t.ClientSecret)

for k, v := range t.UnknownFields {
if _, ok := fields[k]; ok {
return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
}
fields[k] = v
}

return json.Marshal(fields)
}

func setTokenFieldString(fields map[string]interface{}, key, value string) {
if value == "" {
return
}
fields[key] = value
}
func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
if value == nil {
return
}
fields[key] = value
}

func (t *cachedToken) UnmarshalJSON(b []byte) error {
var fields map[string]interface{}
if err := json.Unmarshal(b, &fields); err != nil {
return nil
}

t.UnknownFields = map[string]interface{}{}

for k, v := range fields {
var err error
switch k {
case "accessToken":
err = getTokenFieldString(v, &t.AccessToken)
case "expiresAt":
err = getTokenFieldRFC3339(v, &t.ExpiresAt)
case "refreshToken":
err = getTokenFieldString(v, &t.RefreshToken)
case "clientId":
err = getTokenFieldString(v, &t.ClientID)
case "clientSecret":
err = getTokenFieldString(v, &t.ClientSecret)
default:
t.UnknownFields[k] = v
}

if err != nil {
return fmt.Errorf("field %q, %v", k, err)
}
}

return nil
}

func getTokenFieldString(v interface{}, value *string) error {
var ok bool
*value, ok = v.(string)
if !ok {
return fmt.Errorf("expect value to be string, got %T", v)
}
return nil
}

func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
var stringValue string
if err := getTokenFieldString(v, &stringValue); err != nil {
return err
}

timeValue, err := parseRFC3339(stringValue)
if err != nil {
return err
}

*value = &timeValue
return nil
}

func loadCachedToken(filename string) (cachedToken, error) {
fileBytes, err := ioutil.ReadFile(filename)
if err != nil {
return cachedToken{}, fmt.Errorf("failed to read cached SSO token file, %v", err)
}

var t cachedToken
if err := json.Unmarshal(fileBytes, &t); err != nil {
return cachedToken{}, fmt.Errorf("failed to parse cached SSO token file, %v", err)
}

if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
return cachedToken{}, fmt.Errorf(
"cached SSO token must contain accessToken and expiresAt fields")
}

return t, nil
}

func storeCachedToken(filename string, t cachedToken, fileMode os.FileMode) (err error) {
tmpFilename := filename + ".tmp-" + strconv.FormatInt(nowTime().UnixNano(), 10)
if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
return err
}

if err := os.Rename(tmpFilename, filename); err != nil {
return fmt.Errorf("failed to replace old cached SSO token file, %v", err)
}

return nil
}

func writeCacheFile(filename string, fileMode os.FileMode, t cachedToken) (err error) {
var f *os.File
f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
if err != nil {
return fmt.Errorf("failed to create cached SSO token file %v", err)
}

defer func() {
closeErr := f.Close()
if err == nil && closeErr != nil {
err = fmt.Errorf("failed to close cached SSO token file, %v", closeErr)
}
}()

encoder := json.NewEncoder(f)

if err = encoder.Encode(t); err != nil {
return fmt.Errorf("failed to serialize cached SSO token, %v", err)
}

return nil
}

type rfc3339 time.Time

func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string
var err error

if err = json.Unmarshal(bytes, &value); err != nil {
return err
}

*r, err = parseRFC3339(value)
return err
}

func parseRFC3339(v string) (rfc3339, error) {
parsed, err := time.Parse(time.RFC3339, v)
if err != nil {
return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %v", err)
}

return rfc3339(parsed), nil
}

func (r *rfc3339) MarshalJSON() ([]byte, error) {
value := time.Time(*r).Format(time.RFC3339)

// Use JSON unmarshal to unescape the quoted value making use of JSON's
// quoting rules.
return json.Marshal(value)
}
Loading

0 comments on commit 1077388

Please sign in to comment.