Skip to content

Commit

Permalink
feat: expose internal providers && support init client with provider
Browse files Browse the repository at this point in the history
  • Loading branch information
yndu13 committed Nov 4, 2024
1 parent 47c2eab commit 25ec51c
Show file tree
Hide file tree
Showing 27 changed files with 154 additions and 40 deletions.
16 changes: 8 additions & 8 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (

"github.com/alibabacloud-go/debug/debug"
"github.com/alibabacloud-go/tea/tea"
"github.com/aliyun/credentials-go/credentials/internal/providers"
"github.com/aliyun/credentials-go/credentials/internal/utils"
"github.com/aliyun/credentials-go/credentials/providers"
"github.com/aliyun/credentials-go/credentials/request"
"github.com/aliyun/credentials-go/credentials/response"
)
Expand Down Expand Up @@ -209,7 +209,7 @@ func (s *Config) SetExternalId(v string) *Config {
func NewCredential(config *Config) (credential Credential, err error) {
if config == nil {
provider := providers.NewDefaultCredentialsProvider()
credential = fromCredentialsProvider("default", provider)
credential = FromCredentialsProvider("default", provider)
return
}
switch tea.StringValue(config.Type) {
Expand All @@ -234,7 +234,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
if err != nil {
return nil, err
}
credential = fromCredentialsProvider("oidc_role_arn", provider)
credential = FromCredentialsProvider("oidc_role_arn", provider)
case "access_key":
provider, err := providers.NewStaticAKCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
Expand All @@ -244,7 +244,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
return nil, err
}

credential = fromCredentialsProvider("access_key", provider)
credential = FromCredentialsProvider("access_key", provider)
case "sts":
provider, err := providers.NewStaticSTSCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
Expand All @@ -255,7 +255,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
return nil, err
}

credential = fromCredentialsProvider("sts", provider)
credential = FromCredentialsProvider("sts", provider)
case "ecs_ram_role":
provider, err := providers.NewECSRAMRoleCredentialsProviderBuilder().
WithRoleName(tea.StringValue(config.RoleName)).
Expand All @@ -266,7 +266,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
return nil, err
}

credential = fromCredentialsProvider("ecs_ram_role", provider)
credential = FromCredentialsProvider("ecs_ram_role", provider)
case "ram_role_arn":
var credentialsProvider providers.CredentialsProvider
if config.SecurityToken != nil && *config.SecurityToken != "" {
Expand Down Expand Up @@ -304,7 +304,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
return nil, err
}

