-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sso token provider #4853
Add sso token provider #4853
Changes from 10 commits
8855191
3bf12a6
9e666ea
e3644bf
22968a0
6354a34
3e9df76
df54980
14a703c
6f16394
d0894cf
d5e1072
45be7cf
7b47ebb
e35aaee
3154ff0
3c021ea
359ee1e
19a523b
da6f599
8cf9168
8affb62
2ffc7e2
6e4eb05
b2ca51c
9b2a2a1
7564841
1558f1a
89c821a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package ssocreds | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix: This should either be renamed to remove the notion of bearer OR move it to a new package. I think I'd go with the latter (moving it) should we decide we want/need to implement bearer token auth support. e.g.
|
||
|
||
import ( | ||
"context" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix: Use |
||
"time" | ||
) | ||
|
||
type Token struct { | ||
Value string | ||
|
||
CanExpire bool | ||
Expires time.Time | ||
} | ||
|
||
// Expired returns if the token's Expires time is before or equal to the time | ||
// provided. If CanExpires is false, Expired will always return false. | ||
func (t Token) Expired(now time.Time) bool { | ||
if !t.CanExpire { | ||
return false | ||
} | ||
now = now.Round(0) | ||
return now.Equal(t.Expires) || now.After(t.Expires) | ||
} | ||
|
||
// TokenProvider provides interface for retrieving bearer tokens. | ||
type TokenProvider interface { | ||
RetrieveBearerToken(context.Context) (Token, error) | ||
} | ||
|
||
// TokenProviderFunc provides a helper utility to wrap a function as a type | ||
// that implements the TokenProvider interface. | ||
type TokenProviderFunc func(context.Context) (Token, error) | ||
|
||
// RetrieveBearerToken calls the wrapped function, returning the Token or | ||
// error. | ||
func (fn TokenProviderFunc) RetrieveBearerToken(ctx context.Context) (Token, error) { | ||
return fn(ctx) | ||
} | ||
|
||
// StaticTokenProvider provides a utility for wrapping a static bearer token | ||
// value within an implementation of a token provider. | ||
type StaticTokenProvider struct { | ||
Token Token | ||
} | ||
|
||
// RetrieveBearerToken returns the static token specified. | ||
func (s StaticTokenProvider) RetrieveBearerToken(context.Context) (Token, error) { | ||
return s.Token, nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
package ssocreds | ||
|
||
import ( | ||
"crypto/sha1" | ||
"encoding/hex" | ||
"encoding/json" | ||
"fmt" | ||
"github.com/aws/aws-sdk-go/internal/shareddefaults" | ||
"io/ioutil" | ||
"os" | ||
"path/filepath" | ||
"strconv" | ||
"strings" | ||
"time" | ||
) | ||
|
||
var osUserHomeDur = shareddefaults.UserHomeDir | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i get what youre doing here: you seem to be saving it out to a variable so you can inject over it in the test case. but id bias towards descriptive difference in naming rather than intentional misspelling. something like |
||
|
||
// StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or | ||
// error if unable get derive the path. Key that will be used to compute a SHA1 | ||
// value that is hex encoded. | ||
// | ||
// Derives the filepath using the Key as: | ||
// | ||
// ~/.aws/sso/cache/<sha1-hex-encoded-key>.json | ||
func StandardCachedTokenFilepath(key string) (string, error) { | ||
homeDir := osUserHomeDur() | ||
if len(homeDir) == 0 { | ||
return "", fmt.Errorf("unable to get USER's home directory for cached token") | ||
} | ||
hash := sha1.New() | ||
if _, err := hash.Write([]byte(key)); err != nil { | ||
return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %v", err) | ||
} | ||
|
||
cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json" | ||
|
||
return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil | ||
} | ||
|
||
type tokenKnownFields struct { | ||
AccessToken string `json:"accessToken,omitempty"` | ||
ExpiresAt *rfc3339 `json:"expiresAt,omitempty"` | ||
|
||
RefreshToken string `json:"refreshToken,omitempty"` | ||
ClientID string `json:"clientId,omitempty"` | ||
ClientSecret string `json:"clientSecret,omitempty"` | ||
} | ||
|
||
type cachedToken struct { | ||
tokenKnownFields | ||
UnknownFields map[string]interface{} `json:"-"` | ||
} | ||
|
||
func (t cachedToken) MarshalJSON() ([]byte, error) { | ||
fields := map[string]interface{}{} | ||
|
||
setTokenFieldString(fields, "accessToken", t.AccessToken) | ||
setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt) | ||
|
||
setTokenFieldString(fields, "refreshToken", t.RefreshToken) | ||
setTokenFieldString(fields, "clientId", t.ClientID) | ||
setTokenFieldString(fields, "clientSecret", t.ClientSecret) | ||
|
||
for k, v := range t.UnknownFields { | ||
if _, ok := fields[k]; ok { | ||
return nil, fmt.Errorf("unknown token field %v, duplicates known field", k) | ||
} | ||
fields[k] = v | ||
} | ||
|
||
return json.Marshal(fields) | ||
} | ||
|
||
func setTokenFieldString(fields map[string]interface{}, key, value string) { | ||
if value == "" { | ||
return | ||
} | ||
fields[key] = value | ||
} | ||
func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) { | ||
if value == nil { | ||
return | ||
} | ||
fields[key] = value | ||
} | ||
|
||
func (t *cachedToken) UnmarshalJSON(b []byte) error { | ||
var fields map[string]interface{} | ||
if err := json.Unmarshal(b, &fields); err != nil { | ||
return nil | ||
} | ||
|
||
t.UnknownFields = map[string]interface{}{} | ||
|
||
for k, v := range fields { | ||
var err error | ||
switch k { | ||
case "accessToken": | ||
err = getTokenFieldString(v, &t.AccessToken) | ||
case "expiresAt": | ||
err = getTokenFieldRFC3339(v, &t.ExpiresAt) | ||
case "refreshToken": | ||
err = getTokenFieldString(v, &t.RefreshToken) | ||
case "clientId": | ||
err = getTokenFieldString(v, &t.ClientID) | ||
case "clientSecret": | ||
err = getTokenFieldString(v, &t.ClientSecret) | ||
default: | ||
t.UnknownFields[k] = v | ||
} | ||
|
||
if err != nil { | ||
return fmt.Errorf("field %q, %v", k, err) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func getTokenFieldString(v interface{}, value *string) error { | ||
var ok bool | ||
*value, ok = v.(string) | ||
if !ok { | ||
return fmt.Errorf("expect value to be string, got %T", v) | ||
} | ||
return nil | ||
} | ||
|
||
func getTokenFieldRFC3339(v interface{}, value **rfc3339) error { | ||
var stringValue string | ||
if err := getTokenFieldString(v, &stringValue); err != nil { | ||
return err | ||
} | ||
|
||
timeValue, err := parseRFC3339(stringValue) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
*value = &timeValue | ||
return nil | ||
} | ||
|
||
func loadCachedAccessToken(filename string) (cachedToken, error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question Given this comes from gov2 is there any reason we changed the name from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will change it back. Previously I thought it is mainly used to retrieve access token from cached file. |
||
fileBytes, err := ioutil.ReadFile(filename) | ||
if err != nil { | ||
return cachedToken{}, fmt.Errorf("failed to read cached SSO token file, %v", err) | ||
} | ||
|
||
var t cachedToken | ||
if err := json.Unmarshal(fileBytes, &t); err != nil { | ||
return cachedToken{}, fmt.Errorf("failed to parse cached SSO token file, %v", err) | ||
} | ||
|
||
if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() { | ||
return cachedToken{}, fmt.Errorf( | ||
"cached SSO token must contain accessToken and expiresAt fields") | ||
} | ||
|
||
return t, nil | ||
} | ||
|
||
func storeCachedToken(filename string, t cachedToken, fileMode os.FileMode) (err error) { | ||
tmpFilename := filename + ".tmp-" + strconv.FormatInt(nowTime().UnixNano(), 10) | ||
if err := writeCacheFile(tmpFilename, fileMode, t); err != nil { | ||
return err | ||
} | ||
|
||
if err := os.Rename(tmpFilename, filename); err != nil { | ||
return fmt.Errorf("failed to replace old cached SSO token file, %v", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func writeCacheFile(filename string, fileMode os.FileMode, t cachedToken) (err error) { | ||
var f *os.File | ||
f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode) | ||
if err != nil { | ||
return fmt.Errorf("failed to create cached SSO token file %v", err) | ||
} | ||
|
||
defer func() { | ||
closeErr := f.Close() | ||
if err == nil && closeErr != nil { | ||
err = fmt.Errorf("failed to close cached SSO token file, %v", closeErr) | ||
} | ||
}() | ||
|
||
encoder := json.NewEncoder(f) | ||
|
||
if err = encoder.Encode(t); err != nil { | ||
return fmt.Errorf("failed to serialize cached SSO token, %v", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
type rfc3339 time.Time | ||
|
||
func (r *rfc3339) UnmarshalJSON(bytes []byte) error { | ||
var value string | ||
|
||
if err := json.Unmarshal(bytes, &value); err != nil { | ||
return err | ||
} | ||
|
||
parse, err := time.Parse(time.RFC3339, value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix: This duplicates the logic below, just call e.g. *r, err = parseRFC3339(value)
return err |
||
if err != nil { | ||
return fmt.Errorf("expected RFC3339 timestamp: %v", err) | ||
} | ||
|
||
*r = rfc3339(parse) | ||
|
||
return nil | ||
} | ||
|
||
func parseRFC3339(v string) (rfc3339, error) { | ||
parsed, err := time.Parse(time.RFC3339, v) | ||
if err != nil { | ||
return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %v", err) | ||
} | ||
|
||
return rfc3339(parsed), nil | ||
} | ||
|
||
func (r *rfc3339) MarshalJSON() ([]byte, error) { | ||
value := time.Time(*r).Format(time.RFC3339) | ||
|
||
// Use JSON unmarshal to unescape the quoted value making use of JSON's | ||
// quoting rules. | ||
return json.Marshal(value) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There will only be one changelog entry for this entire body of work so this could be better.
e.g.