From e290d3080b8a78a6451cb258bf21d5bf5eb16a85 Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Mon, 29 Jul 2024 14:01:52 +0800 Subject: [PATCH] refine RAMRoleArnCredential --- credentials/sts_role_arn_credential.go | 42 ++++++++++----------- credentials/sts_role_arn_credential_test.go | 33 +++++++++++----- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/credentials/sts_role_arn_credential.go b/credentials/sts_role_arn_credential.go index 1b75500..428fa24 100644 --- a/credentials/sts_role_arn_credential.go +++ b/credentials/sts_role_arn_credential.go @@ -100,38 +100,38 @@ func (e *RAMRoleArnCredential) GetCredential() (*CredentialModel, error) { // GetAccessKeyId reutrns RamRoleArnCredential's AccessKeyId // if AccessKeyId is not exist or out of date, the function will update it. -func (r *RAMRoleArnCredential) GetAccessKeyId() (*string, error) { - if r.sessionCredential == nil || r.needUpdateCredential() { - err := r.updateCredential() - if err != nil { - return tea.String(""), err - } +func (r *RAMRoleArnCredential) GetAccessKeyId() (accessKeyId *string, err error) { + c, err := r.GetCredential() + if err != nil { + return } - return tea.String(r.sessionCredential.AccessKeyId), nil + + accessKeyId = c.AccessKeyId + return } // GetAccessSecret reutrns RamRoleArnCredential's AccessKeySecret // if AccessKeySecret is not exist or out of date, the function will update it. -func (r *RAMRoleArnCredential) GetAccessKeySecret() (*string, error) { - if r.sessionCredential == nil || r.needUpdateCredential() { - err := r.updateCredential() - if err != nil { - return tea.String(""), err - } +func (r *RAMRoleArnCredential) GetAccessKeySecret() (accessKeySecret *string, err error) { + c, err := r.GetCredential() + if err != nil { + return } - return tea.String(r.sessionCredential.AccessKeySecret), nil + + accessKeySecret = c.AccessKeySecret + return } // GetSecurityToken reutrns RamRoleArnCredential's SecurityToken // if SecurityToken is not exist or out of date, the function will update it. -func (r *RAMRoleArnCredential) GetSecurityToken() (*string, error) { - if r.sessionCredential == nil || r.needUpdateCredential() { - err := r.updateCredential() - if err != nil { - return tea.String(""), err - } +func (r *RAMRoleArnCredential) GetSecurityToken() (securityToken *string, err error) { + c, err := r.GetCredential() + if err != nil { + return } - return tea.String(r.sessionCredential.SecurityToken), nil + + securityToken = c.SecurityToken + return } // GetBearerToken is useless RamRoleArnCredential diff --git a/credentials/sts_role_arn_credential_test.go b/credentials/sts_role_arn_credential_test.go index 42f1eaa..c9f12df 100644 --- a/credentials/sts_role_arn_credential_test.go +++ b/credentials/sts_role_arn_credential_test.go @@ -38,17 +38,17 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err := auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "[InvalidParam]:Assume Role session duration should be in the range of 15min - 1Hr", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) accesskeySecret, err := auth.GetAccessKeySecret() assert.NotNil(t, err) assert.Equal(t, "[InvalidParam]:Assume Role session duration should be in the range of 15min - 1Hr", err.Error()) - assert.Equal(t, "", *accesskeySecret) + assert.Nil(t, accesskeySecret) ststoken, err := auth.GetSecurityToken() assert.NotNil(t, err) assert.Equal(t, "[InvalidParam]:Assume Role session duration should be in the range of 15min - 1Hr", err.Error()) - assert.Equal(t, "", *ststoken) + assert.Nil(t, ststoken) assert.Equal(t, "", *auth.GetBearerToken()) assert.Equal(t, "ram_role_arn", *auth.GetType()) @@ -57,13 +57,13 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Internal error", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) auth.RoleSessionExpiration = 0 accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Internal error", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { @@ -73,7 +73,7 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: httpStatus: 300, message = ", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { @@ -83,7 +83,7 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Json.Unmarshal fail: invalid character ':' after top-level value", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { @@ -93,7 +93,7 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: AccessKeyId: , AccessKeySecret: accessKeySecret, SecurityToken: securitytoken, Expiration: expiration", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { @@ -103,7 +103,7 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Credentials is empty", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { @@ -140,7 +140,7 @@ func Test_RoleArnCredential(t *testing.T) { accesskeyId, err = auth.GetAccessKeyId() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Credentials is empty", err.Error()) - assert.Equal(t, "", *accesskeyId) + assert.Nil(t, accesskeyId) auth = newRAMRoleArnWithExternalIdCredential("accessKeyId", "accessKeySecret", "roleArn", "roleSessionName", "policy", 3600, "externalId", nil) hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { @@ -160,3 +160,16 @@ func Test_RoleArnCredential(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "securitytoken", *ststoken) } + +func TestStsRoleARNCredentialsProviderWithSecurityToken(t *testing.T) { + auth := newRAMRoleArnl("accessKeyId", "accessKeySecret", "securityToken", "roleArn", "roleSessionName", "policy", 3600, "externalId", nil) + hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "securityToken", req.URL.Query().Get("SecurityToken")) + return mockResponse(200, `{"Credentials":{"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"2020-01-02T15:04:05Z"}}`, nil) + } + } + + _, err := auth.GetCredential() + assert.Nil(t, err) +}