Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use RoleArn for RefreshableProvider requests #56

Merged
merged 1 commit into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ func getCacheSlug(role string, assume []string) string {
return strings.Join(elements, "/")
}

func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) {
func (cc *CredentialCache) Get(searchString string, assumeChain []string) (*creds.RefreshableProvider, error) {
log.WithFields(logrus.Fields{
"role": role,
"assumeChain": assumeChain,
"searchString": searchString,
"assumeChain": assumeChain,
}).Info("retrieving credentials")
c, ok := cc.get(getCacheSlug(role, assumeChain))
c, ok := cc.get(getCacheSlug(searchString, assumeChain))
if ok {
log.Debugf("found credentials for %s in cache", role)
log.Debugf("found credentials for %s in cache", searchString)
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
Expand Down
72 changes: 36 additions & 36 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,44 @@ func TestCredentialCache_Get(t *testing.T) {
{
Description: "role in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
Role: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "role in cache with assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "assume role in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{"b", "c"},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
{
Description: "assume role in cache with non-assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{"b", "c"},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
}

Expand All @@ -95,7 +95,7 @@ func TestCredentialCache_Get(t *testing.T) {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult)
}
}
Expand All @@ -120,16 +120,16 @@ func TestCredentialCache_GetDefault(t *testing.T) {
Description: "default role in cache",
DefaultRole: "a",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "no default role set",
DefaultRole: "",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
ExpectedError: errors.NoDefaultRoleSet,
ExpectedResult: nil,
Expand All @@ -138,30 +138,30 @@ func TestCredentialCache_GetDefault(t *testing.T) {
Description: "default role in cache with assume",
DefaultRole: "a",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "default assume role in cache",
DefaultRole: "a/b/c",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
{
Description: "default assume role in cache with non-assume",
DefaultRole: "a/b/c",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
}

Expand All @@ -176,7 +176,7 @@ func TestCredentialCache_GetDefault(t *testing.T) {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult)
}
}
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
cases := []struct {
CacheContents map[string]*creds.RefreshableProvider
ClientResponse interface{}
Role string
SearchString string
AssumeChain []string
Region string
Description string
Expand All @@ -306,30 +306,30 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
{
Description: "role not in cache",
CacheContents: make(map[string]*creds.RefreshableProvider),
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole1"},
},
{
Description: "role not in cache with assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole2"},
},
{
Description: "role already in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleArn: "arn:aws:iam::012345678901:role/coolRole3"},
},
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole3"},
},
}

Expand All @@ -344,14 +344,14 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
SecretAccessKey: "b",
SessionToken: "c",
Expiration: creds.Time(time.Unix(1, 0)),
RoleArn: "e",
RoleArn: tc.ExpectedResult.RoleArn,
},
})
if err != nil {
t.Errorf("test setup failure: %e", err)
continue
}
result, actualError := testCache.GetOrSet(client, tc.Role, tc.Region, tc.AssumeChain)
result, actualError := testCache.GetOrSet(client, tc.SearchString, tc.Region, tc.AssumeChain)
if actualError != tc.ExpectedError {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
Expand All @@ -360,8 +360,8 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
t.Errorf("%s failed: got nil result, expected %v", tc.Description, tc.ExpectedResult)
continue
}
if result != nil && result.Role != tc.ExpectedResult.Role {
t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.Role, result.Role)
if result != nil && result.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.RoleArn, result.RoleArn)
continue
}
}
Expand Down
11 changes: 6 additions & 5 deletions creds/refreshable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewRefreshableProvider(client HTTPClient, role, region string, assumeChain
splitRole := strings.Split(role, "/")
roleName := splitRole[len(splitRole)-1]
rp := &RefreshableProvider{
Role: roleName,
RoleName: roleName,
RoleArn: role,
Region: region,
NoIpRestrict: noIpRestrict,
Expand Down Expand Up @@ -66,7 +66,7 @@ func (rp *RefreshableProvider) AutoRefresh() {
}

func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) {
log.Debugf("checking credentials for %s", rp.Role)
log.Debugf("checking credentials for %s", rp.RoleName)
// refresh creds if we're within 10 minutes of them expiring
diff := time.Duration(threshold*-1) * time.Minute
thresh := rp.Expiration.Add(diff)
Expand All @@ -80,14 +80,14 @@ func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) {
}

func (rp *RefreshableProvider) refresh() error {
log.Debugf("refreshing credentials for %s", rp.Role)
log.Debugf("refreshing credentials for %s", rp.RoleArn)
var err error
var newCreds *AwsCredentials

rp.Lock()
defer rp.Unlock()

newCreds, err = GetCredentialsC(rp.client, rp.Role, rp.NoIpRestrict, rp.AssumeChain)
newCreds, err = GetCredentialsC(rp.client, rp.RoleArn, rp.NoIpRestrict, rp.AssumeChain)
if err != nil {
if err == errors.MutualTLSCertNeedsRefreshError {
log.Error(err)
Expand All @@ -106,11 +106,12 @@ func (rp *RefreshableProvider) refresh() error {
rp.value.SecretAccessKey = newCreds.SecretAccessKey
rp.value.AccessKeyID = newCreds.AccessKeyId
rp.LastRefreshed = Time(time.Now())
// We favor the role ARN from ConsoleMe over the one from the user, which could just be a search string.
rp.RoleArn = newCreds.RoleArn
if rp.value.ProviderName == "" {
rp.value.ProviderName = "WeepRefreshableProvider"
}
log.Debugf("successfully refreshed credentials for %s", rp.Role)
log.Debugf("successfully refreshed credentials for %s", rp.RoleArn)
return nil
}

Expand Down
10 changes: 5 additions & 5 deletions creds/refreshable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestNewRefreshableProvider(t *testing.T) {
Expiration: testExpiration,
LastRefreshed: Time{},
Region: testRegion,
Role: testRole,
RoleName: testRole,
RoleArn: testRoleArn,
NoIpRestrict: false,
AssumeChain: make([]string, 0),
Expand All @@ -89,7 +89,7 @@ func TestNewRefreshableProvider(t *testing.T) {
Expiration: testExpiration,
LastRefreshed: Time{},
Region: testRegion,
Role: testRole,
RoleName: testRole,
RoleArn: testRoleArn,
NoIpRestrict: true,
AssumeChain: make([]string, 0),
Expand Down Expand Up @@ -120,8 +120,8 @@ func TestNewRefreshableProvider(t *testing.T) {
t.Errorf("%s failed: got %v region, expected %v", tc.Description, actualResult.Region, tc.ExpectedResult.Region)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.Role, tc.ExpectedResult.Role)
if actualResult != nil && actualResult.RoleName != tc.ExpectedResult.RoleName {
t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.RoleName, tc.ExpectedResult.RoleName)
continue
}
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestRefreshableProvider_refresh(t *testing.T) {
retries: tc.Retries,
retryDelay: tc.RetryDelay,
Region: tc.Region,
Role: tc.Role,
RoleName: tc.Role,
RoleArn: tc.RoleArn,
NoIpRestrict: tc.NoIpRestrict,
AssumeChain: tc.AssumeChain,
Expand Down
2 changes: 1 addition & 1 deletion creds/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type RefreshableProvider struct {
Expiration Time
LastRefreshed Time
Region string
Role string
RoleName string
RoleArn string
NoIpRestrict bool
AssumeChain []string
Expand Down
2 changes: 1 addition & 1 deletion server/credentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func RoleHandler(w http.ResponseWriter, r *http.Request) {
util.WriteError(w, "error", 500)
return
}
if _, err := w.Write([]byte(defaultRole.Role)); err != nil {
if _, err := w.Write([]byte(defaultRole.RoleName)); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
Expand Down