Skip to content

Commit

Permalink
feat: Add credential refresh to file mode (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Jan 7, 2021
1 parent 3d4e684 commit eafe143
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 11 deletions.
90 changes: 79 additions & 11 deletions cmd/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package cmd
import (
"fmt"
"path"
"time"

"gopkg.in/ini.v1"

Expand All @@ -34,6 +35,7 @@ func init() {
fileCmd.PersistentFlags().StringVarP(&destination, "output", "o", getDefaultCredentialsFile(), "output file for credentials")
fileCmd.PersistentFlags().StringVarP(&profileName, "profile", "p", "default", "profile name")
fileCmd.PersistentFlags().BoolVarP(&force, "force", "f", false, "overwrite existing profile without prompting")
fileCmd.PersistentFlags().BoolVarP(&autoRefresh, "refresh", "R", false, "automatically refresh credentials in file")
rootCmd.AddCommand(fileCmd)
}

Expand All @@ -46,17 +48,54 @@ var fileCmd = &cobra.Command{

func runFile(cmd *cobra.Command, args []string) error {
role = args[0]
err := updateCredentialsFile(role, profileName, destination, noIpRestrict, assumeRole)
if err != nil {
return err
}
if autoRefresh {
log.Infof("starting automatic file refresh for %s", role)
go fileRefresher(role, profileName, destination, noIpRestrict, assumeRole)
<-shutdown
}
return nil
}

func updateCredentialsFile(role, profile, filename string, noIpRestrict bool, assumeRole []string) error {
credentials, err := creds.GetCredentials(role, noIpRestrict, assumeRole)
if err != nil {
return err
}
err = writeCredentialsFile(credentials)
err = writeCredentialsFile(credentials, profile, filename)
if err != nil {
return err
}
return nil
}

func fileRefresher(role, profile, filename string, noIpRestrict bool, assumeRole []string) {
ticker := time.NewTicker(time.Minute)

for {
select {
case _ = <-ticker.C:
log.Debug("checking credentials")
expiring, err := isExpiring(filename, profile, 10)
if err != nil {
log.Errorf("error checking credential expiration: %v", err)
}
if expiring {
log.Info("credentials are expiring soon, refreshing...")
err = updateCredentialsFile(role, profile, filename, noIpRestrict, assumeRole)
if err != nil {
log.Errorf("error updating credentials: %v", err)
} else {
log.Info("credentials refreshed!")
}
}
}
}
}

func getDefaultCredentialsFile() string {
home, err := homedir.Dir()
if err != nil {
Expand All @@ -74,7 +113,7 @@ func getDefaultAwsConfigFile() string {
}

func shouldOverwriteCredentials() bool {
if force {
if force || autoRefresh {
return true
}
userForce, err := util.PromptBool(fmt.Sprintf("Overwrite %s profile?", profileName))
Expand All @@ -84,35 +123,64 @@ func shouldOverwriteCredentials() bool {
return userForce
}

func writeCredentialsFile(credentials *creds.AwsCredentials) error {
func isExpiring(filename, profile string, thresholdMinutes int) (bool, error) {
fileContents, err := ini.Load(filename)
if err != nil {
return false, err
}
section, err := fileContents.GetSection(profile)
if err != nil {
return true, err
}
expiration, err := section.GetKey("expiration")
if err != nil {
return true, err
}
expirationTime, err := expiration.Time()
if err != nil {
return true, err
}
diff := time.Duration(thresholdMinutes) * time.Minute
timeUntilExpiration := expirationTime.Sub(time.Now()).Round(0)
log.Debugf("%s until expiration, refresh threshold is %s", timeUntilExpiration, diff)
if timeUntilExpiration < diff {
log.Debug("will refresh")
return true, nil
}
log.Debug("will not refresh")
return false, nil
}

func writeCredentialsFile(credentials *creds.AwsCredentials, profile, filename string) error {
var credentialsINI *ini.File
var err error

// Disable pretty format, but still put spaces around `=`
ini.PrettyFormat = false
ini.PrettyEqual = true

if util.FileExists(destination) {
credentialsINI, err = ini.Load(destination)
if util.FileExists(filename) {
credentialsINI, err = ini.Load(filename)
if err != nil {
return err
}
} else {
credentialsINI = ini.Empty()
}

if _, err := credentialsINI.GetSection(profileName); err == nil {
if _, err := credentialsINI.GetSection(profile); err == nil {
// section already exists, should we overwrite?
if !shouldOverwriteCredentials() {
// user says no, so we'll just bail out
return fmt.Errorf("not overwriting %s profile", profileName)
return fmt.Errorf("not overwriting %s profile", profile)
}
}

credentialsINI.Section(profileName).Key("aws_access_key_id").SetValue(credentials.AccessKeyId)
credentialsINI.Section(profileName).Key("aws_secret_access_key").SetValue(credentials.SecretAccessKey)
credentialsINI.Section(profileName).Key("aws_session_token").SetValue(credentials.SessionToken)
err = credentialsINI.SaveTo(destination)
credentialsINI.Section(profile).Key("aws_access_key_id").SetValue(credentials.AccessKeyId)
credentialsINI.Section(profile).Key("aws_secret_access_key").SetValue(credentials.SecretAccessKey)
credentialsINI.Section(profile).Key("aws_session_token").SetValue(credentials.SessionToken)
credentialsINI.Section(profile).Key("expiration").SetValue(credentials.Expiration.Format("2006-01-02T15:04:05Z07:00"))
err = credentialsINI.SaveTo(filename)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions cmd/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var (
destination string
destinationConfig string
force bool
autoRefresh bool
noIpRestrict bool
metadataRegion string
metadataListenAddr string
Expand Down

0 comments on commit eafe143

Please sign in to comment.