Skip to content

Commit

Permalink
Remove parameterized AWS session from token.go
Browse files Browse the repository at this point in the history
This simplifies the API, and removes the unnecessary `GetWithRoleForSession()`,
`GetWithRole()`, and `Get()` methods. This also simplifies migration to
aws-sdk-go-v2 by allowing both Generator and TokenOptions to be not bound to a
specific SDK version.
  • Loading branch information
micahhausler authored and sushanth0910 committed Nov 13, 2024
1 parent dc28c71 commit 31f3b60
Showing 1 changed file with 33 additions and 69 deletions.
102 changes: 33 additions & 69 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ type GetTokenOptions struct {
AssumeRoleARN string
AssumeRoleExternalID string
SessionName string
Session *session.Session
}

// FormatError is returned when there is a problem with token that is
Expand Down Expand Up @@ -182,12 +181,6 @@ type getCallerIdentityWrapper struct {

// Generator provides new tokens for the AWS IAM Authenticator.
type Generator interface {
// Get a token using credentials in the default credentials chain.
Get(string) (Token, error)
// GetWithRole creates a token by assuming the provided role, using the credentials in the default chain.
GetWithRole(clusterID, roleARN string) (Token, error)
// GetWithRoleForSession creates a token by assuming the provided role, using the provided session.
GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error)
// Get a token using the provided options
GetWithOptions(options *GetTokenOptions) (Token, error)
// GetWithSTS returns a token valid for clusterID using the given STS client.
Expand All @@ -211,31 +204,6 @@ func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) {
}, nil
}

// Get uses the directly available AWS credentials to return a token valid for
// clusterID. It follows the default AWS credential handling behavior.
func (g generator) Get(clusterID string) (Token, error) {
return g.GetWithOptions(&GetTokenOptions{ClusterID: clusterID})
}

// GetWithRole assumes the given AWS IAM role and returns a token valid for
// clusterID. If roleARN is empty, behaves like Get (does not assume a role).
func (g generator) GetWithRole(clusterID string, roleARN string) (Token, error) {
return g.GetWithOptions(&GetTokenOptions{
ClusterID: clusterID,
AssumeRoleARN: roleARN,
})
}

// GetWithRoleForSession assumes the given AWS IAM role for the given session and behaves
// like GetWithRole.
func (g generator) GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error) {
return g.GetWithOptions(&GetTokenOptions{
ClusterID: clusterID,
AssumeRoleARN: roleARN,
Session: sess,
})
}

// StdinStderrTokenProvider gets MFA token from standard input.
func StdinStderrTokenProvider() (string, error) {
var v string
Expand All @@ -252,46 +220,42 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
return Token{}, fmt.Errorf("ClusterID is required")
}

if options.Session == nil {
// create a session with the "base" credentials available
// (from environment variable, profile files, EC2 metadata, etc)
sess, err := session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: StdinStderrTokenProvider,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return Token{}, fmt.Errorf("could not create session: %v", err)
}
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Fn: request.MakeAddToUserAgentHandler(
"aws-iam-authenticator", pkg.Version),
})
if options.Region != "" {
sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint))
}
// create a session with the "base" credentials available
// (from environment variable, profile files, EC2 metadata, etc)
sess, err := session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: StdinStderrTokenProvider,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return Token{}, fmt.Errorf("could not create session: %v", err)
}
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Fn: request.MakeAddToUserAgentHandler(
"aws-iam-authenticator", pkg.Version),
})
if options.Region != "" {
sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint))
}

if g.cache {
// figure out what profile we're using
var profile string
if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
profile = v
} else {
profile = session.DefaultSharedConfigProfile
}
// create a cacheing Provider wrapper around the Credentials
if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil {
sess.Config.Credentials = credentials.NewCredentials(&cacheProvider)
} else {
_, _ = fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
}
if g.cache {
// figure out what profile we're using
var profile string
if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
profile = v
} else {
profile = session.DefaultSharedConfigProfile
}
// create a cacheing Provider wrapper around the Credentials
if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil {
sess.Config.Credentials = credentials.NewCredentials(&cacheProvider)
} else {
fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
}

options.Session = sess
}

// use an STS client based on the direct credentials
stsAPI := sts.New(options.Session)
stsAPI := sts.New(sess)

// if a roleARN was specified, replace the STS client with one that uses
// temporary credentials from that role.
Expand Down Expand Up @@ -326,10 +290,10 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
}

// create STS-based credentials that will assume the given role
creds := stscreds.NewCredentials(options.Session, options.AssumeRoleARN, sessionSetters...)
creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...)

// create an STS API interface that uses the assumed role's temporary credentials
stsAPI = sts.New(options.Session, &aws.Config{Credentials: creds})
stsAPI = sts.New(sess, &aws.Config{Credentials: creds})
}

return g.GetWithSTS(options.ClusterID, stsAPI)
Expand Down

0 comments on commit 31f3b60

Please sign in to comment.