credential = fromCredentialsProvider("ram_role_arn", provider)
credential = FromCredentialsProvider("ram_role_arn", provider)
case "rsa_key_pair":
err = checkRSAKeyPair(config)
if err != nil {
Expand Down Expand Up @@ -479,7 +479,7 @@ func (cp *credentialsProviderWrap) GetType() *string {
return &cp.typeName
}

func fromCredentialsProvider(typeName string, cp providers.CredentialsProvider) Credential {
func FromCredentialsProvider(typeName string, cp providers.CredentialsProvider) Credential {
return &credentialsProviderWrap{
typeName: typeName,
provider: cp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ type profile struct {
RoleSessionName string `json:"ram_session_name"`
DurationSeconds int `json:"expired_seconds"`
StsRegion string `json:"sts_region"`
EnableVpc bool `json:"enable_vpc"`
SourceProfile string `json:"source_profile"`
RoleName string `json:"ram_role_name"`
OIDCTokenFile string `json:"oidc_token_file"`
OIDCProviderARN string `json:"oidc_provider_arn"`
Policy string `json:"policy"`
ExternalId string `json:"external_id"`
}

type configuration struct {
Expand Down Expand Up @@ -132,6 +135,9 @@ func (provider *CLIProfileCredentialsProvider) getCredentialsProvider(conf *conf
WithRoleSessionName(p.RoleSessionName).
WithDurationSeconds(p.DurationSeconds).
WithStsRegionId(p.StsRegion).
WithEnableVpc(p.EnableVpc).
WithPolicy(p.Policy).
WithExternalId(p.ExternalId).
Build()
case "EcsRamRole":
credentialsProvider, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName(p.RoleName).Build()
Expand All @@ -141,8 +147,10 @@ func (provider *CLIProfileCredentialsProvider) getCredentialsProvider(conf *conf
WithOIDCProviderARN(p.OIDCProviderARN).
WithRoleArn(p.RoleArn).
WithStsRegionId(p.StsRegion).
WithEnableVpc(p.EnableVpc).
WithDurationSeconds(p.DurationSeconds).
WithRoleSessionName(p.RoleSessionName).
WithPolicy(p.Policy).
Build()
case "ChainableRamRoleArn":
previousProvider, err1 := provider.getCredentialsProvider(conf, p.SourceProfile)
Expand All @@ -156,6 +164,9 @@ func (provider *CLIProfileCredentialsProvider) getCredentialsProvider(conf *conf
WithRoleSessionName(p.RoleSessionName).
WithDurationSeconds(p.DurationSeconds).
WithStsRegionId(p.StsRegion).
WithEnableVpc(p.EnableVpc).
WithPolicy(p.Policy).
WithExternalId(p.ExternalId).
Build()
default:
err = fmt.Errorf("unsupported profile mode '%s'", p.Mode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func TestCLIProfileCredentialsProvider_getCredentialsProvider(t *testing.T) {
AccessKeyID: "akid",
AccessKeySecret: "secret",
RoleArn: "arn",
StsRegion: "cn-hangzhou",
EnableVpc: true,
Policy: "policy",
ExternalId: "externalId",
},
{
Mode: "RamRoleArn",
Expand All @@ -107,6 +111,9 @@ func TestCLIProfileCredentialsProvider_getCredentialsProvider(t *testing.T) {
RoleArn: "role_arn",
OIDCTokenFile: "path/to/oidc/file",
OIDCProviderARN: "provider_arn",
StsRegion: "cn-hangzhou",
EnableVpc: true,
Policy: "policy",
},
{
Mode: "ChainableRamRoleArn",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRo
}

if !builder.provider.disableIMDSv1 {
builder.provider.disableIMDSv1 = os.Getenv("ALIBABA_CLOUD_IMDSV1_DISABLED") == "true"
builder.provider.disableIMDSv1 = strings.ToLower(os.Getenv("ALIBABA_CLOUD_IMDSV1_DISABLED")) == "true"
}

provider = builder.provider
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
"mode": "RamRoleArn",
"access_key_id": "akid",
"access_key_secret": "secret",
"ram_role_arn": "arn"
"ram_role_arn": "arn",
"sts_region": "cn-hangzhou",
"enable_vpc": true,
"policy": "policy",
"external_id": "id"
},
{
"name": "EcsRamRole",
Expand All @@ -24,12 +28,15 @@
"mode": "OIDC",
"ram_role_arn": "role_arn",
"oidc_token_file": "path/to/oidc/file",
"oidc_provider_arn": "provider_arn"
"oidc_provider_arn": "provider_arn",
"sts_region": "cn-hangzhou",
"enable_vpc": true,
"policy": "policy"
},
{
"name": "ChainableRamRoleArn",
"mode": "ChainableRamRoleArn",
"source_profile": "AK"
"source_profile": "ChainableRamRoleArn"
},
{
"name": "ChainableRamRoleArn2",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@ import (
"net/http"
"os"
"strconv"
"strings"
"time"

httputil "github.com/aliyun/credentials-go/credentials/internal/http"
"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type OIDCCredentialsProvider struct {
oidcProviderARN string
oidcTokenFilePath string
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegionId string
stsEndpoint string
oidcProviderARN string
oidcTokenFilePath string
roleArn string
roleSessionName string
durationSeconds int
policy string
// for sts endpoint
stsRegionId string
enableVpc bool
stsEndpoint string

lastUpdateTimestamp int64
expirationTimestamp int64
sessionCredentials *sessionCredentials
Expand Down Expand Up @@ -70,6 +74,11 @@ func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCC
return b
}

func (b *OIDCCredentialsProviderBuilder) WithEnableVpc(enableVpc bool) *OIDCCredentialsProviderBuilder {
b.provider.enableVpc = enableVpc
return b
}

func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
b.provider.policy = policy
return b
Expand Down Expand Up @@ -126,10 +135,17 @@ func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvi
}

if b.provider.stsEndpoint == "" {
if !b.provider.enableVpc {
b.provider.enableVpc = strings.ToLower(os.Getenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")) == "true"
}
prefix := "sts"
if b.provider.enableVpc {
prefix = "sts-vpc"
}
if b.provider.stsRegionId != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId)
b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, b.provider.stsRegionId)
} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", region)
b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, region)
} else {
b.provider.stsEndpoint = "sts.aliyuncs.com"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) {
}

func TestNewOIDCCredentialsProvider(t *testing.T) {
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN", "ALIBABA_CLOUD_STS_REGION")
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN", "ALIBABA_CLOUD_STS_REGION", "ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")
defer func() {
rollback()
}()
Expand Down Expand Up @@ -92,10 +92,11 @@ func TestNewOIDCCredentialsProvider(t *testing.T) {

// sts endpoint: with sts endpoint env
os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou")
os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "true")
p, err = NewOIDCCredentialsProviderBuilder().
Build()
assert.Nil(t, err)
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)
assert.Equal(t, "sts-vpc.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint: with sts endpoint
p, err = NewOIDCCredentialsProviderBuilder().
Expand All @@ -107,10 +108,12 @@ func TestNewOIDCCredentialsProvider(t *testing.T) {
// sts endpoint: with sts regionId
p, err = NewOIDCCredentialsProviderBuilder().
WithStsRegionId("cn-beijing").
WithEnableVpc(true).
Build()
assert.Nil(t, err)
assert.Equal(t, "sts.cn-beijing.aliyuncs.com", p.stsEndpoint)
assert.Equal(t, "sts-vpc.cn-beijing.aliyuncs.com", p.stsEndpoint)

os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "false")
p, err = NewOIDCCredentialsProviderBuilder().
WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
WithOIDCProviderARN("provider-arn").
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,17 @@ func (provider *ProfileCredentialsProvider) getCredentialsProvider(ini *ini.File
err = errors.New("get previous credentials provider failed")
return
}
rawPolicy, _ := section.GetKey("policy")
policy := ""
if rawPolicy != nil {
policy = rawPolicy.String()
}

credentialsProvider, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(previous).
WithRoleArn(value3.String()).
WithRoleSessionName(value4.String()).
WithPolicy(policy).
WithDurationSeconds(3600).
Build()
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ access_key_id = foo
access_key_secret = bar
role_arn = role_arn
role_session_name = session_name
policy = {"Statement": [{"Action": ["*"],"Effect": "Allow","Resource": ["*"]}],"Version":"1"}
[noram]
type = ram_role_arn
Expand Down
Loading

0 comments on commit 25ec51c

Please sign in to comment.