diff --git a/credentials/credential.go b/credentials/credential.go index 693f3fa..63ee7ce 100644 --- a/credentials/credential.go +++ b/credentials/credential.go @@ -56,6 +56,7 @@ type Config struct { InAdvanceScale *float64 `json:"inAdvanceScale"` Url *string `json:"url"` STSEndpoint *string `json:"sts_endpoint"` + ExternalId *string `json:"external_id"` } func (s Config) String() string { @@ -234,7 +235,15 @@ func NewCredential(config *Config) (credential Credential, err error) { ConnectTimeout: tea.IntValue(config.ConnectTimeout), STSEndpoint: tea.StringValue(config.STSEndpoint), } - credential = newRAMRoleArnCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret), tea.StringValue(config.RoleArn), tea.StringValue(config.RoleSessionName), tea.StringValue(config.Policy), tea.IntValue(config.RoleSessionExpiration), runtime) + credential = newRAMRoleArnWithExternalIdCredential( + tea.StringValue(config.AccessKeyId), + tea.StringValue(config.AccessKeySecret), + tea.StringValue(config.RoleArn), + tea.StringValue(config.RoleSessionName), + tea.StringValue(config.Policy), + tea.IntValue(config.RoleSessionExpiration), + tea.StringValue(config.ExternalId), + runtime) case "rsa_key_pair": err = checkRSAKeyPair(config) if err != nil { diff --git a/credentials/sts_role_arn_credential.go b/credentials/sts_role_arn_credential.go index f6933bf..77fa3da 100644 --- a/credentials/sts_role_arn_credential.go +++ b/credentials/sts_role_arn_credential.go @@ -23,6 +23,7 @@ type RAMRoleArnCredential struct { RoleSessionName string RoleSessionExpiration int Policy string + ExternalId string sessionCredential *sessionCredential runtime *utils.Runtime } @@ -51,6 +52,20 @@ func newRAMRoleArnCredential(accessKeyId, accessKeySecret, roleArn, roleSessionN } } +func newRAMRoleArnWithExternalIdCredential(accessKeyId, accessKeySecret, roleArn, roleSessionName, policy string, roleSessionExpiration int, externalId string, runtime *utils.Runtime) *RAMRoleArnCredential { + return &RAMRoleArnCredential{ + AccessKeyId: accessKeyId, + AccessKeySecret: accessKeySecret, + RoleArn: roleArn, + RoleSessionName: roleSessionName, + RoleSessionExpiration: roleSessionExpiration, + Policy: policy, + ExternalId: externalId, + credentialUpdater: new(credentialUpdater), + runtime: runtime, + } +} + // GetAccessKeyId reutrns RamRoleArnCredential's AccessKeyId // if AccessKeyId is not exist or out of date, the function will update it. func (r *RAMRoleArnCredential) GetAccessKeyId() (*string, error) { @@ -125,6 +140,9 @@ func (r *RAMRoleArnCredential) updateCredential() (err error) { if r.Policy != "" { request.QueryParams["Policy"] = r.Policy } + if r.ExternalId != "" { + request.QueryParams["ExternalId"] = r.ExternalId + } request.QueryParams["RoleSessionName"] = r.RoleSessionName request.QueryParams["SignatureMethod"] = "HMAC-SHA1" request.QueryParams["SignatureVersion"] = "1.0" diff --git a/credentials/sts_role_arn_credential_test.go b/credentials/sts_role_arn_credential_test.go index 2adc989..e6745db 100644 --- a/credentials/sts_role_arn_credential_test.go +++ b/credentials/sts_role_arn_credential_test.go @@ -133,4 +133,22 @@ func Test_RoleArnCredential(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err: Credentials is empty", err.Error()) assert.Equal(t, "", *accesskeyId) + + auth = newRAMRoleArnWithExternalIdCredential("accessKeyId", "accessKeySecret", "roleArn", "roleSessionName", "policy", 3600, "externalId", nil) + hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { + return mockResponse(200, `{"Credentials":{"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"2020-01-02T15:04:05Z"}}`, nil) + } + } + accesskeyId, err = auth.GetAccessKeyId() + assert.Nil(t, err) + assert.Equal(t, "accessKeyId", *accesskeyId) + + accesskeySecret, err = auth.GetAccessKeySecret() + assert.Nil(t, err) + assert.Equal(t, "accessKeySecret", *accesskeySecret) + + ststoken, err = auth.GetSecurityToken() + assert.Nil(t, err) + assert.Equal(t, "securitytoken", *ststoken) }