From b081bfe329ddab74d2241b483f3927d61fa01472 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 15:37:23 -0600 Subject: [PATCH 1/2] style: session duration can default to 1h so not required Signed-off-by: Samantha Coyle --- .build-tools/builtin-authentication-profiles.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 86d412878e..9b37c895cd 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -54,7 +54,7 @@ aws: If set to 0m, temporary credentials will automatically rotate. default: '1h' example: '0m' - required: true + required: false azuread: - title: "Azure AD: Managed identity" From e6f7699c72bcd6021709fc8c1a085e642824d92c Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 14 Nov 2024 09:47:08 -0600 Subject: [PATCH 2/2] fix: address final feedback Signed-off-by: Samantha Coyle --- .../builtin-authentication-profiles.yaml | 8 - common/authentication/aws/aws_test.go | 13 ++ common/authentication/aws/client_test.go | 13 ++ common/authentication/aws/static.go | 68 +++--- common/authentication/aws/static_test.go | 13 +- common/authentication/aws/x509.go | 198 ++++++++---------- common/authentication/aws/x509_test.go | 22 +- state/aws/dynamodb/dynamodb.go | 2 +- state/aws/dynamodb/dynamodb_test.go | 23 +- 9 files changed, 168 insertions(+), 192 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9b37c895cd..373c2d230d 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -47,14 +47,6 @@ aws: ARN of the AWS IAM role to assume in the trusting AWS account. example: arn:aws:iam:012345678910:role/exampleIAMRoleName required: true - - name: sessionDuration - type: duration - description: | - Duration of the session using AWS IAM Roles Anywhere. - If set to 0m, temporary credentials will automatically rotate. - default: '1h' - example: '0m' - required: false azuread: - title: "Azure AD: Managed identity" diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go index 24b6c5649c..15aac78ad7 100644 --- a/common/authentication/aws/aws_test.go +++ b/common/authentication/aws/aws_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go index 20ed547006..67d2ac88f3 100644 --- a/common/authentication/aws/client_test.go +++ b/common/authentication/aws/client_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 43081fe57f..6997b62a1e 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -31,28 +31,28 @@ import ( type StaticAuth struct { mu sync.RWMutex - Logger logger.Logger + logger logger.Logger - Region string - Endpoint *string - AccessKey *string - SecretKey *string - SessionToken *string + region *string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string - Clients *Clients - Session *session.Session - Cfg *aws.Config + session *session.Session + cfg *aws.Config + Clients *Clients // exported to mock clients in unit tests } func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { auth := &StaticAuth{ - Logger: opts.Logger, - Region: opts.Region, - Endpoint: &opts.Endpoint, - AccessKey: &opts.AccessKey, - SecretKey: &opts.SecretKey, - SessionToken: &opts.SessionToken, - Cfg: func() *aws.Config { + logger: opts.Logger, + region: &opts.Region, + endpoint: &opts.Endpoint, + accessKey: &opts.AccessKey, + secretKey: &opts.SecretKey, + sessionToken: &opts.SessionToken, + cfg: func() *aws.Config { // if nil is passed or it's just a default cfg, // then we use the options to build the aws cfg. if cfg != nil && cfg != aws.NewConfig() { @@ -68,7 +68,7 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth return nil, fmt.Errorf("failed to get token client: %v", err) } - auth.Session = initialSession + auth.session = initialSession return auth, nil } @@ -83,7 +83,7 @@ func (a *StaticAuth) S3() *S3Clients { s3Clients := S3Clients{} a.Clients.s3 = &s3Clients - a.Clients.s3.New(a.Session) + a.Clients.s3.New(a.session) return a.Clients.s3 } @@ -97,7 +97,7 @@ func (a *StaticAuth) DynamoDB() *DynamoDBClients { clients := DynamoDBClients{} a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.Session) + a.Clients.Dynamo.New(a.session) return a.Clients.Dynamo } @@ -112,7 +112,7 @@ func (a *StaticAuth) Sqs() *SqsClients { clients := SqsClients{} a.Clients.sqs = &clients - a.Clients.sqs.New(a.Session) + a.Clients.sqs.New(a.session) return a.Clients.sqs } @@ -127,7 +127,7 @@ func (a *StaticAuth) Sns() *SnsClients { clients := SnsClients{} a.Clients.sns = &clients - a.Clients.sns.New(a.Session) + a.Clients.sns.New(a.session) return a.Clients.sns } @@ -141,7 +141,7 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients { clients := SnsSqsClients{} a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.Session) + a.Clients.snssqs.New(a.session) return a.Clients.snssqs } @@ -155,7 +155,7 @@ func (a *StaticAuth) SecretManager() *SecretManagerClients { clients := SecretManagerClients{} a.Clients.Secret = &clients - a.Clients.Secret.New(a.Session) + a.Clients.Secret.New(a.session) return a.Clients.Secret } @@ -169,7 +169,7 @@ func (a *StaticAuth) ParameterStore() *ParameterStoreClients { clients := ParameterStoreClients{} a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.Session) + a.Clients.ParameterStore.New(a.session) return a.Clients.ParameterStore } @@ -183,7 +183,7 @@ func (a *StaticAuth) Kinesis() *KinesisClients { clients := KinesisClients{} a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.Session) + a.Clients.kinesis.New(a.session) return a.Clients.kinesis } @@ -197,29 +197,29 @@ func (a *StaticAuth) Ses() *SesClients { clients := SesClients{} a.Clients.ses = &clients - a.Clients.ses.New(a.Session) + a.Clients.ses.New(a.session) return a.Clients.ses } func (a *StaticAuth) getTokenClient() (*session.Session, error) { var awsConfig *aws.Config - if a.Cfg == nil { + if a.cfg == nil { awsConfig = aws.NewConfig() } else { - awsConfig = a.Cfg + awsConfig = a.cfg } - if a.Region != "" { - awsConfig = awsConfig.WithRegion(a.Region) + if a.region != nil { + awsConfig = awsConfig.WithRegion(*a.region) } - if a.AccessKey != nil && a.SecretKey != nil { + if a.accessKey != nil && a.secretKey != nil { // session token is an option field - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.AccessKey, *a.SecretKey, *a.SessionToken)) + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) } - if a.Endpoint != nil { - awsConfig = awsConfig.WithEndpoint(*a.Endpoint) + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) } awsSession, err := session.NewSessionWithOptions(session.Options{ diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go index 1f2191cfa4..3a9b3a2d00 100644 --- a/common/authentication/aws/static_test.go +++ b/common/authentication/aws/static_test.go @@ -3,7 +3,6 @@ package aws import ( "testing" - "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,11 +45,11 @@ func TestGetTokenClient(t *testing.T) { { name: "valid token client", awsInstance: &StaticAuth{ - AccessKey: aws.String("testAccessKey"), - SecretKey: aws.String("testSecretKey"), - SessionToken: aws.String("testSessionToken"), - Region: "us-west-2", - Endpoint: aws.String("https://test.endpoint.com"), + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", }, }, } @@ -60,7 +59,7 @@ func TestGetTokenClient(t *testing.T) { session, err := tt.awsInstance.getTokenClient() require.NotNil(t, session) require.NoError(t, err) - assert.Equal(t, tt.awsInstance.Region, *session.Config.Region) + assert.Equal(t, tt.awsInstance.region, *session.Config.Region) }) } } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index f90d763ed6..cb1bafdeb3 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2024 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -41,39 +41,38 @@ import ( "github.com/dapr/kit/ptr" ) -type x509 struct { - mu sync.RWMutex +type x509Options struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} +type x509 struct { + mu sync.RWMutex wg sync.WaitGroup closeCh chan struct{} logger logger.Logger - Clients *Clients + clients *Clients rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests session *session.Session - Cfg *aws.Config + cfg *aws.Config chainPEM []byte keyPEM []byte region *string - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` - SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` + trustProfileArn *string + trustAnchorArn *string + assumeRoleArn *string } func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { - var x509Auth x509 + var x509Auth x509Options if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { return nil, err } - if x509Auth.SessionDuration == nil { - defaultDuration := time.Hour - x509Auth.SessionDuration = &defaultDuration - } - switch { case x509Auth.TrustProfileArn == nil: return nil, errors.New("trustProfileArn is required") @@ -81,18 +80,14 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, errors.New("trustAnchorArn is required") case x509Auth.AssumeRoleArn == nil: return nil, errors.New("assumeRoleArn is required") - case *x509Auth.SessionDuration != 0 && (*x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12): - return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } auth := &x509{ - wg: sync.WaitGroup{}, logger: opts.Logger, - TrustProfileArn: x509Auth.TrustProfileArn, - TrustAnchorArn: x509Auth.TrustAnchorArn, - AssumeRoleArn: x509Auth.AssumeRoleArn, - SessionDuration: x509Auth.SessionDuration, - Cfg: func() *aws.Config { + trustProfileArn: x509Auth.TrustProfileArn, + trustAnchorArn: x509Auth.TrustAnchorArn, + assumeRoleArn: x509Auth.AssumeRoleArn, + cfg: func() *aws.Config { // if nil is passed or it's just a default cfg, // then we use the options to build the aws cfg. if cfg != nil && cfg != aws.NewConfig() { @@ -100,17 +95,15 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) } return GetConfig(opts) }(), - Clients: newClients(), + clients: newClients(), } - var err error - err = auth.getCertPEM(ctx) - if err != nil { + if err := auth.getCertPEM(ctx); err != nil { return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) } // Parse trust anchor and profile ARNs - if err = auth.initializeTrustAnchors(); err != nil { + if err := auth.initializeTrustAnchors(); err != nil { return nil, err } @@ -157,128 +150,128 @@ func (a *x509) S3() *S3Clients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.s3 != nil { - return a.Clients.s3 + if a.clients.s3 != nil { + return a.clients.s3 } s3Clients := S3Clients{} - a.Clients.s3 = &s3Clients - a.Clients.s3.New(a.session) - return a.Clients.s3 + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 } func (a *x509) DynamoDB() *DynamoDBClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Dynamo != nil { - return a.Clients.Dynamo + if a.clients.Dynamo != nil { + return a.clients.Dynamo } clients := DynamoDBClients{} - a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.session) + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) - return a.Clients.Dynamo + return a.clients.Dynamo } func (a *x509) Sqs() *SqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sqs != nil { - return a.Clients.sqs + if a.clients.sqs != nil { + return a.clients.sqs } clients := SqsClients{} - a.Clients.sqs = &clients - a.Clients.sqs.New(a.session) + a.clients.sqs = &clients + a.clients.sqs.New(a.session) - return a.Clients.sqs + return a.clients.sqs } func (a *x509) Sns() *SnsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sns != nil { - return a.Clients.sns + if a.clients.sns != nil { + return a.clients.sns } clients := SnsClients{} - a.Clients.sns = &clients - a.Clients.sns.New(a.session) - return a.Clients.sns + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns } func (a *x509) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.snssqs != nil { - return a.Clients.snssqs + if a.clients.snssqs != nil { + return a.clients.snssqs } clients := SnsSqsClients{} - a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.session) - return a.Clients.snssqs + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs } func (a *x509) SecretManager() *SecretManagerClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Secret != nil { - return a.Clients.Secret + if a.clients.Secret != nil { + return a.clients.Secret } clients := SecretManagerClients{} - a.Clients.Secret = &clients - a.Clients.Secret.New(a.session) - return a.Clients.Secret + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret } func (a *x509) ParameterStore() *ParameterStoreClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ParameterStore != nil { - return a.Clients.ParameterStore + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore } clients := ParameterStoreClients{} - a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.session) - return a.Clients.ParameterStore + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore } func (a *x509) Kinesis() *KinesisClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.kinesis != nil { - return a.Clients.kinesis + if a.clients.kinesis != nil { + return a.clients.kinesis } clients := KinesisClients{} - a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.session) - return a.Clients.kinesis + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis } func (a *x509) Ses() *SesClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ses != nil { - return a.Clients.ses + if a.clients.ses != nil { + return a.clients.ses } clients := SesClients{} - a.Clients.ses = &clients - a.Clients.ses.New(a.session) - return a.Clients.ses + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses } func (a *x509) initializeTrustAnchors() error { @@ -287,16 +280,16 @@ func (a *x509) initializeTrustAnchors() error { profile arn.ARN err error ) - if a.TrustAnchorArn != nil { - trustAnchor, err = arn.Parse(*a.TrustAnchorArn) + if a.trustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.trustAnchorArn) if err != nil { return err } a.region = &trustAnchor.Region } - if a.TrustProfileArn != nil { - profile, err = arn.Parse(*a.TrustProfileArn) + if a.trustProfileArn != nil { + profile, err = arn.Parse(*a.trustProfileArn) if err != nil { return err } @@ -347,10 +340,10 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er var mySession *session.Session var awsConfig *aws.Config - if a.Cfg == nil { + if a.cfg == nil { awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) } else { - awsConfig = a.Cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + awsConfig = a.cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) } if a.region != nil { awsConfig.WithRegion(*a.region) @@ -368,35 +361,17 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er rolesClient = rolesAnywhereClient } - var ( - duration int64 - createSessionRequest rolesanywhere.CreateSessionInput - ) - if *a.SessionDuration != 0 { - duration = int64(a.SessionDuration.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } else { - duration = int64(time.Hour.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.trustProfileArn, + TrustAnchorArn: a.trustAnchorArn, + RoleArn: a.assumeRoleArn, + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + DurationSeconds: aws.Int64(int64(time.Hour.Seconds())), // AWS default is 1hr timeout + InstanceProperties: nil, + SessionName: nil, } + var output *rolesanywhere.CreateSessionOutput if a.rolesAnywhereClient != nil { var err error @@ -432,11 +407,6 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er func (a *x509) startSessionRefresher() { a.logger.Infof("starting session refresher for x509 auth") - // if there is a set session duration, then exit bc we will not auto refresh the session. - if *a.SessionDuration != 0 { - a.logger.Debugf("session duration was set, so there is no authentication refreshing") - return - } a.wg.Add(1) go func() { @@ -465,11 +435,11 @@ func (a *x509) refreshClient() { for { newSession, err := a.createOrRefreshSession(context.Background()) if err == nil { - a.Clients.refresh(newSession) + a.clients.refresh(newSession) a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") return } - a.logger.Errorf("Failed to refresh session: %w", err) + a.logger.Errorf("Failed to refresh session, retrying in 5 seconds: %w", err) select { case <-time.After(time.Second * 5): case <-a.closeCh: diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go index 4e5752f33d..3f7d2189c3 100644 --- a/common/authentication/aws/x509_test.go +++ b/common/authentication/aws/x509_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( @@ -19,7 +32,6 @@ import ( spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/crypto/test" "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" ) type mockRolesAnywhereClient struct { @@ -59,17 +71,15 @@ func TestGetX509Client(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - duration := time.Duration(800) mockSvc := &mockRolesAnywhereClient{ CreateSessionOutput: tt.mockOutput, CreateSessionError: tt.mockError, } mockAWS := x509{ logger: logger.NewLogger("testLogger"), - AssumeRoleArn: ptr.Of("arn:aws:iam:012345678910:role/exampleIAMRoleName"), - TrustAnchorArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), - TrustProfileArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), - SessionDuration: &duration, + assumeRoleArn: aws.String("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + trustAnchorArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + trustProfileArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), rolesAnywhereClient: mockSvc, } pki := test.GenPKI(t, test.PKIOptions{ diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index b85cd5e0e5..ae4ba7c5e9 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -275,7 +275,7 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return nil + return d.authProvider.Close() } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d5dd43aeb8..28a34c8af9 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -22,12 +22,9 @@ import ( "time" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" - "github.com/dapr/kit/logger" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" @@ -161,32 +158,14 @@ func TestInit(t *testing.T) { return nil, errors.New("Requested resource not found") }, } - dynamo := awsAuth.DynamoDBClients{ DynamoDB: mockedDB, } - mockedClients := awsAuth.Clients{ Dynamo: &dynamo, } - mockedSession, err := session.NewSession(&aws.Config{ - Region: aws.String("us-west-1"), - Credentials: credentials.AnonymousCredentials, - }) - require.NoError(t, err) mockAuthProvider := &awsAuth.StaticAuth{ - Logger: logger.NewLogger("test"), Clients: &mockedClients, - // mock client creds so we don't get the error -> access: EmptyStaticCreds: static credentials are empty" - Cfg: &aws.Config{ - Credentials: credentials.AnonymousCredentials, - }, - Region: "us-west-1", - AccessKey: aws.String("mocked"), - SecretKey: aws.String("mocked"), - SessionToken: aws.String("mocked"), - Endpoint: aws.String("mocked"), - Session: mockedSession, } s := StateStore{ @@ -195,7 +174,7 @@ func TestInit(t *testing.T) { table: table, } - err = s.Init(context.Background(), m) + err := s.Init(context.Background(), m) require.Error(t, err) require.EqualError(t, err, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found") })