From 5bd981c1cf9ea47b10f356c6ffb7bd6db335a625 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 13:39:01 -0600 Subject: [PATCH 01/39] feat(iam auth): allow iam roles anywhere auth profile Signed-off-by: Samantha Coyle --- .../builtin-authentication-profiles.yaml | 18 ++ bindings/aws/dynamodb/dynamodb.go | 28 +-- bindings/aws/s3/s3.go | 27 +-- bindings/aws/sns/sns.go | 31 +-- common/authentication/aws/aws.go | 194 ++++++++++++++++-- go.mod | 3 + go.sum | 10 + pubsub/aws/snssqs/snssqs.go | 17 +- tests/certification/go.mod | 1 + tests/certification/go.sum | 2 + 10 files changed, 278 insertions(+), 53 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9113cf286b..439ea82f50 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -29,6 +29,24 @@ aws: type: string - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment + - title: "AWS: IAM Roles Anywhere" + description: Use x.509 certificates to establish trust between AWS and a trusted Certificate Authority using AWS IAM Roles Anywhere. + metadata: + - name: trustAnchorArn + description: | + ARN of the AWS Trust Anchor in the AWS account granting trust to a Certificate Authority. + example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901 + required: true + - name: trustProfileArn + description: | + ARN of the AWS IAM Profile in the trusting AWS account. + example: arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901 + required: true + - name: assumeRoleArn + description: | + ARN of the AWS IAM role to assume in the trusting AWS account. + example: arn:aws:iam:012345678910:role/exampleIAMRoleName + required: true azuread: - title: "Azure AD: Managed identity" diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index bd882e7b55..2e4c7cfdaa 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -51,18 +51,30 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing for DynamoDB. -func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error { +func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err } - client, err := d.getClient(meta) + aws, err := awsAuth.New(awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } + + sess, err := aws.GetClient(ctx) if err != nil { return err } - d.client = client + d.client = dynamodb.New(sess) d.table = meta.Table return nil @@ -105,16 +117,6 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat return &meta, nil } -func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := dynamoDBMetadata{} diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index cc67cec94f..66aabde47f 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,7 +29,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" @@ -110,12 +109,25 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection creation. -func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { +func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) if err != nil { return err } - session, err := s.getSession(m) + + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + }) + if err != nil { + return err + } + + session, err := awsA.GetClient(ctx) if err != nil { return err } @@ -415,15 +427,6 @@ func (s *AWSS3) parseMetadata(md bindings.Metadata) (*s3Metadata, error) { return &m, nil } -func (s *AWSS3) getSession(metadata *s3Metadata) (*session.Session, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return sess, nil -} - // Helper to merge config and request metadata. func (metadata s3Metadata) mergeWithRequestMetadata(req *bindings.InvokeRequest) (s3Metadata, error) { merged := metadata diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 43b63cd2b1..e01ee477ac 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -58,16 +58,31 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err } - client, err := a.getClient(m) + + aws, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + }) if err != nil { return err } - a.client = client + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + a.client = sns.New(sess) + a.topicARN = m.TopicArn return nil @@ -83,16 +98,6 @@ func (a *AWSSNS) parseMetadata(meta bindings.Metadata) (*snsMetadata, error) { return &m, nil } -func (a *AWSSNS) getClient(metadata *snsMetadata) (*sns.SNS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sns.New(sess) - - return c, nil -} - func (a *AWSSNS) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 48c8b209a4..dee30ec98e 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -15,9 +15,15 @@ package aws import ( "context" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" "errors" "fmt" + "net/http" + "runtime" "strconv" + "sync" "time" awsv2 "github.com/aws/aws-sdk-go-v2/aws" @@ -25,13 +31,19 @@ import ( v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - - "github.com/dapr/kit/logger" ) type EnvironmentSettings struct { @@ -61,19 +73,128 @@ func GetConfigV2(accessKey string, secretKey string, sessionToken string, region return awsCfg, nil } -func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) { - awsConfig := aws.NewConfig() +func (a *AWS) GetClient(ctx context.Context) (*session.Session, error) { + a.lock.Lock() + defer a.lock.Unlock() - if region != "" { - awsConfig = awsConfig.WithRegion(region) + switch { + // IAM Roles Anywhere option + case a.x509Auth.TrustAnchorArn != nil && a.x509Auth.AssumeRoleArn != nil: + a.logger.Debug("using X.509 RolesAnywhere authentication using Dapr SVID") + return a.getX509Client(ctx) + default: + a.logger.Debugf("using AWS session client...") + return a.getSessionClient() } +} - if accessKey != "" && secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)) +func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return nil, fmt.Errorf("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return nil, err } - if endpoint != "" { - awsConfig = awsConfig.WithEndpoint(endpoint) + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal SVID: %w", err) + } + + var trustAnchor arn.ARN + if a.x509Auth.TrustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) + if err != nil { + return nil, err + } + a.region = trustAnchor.Region + } + + if a.x509Auth.TrustProfileArn != nil { + profile, err := arn.Parse(*a.x509Auth.TrustProfileArn) + if err != nil { + return nil, err + } + if profile.Region != "" && trustAnchor.Region != profile.Region { + return nil, fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + + mySession, err := session.NewSession() + if err != nil { + return nil, err + } + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + config := aws.NewConfig().WithRegion(trustAnchor.Region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + certs, err := cryptopem.DecodePEMCertificatesChain(chainPEM) + if err != nil { + return nil, err + } + + var ints []x509.Certificate + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(keyPEM) + if err != nil { + return nil, err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + // TODO: make metadata field? + var duration int64 = 10000 + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: &duration, + InstanceProperties: nil, + SessionName: nil, + } + output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app dentity: %w", err) + } + + if len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId + a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey + a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken + + return a.getSessionClient() +} + +func (a *AWS) getSessionClient() (*session.Session, error) { + awsConfig := aws.NewConfig() + + if a.region != "" { + awsConfig = awsConfig.WithRegion(a.region) + } + + if a.accessKey != "" && a.secretKey != "" { + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) } awsSession, err := session.NewSessionWithOptions(session.Options{ @@ -102,6 +223,19 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } +type AWS struct { + lock sync.RWMutex + logger logger.Logger + + x509Auth *x509Auth + + region string + endpoint string + accessKey string + secretKey string + sessionToken string +} + type AWSIAM struct { // Ignored by metadata parser because included in built-in authentication profile // Access key to use for accessing PostgreSQL. @@ -110,17 +244,51 @@ type AWSIAM struct { AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` // AWS region in which PostgreSQL is deployed. AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` + + // AWS IAM Roles anywhere related fields + x509Auth *x509Auth } -type AWSIAMAuthOptions struct { +type Options struct { + Logger logger.Logger + Properties map[string]string + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` ConnectionString string `json:"connectionString" mapstructure:"connectionString"` Region string `json:"region" mapstructure:"region"` AccessKey string `json:"accessKey" mapstructure:"accessKey"` SecretKey string `json:"secretKey" mapstructure:"secretKey"` + SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` +} + +type x509Auth struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} + +func New(opts Options) (*AWS, error) { + var x509Auth x509Auth + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return nil, err + } + if x509Auth.AssumeRoleArn != nil { + opts.Logger.Infof("sam x509 fields %s %s ", *x509Auth.AssumeRoleArn, *x509Auth.TrustAnchorArn) + } else { + opts.Logger.Infof("sam still nil somehow...") + } + + return &AWS{ + x509Auth: &x509Auth, + logger: opts.Logger, + region: opts.Region, + accessKey: opts.AccessKey, + secretKey: opts.SecretKey, + sessionToken: opts.SessionToken, + }, nil } -func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, error) { +func (opts *Options) GetAccessToken(ctx context.Context) (string, error) { dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) var authenticationToken string @@ -160,7 +328,7 @@ func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, erro return authenticationToken, nil } -func (opts *AWSIAMAuthOptions) InitiateAWSIAMAuth() error { +func (opts *Options) InitiateAWSIAMAuth() error { // Set max connection lifetime to 8 minutes in postgres connection pool configuration. // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, // while leveraging the BeforeConnect hook to recreate the token in time dynamically. diff --git a/go.mod b/go.mod index 13a6af3ab6..16d28420b0 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.37 github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3 + github.com/aws/rolesanywhere-credential-helper v1.0.4 github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 github.com/camunda/zeebe/clients/go/v8 v8.2.12 github.com/cenkalti/backoff/v4 v4.2.1 @@ -359,6 +360,7 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/spiffe/go-spiffe/v2 v2.1.7 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/tidwall/gjson v1.14.4 // indirect @@ -379,6 +381,7 @@ require ( github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect + github.com/zeebo/errs v1.3.0 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index 54c28416d8..610b654868 100644 --- a/go.sum +++ b/go.sum @@ -289,6 +289,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -604,6 +606,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= @@ -1517,6 +1521,8 @@ github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5q github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= +github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= +github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 h1:mTC4gyv3lcJ1XpzZMAckqkvWUqeT5Bva4RAT1IoHAAA= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58/go.mod h1:ZAYCOqLJkc9P6fcq14TV4cf+gJ2fHthp9kCGxBViagE= github.com/stealthrocket/wazergo v0.19.1 h1:BPrITETPgSFwiytwmToO0MbUC/+RGC39JScz1JmmG6c= @@ -1652,6 +1658,8 @@ github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7 github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= +github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= github.com/zouyx/agollo/v3 v3.4.5 h1:7YCxzY9ZYaH9TuVUBvmI6Tk0mwMggikah+cfbYogcHQ= github.com/zouyx/agollo/v3 v3.4.5/go.mod h1:LJr3kDmm23QSW+F1Ol4TMHDa7HvJvscMdVxJ2IpUTVc= go.einride.tech/aip v0.66.0 h1:XfV+NQX6L7EOYK11yoHHFtndeaWh3KbD9/cN/6iWEt8= @@ -2316,6 +2324,8 @@ google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACu google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 357cfcabb9..d782320c69 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -145,10 +145,23 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { s.metadata = md - sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) + aws, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: md.Region, + AccessKey: md.AccessKey, + SecretKey: md.SecretKey, + SessionToken: md.SessionToken, + }) if err != nil { - return fmt.Errorf("error creating an AWS client: %w", err) + return err } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + // AWS sns,sqs,sts client. s.snsClient = sns.New(sess) s.sqsClient = sqs.New(sess) diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 1dc9c0ad44..51befe4ed7 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -98,6 +98,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect + github.com/aws/rolesanywhere-credential-helper v1.0.4 // indirect github.com/aws/smithy-go v1.21.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 145f05a305..04311f0e9f 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -234,6 +234,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= From 45db3a1d07019cdec3305ba05f5560662343cfac Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 13:54:32 -0600 Subject: [PATCH 02/39] fix(build): update more aws components Signed-off-by: Samantha Coyle --- bindings/aws/kinesis/kinesis.go | 30 +++++++++--------- bindings/aws/ses/ses.go | 31 ++++++++++--------- bindings/aws/sqs/sqs.go | 28 +++++++++-------- common/authentication/postgresql/metadata.go | 2 +- .../aws/parameterstore/parameterstore.go | 24 ++++++++------ .../aws/secretmanager/secretmanager.go | 24 ++++++++------ 6 files changed, 77 insertions(+), 62 deletions(-) diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index dbe0ceb918..18388faf2e 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -112,13 +112,26 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - client, err := a.getClient(m) + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + }) + if err != nil { + return err + } + + sess, err := awsA.GetClient(ctx) if err != nil { return err } + a.client = kinesis.New(sess) streamName := aws.String(m.StreamName) - stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ + stream, err := a.client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ StreamName: streamName, }) if err != nil { @@ -128,13 +141,12 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error if m.KinesisConsumerMode == SharedThroughput { kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName, m.StreamName, m.Region, m.ConsumerName, - client.Config.Credentials) + a.client.Config.Credentials) a.workerConfig = kclConfig } a.streamARN = stream.StreamDescription.StreamARN a.metadata = m - a.client = client return nil } @@ -354,16 +366,6 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des return w.WaitWithContext(ctx) } -func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - k := kinesis.New(sess) - - return k, nil -} - func (a *AWSKinesis) parseMetadata(meta bindings.Metadata) (*kinesisMetadata, error) { var m kinesisMetadata err := kitmd.DecodeMetadata(meta.Properties, &m) diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 483fde8c64..bca4e097d2 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -61,17 +61,32 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { // Parse input metadata meta, err := a.parseMetadata(metadata) if err != nil { return err } - svc, err := a.getClient(meta) + aws, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) if err != nil { return err } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + // Create an SES instance + svc := ses.New(sess) a.metadata = meta a.svc = svc @@ -158,18 +173,6 @@ func (metadata sesMetadata) mergeWithRequestMetadata(req *bindings.InvokeRequest return merged } -func (a *AWSSES) getClient(metadata *sesMetadata) (*ses.SES, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, fmt.Errorf("SES binding error: error creating AWS session %w", err) - } - - // Create an SES instance - svc := ses.New(sess) - - return svc, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMap) { metadataStruct := sesMetadata{} diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 465e061b61..ae480a5637 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -66,13 +66,26 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - client, err := a.getClient(m) + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + }) + if err != nil { + return err + } + + sess, err := awsA.GetClient(ctx) if err != nil { return err } + a.Client = sqs.New(sess) queueName := m.QueueName - resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + resultURL, err := a.Client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ QueueName: aws.String(queueName), }) if err != nil { @@ -80,7 +93,6 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { } a.QueueURL = resultURL.QueueUrl - a.Client = client return nil } @@ -177,16 +189,6 @@ func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) return &m, nil } -func (a *AWSSQS) getClient(metadata *sqsMetadata) (*sqs.SQS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sqs.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSQS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := sqsMetadata{} diff --git a/common/authentication/postgresql/metadata.go b/common/authentication/postgresql/metadata.go index 7cacecfaa4..4b2135ba6a 100644 --- a/common/authentication/postgresql/metadata.go +++ b/common/authentication/postgresql/metadata.go @@ -162,7 +162,7 @@ func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) { return nil, err } - awsOpts := aws.AWSIAMAuthOptions{ + awsOpts := aws.Options{ PoolConfig: config, ConnectionString: m.ConnectionString, Region: awsRegion, diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 1f82031ba8..eecade0695 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -65,10 +65,23 @@ func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadat return err } - s.client, err = s.getClient(meta) + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } + + session, err := awsA.GetClient(ctx) if err != nil { return err } + s.client = ssm.New(session) s.prefix = meta.Prefix return nil @@ -155,15 +168,6 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul return resp, nil } -func (s *ssmSecretStore) getClient(metadata *ParameterStoreMetaData) (*ssm.SSM, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, err - } - - return ssm.New(sess), nil -} - func (s *ssmSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*ParameterStoreMetaData, error) { meta := ParameterStoreMetaData{} err := kitmd.DecodeMetadata(spec.Properties, &meta) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 33f5a54c9d..31eb7b9261 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -59,11 +59,24 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - s.client, err = s.getClient(meta) + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } + + session, err := awsA.GetClient(ctx) if err != nil { return err } + s.client = secretsmanager.New(session) return nil } @@ -135,15 +148,6 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk return resp, nil } -func (s *smSecretStore) getClient(metadata *SecretManagerMetaData) (*secretsmanager.SecretsManager, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, err - } - - return secretsmanager.New(sess), nil -} - func (s *smSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*SecretManagerMetaData, error) { b, err := json.Marshal(spec.Properties) if err != nil { From 58bbeaa9cd96534ed0fd3a9987996ab36b43c9f7 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 13:59:02 -0600 Subject: [PATCH 03/39] fix(metadata): add endpoint field to options Signed-off-by: Samantha Coyle --- bindings/aws/dynamodb/dynamodb.go | 1 + bindings/aws/kinesis/kinesis.go | 1 + bindings/aws/s3/s3.go | 1 + bindings/aws/sns/sns.go | 1 + bindings/aws/sqs/sqs.go | 1 + common/authentication/aws/aws.go | 6 ++++++ 6 files changed, 11 insertions(+) diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 2e4c7cfdaa..23439dcbcc 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -64,6 +64,7 @@ func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, }) if err != nil { return err diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 18388faf2e..917c1cddfb 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -119,6 +119,7 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, + Endpoint: m.Endpoint, }) if err != nil { return err diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 66aabde47f..b180a43829 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -122,6 +122,7 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, + Endpoint: m.Endpoint, }) if err != nil { return err diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index e01ee477ac..275743c477 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -71,6 +71,7 @@ func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, + Endpoint: m.Endpoint, }) if err != nil { return err diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index ae480a5637..80db597075 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -73,6 +73,7 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, + Endpoint: m.Endpoint, }) if err != nil { return err diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index dee30ec98e..7f6c78a1a3 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -197,6 +197,10 @@ func (a *AWS) getSessionClient() (*session.Session, error) { awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) } + if a.endpoint != "" { + awsConfig = awsConfig.WithEndpoint(a.endpoint) + } + awsSession, err := session.NewSessionWithOptions(session.Options{ Config: *awsConfig, SharedConfigState: session.SharedConfigEnable, @@ -259,6 +263,7 @@ type Options struct { AccessKey string `json:"accessKey" mapstructure:"accessKey"` SecretKey string `json:"secretKey" mapstructure:"secretKey"` SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` + Endpoint string `json:"endpoint" mapstructure:"endpoint"` } type x509Auth struct { @@ -285,6 +290,7 @@ func New(opts Options) (*AWS, error) { accessKey: opts.AccessKey, secretKey: opts.SecretKey, sessionToken: opts.SessionToken, + endpoint: opts.Endpoint, }, nil } From e9493ae1e4985e353d14863afff53e2b21582670 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 14:00:11 -0600 Subject: [PATCH 04/39] style: update descriptions on new fields Signed-off-by: Samantha Coyle --- .build-tools/builtin-authentication-profiles.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 439ea82f50..4970fd67bc 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -30,11 +30,11 @@ aws: - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment - title: "AWS: IAM Roles Anywhere" - description: Use x.509 certificates to establish trust between AWS and a trusted Certificate Authority using AWS IAM Roles Anywhere. + description: Use x.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. metadata: - name: trustAnchorArn description: | - ARN of the AWS Trust Anchor in the AWS account granting trust to a Certificate Authority. + ARN of the AWS Trust Anchor in the AWS account granting trust to the Dapr Certificate Authority. example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901 required: true - name: trustProfileArn From bff9b16912a9dfdfbc0a6b590a54c618c822bb1d Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 14:13:17 -0600 Subject: [PATCH 05/39] fix: acct for nil check Signed-off-by: Samantha Coyle --- bindings/aws/dynamodb/dynamodb.go | 39 +++++----- bindings/aws/kinesis/kinesis.go | 34 ++++----- bindings/aws/s3/s3.go | 71 ++++++++++--------- bindings/aws/ses/ses.go | 39 +++++----- bindings/aws/sns/sns.go | 38 +++++----- bindings/aws/sqs/sqs.go | 34 ++++----- .../aws/parameterstore/parameterstore.go | 33 +++++---- .../aws/secretmanager/secretmanager.go | 33 +++++---- state/aws/dynamodb/dynamodb.go | 28 +++++--- 9 files changed, 187 insertions(+), 162 deletions(-) diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 23439dcbcc..dcf3155836 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -56,26 +56,27 @@ func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { if err != nil { return err } - - aws, err := awsAuth.New(awsAuth.Options{ - Logger: d.logger, - Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - Endpoint: meta.Endpoint, - }) - if err != nil { - return err + if d.client == nil { + aws, err := awsAuth.New(awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, + }) + if err != nil { + return err + } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + d.client = dynamodb.New(sess) } - - sess, err := aws.GetClient(ctx) - if err != nil { - return err - } - - d.client = dynamodb.New(sess) d.table = meta.Table return nil diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 917c1cddfb..50f5742f05 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -112,24 +112,26 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - awsA, err := awsAuth.New(awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err - } + if a.client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + Endpoint: m.Endpoint, + }) + if err != nil { + return err + } - sess, err := awsA.GetClient(ctx) - if err != nil { - return err + sess, err := awsA.GetClient(ctx) + if err != nil { + return err + } + a.client = kinesis.New(sess) } - a.client = kinesis.New(sess) streamName := aws.String(m.StreamName) stream, err := a.client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index b180a43829..540b33a5ff 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -115,47 +115,50 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - awsA, err := awsAuth.New(awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err - } - - session, err := awsA.GetClient(ctx) - if err != nil { - return err - } - - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) + if s.s3Client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + Endpoint: m.Endpoint, + }) + if err != nil { + return err + } - // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { - customTransport := http.DefaultTransport.(*http.Transport).Clone() - customTransport.TLSClientConfig = &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, + session, err := awsA.GetClient(ctx) + if err != nil { + return err } - client := &http.Client{ - Transport: customTransport, + + cfg := aws.NewConfig(). + WithS3ForcePathStyle(m.ForcePathStyle). + WithDisableSSL(m.DisableSSL) + + // Use a custom HTTP client to allow self-signed certs + if m.InsecureSSL { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + } + client := &http.Client{ + Transport: customTransport, + } + cfg = cfg.WithHTTPClient(client) + + s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") } - cfg = cfg.WithHTTPClient(client) - s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") + s.s3Client = s3.New(session, cfg) + s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) + s.uploader = s3manager.NewUploaderWithClient(s.s3Client) } s.metadata = m - s.s3Client = s3.New(session, cfg) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) return nil } diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index bca4e097d2..1d28026a9f 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -68,27 +68,30 @@ func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - aws, err := awsAuth.New(awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err - } - - sess, err := aws.GetClient(ctx) - if err != nil { - return err + if a.svc.Client == nil { + aws, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + // Create an SES instance + svc := ses.New(sess) + a.svc = svc } - // Create an SES instance - svc := ses.New(sess) a.metadata = meta - a.svc = svc return nil } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 275743c477..e3617bdaf2 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -64,26 +64,28 @@ func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - aws, err := awsAuth.New(awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err - } - - sess, err := aws.GetClient(ctx) - if err != nil { - return err + if a.client == nil { + aws, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + Endpoint: m.Endpoint, + }) + if err != nil { + return err + } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + a.client = sns.New(sess) } - a.client = sns.New(sess) - a.topicARN = m.TopicArn return nil diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 80db597075..fb51b24898 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -66,24 +66,26 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - awsA, err := awsAuth.New(awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err - } + if a.Client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + Endpoint: m.Endpoint, + }) + if err != nil { + return err + } - sess, err := awsA.GetClient(ctx) - if err != nil { - return err + sess, err := awsA.GetClient(ctx) + if err != nil { + return err + } + a.Client = sqs.New(sess) } - a.Client = sqs.New(sess) queueName := m.QueueName resultURL, err := a.Client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index eecade0695..1735bbd674 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -65,23 +65,26 @@ func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadat return err } - awsA, err := awsAuth.New(awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err - } + if s.client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } - session, err := awsA.GetClient(ctx) - if err != nil { - return err + session, err := awsA.GetClient(ctx) + if err != nil { + return err + } + s.client = ssm.New(session) } - s.client = ssm.New(session) + s.prefix = meta.Prefix return nil diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 31eb7b9261..2010c5ada8 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -59,24 +59,27 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - awsA, err := awsAuth.New(awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err - } + if s.client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + }) + if err != nil { + return err + } - session, err := awsA.GetClient(ctx) - if err != nil { - return err + session, err := awsA.GetClient(ctx) + if err != nil { + return err + } + + s.client = secretsmanager.New(session) } - s.client = secretsmanager.New(session) return nil } diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 503d7082c7..77642a1f8c 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -41,6 +41,7 @@ import ( type StateStore struct { state.BulkStore + logger logger.Logger client dynamodbiface.DynamoDBAPI table string ttlAttributeName string @@ -83,10 +84,25 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { // This check is needed because d.client is set to a mock in tests if d.client == nil { - d.client, err = d.getClient(meta) + aws, err := awsAuth.New(awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, + }) if err != nil { return err } + + sess, err := aws.GetClient(ctx) + if err != nil { + return err + } + + d.client = dynamodb.New(sess) } d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName @@ -281,16 +297,6 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata return &m, err } -func (d *StateStore) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // getItemFromReq converts a dapr state.SetRequest into an dynamodb item func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) { value, err := d.marshalToString(req.Value) From fd30089c0c0c18f659684f1558034dc317708728 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 16:28:24 -0600 Subject: [PATCH 06/39] style: make linter happy Signed-off-by: Samantha Coyle --- bindings/aws/sqs/sqs.go | 3 ++- common/authentication/aws/aws.go | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index fb51b24898..4ea4291280 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -67,7 +67,8 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { } if a.Client == nil { - awsA, err := awsAuth.New(awsAuth.Options{ + var awsA *awsAuth.AWS + awsA, err = awsAuth.New(awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, Region: m.Region, diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 7f6c78a1a3..8e6a6a204a 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -37,13 +37,14 @@ import ( "github.com/aws/aws-sdk-go/aws/session" awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + cryptopem "github.com/dapr/kit/crypto/pem" spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/logger" kitmd "github.com/dapr/kit/metadata" "github.com/dapr/kit/ptr" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" ) type EnvironmentSettings struct { @@ -106,7 +107,11 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { return nil, fmt.Errorf("failed to marshal SVID: %w", err) } - var trustAnchor arn.ARN + var ( + trustAnchor arn.ARN + profile arn.ARN + ) + if a.x509Auth.TrustAnchorArn != nil { trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) if err != nil { @@ -116,7 +121,7 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { } if a.x509Auth.TrustProfileArn != nil { - profile, err := arn.Parse(*a.x509Auth.TrustProfileArn) + profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) if err != nil { return nil, err } @@ -140,7 +145,7 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { return nil, err } - var ints []x509.Certificate + var ints = make([]x509.Certificate, len(certs)-1) for i := range certs[1:] { ints = append(ints, *certs[i+1]) } @@ -248,9 +253,6 @@ type AWSIAM struct { AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` // AWS region in which PostgreSQL is deployed. AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` - - // AWS IAM Roles anywhere related fields - x509Auth *x509Auth } type Options struct { From f712b694ab36869dcb9d1486cda9c1123d976f1c Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 16:35:05 -0600 Subject: [PATCH 07/39] style: more linter fixes Signed-off-by: Samantha Coyle --- bindings/aws/kinesis/kinesis.go | 7 +++++-- bindings/aws/sqs/sqs.go | 4 +++- common/authentication/aws/aws.go | 12 ++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 50f5742f05..32d7382f09 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -24,6 +24,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" @@ -113,7 +114,8 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error } if a.client == nil { - awsA, err := awsAuth.New(awsAuth.Options{ + var awsA *awsAuth.AWS + awsA, err = awsAuth.New(awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, Region: m.Region, @@ -126,7 +128,8 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return err } - sess, err := awsA.GetClient(ctx) + var sess *session.Session + sess, err = awsA.GetClient(ctx) if err != nil { return err } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 4ea4291280..07b519293d 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -22,6 +22,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/dapr/components-contrib/bindings" @@ -81,7 +82,8 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - sess, err := awsA.GetClient(ctx) + var sess *session.Session + sess, err = awsA.GetClient(ctx) if err != nil { return err } diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 8e6a6a204a..b5156108c9 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -145,7 +145,7 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { return nil, err } - var ints = make([]x509.Certificate, len(certs)-1) + ints := make([]x509.Certificate, len(certs)-1) for i := range certs[1:] { ints = append(ints, *certs[i+1]) } @@ -275,18 +275,18 @@ type x509Auth struct { } func New(opts Options) (*AWS, error) { - var x509Auth x509Auth - if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + var x509AuthConfig x509Auth + if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { return nil, err } - if x509Auth.AssumeRoleArn != nil { - opts.Logger.Infof("sam x509 fields %s %s ", *x509Auth.AssumeRoleArn, *x509Auth.TrustAnchorArn) + if x509AuthConfig.AssumeRoleArn != nil { + opts.Logger.Infof("sam x509 fields %s %s ", *x509AuthConfig.AssumeRoleArn, *x509AuthConfig.TrustAnchorArn) } else { opts.Logger.Infof("sam still nil somehow...") } return &AWS{ - x509Auth: &x509Auth, + x509Auth: &x509AuthConfig, logger: opts.Logger, region: opts.Region, accessKey: opts.AccessKey, From 78cf670341d690c80d219ac855c332eeaf965ec2 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 4 Nov 2024 16:51:31 -0600 Subject: [PATCH 08/39] style: final linter tweaks Signed-off-by: Samantha Coyle --- common/authentication/aws/aws.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index b5156108c9..f88fd8e912 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -93,7 +93,7 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { // retrieve svid from spiffe context svid, ok := spiffecontext.From(ctx) if !ok { - return nil, fmt.Errorf("no SVID found in context") + return nil, errors.New("no SVID found in context") } // get x.509 svid svidx, err := svid.GetX509SVID() @@ -147,7 +147,7 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { ints := make([]x509.Certificate, len(certs)-1) for i := range certs[1:] { - ints = append(ints, *certs[i+1]) + ints[i] = *certs[i+1] } key, err := cryptopem.DecodePEMPrivateKey(keyPEM) From 8aec6a5b592d1f20702085060fab6bfc9d4ffd9b Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 6 Nov 2024 12:28:32 -0600 Subject: [PATCH 09/39] fix(session): apply auto refresh to s3 Signed-off-by: Samantha Coyle --- .../builtin-authentication-profiles.yaml | 8 + bindings/aws/s3/s3.go | 67 ++- common/authentication/aws/aws.go | 424 ++++++++++++------ 3 files changed, 345 insertions(+), 154 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 4970fd67bc..2abe6ac3df 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -47,6 +47,14 @@ aws: ARN of the AWS IAM role to assume in the trusting AWS account. example: arn:aws:iam:012345678910:role/exampleIAMRoleName required: true + - name: sessionDuration + type: duration + description: | + Duration of the session using AWS IAM Roles Anywhere. + If set to 0m, temporary credentials will automatically rotate. + default: '15m' + example: '0m' + required: true azuread: - title: "Azure AD: Managed identity" diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 540b33a5ff..6d1d54d483 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,6 +29,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" @@ -108,6 +109,26 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } +func (s *AWSS3) getAWSConfig(awsA *awsAuth.AWS) *aws.Config { + cfg := awsA.GetConfig().WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) + + // Use a custom HTTP client to allow self-signed certs + if s.metadata.InsecureSSL { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + } + client := &http.Client{ + Transport: customTransport, + } + cfg = cfg.WithHTTPClient(client) + + s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") + } + return cfg +} + // Init does metadata parsing and connection creation. func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) @@ -116,46 +137,40 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { } if s.s3Client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ Logger: s.logger, Properties: metadata.Properties, Region: m.Region, + Endpoint: m.Endpoint, AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, - Endpoint: m.Endpoint, }) if err != nil { return err } - - session, err := awsA.GetClient(ctx) + // initiate clients, before refreshing if needed + sess, err := awsA.GetClient(ctx) if err != nil { return err } - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) - - // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { - customTransport := http.DefaultTransport.(*http.Transport).Clone() - customTransport.TLSClientConfig = &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, - } - client := &http.Client{ - Transport: customTransport, - } - cfg = cfg.WithHTTPClient(client) - - s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") - } - - s.s3Client = s3.New(session, cfg) + s.metadata = m + s.s3Client = s3.New(sess, s.getAWSConfig(awsA)) s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) s.uploader = s3manager.NewUploaderWithClient(s.s3Client) + + go func() { + for { + select { + case refreshSession := <-awsA.GetSessionUpdateChannel(): + s.updateAWSClients(refreshSession, s.getAWSConfig(awsA)) + case <-ctx.Done(): + return + } + } + }() } s.metadata = m @@ -163,6 +178,12 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { return nil } +func (s *AWSS3) updateAWSClients(session *session.Session, cfgs *aws.Config) { + s.s3Client = s3.New(session, cfgs) + s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) + s.uploader = s3manager.NewUploaderWithClient(s.s3Client) +} + func (s *AWSS3) Close() error { return nil } diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index f88fd8e912..f55eba5930 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -37,6 +37,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -51,6 +52,90 @@ type EnvironmentSettings struct { Metadata map[string]string } +type AWS struct { + mu sync.RWMutex + logger logger.Logger + + x509Auth *x509Auth + + region string + endpoint string + accessKey string + secretKey string + sessionToken string +} + +type AWSIAM struct { + // Ignored by metadata parser because included in built-in authentication profile + // Access key to use for accessing PostgreSQL. + AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + // Secret key to use for accessing PostgreSQL. + AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` + // AWS region in which PostgreSQL is deployed. + AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` +} + +type Options struct { + Logger logger.Logger + Properties map[string]string + + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` + SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` + Endpoint string `json:"endpoint" mapstructure:"endpoint"` +} + +type x509Auth struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` + sessionExpiration time.Time + + chainPEM []byte + keyPEM []byte + + sessionUpdateChannel chan *session.Session + + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI +} + +func New(opts Options) (*AWS, error) { + var x509AuthConfig x509Auth + if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { + return nil, err + } + + return &AWS{ + x509Auth: &x509AuthConfig, + logger: opts.Logger, + region: opts.Region, + accessKey: opts.AccessKey, + secretKey: opts.SecretKey, + sessionToken: opts.SessionToken, + endpoint: opts.Endpoint, + }, nil +} + +func (a *AWS) GetConfig() *aws.Config { + cfg := aws.NewConfig() + + if a.region != "" { + cfg.WithRegion(a.region) + } + + return cfg +} + +func (a *AWS) GetSessionUpdateChannel() chan *session.Session { + a.mu.Lock() + defer a.mu.Unlock() + return a.x509Auth.sessionUpdateChannel +} + func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { optFns := []func(*config.LoadOptions) error{} if region != "" { @@ -75,47 +160,116 @@ func GetConfigV2(accessKey string, secretKey string, sessionToken string, region } func (a *AWS) GetClient(ctx context.Context) (*session.Session, error) { - a.lock.Lock() - defer a.lock.Unlock() + a.mu.Lock() + defer a.mu.Unlock() switch { // IAM Roles Anywhere option case a.x509Auth.TrustAnchorArn != nil && a.x509Auth.AssumeRoleArn != nil: a.logger.Debug("using X.509 RolesAnywhere authentication using Dapr SVID") - return a.getX509Client(ctx) + session, err := a.getX509Client(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create X.509 RolesAnywhere client") + } + // start a session refresher background goroutine to keep rotating the temporary creds + // use background context to keep alive + go a.startSessionRefresher(context.Background()) + + return session, nil default: a.logger.Debugf("using AWS session client...") - return a.getSessionClient() + return a.getTokenClient() } } -func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { +func (a *AWS) getTokenClient() (*session.Session, error) { + awsConfig := aws.NewConfig() + + if a.region != "" { + awsConfig = awsConfig.WithRegion(a.region) + } + + if a.accessKey != "" && a.secretKey != "" { + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) + } + + if a.endpoint != "" { + awsConfig = awsConfig.WithEndpoint(a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *AWS) getCertPEM(ctx context.Context) error { // retrieve svid from spiffe context svid, ok := spiffecontext.From(ctx) if !ok { - return nil, errors.New("no SVID found in context") + return errors.New("no SVID found in context") } // get x.509 svid svidx, err := svid.GetX509SVID() if err != nil { - return nil, err + return err } // marshal x.509 svid to pem format chainPEM, keyPEM, err := svidx.Marshal() if err != nil { - return nil, fmt.Errorf("failed to marshal SVID: %w", err) + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.x509Auth.chainPEM = chainPEM + a.x509Auth.keyPEM = keyPEM + return nil +} + +func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { + // retrieve svid from spiffe context + err := a.getCertPEM(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + if err := a.initializeTrustAnchors(); err != nil { + return nil, err + } + + if err := a.initializeRolesAnywhereClient(); err != nil { + return nil, err + } + + err = a.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to refresh token for new session client") } + return a.getTokenClient() +} + +func (a *AWS) initializeTrustAnchors() error { var ( trustAnchor arn.ARN profile arn.ARN + err error ) - if a.x509Auth.TrustAnchorArn != nil { trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) if err != nil { - return nil, err + return err } a.region = trustAnchor.Region } @@ -123,104 +277,176 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { if a.x509Auth.TrustProfileArn != nil { profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) if err != nil { - return nil, err + return err } + if profile.Region != "" && trustAnchor.Region != profile.Region { - return nil, fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", trustAnchor.Region, profile.Region) } } + return nil +} - mySession, err := session.NewSession() - if err != nil { - return nil, err +func (a *AWS) initializeRolesAnywhereClient() error { + if a.x509Auth.rolesAnywhereClient == nil { + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + mySession, err := session.NewSession() + if err != nil { + return err + } + config := aws.NewConfig().WithRegion(a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return err + } + a.x509Auth.rolesAnywhereClient = rolesAnywhereClient } - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, - }} - config := aws.NewConfig().WithRegion(trustAnchor.Region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - rolesAnywhereClient := rolesanywhere.New(mySession, config) - certs, err := cryptopem.DecodePEMCertificatesChain(chainPEM) + return nil + +} + +func (a *AWS) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.x509Auth.chainPEM) if err != nil { - return nil, err + return err } - ints := make([]x509.Certificate, len(certs)-1) + var ints []x509.Certificate for i := range certs[1:] { - ints[i] = *certs[i+1] + ints = append(ints, *certs[i+1]) } - key, err := cryptopem.DecodePEMPrivateKey(keyPEM) + key, err := cryptopem.DecodePEMPrivateKey(a.x509Auth.keyPEM) if err != nil { - return nil, err + return err } keyECDSA := key.(*ecdsa.PrivateKey) signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) - agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) rolesAnywhereClient.Handlers.Sign.Clear() rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) - // TODO: make metadata field? - var duration int64 = 10000 - createSessionRequest := rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: &duration, - InstanceProperties: nil, - SessionName: nil, - } - output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return nil, fmt.Errorf("failed to create session using dapr app dentity: %w", err) + return nil +} + +func (a *AWS) createOrRefreshSession(ctx context.Context) error { + var ( + duration int64 + createSessionRequest rolesanywhere.CreateSessionInput + ) + + if *a.x509Auth.SessionDuration != 0 { + duration = int64(a.x509Auth.SessionDuration.Seconds()) + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } else { + duration = 900 // 15 minutes in seconds by default and be autorefreshed + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } } + output, err := a.x509Auth.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return fmt.Errorf("failed to create session using dapr app identity: %w", err) + } if len(output.CredentialSet) != 1 { - return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + return fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) } a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken - return a.getSessionClient() -} - -func (a *AWS) getSessionClient() (*session.Session, error) { - awsConfig := aws.NewConfig() - - if a.region != "" { - awsConfig = awsConfig.WithRegion(a.region) + // convert expiration time from *string to time.Time + expirationStr := output.CredentialSet[0].Credentials.Expiration + if expirationStr == nil { + return fmt.Errorf("expiration time is nil") } - if a.accessKey != "" && a.secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) + expirationTime, err := time.Parse(time.RFC3339, *expirationStr) + if err != nil { + return fmt.Errorf("failed to parse expiration time: %w", err) } - if a.endpoint != "" { - awsConfig = awsConfig.WithEndpoint(a.endpoint) - } + a.x509Auth.sessionExpiration = expirationTime - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } + return nil +} - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), +func (a *AWS) startSessionRefresher(ctx context.Context) error { + a.logger.Debugf("starting session refresher for x509 auth") + // if there is a set session duration, then exit bc we will not auto refresh the session. + if *a.x509Auth.SessionDuration != 0 { + return nil } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - return awsSession, nil + errChan := make(chan error, 1) + go func() { + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + remaining := time.Until(a.x509Auth.sessionExpiration) + if remaining <= 8*time.Minute { + a.logger.Infof("Refreshing session as expiration is within %v", remaining) + + // Refresh the session + err := a.createOrRefreshSession(ctx) + if err != nil { + errChan <- fmt.Errorf("failed to refresh session: %w", err) + return + } + + a.logger.Debugf("AWS IAM Roles Anywhere session refreshed successfully") + refreshedSession, err := a.getTokenClient() + if err != nil { + errChan <- fmt.Errorf("failed to get token client with refreshed credentials: %v", err) + return + } + a.x509Auth.sessionUpdateChannel <- refreshedSession + + } + case <-ctx.Done(): + a.logger.Infof("Session refresher stopped due to context cancellation") + errChan <- nil + return + } + } + }() + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. @@ -232,70 +458,6 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } -type AWS struct { - lock sync.RWMutex - logger logger.Logger - - x509Auth *x509Auth - - region string - endpoint string - accessKey string - secretKey string - sessionToken string -} - -type AWSIAM struct { - // Ignored by metadata parser because included in built-in authentication profile - // Access key to use for accessing PostgreSQL. - AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` - // Secret key to use for accessing PostgreSQL. - AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` - // AWS region in which PostgreSQL is deployed. - AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` -} - -type Options struct { - Logger logger.Logger - Properties map[string]string - - PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - Region string `json:"region" mapstructure:"region"` - AccessKey string `json:"accessKey" mapstructure:"accessKey"` - SecretKey string `json:"secretKey" mapstructure:"secretKey"` - SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` - Endpoint string `json:"endpoint" mapstructure:"endpoint"` -} - -type x509Auth struct { - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` -} - -func New(opts Options) (*AWS, error) { - var x509AuthConfig x509Auth - if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { - return nil, err - } - if x509AuthConfig.AssumeRoleArn != nil { - opts.Logger.Infof("sam x509 fields %s %s ", *x509AuthConfig.AssumeRoleArn, *x509AuthConfig.TrustAnchorArn) - } else { - opts.Logger.Infof("sam still nil somehow...") - } - - return &AWS{ - x509Auth: &x509AuthConfig, - logger: opts.Logger, - region: opts.Region, - accessKey: opts.AccessKey, - secretKey: opts.SecretKey, - sessionToken: opts.SessionToken, - endpoint: opts.Endpoint, - }, nil -} - func (opts *Options) GetAccessToken(ctx context.Context) (string, error) { dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) var authenticationToken string From 3a5c8bf45e7052e89c6d3445cdd51d0cc68923c6 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 6 Nov 2024 13:07:05 -0600 Subject: [PATCH 10/39] style: mv x509 auth around Signed-off-by: Samantha Coyle --- common/authentication/aws/aws.go | 268 ----------------------- common/authentication/aws/aws_x509.go | 295 ++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 268 deletions(-) create mode 100644 common/authentication/aws/aws_x509.go diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index f55eba5930..83d8f350da 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -15,13 +15,8 @@ package aws import ( "context" - "crypto/ecdsa" - "crypto/tls" - "crypto/x509" "errors" "fmt" - "net/http" - "runtime" "strconv" "sync" "time" @@ -31,21 +26,14 @@ import ( v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" - awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" - "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" - "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - cryptopem "github.com/dapr/kit/crypto/pem" - spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/logger" kitmd "github.com/dapr/kit/metadata" - "github.com/dapr/kit/ptr" ) type EnvironmentSettings struct { @@ -88,21 +76,6 @@ type Options struct { Endpoint string `json:"endpoint" mapstructure:"endpoint"` } -type x509Auth struct { - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` - SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` - sessionExpiration time.Time - - chainPEM []byte - keyPEM []byte - - sessionUpdateChannel chan *session.Session - - rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI -} - func New(opts Options) (*AWS, error) { var x509AuthConfig x509Auth if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { @@ -130,12 +103,6 @@ func (a *AWS) GetConfig() *aws.Config { return cfg } -func (a *AWS) GetSessionUpdateChannel() chan *session.Session { - a.mu.Lock() - defer a.mu.Unlock() - return a.x509Auth.sessionUpdateChannel -} - func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { optFns := []func(*config.LoadOptions) error{} if region != "" { @@ -214,241 +181,6 @@ func (a *AWS) getTokenClient() (*session.Session, error) { return awsSession, nil } -func (a *AWS) getCertPEM(ctx context.Context) error { - // retrieve svid from spiffe context - svid, ok := spiffecontext.From(ctx) - if !ok { - return errors.New("no SVID found in context") - } - // get x.509 svid - svidx, err := svid.GetX509SVID() - if err != nil { - return err - } - - // marshal x.509 svid to pem format - chainPEM, keyPEM, err := svidx.Marshal() - if err != nil { - return fmt.Errorf("failed to marshal SVID: %w", err) - } - - a.x509Auth.chainPEM = chainPEM - a.x509Auth.keyPEM = keyPEM - return nil -} - -func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { - // retrieve svid from spiffe context - err := a.getCertPEM(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) - } - - if err := a.initializeTrustAnchors(); err != nil { - return nil, err - } - - if err := a.initializeRolesAnywhereClient(); err != nil { - return nil, err - } - - err = a.createOrRefreshSession(ctx) - if err != nil { - return nil, fmt.Errorf("failed to refresh token for new session client") - } - - return a.getTokenClient() -} - -func (a *AWS) initializeTrustAnchors() error { - var ( - trustAnchor arn.ARN - profile arn.ARN - err error - ) - if a.x509Auth.TrustAnchorArn != nil { - trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) - if err != nil { - return err - } - a.region = trustAnchor.Region - } - - if a.x509Auth.TrustProfileArn != nil { - profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) - if err != nil { - return err - } - - if profile.Region != "" && trustAnchor.Region != profile.Region { - return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", - trustAnchor.Region, profile.Region) - } - } - return nil -} - -func (a *AWS) initializeRolesAnywhereClient() error { - if a.x509Auth.rolesAnywhereClient == nil { - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, - }} - mySession, err := session.NewSession() - if err != nil { - return err - } - config := aws.NewConfig().WithRegion(a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - rolesAnywhereClient := rolesanywhere.New(mySession, config) - - // Set up signing function and handlers - if err := a.setSigningFunction(rolesAnywhereClient); err != nil { - return err - } - a.x509Auth.rolesAnywhereClient = rolesAnywhereClient - } - return nil - -} - -func (a *AWS) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { - certs, err := cryptopem.DecodePEMCertificatesChain(a.x509Auth.chainPEM) - if err != nil { - return err - } - - var ints []x509.Certificate - for i := range certs[1:] { - ints = append(ints, *certs[i+1]) - } - - key, err := cryptopem.DecodePEMPrivateKey(a.x509Auth.keyPEM) - if err != nil { - return err - } - - keyECDSA := key.(*ecdsa.PrivateKey) - signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) - - agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) - rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") - rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) - rolesAnywhereClient.Handlers.Sign.Clear() - rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) - - return nil -} - -func (a *AWS) createOrRefreshSession(ctx context.Context) error { - var ( - duration int64 - createSessionRequest rolesanywhere.CreateSessionInput - ) - - if *a.x509Auth.SessionDuration != 0 { - duration = int64(a.x509Auth.SessionDuration.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.x509Auth.chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } else { - duration = 900 // 15 minutes in seconds by default and be autorefreshed - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.x509Auth.chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } - - output, err := a.x509Auth.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return fmt.Errorf("failed to create session using dapr app identity: %w", err) - } - if len(output.CredentialSet) != 1 { - return fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) - } - - a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId - a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey - a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken - - // convert expiration time from *string to time.Time - expirationStr := output.CredentialSet[0].Credentials.Expiration - if expirationStr == nil { - return fmt.Errorf("expiration time is nil") - } - - expirationTime, err := time.Parse(time.RFC3339, *expirationStr) - if err != nil { - return fmt.Errorf("failed to parse expiration time: %w", err) - } - - a.x509Auth.sessionExpiration = expirationTime - - return nil -} - -func (a *AWS) startSessionRefresher(ctx context.Context) error { - a.logger.Debugf("starting session refresher for x509 auth") - // if there is a set session duration, then exit bc we will not auto refresh the session. - if *a.x509Auth.SessionDuration != 0 { - return nil - } - - errChan := make(chan error, 1) - go func() { - ticker := time.NewTicker(2 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - remaining := time.Until(a.x509Auth.sessionExpiration) - if remaining <= 8*time.Minute { - a.logger.Infof("Refreshing session as expiration is within %v", remaining) - - // Refresh the session - err := a.createOrRefreshSession(ctx) - if err != nil { - errChan <- fmt.Errorf("failed to refresh session: %w", err) - return - } - - a.logger.Debugf("AWS IAM Roles Anywhere session refreshed successfully") - refreshedSession, err := a.getTokenClient() - if err != nil { - errChan <- fmt.Errorf("failed to get token client with refreshed credentials: %v", err) - return - } - a.x509Auth.sessionUpdateChannel <- refreshedSession - - } - case <-ctx.Done(): - a.logger.Infof("Session refresher stopped due to context cancellation") - errChan <- nil - return - } - } - }() - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { es := EnvironmentSettings{ diff --git a/common/authentication/aws/aws_x509.go b/common/authentication/aws/aws_x509.go new file mode 100644 index 0000000000..0b77f9782f --- /dev/null +++ b/common/authentication/aws/aws_x509.go @@ -0,0 +1,295 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" +) + +type x509Auth struct { + // todo unexport these fields except the channel + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` + sessionExpiration time.Time + + chainPEM []byte + keyPEM []byte + + sessionUpdateChannel chan *session.Session + + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI +} + +func (a *AWS) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return errors.New("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.x509Auth.chainPEM = chainPEM + a.x509Auth.keyPEM = keyPEM + return nil +} + +func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { + // retrieve svid from spiffe context + err := a.getCertPEM(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + if err := a.initializeTrustAnchors(); err != nil { + return nil, err + } + + if err := a.initializeRolesAnywhereClient(); err != nil { + return nil, err + } + + err = a.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to refresh token for new session client") + } + + return a.getTokenClient() +} + +func (a *AWS) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.x509Auth.TrustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) + if err != nil { + return err + } + a.region = trustAnchor.Region + } + + if a.x509Auth.TrustProfileArn != nil { + profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *AWS) initializeRolesAnywhereClient() error { + if a.x509Auth.rolesAnywhereClient == nil { + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + mySession, err := session.NewSession() + if err != nil { + return err + } + config := aws.NewConfig().WithRegion(a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return err + } + a.x509Auth.rolesAnywhereClient = rolesAnywhereClient + } + return nil + +} + +func (a *AWS) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.x509Auth.chainPEM) + if err != nil { + return err + } + + var ints []x509.Certificate + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.x509Auth.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *AWS) createOrRefreshSession(ctx context.Context) error { + var ( + duration int64 + createSessionRequest rolesanywhere.CreateSessionInput + ) + + if *a.x509Auth.SessionDuration != 0 { + duration = int64(a.x509Auth.SessionDuration.Seconds()) + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } else { + duration = 900 // 15 minutes in seconds by default and be autorefreshed + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } + + output, err := a.x509Auth.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + if len(output.CredentialSet) != 1 { + return fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId + a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey + a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken + + // convert expiration time from *string to time.Time + expirationStr := output.CredentialSet[0].Credentials.Expiration + if expirationStr == nil { + return fmt.Errorf("expiration time is nil") + } + + expirationTime, err := time.Parse(time.RFC3339, *expirationStr) + if err != nil { + return fmt.Errorf("failed to parse expiration time: %w", err) + } + + a.x509Auth.sessionExpiration = expirationTime + + return nil +} + +func (a *AWS) startSessionRefresher(ctx context.Context) error { + a.logger.Debugf("starting session refresher for x509 auth") + // if there is a set session duration, then exit bc we will not auto refresh the session. + if *a.x509Auth.SessionDuration != 0 { + return nil + } + + errChan := make(chan error, 1) + go func() { + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + remaining := time.Until(a.x509Auth.sessionExpiration) + if remaining <= 8*time.Minute { + a.logger.Infof("Refreshing session as expiration is within %v", remaining) + + // Refresh the session + err := a.createOrRefreshSession(ctx) + if err != nil { + errChan <- fmt.Errorf("failed to refresh session: %w", err) + return + } + + a.logger.Debugf("AWS IAM Roles Anywhere session refreshed successfully") + refreshedSession, err := a.getTokenClient() + if err != nil { + errChan <- fmt.Errorf("failed to get token client with refreshed credentials: %v", err) + return + } + a.x509Auth.sessionUpdateChannel <- refreshedSession + + } + case <-ctx.Done(): + a.logger.Infof("Session refresher stopped due to context cancellation") + errChan <- nil + return + } + } + }() + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func (a *AWS) GetSessionUpdateChannel() chan *session.Session { + a.mu.Lock() + defer a.mu.Unlock() + return a.x509Auth.sessionUpdateChannel +} From 43ba1e3b3168aa98c88deafb33a3d0fa99e442e0 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Sun, 10 Nov 2024 20:03:38 -0600 Subject: [PATCH 11/39] refactor: overhaul + interfaces for cleanliness + tests + update comps Signed-off-by: Samantha Coyle --- bindings/aws/dynamodb/dynamodb.go | 22 +- bindings/aws/kinesis/kinesis.go | 82 ++-- bindings/aws/s3/s3.go | 69 +-- bindings/aws/ses/ses.go | 40 +- bindings/aws/sns/sns.go | 25 +- bindings/aws/sqs/sqs.go | 66 ++- common/authentication/aws/aws.go | 161 +++---- common/authentication/aws/aws_client.go | 215 +++++++++ common/authentication/aws/aws_test.go | 30 ++ common/authentication/aws/aws_x509.go | 295 ------------ common/authentication/aws/static_iam.go | 257 +++++++++++ common/authentication/aws/x509_iam.go | 430 ++++++++++++++++++ go.mod | 3 +- go.sum | 2 + pubsub/aws/snssqs/snssqs.go | 78 ++-- .../aws/parameterstore/parameterstore.go | 38 +- .../aws/secretmanager/secretmanager.go | 27 +- state/aws/dynamodb/dynamodb.go | 29 +- tests/certification/go.mod | 2 + tests/certification/go.sum | 2 + 20 files changed, 1189 insertions(+), 684 deletions(-) create mode 100644 common/authentication/aws/aws_client.go create mode 100644 common/authentication/aws/aws_test.go delete mode 100644 common/authentication/aws/aws_x509.go create mode 100644 common/authentication/aws/static_iam.go create mode 100644 common/authentication/aws/x509_iam.go diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index dcf3155836..031e751dbf 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -31,9 +31,9 @@ import ( // DynamoDB allows performing stateful operations on AWS DynamoDB. type DynamoDB struct { - client *dynamodb.DynamoDB - table string - logger logger.Logger + authProvider awsAuth.Provider + table string + logger logger.Logger } type dynamoDBMetadata struct { @@ -56,26 +56,22 @@ func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { if err != nil { return err } - if d.client == nil { - aws, err := awsAuth.New(awsAuth.Options{ + if d.authProvider == nil { + opts := awsAuth.Options{ Logger: d.logger, Properties: metadata.Properties, Region: meta.Region, + Endpoint: meta.Endpoint, AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, SessionToken: meta.SessionToken, - Endpoint: meta.Endpoint, - }) - if err != nil { - return err } - sess, err := aws.GetClient(ctx) + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - - d.client = dynamodb.New(sess) + d.authProvider = provider } d.table = meta.Table @@ -98,7 +94,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - _, err = d.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + _, err = d.authProvider.DynamoDB(ctx).DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 32d7382f09..8d2418821e 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -24,11 +24,9 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "github.com/vmware/vmware-go-kcl/clientlibrary/config" "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces" "github.com/vmware/vmware-go-kcl/clientlibrary/worker" @@ -41,15 +39,16 @@ import ( // AWSKinesis allows receiving and sending data to/from AWS Kinesis stream. type AWSKinesis struct { - client *kinesis.Kinesis - metadata *kinesisMetadata + authProvider awsAuth.Provider + metadata *kinesisMetadata - worker *worker.Worker - workerConfig *config.KinesisClientLibConfiguration + worker *worker.Worker - streamARN *string - consumerARN *string - logger logger.Logger + streamName string + consumerName string + consumerARN *string + logger logger.Logger + consumerMode string closed atomic.Bool closeCh chan struct{} @@ -113,47 +112,28 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - if a.client == nil { - var awsA *awsAuth.AWS - awsA, err = awsAuth.New(awsAuth.Options{ + a.consumerMode = m.KinesisConsumerMode + a.streamName = m.StreamName + a.consumerName = m.ConsumerName + a.metadata = m + + if a.authProvider == nil { + opts := awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, Region: m.Region, AccessKey: m.AccessKey, SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err + SessionToken: "", } - - var sess *session.Session - sess, err = awsA.GetClient(ctx) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - a.client = kinesis.New(sess) - } - - streamName := aws.String(m.StreamName) - stream, err := a.client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ - StreamName: streamName, - }) - if err != nil { - return err + a.authProvider = provider } - if m.KinesisConsumerMode == SharedThroughput { - kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName, - m.StreamName, m.Region, m.ConsumerName, - a.client.Config.Credentials) - a.workerConfig = kclConfig - } - - a.streamARN = stream.StreamDescription.StreamARN - a.metadata = m - return nil } @@ -166,7 +146,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - _, err := a.client.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + _, err := a.authProvider.Kinesis(ctx).Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -181,14 +161,14 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis(ctx).WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = a.client.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = a.authProvider.Kinesis(ctx).Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -200,6 +180,10 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er // Wait for context cancelation then stop a.wg.Add(1) + stream, err := a.authProvider.Kinesis(ctx).Stream(ctx, a.streamName) + if err != nil { + return fmt.Errorf("failed to get kinesis stream arn: %v", err) + } go func() { defer a.wg.Done() select { @@ -209,7 +193,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { - a.deregisterConsumer(a.streamARN, a.consumerARN) + a.deregisterConsumer(stream, a.consumerARN) } }() @@ -245,7 +229,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes default: } - sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + sub, err := a.authProvider.Kinesis(ctx).Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -287,14 +271,14 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis(ctx).Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -306,7 +290,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - consumer, err := a.client.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis(ctx).Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -329,7 +313,7 @@ func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) if a.consumerARN != nil { // Use a background context because the running context may have been canceled already ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err := a.client.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + _, err := a.authProvider.Kinesis(ctx).Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -360,7 +344,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - req, _ := a.client.DescribeStreamConsumerRequest(inCpy) + req, _ := a.authProvider.Kinesis(ctx).Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 6d1d54d483..9ad69da5fa 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,7 +29,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" @@ -61,11 +60,9 @@ const ( // AWSS3 is a binding for an AWS S3 storage bucket. type AWSS3 struct { - metadata *s3Metadata - s3Client *s3.S3 - uploader *s3manager.Uploader - downloader *s3manager.Downloader - logger logger.Logger + metadata *s3Metadata + authProvider awsAuth.Provider + logger logger.Logger } type s3Metadata struct { @@ -109,8 +106,8 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } -func (s *AWSS3) getAWSConfig(awsA *awsAuth.AWS) *aws.Config { - cfg := awsA.GetConfig().WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) +func (s *AWSS3) getAWSConfig(opts awsAuth.Options) *aws.Config { + cfg := awsAuth.GetConfig(opts).WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) // Use a custom HTTP client to allow self-signed certs if s.metadata.InsecureSSL { @@ -136,9 +133,10 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - if s.s3Client == nil { + if s.authProvider == nil { + s.metadata = m - awsA, err := awsAuth.New(awsAuth.Options{ + opts := awsAuth.Options{ Logger: s.logger, Properties: metadata.Properties, Region: m.Region, @@ -146,46 +144,21 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, - }) - if err != nil { - return err } - // initiate clients, before refreshing if needed - sess, err := awsA.GetClient(ctx) + // extra configs needed per component type + cfg := s.getAWSConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) if err != nil { return err } - - s.metadata = m - s.s3Client = s3.New(sess, s.getAWSConfig(awsA)) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) - - go func() { - for { - select { - case refreshSession := <-awsA.GetSessionUpdateChannel(): - s.updateAWSClients(refreshSession, s.getAWSConfig(awsA)) - case <-ctx.Done(): - return - } - } - }() + s.authProvider = provider } - s.metadata = m - return nil } -func (s *AWSS3) updateAWSClients(session *session.Session, cfgs *aws.Config) { - s.s3Client = s3.New(session, cfgs) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) -} - func (s *AWSS3) Close() error { - return nil + return s.authProvider.Close() } func (s *AWSS3) Operations() []bindings.OperationKind { @@ -239,7 +212,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi storageClass = aws.String(metadata.StorageClass) } - resultUpload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + resultUpload, err := s.authProvider.S3(ctx).Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -252,7 +225,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi var presignURL string if metadata.PresignTTL != "" { - url, presignErr := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, presignErr := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if presignErr != nil { return nil, fmt.Errorf("s3 binding error: %s", presignErr) } @@ -292,7 +265,7 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataPresignTTL) } - url, err := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, err := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if err != nil { return nil, fmt.Errorf("s3 binding error: %w", err) } @@ -309,13 +282,13 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind }, nil } -func (s *AWSS3) presignObject(bucket, key, ttl string) (string, error) { +func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (string, error) { d, err := time.ParseDuration(ttl) if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - objReq, _ := s.s3Client.GetObjectRequest(&s3.GetObjectInput{ + objReq, _ := s.authProvider.S3(ctx).S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -340,7 +313,7 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings buff := &aws.WriteAtBuffer{} - _, err = s.downloader.DownloadWithContext(ctx, + _, err = s.authProvider.S3(ctx).Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -375,7 +348,7 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - _, err := s.s3Client.DeleteObjectWithContext( + _, err := s.authProvider.S3(ctx).S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -405,7 +378,7 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding payload.MaxResults = defaultMaxResults } - result, err := s.s3Client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + result, err := s.authProvider.S3(ctx).S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 1d28026a9f..b65f3d0358 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -38,9 +38,9 @@ const ( // AWSSES is an AWS SNS binding. type AWSSES struct { - metadata *sesMetadata - logger logger.Logger - svc *ses.SES + authProvider awsAuth.Provider + metadata *sesMetadata + logger logger.Logger } type sesMetadata struct { @@ -63,36 +63,32 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { // Init does metadata parsing. func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { // Parse input metadata - meta, err := a.parseMetadata(metadata) + m, err := a.parseMetadata(metadata) if err != nil { return err } - if a.svc.Client == nil { - aws, err := awsAuth.New(awsAuth.Options{ + a.metadata = m + + if a.authProvider == nil { + a.metadata = m + + opts := awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", } - - sess, err := aws.GetClient(ctx) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - - // Create an SES instance - svc := ses.New(sess) - a.svc = svc + a.authProvider = provider } - a.metadata = meta - return nil } @@ -159,7 +155,7 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - result, err := a.svc.SendEmail(input) + result, err := a.authProvider.Ses(ctx).Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index e3617bdaf2..adc82dfca6 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -19,6 +19,7 @@ import ( "fmt" "reflect" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sns" "github.com/dapr/components-contrib/bindings" @@ -30,8 +31,8 @@ import ( // AWSSNS is an AWS SNS binding. type AWSSNS struct { - client *sns.SNS - topicARN string + authProvider awsAuth.Provider + topicARN string logger logger.Logger } @@ -64,26 +65,22 @@ func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - if a.client == nil { - aws, err := awsAuth.New(awsAuth.Options{ + if a.authProvider == nil { + opts := awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, Region: m.Region, + Endpoint: m.Endpoint, AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err } - - sess, err := aws.GetClient(ctx) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - - a.client = sns.New(sess) + a.authProvider = provider } a.topicARN = m.TopicArn @@ -115,7 +112,7 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - _, err = a.client.PublishWithContext(ctx, &sns.PublishInput{ + _, err = a.authProvider.Sns(ctx).Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, @@ -135,5 +132,5 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 07b519293d..2075966450 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -16,13 +16,13 @@ package sqs import ( "context" "errors" + "fmt" "reflect" "sync" "sync/atomic" "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/dapr/components-contrib/bindings" @@ -34,13 +34,12 @@ import ( // AWSSQS allows receiving and sending data to/from AWS SQS. type AWSSQS struct { - Client *sqs.SQS - QueueURL *string - - logger logger.Logger - wg sync.WaitGroup - closeCh chan struct{} - closed atomic.Bool + authProvider awsAuth.Provider + queueName string + logger logger.Logger + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type sqsMetadata struct { @@ -67,38 +66,25 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - if a.Client == nil { - var awsA *awsAuth.AWS - awsA, err = awsAuth.New(awsAuth.Options{ + if a.authProvider == nil { + opts := awsAuth.Options{ Logger: a.logger, Properties: metadata.Properties, Region: m.Region, + Endpoint: m.Endpoint, AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, - Endpoint: m.Endpoint, - }) - if err != nil { - return err } - - var sess *session.Session - sess, err = awsA.GetClient(ctx) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - a.Client = sqs.New(sess) - } - - queueName := m.QueueName - resultURL, err := a.Client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) - if err != nil { - return err + a.authProvider = provider } - a.QueueURL = resultURL.QueueUrl + a.queueName = m.QueueName return nil } @@ -109,9 +95,13 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - _, err := a.Client.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + url, err := a.authProvider.Sqs(ctx).QueueURL(ctx, a.queueName) + if err != nil { + return nil, fmt.Errorf("failed to get queue url: %v", err) + } + _, err = a.authProvider.Sqs(ctx).Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, - QueueUrl: a.QueueURL, + QueueUrl: url, }) return nil, err @@ -131,9 +121,13 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } + url, err := a.authProvider.Sqs(ctx).QueueURL(ctx, a.queueName) + if err != nil { + fmt.Errorf("failed to get queue url: %v", err) + } - result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ - QueueUrl: a.QueueURL, + result, err := a.authProvider.Sqs(ctx).Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", }), @@ -144,7 +138,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { WaitTimeSeconds: aws.Int64(20), }) if err != nil { - a.logger.Errorf("Unable to receive message from queue %q, %v.", *a.QueueURL, err) + a.logger.Errorf("Unable to receive message from queue %q, %v.", url, err) } if len(result.Messages) > 0 { @@ -158,8 +152,8 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - a.Client.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ - QueueUrl: a.QueueURL, + a.authProvider.Sqs(ctx).Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: url, ReceiptHandle: msgHandle, }) } @@ -182,7 +176,7 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) { diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 83d8f350da..b841010152 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -18,41 +18,21 @@ import ( "errors" "fmt" "strconv" - "sync" "time" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/dapr/kit/logger" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - - "github.com/dapr/kit/logger" - kitmd "github.com/dapr/kit/metadata" ) type EnvironmentSettings struct { Metadata map[string]string } -type AWS struct { - mu sync.RWMutex - logger logger.Logger - - x509Auth *x509Auth - - region string - endpoint string - accessKey string - secretKey string - sessionToken string -} - type AWSIAM struct { // Ignored by metadata parser because included in built-in authentication profile // Access key to use for accessing PostgreSQL. @@ -63,122 +43,83 @@ type AWSIAM struct { AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` } -type Options struct { - Logger logger.Logger - Properties map[string]string - +type AWSIAMAuthOptions struct { PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` ConnectionString string `json:"connectionString" mapstructure:"connectionString"` Region string `json:"region" mapstructure:"region"` AccessKey string `json:"accessKey" mapstructure:"accessKey"` SecretKey string `json:"secretKey" mapstructure:"secretKey"` - SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` - Endpoint string `json:"endpoint" mapstructure:"endpoint"` } -func New(opts Options) (*AWS, error) { - var x509AuthConfig x509Auth - if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { - return nil, err - } +type Options struct { + Logger logger.Logger + Properties map[string]string + + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` - return &AWS{ - x509Auth: &x509AuthConfig, - logger: opts.Logger, - region: opts.Region, - accessKey: opts.AccessKey, - secretKey: opts.SecretKey, - sessionToken: opts.SessionToken, - endpoint: opts.Endpoint, - }, nil + Endpoint string + SessionToken string } -func (a *AWS) GetConfig() *aws.Config { +func GetConfig(opts Options) *aws.Config { cfg := aws.NewConfig() - if a.region != "" { - cfg.WithRegion(a.region) + switch { + case opts.Region != "": + cfg.WithRegion(opts.Region) + case opts.Endpoint != "": + cfg.WithEndpoint(opts.Endpoint) } return cfg } -func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { - optFns := []func(*config.LoadOptions) error{} - if region != "" { - optFns = append(optFns, config.WithRegion(region)) - } - - if accessKey != "" && secretKey != "" { - provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) - optFns = append(optFns, config.WithCredentialsProvider(provider)) - } - - awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) - if err != nil { - return awsv2.Config{}, err - } - - if endpoint != "" { - awsCfg.BaseEndpoint = &endpoint - } - - return awsCfg, nil +type Provider interface { + Initialize(ctx context.Context, opts Options, cfg *aws.Config) error + + S3(ctx context.Context) *S3Clients + DynamoDB(ctx context.Context) *DynamoDBClients + DynamoDBI(ctx context.Context) *DynamoDBClientsI + Sqs(ctx context.Context) *SqsClients + Sns(ctx context.Context) *SnsClients + SnsSqs(ctx context.Context) *SnsSqsClients + SecretManager(ctx context.Context) *SecretManagerClients + ParameterStore(ctx context.Context) *ParameterStoreClients + Kinesis(ctx context.Context) *KinesisClients + Ses(ctx context.Context) *SesClients + + Close() error } -func (a *AWS) GetClient(ctx context.Context) (*session.Session, error) { - a.mu.Lock() - defer a.mu.Unlock() +func isX509Auth(m map[string]string) bool { + tp, _ := m["trustProfileArn"] + ta, _ := m["trustAnchorArn"] + ar, _ := m["assumeRoleArn"] + return tp != "" && ta != "" && ar != "" +} - switch { - // IAM Roles Anywhere option - case a.x509Auth.TrustAnchorArn != nil && a.x509Auth.AssumeRoleArn != nil: - a.logger.Debug("using X.509 RolesAnywhere authentication using Dapr SVID") - session, err := a.getX509Client(ctx) +func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) { + if isX509Auth(opts.Properties) { + provider := &x509TempAuth{} + err := provider.Initialize(ctx, opts, cfg) if err != nil { - return nil, fmt.Errorf("failed to create X.509 RolesAnywhere client") + return nil, fmt.Errorf("failed to initialize AWS Roles Anywhere authentication: %v", err) } - // start a session refresher background goroutine to keep rotating the temporary creds - // use background context to keep alive - go a.startSessionRefresher(context.Background()) - - return session, nil - default: - a.logger.Debugf("using AWS session client...") - return a.getTokenClient() - } -} - -func (a *AWS) getTokenClient() (*session.Session, error) { - awsConfig := aws.NewConfig() - - if a.region != "" { - awsConfig = awsConfig.WithRegion(a.region) - } - - if a.accessKey != "" && a.secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) - } - - if a.endpoint != "" { - awsConfig = awsConfig.WithEndpoint(a.endpoint) + return provider, nil } - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) + provider := &StaticAuth{} + err := provider.Initialize(ctx, opts, cfg) if err != nil { - return nil, err - } - - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + return nil, fmt.Errorf("failed to initialize AWS IAM authentication: %v", err) } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + return provider, nil - return awsSession, nil } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. diff --git a/common/authentication/aws/aws_client.go b/common/authentication/aws/aws_client.go new file mode 100644 index 0000000000..8425bd19df --- /dev/null +++ b/common/authentication/aws/aws_client.go @@ -0,0 +1,215 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/dapr/kit/logger" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +var log logger.Logger + +type Clients struct { + mu sync.RWMutex + + s3 *S3Clients + dynamo *DynamoDBClients + dynamoI *DynamoDBClientsI + sns *SnsClients + sqs *SqsClients + snssqs *SnsSqsClients + secret *SecretManagerClients + parameterStore *ParameterStoreClients + kinesis *KinesisClients + ses *SesClients +} + +func newClients() *Clients { + clients := &Clients{ + mu: sync.RWMutex{}, + } + return clients +} + +func (c *Clients) refresh(session *session.Session) { + c.mu.Lock() + defer c.mu.Unlock() + switch { + case c.s3 != nil: + c.s3.New(session) + case c.dynamo != nil: + c.dynamo.New(session) + case c.dynamoI != nil: + c.dynamoI.New(session) + case c.sns != nil: + c.sns.New(session) + case c.sqs != nil: + c.sqs.New(session) + case c.snssqs != nil: + c.snssqs.New(session) + case c.secret != nil: + c.secret.New(session) + case c.parameterStore != nil: + c.parameterStore.New(session) + case c.kinesis != nil: + c.kinesis.New(session) + case c.ses != nil: + c.ses.New(session) + } +} + +type S3Clients struct { + S3 *s3.S3 + Uploader *s3manager.Uploader + Downloader *s3manager.Downloader +} + +type DynamoDBClients struct { + DynamoDB *dynamodb.DynamoDB +} + +type DynamoDBClientsI struct { + DynamoDB dynamodbiface.DynamoDBAPI +} + +type SnsSqsClients struct { + Sns *sns.SNS + Sqs *sqs.SQS + Sts *sts.STS +} + +type SnsClients struct { + Sns *sns.SNS +} + +type SqsClients struct { + Sqs *sqs.SQS + queueURL *string +} + +type SecretManagerClients struct { + Manager secretsmanageriface.SecretsManagerAPI +} + +type ParameterStoreClients struct { + Store ssmiface.SSMAPI +} + +type KinesisClients struct { + Kinesis *kinesis.Kinesis +} + +type SesClients struct { + Ses *ses.SES +} + +func (c *S3Clients) New(session *session.Session) { + refreshedS3 := s3.New(session, session.Config) + c.S3 = refreshedS3 + c.Uploader = s3manager.NewUploaderWithClient(refreshedS3) + c.Downloader = s3manager.NewDownloaderWithClient(refreshedS3) +} + +func (c *DynamoDBClients) New(session *session.Session) { + c.DynamoDB = dynamodb.New(session, session.Config) +} + +func (c *DynamoDBClientsI) New(session *session.Session) { + c.DynamoDB = dynamodb.New(session, session.Config) +} + +func (c *SnsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) +} + +func (c *SnsSqsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) + c.Sqs = sqs.New(session, session.Config) +} + +func (c *SqsClients) New(session *session.Session) { + c.Sqs = sqs.New(session, session.Config) +} + +func (c *SqsClients) QueueURL(ctx context.Context, queueName string) (*string, error) { + if c.Sqs != nil { + resultURL, err := c.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + QueueName: aws.String(queueName), + }) + return resultURL.QueueUrl, err + } + + return nil, errors.New("unable to get queue url due to empty client") +} + +func (c *SecretManagerClients) New(session *session.Session) { + c.Manager = secretsmanager.New(session, session.Config) +} + +func (c *ParameterStoreClients) New(session *session.Session) { + c.Store = ssm.New(session, session.Config) +} + +func (c *KinesisClients) New(session *session.Session) { + c.Kinesis = kinesis.New(session, session.Config) +} + +func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string, error) { + if c.Kinesis != nil { + stream, err := c.Kinesis.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ + StreamName: aws.String(streamName), + }) + return stream.StreamDescription.StreamARN, err + } + + return nil, errors.New("unable to get stream arn due to empty client") +} + +func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode string) *config.KinesisClientLibConfiguration { + const sharedMode = "shared" + if c.Kinesis != nil { + if mode == sharedMode { + kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, + stream, *c.Kinesis.Config.Region, consumer, + c.Kinesis.Config.Credentials) + + return kclConfig + } + } + + return nil +} + +func (c *SesClients) New(session *session.Session) { + c.Ses = ses.New(session, session.Config) +} diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go new file mode 100644 index 0000000000..a60fc6570e --- /dev/null +++ b/common/authentication/aws/aws_test.go @@ -0,0 +1,30 @@ +package aws + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewEnvironmentSettings(t *testing.T) { + tests := []struct { + name string + metadata map[string]string + }{ + { + name: "valid metadata", + metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewEnvironmentSettings(tt.metadata) + assert.NoError(t, err) + assert.NotNil(t, result) + }) + } +} diff --git a/common/authentication/aws/aws_x509.go b/common/authentication/aws/aws_x509.go deleted file mode 100644 index 0b77f9782f..0000000000 --- a/common/authentication/aws/aws_x509.go +++ /dev/null @@ -1,295 +0,0 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "context" - "crypto/ecdsa" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "net/http" - "runtime" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" - "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" - "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" - cryptopem "github.com/dapr/kit/crypto/pem" - spiffecontext "github.com/dapr/kit/crypto/spiffe/context" - "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" -) - -type x509Auth struct { - // todo unexport these fields except the channel - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` - SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` - sessionExpiration time.Time - - chainPEM []byte - keyPEM []byte - - sessionUpdateChannel chan *session.Session - - rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI -} - -func (a *AWS) getCertPEM(ctx context.Context) error { - // retrieve svid from spiffe context - svid, ok := spiffecontext.From(ctx) - if !ok { - return errors.New("no SVID found in context") - } - // get x.509 svid - svidx, err := svid.GetX509SVID() - if err != nil { - return err - } - - // marshal x.509 svid to pem format - chainPEM, keyPEM, err := svidx.Marshal() - if err != nil { - return fmt.Errorf("failed to marshal SVID: %w", err) - } - - a.x509Auth.chainPEM = chainPEM - a.x509Auth.keyPEM = keyPEM - return nil -} - -func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { - // retrieve svid from spiffe context - err := a.getCertPEM(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) - } - - if err := a.initializeTrustAnchors(); err != nil { - return nil, err - } - - if err := a.initializeRolesAnywhereClient(); err != nil { - return nil, err - } - - err = a.createOrRefreshSession(ctx) - if err != nil { - return nil, fmt.Errorf("failed to refresh token for new session client") - } - - return a.getTokenClient() -} - -func (a *AWS) initializeTrustAnchors() error { - var ( - trustAnchor arn.ARN - profile arn.ARN - err error - ) - if a.x509Auth.TrustAnchorArn != nil { - trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) - if err != nil { - return err - } - a.region = trustAnchor.Region - } - - if a.x509Auth.TrustProfileArn != nil { - profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) - if err != nil { - return err - } - - if profile.Region != "" && trustAnchor.Region != profile.Region { - return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", - trustAnchor.Region, profile.Region) - } - } - return nil -} - -func (a *AWS) initializeRolesAnywhereClient() error { - if a.x509Auth.rolesAnywhereClient == nil { - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, - }} - mySession, err := session.NewSession() - if err != nil { - return err - } - config := aws.NewConfig().WithRegion(a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - rolesAnywhereClient := rolesanywhere.New(mySession, config) - - // Set up signing function and handlers - if err := a.setSigningFunction(rolesAnywhereClient); err != nil { - return err - } - a.x509Auth.rolesAnywhereClient = rolesAnywhereClient - } - return nil - -} - -func (a *AWS) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { - certs, err := cryptopem.DecodePEMCertificatesChain(a.x509Auth.chainPEM) - if err != nil { - return err - } - - var ints []x509.Certificate - for i := range certs[1:] { - ints = append(ints, *certs[i+1]) - } - - key, err := cryptopem.DecodePEMPrivateKey(a.x509Auth.keyPEM) - if err != nil { - return err - } - - keyECDSA := key.(*ecdsa.PrivateKey) - signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) - - agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) - rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") - rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) - rolesAnywhereClient.Handlers.Sign.Clear() - rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) - - return nil -} - -func (a *AWS) createOrRefreshSession(ctx context.Context) error { - var ( - duration int64 - createSessionRequest rolesanywhere.CreateSessionInput - ) - - if *a.x509Auth.SessionDuration != 0 { - duration = int64(a.x509Auth.SessionDuration.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.x509Auth.chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } else { - duration = 900 // 15 minutes in seconds by default and be autorefreshed - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.x509Auth.chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } - - output, err := a.x509Auth.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return fmt.Errorf("failed to create session using dapr app identity: %w", err) - } - if len(output.CredentialSet) != 1 { - return fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) - } - - a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId - a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey - a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken - - // convert expiration time from *string to time.Time - expirationStr := output.CredentialSet[0].Credentials.Expiration - if expirationStr == nil { - return fmt.Errorf("expiration time is nil") - } - - expirationTime, err := time.Parse(time.RFC3339, *expirationStr) - if err != nil { - return fmt.Errorf("failed to parse expiration time: %w", err) - } - - a.x509Auth.sessionExpiration = expirationTime - - return nil -} - -func (a *AWS) startSessionRefresher(ctx context.Context) error { - a.logger.Debugf("starting session refresher for x509 auth") - // if there is a set session duration, then exit bc we will not auto refresh the session. - if *a.x509Auth.SessionDuration != 0 { - return nil - } - - errChan := make(chan error, 1) - go func() { - ticker := time.NewTicker(2 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - remaining := time.Until(a.x509Auth.sessionExpiration) - if remaining <= 8*time.Minute { - a.logger.Infof("Refreshing session as expiration is within %v", remaining) - - // Refresh the session - err := a.createOrRefreshSession(ctx) - if err != nil { - errChan <- fmt.Errorf("failed to refresh session: %w", err) - return - } - - a.logger.Debugf("AWS IAM Roles Anywhere session refreshed successfully") - refreshedSession, err := a.getTokenClient() - if err != nil { - errChan <- fmt.Errorf("failed to get token client with refreshed credentials: %v", err) - return - } - a.x509Auth.sessionUpdateChannel <- refreshedSession - - } - case <-ctx.Done(): - a.logger.Infof("Session refresher stopped due to context cancellation") - errChan <- nil - return - } - } - }() - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -func (a *AWS) GetSessionUpdateChannel() chan *session.Session { - a.mu.Lock() - defer a.mu.Unlock() - return a.x509Auth.sessionUpdateChannel -} diff --git a/common/authentication/aws/static_iam.go b/common/authentication/aws/static_iam.go new file mode 100644 index 0000000000..a1b898164d --- /dev/null +++ b/common/authentication/aws/static_iam.go @@ -0,0 +1,257 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "fmt" + "sync" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/dapr/kit/logger" +) + +type StaticAuth struct { + mu sync.RWMutex + logger logger.Logger + + region string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string + + clients *Clients + session *session.Session + cfg *aws.Config +} + +func (a *StaticAuth) Initialize(_ context.Context, opts Options, cfg *aws.Config) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.logger = opts.Logger + a.region = opts.Region + a.endpoint = &opts.Endpoint + a.accessKey = &opts.AccessKey + a.secretKey = &opts.SecretKey + a.sessionToken = &opts.SessionToken + a.cfg = cfg + a.clients = newClients() + + initialSession, err := a.getTokenClient() + if err != nil { + return fmt.Errorf("failed to get token client: %v", err) + } + + a.session = initialSession + + return nil +} + +func (a *StaticAuth) S3(ctx context.Context) *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 == nil { + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + } + + return a.clients.s3 +} + +func (a *StaticAuth) DynamoDB(ctx context.Context) *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.dynamo == nil { + clients := DynamoDBClients{} + a.clients.dynamo = &clients + a.clients.dynamo.New(a.session) + } + + return a.clients.dynamo +} + +func (a *StaticAuth) DynamoDBI(ctx context.Context) *DynamoDBClientsI { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.dynamoI == nil { + clients := DynamoDBClientsI{} + a.clients.dynamoI = &clients + a.clients.dynamoI.New(a.session) + } + + return a.clients.dynamoI +} + +func (a *StaticAuth) Sqs(ctx context.Context) *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs == nil { + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + } + + return a.clients.sqs +} + +func (a *StaticAuth) Sns(ctx context.Context) *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns == nil { + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + } + + return a.clients.sns +} + +func (a *StaticAuth) SnsSqs(ctx context.Context) *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs == nil { + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + } + + return a.clients.snssqs +} + +func (a *StaticAuth) SecretManager(ctx context.Context) *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.secret == nil { + clients := SecretManagerClients{} + a.clients.secret = &clients + a.clients.secret.New(a.session) + } + + return a.clients.secret +} + +func (a *StaticAuth) ParameterStore(ctx context.Context) *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.parameterStore == nil { + clients := ParameterStoreClients{} + a.clients.parameterStore = &clients + a.clients.parameterStore.New(a.session) + } + + return a.clients.parameterStore +} + +func (a *StaticAuth) Kinesis(ctx context.Context) *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis == nil { + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + } + + return a.clients.kinesis +} + +func (a *StaticAuth) Ses(ctx context.Context) *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses == nil { + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + } + + return a.clients.ses +} + +func (a *StaticAuth) getTokenClient() (*session.Session, error) { + awsConfig := aws.NewConfig() + + if a.region != "" { + awsConfig = awsConfig.WithRegion(a.region) + } + + if a.accessKey != nil && a.secretKey != nil { + // session token is an option field + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) + } + + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *StaticAuth) Close() error { + return nil +} + +func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { + optFns := []func(*config.LoadOptions) error{} + if region != "" { + optFns = append(optFns, config.WithRegion(region)) + } + + if accessKey != "" && secretKey != "" { + provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) + optFns = append(optFns, config.WithCredentialsProvider(provider)) + } + + awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) + if err != nil { + return awsv2.Config{}, err + } + + if endpoint != "" { + awsCfg.BaseEndpoint = &endpoint + } + + return awsCfg, nil +} diff --git a/common/authentication/aws/x509_iam.go b/common/authentication/aws/x509_iam.go new file mode 100644 index 0000000000..266528d554 --- /dev/null +++ b/common/authentication/aws/x509_iam.go @@ -0,0 +1,430 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" +) + +type x509TempAuth struct { + mu sync.RWMutex + logger logger.Logger + clients *Clients + session *session.Session + cfg *aws.Config + + chainPEM []byte + keyPEM []byte + + region *string + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` +} + +func (a *x509TempAuth) Initialize(ctx context.Context, opts Options, cfg *aws.Config) error { + var x509Auth x509TempAuth + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return err + } + + switch { + case x509Auth.TrustProfileArn == nil: + return errors.New("trustProfileArn is required") + case x509Auth.TrustAnchorArn == nil: + return errors.New("trustAnchorArn is required") + case x509Auth.AssumeRoleArn == nil: + return errors.New("assumeRoleArn is required") + case x509Auth.SessionDuration == nil: + awsDefaultDuration := time.Duration(900) // default 15m + x509Auth.SessionDuration = &awsDefaultDuration + } + + a.logger = opts.Logger + a.TrustProfileArn = x509Auth.TrustProfileArn + a.TrustAnchorArn = x509Auth.TrustAnchorArn + a.AssumeRoleArn = x509Auth.AssumeRoleArn + a.SessionDuration = x509Auth.SessionDuration + a.cfg = GetConfig(opts) + a.clients = newClients() + + err := a.getCertPEM(ctx) + if err != nil { + return fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + // Parse trust anchor and profile ARNs + if err := a.initializeTrustAnchors(); err != nil { + return err + } + + initialSession, err := a.createOrRefreshSession(ctx) + if err != nil { + return fmt.Errorf("failed to create the initial session: %v", err) + } + a.session = initialSession + go a.startSessionRefresher(context.Background()) + + return nil +} + +func (a *x509TempAuth) Close() error { + return nil +} + +func (a *x509TempAuth) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return fmt.Errorf("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.chainPEM = chainPEM + a.keyPEM = keyPEM + return nil +} + +func (a *x509TempAuth) S3(ctx context.Context) *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 == nil { + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + } + + return a.clients.s3 +} + +func (a *x509TempAuth) DynamoDB(ctx context.Context) *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.dynamo == nil { + clients := DynamoDBClients{} + a.clients.dynamo = &clients + a.clients.dynamo.New(a.session) + } + + return a.clients.dynamo +} + +func (a *x509TempAuth) DynamoDBI(ctx context.Context) *DynamoDBClientsI { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.dynamoI == nil { + clients := DynamoDBClientsI{} + a.clients.dynamoI = &clients + a.clients.dynamoI.New(a.session) + } + + return a.clients.dynamoI +} + +func (a *x509TempAuth) Sqs(ctx context.Context) *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs == nil { + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + } + + return a.clients.sqs +} + +func (a *x509TempAuth) Sns(ctx context.Context) *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns == nil { + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + } + + return a.clients.sns +} + +func (a *x509TempAuth) SnsSqs(ctx context.Context) *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs == nil { + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + } + + return a.clients.snssqs +} + +func (a *x509TempAuth) SecretManager(ctx context.Context) *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.secret == nil { + clients := SecretManagerClients{} + a.clients.secret = &clients + a.clients.secret.New(a.session) + } + + return a.clients.secret +} + +func (a *x509TempAuth) ParameterStore(ctx context.Context) *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.parameterStore == nil { + clients := ParameterStoreClients{} + a.clients.parameterStore = &clients + a.clients.parameterStore.New(a.session) + } + + return a.clients.parameterStore +} + +func (a *x509TempAuth) Kinesis(ctx context.Context) *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis == nil { + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + } + + return a.clients.kinesis +} + +func (a *x509TempAuth) Ses(ctx context.Context) *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses == nil { + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + } + + return a.clients.ses +} + +func (a *x509TempAuth) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.TrustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.TrustAnchorArn) + if err != nil { + return err + } + a.region = &trustAnchor.Region + } + + if a.TrustProfileArn != nil { + profile, err = arn.Parse(*a.TrustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *x509TempAuth) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) + if err != nil { + return err + } + + var ints []x509.Certificate + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *x509TempAuth) createOrRefreshSession(ctx context.Context) (*session.Session, error) { + a.mu.Lock() + defer a.mu.Unlock() + + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + var mySession *session.Session + var err error + + config := a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + mySession = session.Must(session.NewSession(config)) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + + var ( + duration int64 + createSessionRequest rolesanywhere.CreateSessionInput + ) + if *a.SessionDuration != 0 { + duration = int64(a.SessionDuration.Seconds()) + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.TrustProfileArn, + TrustAnchorArn: a.TrustAnchorArn, + RoleArn: a.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } else { + duration = 900 // 15 minutes in seconds by default and be autorefreshed + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.TrustProfileArn, + TrustAnchorArn: a.TrustAnchorArn, + RoleArn: a.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } + + output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + + if output == nil || len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + accessKey := output.CredentialSet[0].Credentials.AccessKeyId + secretKey := output.CredentialSet[0].Credentials.SecretAccessKey + sessionToken := output.CredentialSet[0].Credentials.SessionToken + awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) + sess := session.Must(session.NewSession(&aws.Config{ + Credentials: awsCreds, + }, config)) + if sess == nil { + return nil, fmt.Errorf("sam session is nil somehow %v", sess) + } + + return sess, nil +} + +func (a *x509TempAuth) startSessionRefresher(ctx context.Context) error { + // if there is a set session duration, then exit bc we will not auto refresh the session. + if *a.SessionDuration != 0 { + return nil + } + + a.logger.Debugf("starting session refresher for x509 auth") + errChan := make(chan error, 1) + go func() { + // renew at ~half the lifespan + ticker := time.NewTicker(8 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + a.logger.Infof("Refreshing session as expiration is near") + newSession, err := a.createOrRefreshSession(ctx) + if err != nil { + errChan <- fmt.Errorf("failed to refresh session: %w", err) + return + } + + a.clients.refresh(newSession) + + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + case <-ctx.Done(): + a.logger.Debugf("Session refresher stopped due to context cancellation") + errChan <- nil + return + } + } + }() + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/go.mod b/go.mod index 16d28420b0..a8ece053a6 100644 --- a/go.mod +++ b/go.mod @@ -107,6 +107,7 @@ require ( github.com/sendgrid/sendgrid-go v3.13.0+incompatible github.com/sijms/go-ora/v2 v2.7.18 github.com/spf13/cast v1.5.1 + github.com/spiffe/go-spiffe/v2 v2.1.7 github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 github.com/stretchr/testify v1.9.0 github.com/supplyon/gremcos v0.1.40 @@ -360,7 +361,6 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/spiffe/go-spiffe/v2 v2.1.7 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/tidwall/gjson v1.14.4 // indirect @@ -405,6 +405,7 @@ require ( google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/couchbase/gocbcore.v7 v7.1.18 // indirect gopkg.in/couchbaselabs/gocbconnstr.v1 v1.0.4 // indirect diff --git a/go.sum b/go.sum index 610b654868..9bafb4a502 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXY github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d h1:wvStE9wLpws31NiWUx+38wny1msZ/tm+eL5xmm4Y7So= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d/go.mod h1:9XMFaCeRyW7fC9XJOWQ+NdAv8VLG7ys7l3x4ozEGLUQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index d782320c69..4eca703f45 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -49,9 +49,7 @@ type snsSqs struct { queues map[string]*sqsQueueInfo // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. subscriptions map[string]string - snsClient *sns.SNS - sqsClient *sqs.SQS - stsClient *sts.STS + authProvider awsAuth.Provider metadata *snsSqsMetadata logger logger.Logger id string @@ -138,36 +136,32 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { } func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { - md, err := s.getSnsSqsMetatdata(metadata) + m, err := s.getSnsSqsMetatdata(metadata) if err != nil { return err } - s.metadata = md + s.metadata = m - aws, err := awsAuth.New(awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: md.Region, - AccessKey: md.AccessKey, - SecretKey: md.SecretKey, - SessionToken: md.SessionToken, - }) - if err != nil { - return err - } - - sess, err := aws.GetClient(ctx) - if err != nil { - return err + if s.authProvider == nil { + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + s.authProvider = provider } - // AWS sns,sqs,sts client. - s.snsClient = sns.New(sess) - s.sqsClient = sqs.New(sess) - s.stsClient = sts.New(sess) - - s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second)) + s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) if err != nil { @@ -196,7 +190,7 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := s.authProvider.SnsSqs(ctx).Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -223,7 +217,7 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := s.authProvider.SnsSqs(ctx).Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -235,7 +229,7 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := s.authProvider.SnsSqs(ctx).Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -302,14 +296,14 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ sqsCreateQueueInput.SetAttributes(attributes) } ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := s.authProvider.SnsSqs(ctx).Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -326,7 +320,7 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -334,7 +328,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -395,7 +389,7 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := s.authProvider.SnsSqs(ctx).Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -415,7 +409,7 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := s.authProvider.SnsSqs(ctx).Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -464,7 +458,7 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) - _, err := s.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err := s.authProvider.SnsSqs(ctx).Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -479,7 +473,7 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err := s.authProvider.SnsSqs(ctx).Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -611,7 +605,7 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := s.authProvider.SnsSqs(ctx).Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -705,7 +699,7 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = s.authProvider.SnsSqs(ctx).Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -725,7 +719,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -752,7 +746,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, } ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = s.sqsClient.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = s.authProvider.SnsSqs(ctx).Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -865,7 +859,7 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error } // sns client has internal exponential backoffs. - _, err = s.snsClient.PublishWithContext(ctx, snsPublishInput) + _, err = s.authProvider.SnsSqs(ctx).Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 1735bbd674..84c2dc05a4 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -53,39 +52,36 @@ type ParameterStoreMetaData struct { } type ssmSecretStore struct { - client ssmiface.SSMAPI - prefix string - logger logger.Logger + authProvider awsAuth.Provider + prefix string + logger logger.Logger } // Init creates an AWS secret manager client. func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error { - meta, err := s.getSecretManagerMetadata(metadata) + m, err := s.getSecretManagerMetadata(metadata) if err != nil { return err } - if s.client == nil { - awsA, err := awsAuth.New(awsAuth.Options{ + if s.authProvider == nil { + opts := awsAuth.Options{ Logger: s.logger, Properties: metadata.Properties, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", } - - session, err := awsA.GetClient(ctx) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - s.client = ssm.New(session) + s.authProvider = provider } - s.prefix = meta.Prefix + s.prefix = m.Prefix return nil } @@ -100,7 +96,7 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr name = fmt.Sprintf("%s:%s", req.Name, versionID) } - output, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + output, err := s.authProvider.ParameterStore(ctx).Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -140,7 +136,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - output, err := s.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + output, err := s.authProvider.ParameterStore(ctx).Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -150,7 +146,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - params, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + params, err := s.authProvider.ParameterStore(ctx).Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 2010c5ada8..e50068b109 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -19,8 +19,8 @@ import ( "fmt" "reflect" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -48,8 +48,8 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - client secretsmanageriface.SecretsManagerAPI - logger logger.Logger + authProvider awsAuth.Provider + logger logger.Logger } // Init creates an AWS secret manager client. @@ -59,25 +59,20 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - if s.client == nil { - awsA, err := awsAuth.New(awsAuth.Options{ + if s.authProvider == nil { + opts := awsAuth.Options{ Logger: s.logger, - Properties: metadata.Properties, Region: meta.Region, AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - }) - if err != nil { - return err + SessionToken: "", } - session, err := awsA.GetClient(ctx) + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - - s.client = secretsmanager.New(session) + s.authProvider = provider } return nil @@ -94,7 +89,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre versionStage = &value } - output, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager(ctx).Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -123,7 +118,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.client.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager(ctx).Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -132,7 +127,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager(ctx).Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 77642a1f8c..a1fa210ab2 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -23,9 +23,9 @@ import ( "strconv" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" jsoniterator "github.com/json-iterator/go" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -42,7 +42,7 @@ type StateStore struct { state.BulkStore logger logger.Logger - client dynamodbiface.DynamoDBAPI + authProvider awsAuth.Provider table string ttlAttributeName string partitionKey string @@ -83,26 +83,21 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { } // This check is needed because d.client is set to a mock in tests - if d.client == nil { - aws, err := awsAuth.New(awsAuth.Options{ + if d.authProvider == nil { + opts := awsAuth.Options{ Logger: d.logger, Properties: metadata.Properties, Region: meta.Region, + Endpoint: meta.Endpoint, AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, SessionToken: meta.SessionToken, - Endpoint: meta.Endpoint, - }) - if err != nil { - return err } - - sess, err := aws.GetClient(ctx) + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } - - d.client = dynamodb.New(sess) + d.authProvider = provider } d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName @@ -128,7 +123,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, } - _, err := d.client.GetItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.GetItemWithContext(ctx, input) return err } @@ -161,7 +156,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, } - result, err := d.client.GetItemWithContext(ctx, input) + result, err := d.authProvider.DynamoDBI(ctx).DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -234,7 +229,7 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { input.ConditionExpression = &condExpr } - _, err = d.client.PutItemWithContext(ctx, input) + _, err = d.authProvider.DynamoDBI(ctx).DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -266,7 +261,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error input.ExpressionAttributeValues = exprAttrValues } - _, err := d.client.DeleteItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -438,7 +433,7 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat twinput.TransactItems = append(twinput.TransactItems, twi) } - _, err := d.client.TransactWriteItemsWithContext(ctx, twinput) + _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 51befe4ed7..3fe60b4224 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -291,6 +291,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect + github.com/vmware/vmware-go-kcl v1.5.1 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect @@ -334,6 +335,7 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.30.2 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 04311f0e9f..d21aec5c3c 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -1383,6 +1383,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/vmware/vmware-go-kcl v1.5.1 h1:1rJLfAX4sDnCyatNoD/WJzVafkwST6u/cgY/Uf2VgHk= +github.com/vmware/vmware-go-kcl v1.5.1/go.mod h1:kXJmQ6h0dRMRrp1uWU9XbIXvwelDpTxSPquvQUBdpbo= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= From 423e993fa19b4dbf64f6ad24a451c08cb3c55276 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 12:16:44 -0600 Subject: [PATCH 12/39] fix: address initial feedback and fix tests Signed-off-by: Samantha Coyle --- .../builtin-authentication-profiles.yaml | 2 +- bindings/aws/dynamodb/dynamodb.go | 40 +- bindings/aws/kinesis/kinesis.go | 82 ++- bindings/aws/s3/s3.go | 70 +- bindings/aws/ses/ses.go | 38 +- bindings/aws/sns/sns.go | 37 +- bindings/aws/sqs/sqs.go | 55 +- common/authentication/aws/aws.go | 36 +- .../aws/{aws_client.go => client.go} | 42 +- common/authentication/aws/static.go | 361 ++++++++++ common/authentication/aws/static_iam.go | 257 ------- common/authentication/aws/x509.go | 558 +++++++++++++++ common/authentication/aws/x509_iam.go | 430 ----------- pubsub/aws/snssqs/metadata.go | 2 +- pubsub/aws/snssqs/snssqs.go | 103 ++- pubsub/aws/snssqs/snssqs_test.go | 22 +- .../aws/parameterstore/parameterstore.go | 50 +- .../aws/parameterstore/parameterstore_test.go | 333 ++++++--- .../aws/secretmanager/secretmanager.go | 42 +- state/aws/dynamodb/dynamodb.go | 70 +- state/aws/dynamodb/dynamodb_test.go | 676 +++++++++++++----- 21 files changed, 2041 insertions(+), 1265 deletions(-) rename common/authentication/aws/{aws_client.go => client.go} (80%) create mode 100644 common/authentication/aws/static.go delete mode 100644 common/authentication/aws/static_iam.go create mode 100644 common/authentication/aws/x509.go delete mode 100644 common/authentication/aws/x509_iam.go diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 2abe6ac3df..4b25dab0ff 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -30,7 +30,7 @@ aws: - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment - title: "AWS: IAM Roles Anywhere" - description: Use x.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. + description: Use X.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. metadata: - name: trustAnchorArn description: | diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 031e751dbf..823894b1ed 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -16,6 +16,7 @@ package dynamodb import ( "context" "encoding/json" + "fmt" "reflect" "github.com/aws/aws-sdk-go/aws" @@ -56,23 +57,22 @@ func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { if err != nil { return err } - if d.authProvider == nil { - opts := awsAuth.Options{ - Logger: d.logger, - Properties: metadata.Properties, - Region: meta.Region, - Endpoint: meta.Endpoint, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - } - - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - d.authProvider = provider + + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err } + d.authProvider = provider d.table = meta.Table return nil @@ -94,7 +94,11 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - _, err = d.authProvider.DynamoDB(ctx).DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) @@ -123,5 +127,5 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { - return nil + return d.authProvider.Close() } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 8d2418821e..f43cf4330f 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -117,23 +117,20 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error a.consumerName = m.ConsumerName a.metadata = m - if a.authProvider == nil { - opts := awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: "", - } - // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - a.authProvider = provider + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", } - + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + a.authProvider = provider return nil } @@ -146,7 +143,11 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - _, err := a.authProvider.Kinesis(ctx).Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -159,16 +160,19 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.closed.Load() { return errors.New("binding is closed") } - + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis(ctx).WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), clients.WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = a.authProvider.Kinesis(ctx).Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = clients.Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -178,12 +182,12 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } } - // Wait for context cancelation then stop - a.wg.Add(1) - stream, err := a.authProvider.Kinesis(ctx).Stream(ctx, a.streamName) + stream, err := clients.Stream(ctx, a.streamName) if err != nil { return fmt.Errorf("failed to get kinesis stream arn: %v", err) } + // Wait for context cancelation then stop + a.wg.Add(1) go func() { defer a.wg.Done() select { @@ -228,8 +232,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes return default: } - - sub, err := a.authProvider.Kinesis(ctx).Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + a.logger.Errorf("failed to get client: %v", err) + return + } + sub, err := clients.Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -278,7 +286,11 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - consumer, err := a.authProvider.Kinesis(ctx).Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + consumer, err := clients.Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -290,7 +302,11 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - consumer, err := a.authProvider.Kinesis(ctx).Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + consumer, err := clients.Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -313,7 +329,11 @@ func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) if a.consumerARN != nil { // Use a background context because the running context may have been canceled already ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err := a.authProvider.Kinesis(ctx).Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -344,7 +364,11 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - req, _ := a.authProvider.Kinesis(ctx).Kinesis.DescribeStreamConsumerRequest(inCpy) + clients, err := a.authProvider.Kinesis(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + req, _ := clients.Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 9ad69da5fa..e57135133b 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -30,7 +30,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/google/uuid" @@ -132,27 +131,24 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { if err != nil { return err } + s.metadata = m - if s.authProvider == nil { - s.metadata = m - - opts := awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: m.Region, - Endpoint: m.Endpoint, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - } - // extra configs needed per component type - cfg := s.getAWSConfig(opts) - provider, err := awsAuth.NewProvider(ctx, opts, cfg) - if err != nil { - return err - } - s.authProvider = provider + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } + // extra configs needed per component type + cfg := s.getAWSConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) + if err != nil { + return err + } + s.authProvider = provider return nil } @@ -172,6 +168,7 @@ func (s *AWSS3) Operations() []bindings.OperationKind { } func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { + metadata, err := s.metadata.mergeWithRequestMetadata(req) if err != nil { return nil, fmt.Errorf("s3 binding error: error merging metadata: %w", err) @@ -212,7 +209,11 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi storageClass = aws.String(metadata.StorageClass) } - resultUpload, err := s.authProvider.S3(ctx).Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + clients, err := s.authProvider.S3(ctx) + if err != nil { + return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) + } + resultUpload, err := clients.Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -287,8 +288,11 @@ func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (str if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - - objReq, _ := s.authProvider.S3(ctx).S3.GetObjectRequest(&s3.GetObjectInput{ + clients, err := s.authProvider.S3(ctx) + if err != nil { + return "", fmt.Errorf("s3 binding error: failed to get client: %v", err) + } + objReq, _ := clients.S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -313,7 +317,11 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings buff := &aws.WriteAtBuffer{} - _, err = s.authProvider.S3(ctx).Downloader.DownloadWithContext(ctx, + clients, err := s.authProvider.S3(ctx) + if err != nil { + return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) + } + _, err = clients.Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -348,7 +356,12 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - _, err := s.authProvider.S3(ctx).S3.DeleteObjectWithContext( + clients, err := s.authProvider.S3(ctx) + if err != nil { + return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) + } + + _, err = clients.S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -378,7 +391,12 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding payload.MaxResults = defaultMaxResults } - result, err := s.authProvider.S3(ctx).S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + clients, err := s.authProvider.S3(ctx) + if err != nil { + return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) + } + + result, err := clients.S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index b65f3d0358..211d434bad 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -70,24 +70,20 @@ func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { a.metadata = m - if a.authProvider == nil { - a.metadata = m - - opts := awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: "", - } - // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - a.authProvider = provider + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + a.authProvider = provider return nil } @@ -155,7 +151,11 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - result, err := a.authProvider.Ses(ctx).Ses.SendEmail(input) + clients, err := a.authProvider.Ses(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + result, err := clients.Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } @@ -180,5 +180,5 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index adc82dfca6..990deffe92 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -65,24 +65,21 @@ func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - if a.authProvider == nil { - opts := awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - Endpoint: m.Endpoint, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - } - // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - a.authProvider = provider + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } - + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + a.authProvider = provider a.topicARN = m.TopicArn return nil @@ -112,7 +109,11 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - _, err = a.authProvider.Sns(ctx).Sns.PublishWithContext(ctx, &sns.PublishInput{ + clients, err := a.authProvider.Sns(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 2075966450..b12b22ad63 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -16,7 +16,6 @@ package sqs import ( "context" "errors" - "fmt" "reflect" "sync" "sync/atomic" @@ -66,24 +65,21 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - if a.authProvider == nil { - opts := awsAuth.Options{ - Logger: a.logger, - Properties: metadata.Properties, - Region: m.Region, - Endpoint: m.Endpoint, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: m.SessionToken, - } - // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - a.authProvider = provider + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } - + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + a.authProvider = provider a.queueName = m.QueueName return nil @@ -95,11 +91,16 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - url, err := a.authProvider.Sqs(ctx).QueueURL(ctx, a.queueName) + clients, err := a.authProvider.Sqs(ctx) if err != nil { - return nil, fmt.Errorf("failed to get queue url: %v", err) + a.logger.Errorf("failed to get client: %v", err) } - _, err = a.authProvider.Sqs(ctx).Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + url, err := clients.QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } + + _, err = clients.Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, QueueUrl: url, }) @@ -121,12 +122,16 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } - url, err := a.authProvider.Sqs(ctx).QueueURL(ctx, a.queueName) + clients, err := a.authProvider.Sqs(ctx) + if err != nil { + a.logger.Errorf("failed to get client: %v", err) + } + url, err := clients.QueueURL(ctx, a.queueName) if err != nil { - fmt.Errorf("failed to get queue url: %v", err) + a.logger.Errorf("failed to get queue url: %v", err) } - result, err := a.authProvider.Sqs(ctx).Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + result, err := clients.Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", @@ -152,7 +157,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - a.authProvider.Sqs(ctx).Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + clients.Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ QueueUrl: url, ReceiptHandle: msgHandle, }) diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index b841010152..2cedd20b89 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -80,18 +80,15 @@ func GetConfig(opts Options) *aws.Config { } type Provider interface { - Initialize(ctx context.Context, opts Options, cfg *aws.Config) error - - S3(ctx context.Context) *S3Clients - DynamoDB(ctx context.Context) *DynamoDBClients - DynamoDBI(ctx context.Context) *DynamoDBClientsI - Sqs(ctx context.Context) *SqsClients - Sns(ctx context.Context) *SnsClients - SnsSqs(ctx context.Context) *SnsSqsClients - SecretManager(ctx context.Context) *SecretManagerClients - ParameterStore(ctx context.Context) *ParameterStoreClients - Kinesis(ctx context.Context) *KinesisClients - Ses(ctx context.Context) *SesClients + S3(ctx context.Context) (*S3Clients, error) + DynamoDB(ctx context.Context) (*DynamoDBClients, error) + Sqs(ctx context.Context) (*SqsClients, error) + Sns(ctx context.Context) (*SnsClients, error) + SnsSqs(ctx context.Context) (*SnsSqsClients, error) + SecretManager(ctx context.Context) (*SecretManagerClients, error) + ParameterStore(ctx context.Context) (*ParameterStoreClients, error) + Kinesis(ctx context.Context) (*KinesisClients, error) + Ses(ctx context.Context) (*SesClients, error) Close() error } @@ -105,20 +102,9 @@ func isX509Auth(m map[string]string) bool { func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) { if isX509Auth(opts.Properties) { - provider := &x509TempAuth{} - err := provider.Initialize(ctx, opts, cfg) - if err != nil { - return nil, fmt.Errorf("failed to initialize AWS Roles Anywhere authentication: %v", err) - } - return provider, nil - } - - provider := &StaticAuth{} - err := provider.Initialize(ctx, opts, cfg) - if err != nil { - return nil, fmt.Errorf("failed to initialize AWS IAM authentication: %v", err) + return newX509(ctx, opts, cfg) } - return provider, nil + return newStaticIAM(ctx, opts, cfg) } diff --git a/common/authentication/aws/aws_client.go b/common/authentication/aws/client.go similarity index 80% rename from common/authentication/aws/aws_client.go rename to common/authentication/aws/client.go index 8425bd19df..14d1c11ee5 100644 --- a/common/authentication/aws/aws_client.go +++ b/common/authentication/aws/client.go @@ -1,16 +1,3 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package aws import ( @@ -43,22 +30,18 @@ type Clients struct { mu sync.RWMutex s3 *S3Clients - dynamo *DynamoDBClients - dynamoI *DynamoDBClientsI + Dynamo *DynamoDBClients sns *SnsClients sqs *SqsClients snssqs *SnsSqsClients secret *SecretManagerClients - parameterStore *ParameterStoreClients + ParameterStore *ParameterStoreClients kinesis *KinesisClients ses *SesClients } func newClients() *Clients { - clients := &Clients{ - mu: sync.RWMutex{}, - } - return clients + return new(Clients) } func (c *Clients) refresh(session *session.Session) { @@ -67,10 +50,8 @@ func (c *Clients) refresh(session *session.Session) { switch { case c.s3 != nil: c.s3.New(session) - case c.dynamo != nil: - c.dynamo.New(session) - case c.dynamoI != nil: - c.dynamoI.New(session) + case c.Dynamo != nil: + c.Dynamo.New(session) case c.sns != nil: c.sns.New(session) case c.sqs != nil: @@ -79,8 +60,8 @@ func (c *Clients) refresh(session *session.Session) { c.snssqs.New(session) case c.secret != nil: c.secret.New(session) - case c.parameterStore != nil: - c.parameterStore.New(session) + case c.ParameterStore != nil: + c.ParameterStore.New(session) case c.kinesis != nil: c.kinesis.New(session) case c.ses != nil: @@ -95,10 +76,6 @@ type S3Clients struct { } type DynamoDBClients struct { - DynamoDB *dynamodb.DynamoDB -} - -type DynamoDBClientsI struct { DynamoDB dynamodbiface.DynamoDBAPI } @@ -144,10 +121,6 @@ func (c *DynamoDBClients) New(session *session.Session) { c.DynamoDB = dynamodb.New(session, session.Config) } -func (c *DynamoDBClientsI) New(session *session.Session) { - c.DynamoDB = dynamodb.New(session, session.Config) -} - func (c *SnsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) } @@ -205,6 +178,7 @@ func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode s return kclConfig } + } return nil diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go new file mode 100644 index 0000000000..c5454ae924 --- /dev/null +++ b/common/authentication/aws/static.go @@ -0,0 +1,361 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "fmt" + "sync" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/dapr/kit/logger" +) + +type StaticAuth struct { + mu sync.RWMutex + logger logger.Logger + + region string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string + + Clients *Clients + session *session.Session + cfg *aws.Config +} + +func newStaticIAM(ctx context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { + auth := &StaticAuth{ + logger: opts.Logger, + region: opts.Region, + endpoint: &opts.Endpoint, + accessKey: &opts.AccessKey, + secretKey: &opts.SecretKey, + sessionToken: &opts.SessionToken, + cfg: cfg, + Clients: newClients(), + } + + initialSession, err := auth.getTokenClient() + if err != nil { + return nil, fmt.Errorf("failed to get token client: %v", err) + } + + auth.session = initialSession + + return auth, nil +} + +func (a *StaticAuth) S3(ctx context.Context) (*S3Clients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.s3 != nil { + return a.Clients.s3, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + s3Clients := S3Clients{} + a.Clients.s3 = &s3Clients + a.logger.Debugf("Initializing S3 clients with session %v", a.session) + a.Clients.s3.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.s3, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.Dynamo != nil { + return a.Clients.Dynamo, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := DynamoDBClients{} + a.Clients.Dynamo = &clients + a.Clients.Dynamo.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.Dynamo, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) Sqs(ctx context.Context) (*SqsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.sqs != nil { + return a.Clients.sqs, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SqsClients{} + a.Clients.sqs = &clients + a.Clients.sqs.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.sqs, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) Sns(ctx context.Context) (*SnsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.sns != nil { + return a.Clients.sns, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SnsClients{} + a.Clients.sns = &clients + a.Clients.sns.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.sns, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) SnsSqs(ctx context.Context) (*SnsSqsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.snssqs != nil { + return a.Clients.snssqs, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SnsSqsClients{} + a.Clients.snssqs = &clients + a.Clients.snssqs.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.snssqs, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) SecretManager(ctx context.Context) (*SecretManagerClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.secret != nil { + return a.Clients.secret, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SecretManagerClients{} + a.Clients.secret = &clients + a.Clients.secret.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.secret, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) ParameterStore(ctx context.Context) (*ParameterStoreClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.ParameterStore != nil { + return a.Clients.ParameterStore, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := ParameterStoreClients{} + a.Clients.ParameterStore = &clients + a.Clients.ParameterStore.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.ParameterStore, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) Kinesis(ctx context.Context) (*KinesisClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.kinesis != nil { + return a.Clients.kinesis, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := KinesisClients{} + a.Clients.kinesis = &clients + a.Clients.kinesis.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.kinesis, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) Ses(ctx context.Context) (*SesClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.ses != nil { + return a.Clients.ses, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SesClients{} + a.Clients.ses = &clients + a.Clients.ses.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.ses, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *StaticAuth) getTokenClient() (*session.Session, error) { + awsConfig := aws.NewConfig() + + if a.region != "" { + awsConfig = awsConfig.WithRegion(a.region) + } + + if a.accessKey != nil && a.secretKey != nil { + // session token is an option field + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) + } + + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *StaticAuth) Close() error { + return nil +} + +func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { + optFns := []func(*config.LoadOptions) error{} + if region != "" { + optFns = append(optFns, config.WithRegion(region)) + } + + if accessKey != "" && secretKey != "" { + provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) + optFns = append(optFns, config.WithCredentialsProvider(provider)) + } + + awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) + if err != nil { + return awsv2.Config{}, err + } + + if endpoint != "" { + awsCfg.BaseEndpoint = &endpoint + } + + return awsCfg, nil +} diff --git a/common/authentication/aws/static_iam.go b/common/authentication/aws/static_iam.go deleted file mode 100644 index a1b898164d..0000000000 --- a/common/authentication/aws/static_iam.go +++ /dev/null @@ -1,257 +0,0 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "context" - "fmt" - "sync" - - awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - v2creds "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/dapr/kit/logger" -) - -type StaticAuth struct { - mu sync.RWMutex - logger logger.Logger - - region string - endpoint *string - accessKey *string - secretKey *string - sessionToken *string - - clients *Clients - session *session.Session - cfg *aws.Config -} - -func (a *StaticAuth) Initialize(_ context.Context, opts Options, cfg *aws.Config) error { - a.mu.Lock() - defer a.mu.Unlock() - - a.logger = opts.Logger - a.region = opts.Region - a.endpoint = &opts.Endpoint - a.accessKey = &opts.AccessKey - a.secretKey = &opts.SecretKey - a.sessionToken = &opts.SessionToken - a.cfg = cfg - a.clients = newClients() - - initialSession, err := a.getTokenClient() - if err != nil { - return fmt.Errorf("failed to get token client: %v", err) - } - - a.session = initialSession - - return nil -} - -func (a *StaticAuth) S3(ctx context.Context) *S3Clients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.s3 == nil { - s3Clients := S3Clients{} - a.clients.s3 = &s3Clients - a.clients.s3.New(a.session) - } - - return a.clients.s3 -} - -func (a *StaticAuth) DynamoDB(ctx context.Context) *DynamoDBClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.dynamo == nil { - clients := DynamoDBClients{} - a.clients.dynamo = &clients - a.clients.dynamo.New(a.session) - } - - return a.clients.dynamo -} - -func (a *StaticAuth) DynamoDBI(ctx context.Context) *DynamoDBClientsI { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.dynamoI == nil { - clients := DynamoDBClientsI{} - a.clients.dynamoI = &clients - a.clients.dynamoI.New(a.session) - } - - return a.clients.dynamoI -} - -func (a *StaticAuth) Sqs(ctx context.Context) *SqsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.sqs == nil { - clients := SqsClients{} - a.clients.sqs = &clients - a.clients.sqs.New(a.session) - } - - return a.clients.sqs -} - -func (a *StaticAuth) Sns(ctx context.Context) *SnsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.sns == nil { - clients := SnsClients{} - a.clients.sns = &clients - a.clients.sns.New(a.session) - } - - return a.clients.sns -} - -func (a *StaticAuth) SnsSqs(ctx context.Context) *SnsSqsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.snssqs == nil { - clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - } - - return a.clients.snssqs -} - -func (a *StaticAuth) SecretManager(ctx context.Context) *SecretManagerClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.secret == nil { - clients := SecretManagerClients{} - a.clients.secret = &clients - a.clients.secret.New(a.session) - } - - return a.clients.secret -} - -func (a *StaticAuth) ParameterStore(ctx context.Context) *ParameterStoreClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.parameterStore == nil { - clients := ParameterStoreClients{} - a.clients.parameterStore = &clients - a.clients.parameterStore.New(a.session) - } - - return a.clients.parameterStore -} - -func (a *StaticAuth) Kinesis(ctx context.Context) *KinesisClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.kinesis == nil { - clients := KinesisClients{} - a.clients.kinesis = &clients - a.clients.kinesis.New(a.session) - } - - return a.clients.kinesis -} - -func (a *StaticAuth) Ses(ctx context.Context) *SesClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.ses == nil { - clients := SesClients{} - a.clients.ses = &clients - a.clients.ses.New(a.session) - } - - return a.clients.ses -} - -func (a *StaticAuth) getTokenClient() (*session.Session, error) { - awsConfig := aws.NewConfig() - - if a.region != "" { - awsConfig = awsConfig.WithRegion(a.region) - } - - if a.accessKey != nil && a.secretKey != nil { - // session token is an option field - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) - } - - if a.endpoint != nil { - awsConfig = awsConfig.WithEndpoint(*a.endpoint) - } - - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } - - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), - } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - - return awsSession, nil -} - -func (a *StaticAuth) Close() error { - return nil -} - -func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { - optFns := []func(*config.LoadOptions) error{} - if region != "" { - optFns = append(optFns, config.WithRegion(region)) - } - - if accessKey != "" && secretKey != "" { - provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) - optFns = append(optFns, config.WithCredentialsProvider(provider)) - } - - awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) - if err != nil { - return awsv2.Config{}, err - } - - if endpoint != "" { - awsCfg.BaseEndpoint = &endpoint - } - - return awsCfg, nil -} diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go new file mode 100644 index 0000000000..06e11e9bee --- /dev/null +++ b/common/authentication/aws/x509.go @@ -0,0 +1,558 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + cryptoX509 "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" +) + +type x509 struct { + mu sync.RWMutex + + wg sync.WaitGroup + // used for background session refresh logic that cannot use the context passed to the newx509 function + internalContext context.Context + internalContextCancel func() + + logger logger.Logger + Clients *Clients + session *session.Session + cfg *aws.Config + + chainPEM []byte + keyPEM []byte + + region *string + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` +} + +func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { + var x509Auth x509 + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return nil, err + } + + switch { + case x509Auth.TrustProfileArn == nil: + return nil, errors.New("trustProfileArn is required") + case x509Auth.TrustAnchorArn == nil: + return nil, errors.New("trustAnchorArn is required") + case x509Auth.AssumeRoleArn == nil: + return nil, errors.New("assumeRoleArn is required") + + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + case x509Auth.SessionDuration == nil: + awsDefaultDuration := time.Hour // default 1 hour from AWS + x509Auth.SessionDuration = &awsDefaultDuration + case *x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12: + return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") + } + + auth := &x509{ + wg: sync.WaitGroup{}, + logger: opts.Logger, + TrustProfileArn: x509Auth.TrustProfileArn, + TrustAnchorArn: x509Auth.TrustAnchorArn, + AssumeRoleArn: x509Auth.AssumeRoleArn, + SessionDuration: x509Auth.SessionDuration, + cfg: GetConfig(opts), + Clients: newClients(), + } + + err := auth.getCertPEM(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + // Parse trust anchor and profile ARNs + if err := auth.initializeTrustAnchors(); err != nil { + return nil, err + } + + initialSession, err := auth.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create the initial session: %v", err) + } + auth.session = initialSession + + // This is needed to keep the session refresher on the background context, but still cancellable. + auth.internalContext, auth.internalContextCancel = context.WithCancel(context.Background()) + auth.startSessionRefresher() + + return auth, nil +} + +func (a *x509) Close() error { + if a.internalContextCancel != nil { + a.internalContextCancel() + } + a.wg.Wait() + return nil +} + +func (a *x509) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return fmt.Errorf("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.chainPEM = chainPEM + a.keyPEM = keyPEM + return nil +} + +func (a *x509) S3(ctx context.Context) (*S3Clients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.s3 != nil { + return a.Clients.s3, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + s3Clients := S3Clients{} + a.Clients.s3 = &s3Clients + a.logger.Debugf("Initializing S3 clients with session %v", a.session) + a.Clients.s3.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.s3, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.Dynamo != nil { + return a.Clients.Dynamo, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := DynamoDBClients{} + a.Clients.Dynamo = &clients + a.Clients.Dynamo.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.Dynamo, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) Sqs(ctx context.Context) (*SqsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.sqs != nil { + return a.Clients.sqs, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SqsClients{} + a.Clients.sqs = &clients + a.Clients.sqs.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.sqs, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) Sns(ctx context.Context) (*SnsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.sns != nil { + return a.Clients.sns, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SnsClients{} + a.Clients.sns = &clients + a.Clients.sns.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.sns, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) SnsSqs(ctx context.Context) (*SnsSqsClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.snssqs != nil { + return a.Clients.snssqs, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SnsSqsClients{} + a.Clients.snssqs = &clients + a.Clients.snssqs.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.snssqs, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) SecretManager(ctx context.Context) (*SecretManagerClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.secret != nil { + return a.Clients.secret, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SecretManagerClients{} + a.Clients.secret = &clients + a.Clients.secret.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.secret, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) ParameterStore(ctx context.Context) (*ParameterStoreClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.ParameterStore != nil { + return a.Clients.ParameterStore, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := ParameterStoreClients{} + a.Clients.ParameterStore = &clients + a.Clients.ParameterStore.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.ParameterStore, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) Kinesis(ctx context.Context) (*KinesisClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.kinesis != nil { + return a.Clients.kinesis, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := KinesisClients{} + a.Clients.kinesis = &clients + a.Clients.kinesis.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.kinesis, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) Ses(ctx context.Context) (*SesClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.Clients.ses != nil { + return a.Clients.ses, nil + } + + // respect context cancellation while initializing client + done := make(chan struct{}) + go func() { + defer close(done) + clients := SesClients{} + a.Clients.ses = &clients + a.Clients.ses.New(a.session) + }() + + // wait for new client or context to be canceled + select { + case <-done: + return a.Clients.ses, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (a *x509) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.TrustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.TrustAnchorArn) + if err != nil { + return err + } + a.region = &trustAnchor.Region + } + + if a.TrustProfileArn != nil { + profile, err = arn.Parse(*a.TrustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *x509) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) + if err != nil { + return err + } + + var ints []cryptoX509.Certificate + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, error) { + a.mu.Lock() + defer a.mu.Unlock() + + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + var mySession *session.Session + var err error + + config := a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + mySession = session.Must(session.NewSession(config)) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + + var ( + duration int64 + createSessionRequest rolesanywhere.CreateSessionInput + ) + if *a.SessionDuration != 0 { + duration = int64(a.SessionDuration.Seconds()) + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.TrustProfileArn, + TrustAnchorArn: a.TrustAnchorArn, + RoleArn: a.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } else { + duration = 900 // 15 minutes in seconds by default and be autorefreshed + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.TrustProfileArn, + TrustAnchorArn: a.TrustAnchorArn, + RoleArn: a.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } + + output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + + if output == nil || len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + accessKey := output.CredentialSet[0].Credentials.AccessKeyId + secretKey := output.CredentialSet[0].Credentials.SecretAccessKey + sessionToken := output.CredentialSet[0].Credentials.SessionToken + awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) + sess := session.Must(session.NewSession(&aws.Config{ + Credentials: awsCreds, + }, config)) + if sess == nil { + return nil, fmt.Errorf("session is nil: %v", sess) + } + + return sess, nil +} + +func (a *x509) startSessionRefresher() { + a.logger.Infof("starting session refresher for x509 auth") + // if there is a set session duration, then exit bc we will not auto refresh the session. + if *a.SessionDuration != 0 { + a.logger.Debugf("session duration was set, so there is no authentication refreshing") + return + } + + a.wg.Add(1) + go func() { + defer a.wg.Done() + + // renew at ~half the lifespan + expiration, err := a.session.Config.Credentials.ExpiresAt() + if err != nil { + a.logger.Errorf("failed to retrieve session expiration time: %w", err) + return + } + + timeUntilExpiration := time.Until(expiration) + refreshInterval := timeUntilExpiration / 2 + ticker := time.NewTicker(refreshInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + a.logger.Debugf("Refreshing session as expiration is near") + newSession, err := a.createOrRefreshSession(a.internalContext) + if err != nil { + a.logger.Errorf("failed to refresh session: %w", err) + return + } + + a.Clients.refresh(newSession) + + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + case <-a.internalContext.Done(): + a.logger.Debugf("Session refresher stopped due to context cancellation") + return + } + } + }() +} diff --git a/common/authentication/aws/x509_iam.go b/common/authentication/aws/x509_iam.go deleted file mode 100644 index 266528d554..0000000000 --- a/common/authentication/aws/x509_iam.go +++ /dev/null @@ -1,430 +0,0 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "context" - "crypto/ecdsa" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "net/http" - "runtime" - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" - "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" - cryptopem "github.com/dapr/kit/crypto/pem" - spiffecontext "github.com/dapr/kit/crypto/spiffe/context" - "github.com/dapr/kit/logger" - kitmd "github.com/dapr/kit/metadata" - "github.com/dapr/kit/ptr" -) - -type x509TempAuth struct { - mu sync.RWMutex - logger logger.Logger - clients *Clients - session *session.Session - cfg *aws.Config - - chainPEM []byte - keyPEM []byte - - region *string - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` - SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` -} - -func (a *x509TempAuth) Initialize(ctx context.Context, opts Options, cfg *aws.Config) error { - var x509Auth x509TempAuth - if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { - return err - } - - switch { - case x509Auth.TrustProfileArn == nil: - return errors.New("trustProfileArn is required") - case x509Auth.TrustAnchorArn == nil: - return errors.New("trustAnchorArn is required") - case x509Auth.AssumeRoleArn == nil: - return errors.New("assumeRoleArn is required") - case x509Auth.SessionDuration == nil: - awsDefaultDuration := time.Duration(900) // default 15m - x509Auth.SessionDuration = &awsDefaultDuration - } - - a.logger = opts.Logger - a.TrustProfileArn = x509Auth.TrustProfileArn - a.TrustAnchorArn = x509Auth.TrustAnchorArn - a.AssumeRoleArn = x509Auth.AssumeRoleArn - a.SessionDuration = x509Auth.SessionDuration - a.cfg = GetConfig(opts) - a.clients = newClients() - - err := a.getCertPEM(ctx) - if err != nil { - return fmt.Errorf("failed to get x.509 credentials: %v", err) - } - - // Parse trust anchor and profile ARNs - if err := a.initializeTrustAnchors(); err != nil { - return err - } - - initialSession, err := a.createOrRefreshSession(ctx) - if err != nil { - return fmt.Errorf("failed to create the initial session: %v", err) - } - a.session = initialSession - go a.startSessionRefresher(context.Background()) - - return nil -} - -func (a *x509TempAuth) Close() error { - return nil -} - -func (a *x509TempAuth) getCertPEM(ctx context.Context) error { - // retrieve svid from spiffe context - svid, ok := spiffecontext.From(ctx) - if !ok { - return fmt.Errorf("no SVID found in context") - } - // get x.509 svid - svidx, err := svid.GetX509SVID() - if err != nil { - return err - } - - // marshal x.509 svid to pem format - chainPEM, keyPEM, err := svidx.Marshal() - if err != nil { - return fmt.Errorf("failed to marshal SVID: %w", err) - } - - a.chainPEM = chainPEM - a.keyPEM = keyPEM - return nil -} - -func (a *x509TempAuth) S3(ctx context.Context) *S3Clients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.s3 == nil { - s3Clients := S3Clients{} - a.clients.s3 = &s3Clients - a.clients.s3.New(a.session) - } - - return a.clients.s3 -} - -func (a *x509TempAuth) DynamoDB(ctx context.Context) *DynamoDBClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.dynamo == nil { - clients := DynamoDBClients{} - a.clients.dynamo = &clients - a.clients.dynamo.New(a.session) - } - - return a.clients.dynamo -} - -func (a *x509TempAuth) DynamoDBI(ctx context.Context) *DynamoDBClientsI { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.dynamoI == nil { - clients := DynamoDBClientsI{} - a.clients.dynamoI = &clients - a.clients.dynamoI.New(a.session) - } - - return a.clients.dynamoI -} - -func (a *x509TempAuth) Sqs(ctx context.Context) *SqsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.sqs == nil { - clients := SqsClients{} - a.clients.sqs = &clients - a.clients.sqs.New(a.session) - } - - return a.clients.sqs -} - -func (a *x509TempAuth) Sns(ctx context.Context) *SnsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.sns == nil { - clients := SnsClients{} - a.clients.sns = &clients - a.clients.sns.New(a.session) - } - - return a.clients.sns -} - -func (a *x509TempAuth) SnsSqs(ctx context.Context) *SnsSqsClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.snssqs == nil { - clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - } - - return a.clients.snssqs -} - -func (a *x509TempAuth) SecretManager(ctx context.Context) *SecretManagerClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.secret == nil { - clients := SecretManagerClients{} - a.clients.secret = &clients - a.clients.secret.New(a.session) - } - - return a.clients.secret -} - -func (a *x509TempAuth) ParameterStore(ctx context.Context) *ParameterStoreClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.parameterStore == nil { - clients := ParameterStoreClients{} - a.clients.parameterStore = &clients - a.clients.parameterStore.New(a.session) - } - - return a.clients.parameterStore -} - -func (a *x509TempAuth) Kinesis(ctx context.Context) *KinesisClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.kinesis == nil { - clients := KinesisClients{} - a.clients.kinesis = &clients - a.clients.kinesis.New(a.session) - } - - return a.clients.kinesis -} - -func (a *x509TempAuth) Ses(ctx context.Context) *SesClients { - a.mu.Lock() - defer a.mu.Unlock() - - if a.clients.ses == nil { - clients := SesClients{} - a.clients.ses = &clients - a.clients.ses.New(a.session) - } - - return a.clients.ses -} - -func (a *x509TempAuth) initializeTrustAnchors() error { - var ( - trustAnchor arn.ARN - profile arn.ARN - err error - ) - if a.TrustAnchorArn != nil { - trustAnchor, err = arn.Parse(*a.TrustAnchorArn) - if err != nil { - return err - } - a.region = &trustAnchor.Region - } - - if a.TrustProfileArn != nil { - profile, err = arn.Parse(*a.TrustProfileArn) - if err != nil { - return err - } - - if profile.Region != "" && trustAnchor.Region != profile.Region { - return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", - trustAnchor.Region, profile.Region) - } - } - return nil -} - -func (a *x509TempAuth) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { - certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) - if err != nil { - return err - } - - var ints []x509.Certificate - for i := range certs[1:] { - ints = append(ints, *certs[i+1]) - } - - key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) - if err != nil { - return err - } - - keyECDSA := key.(*ecdsa.PrivateKey) - signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) - - agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) - rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") - rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) - rolesAnywhereClient.Handlers.Sign.Clear() - rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) - - return nil -} - -func (a *x509TempAuth) createOrRefreshSession(ctx context.Context) (*session.Session, error) { - a.mu.Lock() - defer a.mu.Unlock() - - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, - }} - var mySession *session.Session - var err error - - config := a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - mySession = session.Must(session.NewSession(config)) - rolesAnywhereClient := rolesanywhere.New(mySession, config) - - // Set up signing function and handlers - if err := a.setSigningFunction(rolesAnywhereClient); err != nil { - return nil, err - } - - var ( - duration int64 - createSessionRequest rolesanywhere.CreateSessionInput - ) - if *a.SessionDuration != 0 { - duration = int64(a.SessionDuration.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } else { - duration = 900 // 15 minutes in seconds by default and be autorefreshed - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } - - output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) - } - - if output == nil || len(output.CredentialSet) != 1 { - return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) - } - - accessKey := output.CredentialSet[0].Credentials.AccessKeyId - secretKey := output.CredentialSet[0].Credentials.SecretAccessKey - sessionToken := output.CredentialSet[0].Credentials.SessionToken - awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) - sess := session.Must(session.NewSession(&aws.Config{ - Credentials: awsCreds, - }, config)) - if sess == nil { - return nil, fmt.Errorf("sam session is nil somehow %v", sess) - } - - return sess, nil -} - -func (a *x509TempAuth) startSessionRefresher(ctx context.Context) error { - // if there is a set session duration, then exit bc we will not auto refresh the session. - if *a.SessionDuration != 0 { - return nil - } - - a.logger.Debugf("starting session refresher for x509 auth") - errChan := make(chan error, 1) - go func() { - // renew at ~half the lifespan - ticker := time.NewTicker(8 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - a.logger.Infof("Refreshing session as expiration is near") - newSession, err := a.createOrRefreshSession(ctx) - if err != nil { - errChan <- fmt.Errorf("failed to refresh session: %w", err) - return - } - - a.clients.refresh(newSession) - - a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") - case <-ctx.Done(): - a.logger.Debugf("Session refresher stopped due to context cancellation") - errChan <- nil - return - } - } - }() - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index db45fb8d84..4b469106b7 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -67,7 +67,7 @@ func maskLeft(s string) string { return string(rs) } -func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, error) { +func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error) { md := &snsSqsMetadata{ AssetsManagementTimeoutSeconds: assetsManagementDefaultTimeoutSeconds, MessageVisibilityTimeout: 10, diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 4eca703f45..e560e7f0c4 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -136,7 +136,7 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { } func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { - m, err := s.getSnsSqsMetatdata(metadata) + m, err := s.getSnsSqsMetadata(metadata) if err != nil { return err } @@ -189,8 +189,13 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { return nil } + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := s.authProvider.SnsSqs(ctx).Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := clients.Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -215,9 +220,13 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} snsCreateTopicInput.SetAttributes(attributes) } + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return "", fmt.Errorf("failed to get client: %v", err) + } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := s.authProvider.SnsSqs(ctx).Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := clients.Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -227,9 +236,13 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e } func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return "", fmt.Errorf("failed to get client: %v", err) + } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := s.authProvider.SnsSqs(ctx).Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := clients.Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -295,15 +308,21 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} sqsCreateQueueInput.SetAttributes(attributes) } + + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := s.authProvider.SnsSqs(ctx).Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := clients.Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -319,8 +338,13 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ } func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := clients.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -328,7 +352,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -388,8 +412,13 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { } func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return "", fmt.Errorf("failed to get client: %v", err) + } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := s.authProvider.SnsSqs(ctx).Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := clients.Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -408,8 +437,12 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t } func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return "", fmt.Errorf("failed to get client: %v", err) + } ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := s.authProvider.SnsSqs(ctx).Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := clients.Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -457,8 +490,12 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to } func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } ctx, cancelFn := context.WithCancel(parentCtx) - _, err := s.authProvider.SnsSqs(ctx).Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err = clients.Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -471,9 +508,14 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, } func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err := s.authProvider.SnsSqs(ctx).Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err = clients.Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -601,11 +643,16 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters break } - // Internally, by default, aws go sdk performs 3 retires with exponential backoff to contact + clients, err := s.authProvider.SnsSqs(ctx) + if err != nil { + s.logger.Errorf("failed to get client: %v", err) + } + + // Internally, by default, aws go sdk performs 3 retries with exponential backoff to contact // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := s.authProvider.SnsSqs(ctx).Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := clients.Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -698,8 +745,13 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI return wrappedErr } + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + s.logger.Errorf("failed to get client: %v", err) + } + ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = s.authProvider.SnsSqs(ctx).Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = clients.Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -717,9 +769,14 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return nil } + clients, err := s.authProvider.SnsSqs(parentCtx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := s.authProvider.SnsSqs(ctx).Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -745,8 +802,13 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return fmt.Errorf("failed serializing new sqs policy: %w", uerr) } + clients, err = s.authProvider.SnsSqs(parentCtx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = s.authProvider.SnsSqs(ctx).Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = clients.Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -858,8 +920,13 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error snsPublishInput.MessageGroupId = s.getMessageGroupID(req) } + clients, err := s.authProvider.SnsSqs(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + // sns client has internal exponential backoffs. - _, err = s.authProvider.SnsSqs(ctx).Sns.PublishWithContext(ctx, snsPublishInput) + _, err = clients.Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 1c789b67be..f396ddef47 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -38,7 +38,7 @@ func Test_parseTopicArn(t *testing.T) { } // Verify that all metadata ends up in the correct spot. -func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { +func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -47,7 +47,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "Endpoint": "endpoint", "concurrencyMode": string(pubsub.Single), @@ -80,7 +80,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { r.Equal(int64(6), md.MessageReceiveLimit) } -func Test_getSnsSqsMetatdata_defaults(t *testing.T) { +func Test_getSnsSqsMetadata_defaults(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -89,7 +89,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -114,7 +114,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { r.False(md.DisableDeleteOnRetryLimit) } -func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { +func Test_getSnsSqsMetadata_legacyaliases(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -123,7 +123,7 @@ func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "awsAccountID": "acctId", "awsSecret": "secret", @@ -151,13 +151,13 @@ func testMetadataParsingShouldFail(t *testing.T, metadata pubsub.Metadata, l log logger: l, } - md, err := ps.getSnsSqsMetatdata(metadata) + md, err := ps.getSnsSqsMetadata(metadata) r.Error(err) r.Nil(md) } -func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) { +func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) { t.Parallel() fixtures := []testUnitFixture{ @@ -432,7 +432,7 @@ func Test_buildARN_DefaultPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -455,7 +455,7 @@ func Test_buildARN_StandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -478,7 +478,7 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 84c2dc05a4..a80507f7d3 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -64,23 +64,20 @@ func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadat return err } - if s.authProvider == nil { - opts := awsAuth.Options{ - Logger: s.logger, - Properties: metadata.Properties, - Region: m.Region, - AccessKey: m.AccessKey, - SecretKey: m.SecretKey, - SessionToken: "", - } - // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - s.authProvider = provider + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", } - + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + s.authProvider = provider s.prefix = m.Prefix return nil @@ -95,8 +92,11 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr versionID = value name = fmt.Sprintf("%s:%s", req.Name, versionID) } - - output, err := s.authProvider.ParameterStore(ctx).Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + clients, err := s.authProvider.ParameterStore(ctx) + if err != nil { + return secretstores.GetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) + } + output, err := clients.Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -136,7 +136,11 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - output, err := s.authProvider.ParameterStore(ctx).Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + clients, err := s.authProvider.ParameterStore(ctx) + if err != nil { + return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) + } + output, err := clients.Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -146,7 +150,11 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - params, err := s.authProvider.ParameterStore(ctx).Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + clients, err = s.authProvider.ParameterStore(ctx) + if err != nil { + return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) + } + params, err := clients.Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) @@ -185,5 +193,5 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 8d9bcf6065..b07e1467f3 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -22,6 +22,8 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" @@ -68,21 +70,34 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("with valid path", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &mockedSSM{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -93,25 +108,39 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - keys := strings.Split(*input.Name, ":") - assert.NotNil(t, keys) - assert.Len(t, keys, 2) - assert.Equalf(t, "1", keys[1], "Version IDs are same") - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: &keys[0], - Value: &secret, - }, - }, nil - }, + mockSSM := &mockedSSM{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + keys := strings.Split(*input.Name, ":") + assert.NotNil(t, keys) + assert.Len(t, keys, 2) + assert.Equalf(t, "1", keys[1], "Version IDs are same") + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: &keys[0], + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{ @@ -124,21 +153,35 @@ func TestGetSecret(t *testing.T) { }) t.Run("with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &mockedSSM{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) + secret := secretValue + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.GetSecretRequest{ @@ -152,13 +195,29 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &mockedSSM{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -170,31 +229,44 @@ func TestGetSecret(t *testing.T) { func TestGetBulkSecrets(t *testing.T) { t.Run("successfully retrieve bulk secrets", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &mockedSSM{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -205,30 +277,43 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("successfully retrieve bulk secrets with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/prefix/aws/dev/secret1"), - }, - { - Name: aws.String("/prefix/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &mockedSSM{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/prefix/aws/dev/secret1"), + }, + { + Name: aws.String("/prefix/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.BulkGetSecretRequest{ @@ -241,23 +326,37 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on get parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &mockedSSM{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -266,13 +365,27 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on describe parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &mockedSSM{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index e50068b109..f67e0f459b 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -59,22 +59,19 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - if s.authProvider == nil { - opts := awsAuth.Options{ - Logger: s.logger, - Region: meta.Region, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: "", - } - - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - s.authProvider = provider + opts := awsAuth.Options{ + Logger: s.logger, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: "", } + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + s.authProvider = provider return nil } @@ -88,8 +85,11 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - - output, err := s.authProvider.SecretManager(ctx).Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + clients, err := s.authProvider.SecretManager(ctx) + if err != nil { + return secretstores.GetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) + } + output, err := clients.Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -118,7 +118,11 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.authProvider.SecretManager(ctx).Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + clients, err := s.authProvider.SecretManager(ctx) + if err != nil { + return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) + } + output, err := clients.Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -127,7 +131,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.authProvider.SecretManager(ctx).Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := clients.Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -173,5 +177,5 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index a1fa210ab2..cd1e5a2df2 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -41,8 +41,8 @@ import ( type StateStore struct { state.BulkStore - logger logger.Logger authProvider awsAuth.Provider + logger logger.Logger table string ttlAttributeName string partitionKey string @@ -67,9 +67,10 @@ const ( ) // NewDynamoDBStateStore returns a new dynamoDB state store. -func NewDynamoDBStateStore(_ logger.Logger) state.Store { +func NewDynamoDBStateStore(logger logger.Logger) state.Store { s := &StateStore{ partitionKey: defaultPartitionKeyName, + logger: logger, } s.BulkStore = state.NewDefaultBulkStore(s) return s @@ -81,24 +82,20 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { if err != nil { return err } - - // This check is needed because d.client is set to a mock in tests - if d.authProvider == nil { - opts := awsAuth.Options{ - Logger: d.logger, - Properties: metadata.Properties, - Region: meta.Region, - Endpoint: meta.Endpoint, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - } - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err - } - d.authProvider = provider + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, } + provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + if err != nil { + return err + } + d.authProvider = provider d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName d.partitionKey = meta.PartitionKey @@ -122,8 +119,11 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, }, } - - _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.GetItemWithContext(ctx, input) + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.DynamoDB.GetItemWithContext(ctx, input) return err } @@ -155,8 +155,11 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, }, } - - result, err := d.authProvider.DynamoDBI(ctx).DynamoDB.GetItemWithContext(ctx, input) + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %v", err) + } + result, err := clients.DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -228,8 +231,11 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { condExpr := "attribute_not_exists(etag)" input.ConditionExpression = &condExpr } - - _, err = d.authProvider.DynamoDBI(ctx).DynamoDB.PutItemWithContext(ctx, input) + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -260,8 +266,11 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } input.ExpressionAttributeValues = exprAttrValues } - - _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.DeleteItemWithContext(ctx, input) + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -432,8 +441,11 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat } twinput.TransactItems = append(twinput.TransactItems, twi) } - - _, err := d.authProvider.DynamoDBI(ctx).DynamoDB.TransactWriteItemsWithContext(ctx, twinput) + clients, err := d.authProvider.DynamoDB(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %v", err) + } + _, err = clients.DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d1f98b70ba..ca0a97109a 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -21,6 +21,8 @@ import ( "testing" "time" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" @@ -74,16 +76,30 @@ func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, inpu func TestInit(t *testing.T) { m := state.Metadata{} - s := &StateStore{ - partitionKey: defaultPartitionKeyName, - client: &mockedDynamoDB{ - // We're adding this so we can pass the connection check on Init - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return nil, nil - }, + mockedDb := &mockedDynamoDB{ + // We're adding this so we can pass the connection check on Init + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return nil, nil }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) { assert.NotNil(t, s) assert.Equal(t, defaultPartitionKeyName, s.partitionKey) @@ -137,12 +153,29 @@ func TestInit(t *testing.T) { "Region": "eu-west-1", } - s.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + err := s.Init(context.Background(), m) require.Error(t, err) require.EqualError(t, err, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found") @@ -151,10 +184,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -172,6 +202,22 @@ func TestGet(t *testing.T) { }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.GetRequest{ Key: "someKey", Metadata: nil, @@ -179,34 +225,48 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("4074862051"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDb := &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), + }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("4074862051"), }, - }, nil - }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -216,7 +276,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -226,27 +286,41 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("35489251"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDb := &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), + }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("35489251"), }, - }, nil - }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -256,20 +330,35 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return nil, errors.New("failed to retrieve data") - }, + mockedDb := &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return nil, errors.New("failed to retrieve data") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -277,20 +366,34 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.Error(t, err) assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{}, - }, nil - }, + mockedDb := &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -298,26 +401,40 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "value2": { - S: aws.String("value"), - }, + mockedDb := &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "value2": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -325,7 +442,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Empty(t, out.Data) assert.Nil(t, out.ETag) @@ -338,10 +455,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -360,21 +474,36 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -397,6 +526,23 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "1bdead4badc0ffee" req := &state.SetRequest{ ETag: &etag, @@ -405,15 +551,12 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -431,6 +574,23 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "bogusetag" req := &state.SetRequest{ ETag: &etag, @@ -440,7 +600,7 @@ func TestSet(t *testing.T) { }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -451,10 +611,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -474,6 +631,23 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -483,15 +657,12 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -506,6 +677,23 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -515,7 +703,7 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch err.(type) { case *state.ETagError: @@ -525,10 +713,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -547,7 +732,23 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + ttlAttributeName: "testAttributeName", + } req := &state.SetRequest{ Key: "someKey", @@ -558,14 +759,11 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -584,7 +782,24 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + ttlAttributeName: "testAttributeName", + } req := &state.SetRequest{ Key: "someKey", @@ -595,33 +810,44 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -640,6 +866,23 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", Value: value{ @@ -649,34 +892,48 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("somekey"), - }, - "value": { - S: aws.String(`{"Value":"somevalue"}`), - }, - "ttlInSeconds": { - N: aws.String("180"), - }, - }, input.Item) + mockedDb := &mockedDynamoDB{ + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("somekey"), + }, + "value": { + S: aws.String(`{"Value":"somevalue"}`), + }, + "ttlInSeconds": { + N: aws.String("180"), + }, + }, input.Item) - return &dynamodb.PutItemOutput{ - Attributes: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("value"), - }, + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.SetRequest{ @@ -688,7 +945,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -700,10 +957,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -715,7 +969,24 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -725,10 +996,8 @@ func TestDelete(t *testing.T) { ETag: &etag, Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + + mockedDb := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -744,7 +1013,24 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -755,10 +1041,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -775,7 +1058,23 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + err := s.Delete(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -786,26 +1085,38 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { - return nil, errors.New("unable to delete item") - }, + mockedDb := &mockedDynamoDB{ + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { + return nil, errors.New("unable to delete item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(context.Background(), req) + err := s.Delete(context.Background(), req) require.Error(t, err) }) } func TestMultiTx(t *testing.T) { t.Run("Successfully Multiple Transaction Operations", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } firstKey := "key1" secondKey := "key2" secondValue := "value2" @@ -829,7 +1140,7 @@ func TestMultiTx(t *testing.T) { }, } - ss.client = &mockedDynamoDB{ + mockedDb := &mockedDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -853,13 +1164,30 @@ func TestMultiTx(t *testing.T) { return &dynamodb.TransactWriteItemsOutput{}, nil }, } - ss.table = tableName + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDb, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := StateStore{ + authProvider: mockAuthProvider, + table: tableName, + partitionKey: defaultPartitionKeyName, + } req := &state.TransactionalStateRequest{ Operations: ops, Metadata: map[string]string{}, } - err := ss.Multi(context.Background(), req) + err := s.Multi(context.Background(), req) require.NoError(t, err) }) } From b228e390d600eec7ff95fe5769614e1247cfe242 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 12:19:05 -0600 Subject: [PATCH 13/39] test: add tests and make things more testable Signed-off-by: Samantha Coyle --- common/authentication/aws/client.go | 32 ++- common/authentication/aws/client_test.go | 252 +++++++++++++++++++++++ common/authentication/aws/static_test.go | 65 ++++++ common/authentication/aws/x509.go | 47 +++-- common/authentication/aws/x509_test.go | 148 +++++++++++++ 5 files changed, 519 insertions(+), 25 deletions(-) create mode 100644 common/authentication/aws/client_test.go create mode 100644 common/authentication/aws/static_test.go create mode 100644 common/authentication/aws/x509_test.go diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 14d1c11ee5..925f86542c 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -6,10 +6,12 @@ import ( "sync" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/secretsmanager" @@ -17,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/service/ses" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/aws/aws-sdk-go/service/sts" @@ -90,7 +93,7 @@ type SnsClients struct { } type SqsClients struct { - Sqs *sqs.SQS + Sqs sqsiface.SQSAPI queueURL *string } @@ -103,7 +106,9 @@ type ParameterStoreClients struct { } type KinesisClients struct { - Kinesis *kinesis.Kinesis + Kinesis kinesisiface.KinesisAPI + Region string + Credentials *credentials.Credentials } type SesClients struct { @@ -139,9 +144,10 @@ func (c *SqsClients) QueueURL(ctx context.Context, queueName string) (*string, e resultURL, err := c.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ QueueName: aws.String(queueName), }) - return resultURL.QueueUrl, err + if resultURL != nil { + return resultURL.QueueUrl, err + } } - return nil, errors.New("unable to get queue url due to empty client") } @@ -155,6 +161,8 @@ func (c *ParameterStoreClients) New(session *session.Session) { func (c *KinesisClients) New(session *session.Session) { c.Kinesis = kinesis.New(session, session.Config) + c.Region = *session.Config.Region + c.Credentials = session.Config.Credentials } func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string, error) { @@ -162,7 +170,9 @@ func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string stream, err := c.Kinesis.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ StreamName: aws.String(streamName), }) - return stream.StreamDescription.StreamARN, err + if stream != nil { + return stream.StreamDescription.StreamARN, err + } } return nil, errors.New("unable to get stream arn due to empty client") @@ -172,16 +182,18 @@ func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode s const sharedMode = "shared" if c.Kinesis != nil { if mode == sharedMode { - kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, - stream, *c.Kinesis.Config.Region, consumer, - c.Kinesis.Config.Credentials) - - return kclConfig + if c.Credentials != nil { + kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, + stream, c.Region, consumer, + c.Credentials) + return kclConfig + } } } return nil + } func (c *SesClients) New(session *session.Session) { diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go new file mode 100644 index 0000000000..ac87e74859 --- /dev/null +++ b/common/authentication/aws/client_test.go @@ -0,0 +1,252 @@ +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type mockedSQS struct { + sqsiface.SQSAPI + GetQueueUrlFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) +} + +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { + return m.GetQueueUrlFn(ctx, input) +} + +type mockedKinesis struct { + kinesisiface.KinesisAPI + DescribeStreamFn func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) +} + +func (m *mockedKinesis) DescribeStreamWithContext(ctx context.Context, input *kinesis.DescribeStreamInput, opts ...request.Option) (*kinesis.DescribeStreamOutput, error) { + return m.DescribeStreamFn(ctx, input) +} + +func TestS3Clients_New(t *testing.T) { + tests := []struct { + name string + s3Client *S3Clients + session *session.Session + }{ + {"initializes S3 client", &S3Clients{}, session.Must(session.NewSession())}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s3Client.New(tt.session) + require.NotNil(t, tt.s3Client.S3) + require.NotNil(t, tt.s3Client.Uploader) + require.NotNil(t, tt.s3Client.Downloader) + }) + } +} + +func TestSqsClients_QueueURL(t *testing.T) { + tests := []struct { + name string + mockFn func() *mockedSQS + queueName string + expectedURL *string + expectError bool + }{ + { + name: "returns queue URL successfully", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueUrlFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return &sqs.GetQueueUrlOutput{ + QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"), + }, nil + }, + } + }, + queueName: "valid-queue", + expectedURL: aws.String("https://sqs.aws.com/123456789012/queue"), + expectError: false, + }, + { + name: "returns error when queue URL not found", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueUrlFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return nil, errors.New("unable to get stream arn due to empty client") + }, + } + }, + queueName: "missing-queue", + expectedURL: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSQS := tt.mockFn() + + // Initialize SqsClients with the mocked SQS client + client := &SqsClients{ + Sqs: mockSQS, + } + + url, err := client.QueueURL(context.Background(), tt.queueName) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedURL, url) + } + }) + } +} + +func TestKinesisClients_Stream(t *testing.T) { + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + mockStreamARN *string + mockError error + expectedStream *string + expectedErr error + }{ + { + name: "successfully retrieves stream ARN", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + }, + }, nil + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "some-stream", + expectedStream: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + expectedErr: nil, + }, + { + name: "returns error when stream not found", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return nil, errors.New("stream not found") + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "nonexistent-stream", + expectedStream: nil, + expectedErr: errors.New("unable to get stream arn due to empty client"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kinesisClient.Stream(context.Background(), tt.streamName) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedStream, got) + } + }) + } +} + +func TestKinesisClients_WorkerCfg(t *testing.T) { + testCreds := credentials.NewStaticCredentials("accessKey", "secretKey", "") + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + consumer string + mode string + expectedConfig *config.KinesisClientLibConfiguration + }{ + { + name: "successfully creates shared mode worker config", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: config.NewKinesisClientLibConfigWithCredential( + "consumer1", "existing-stream", "us-west-1", "consumer1", testCreds, + ), + }, + { + name: "returns nil when mode is not shared", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "exclusive", + expectedConfig: nil, + }, + { + name: "returns nil when client is nil", + kinesisClient: &KinesisClients{ + Kinesis: nil, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.kinesisClient.WorkerCfg(context.Background(), tt.streamName, tt.consumer, tt.mode) + if tt.expectedConfig == nil { + assert.Equal(t, tt.expectedConfig, cfg) + return + } + assert.Equal(t, tt.expectedConfig.StreamName, cfg.StreamName) + assert.Equal(t, tt.expectedConfig.EnhancedFanOutConsumerName, cfg.EnhancedFanOutConsumerName) + assert.Equal(t, tt.expectedConfig.EnableEnhancedFanOutConsumer, cfg.EnableEnhancedFanOutConsumer) + assert.Equal(t, tt.expectedConfig.RegionName, cfg.RegionName) + }) + } +} diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go new file mode 100644 index 0000000000..1c88eb0057 --- /dev/null +++ b/common/authentication/aws/static_test.go @@ -0,0 +1,65 @@ +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" +) + +func TestGetConfigV2(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + sessionToken string + region string + endpoint string + }{ + { + name: "valid config", + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + awsCfg, err := GetConfigV2(tt.accessKey, tt.secretKey, tt.sessionToken, tt.region, tt.endpoint) + assert.NoError(t, err) + assert.NotNil(t, awsCfg) + assert.Equal(t, tt.region, awsCfg.Region) + assert.Equal(t, tt.endpoint, *awsCfg.BaseEndpoint) + }) + } +} + +func TestGetTokenClient(t *testing.T) { + tests := []struct { + name string + awsInstance *StaticAuth + }{ + { + name: "valid token client", + awsInstance: &StaticAuth{ + accessKey: aws.String("testAccessKey"), + secretKey: aws.String("testSecretKey"), + sessionToken: aws.String("testSessionToken"), + region: "us-west-2", + endpoint: aws.String("https://test.endpoint.com"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session, err := tt.awsInstance.getTokenClient() + assert.NotNil(t, session) + assert.NoError(t, err) + assert.Equal(t, tt.awsInstance.region, *session.Config.Region) + }) + } +} diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 06e11e9bee..07b99f6f8c 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -32,6 +32,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" cryptopem "github.com/dapr/kit/crypto/pem" spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/logger" @@ -47,10 +48,11 @@ type x509 struct { internalContext context.Context internalContextCancel func() - logger logger.Logger - Clients *Clients - session *session.Session - cfg *aws.Config + logger logger.Logger + Clients *Clients + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests + session *session.Session + cfg *aws.Config chainPEM []byte keyPEM []byte @@ -80,7 +82,7 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) case x509Auth.SessionDuration == nil: awsDefaultDuration := time.Hour // default 1 hour from AWS x509Auth.SessionDuration = &awsDefaultDuration - case *x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12: + case *x509Auth.SessionDuration != 0 && (*x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12): return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } @@ -99,11 +101,13 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) if err != nil { return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) } + auth.logger.Infof("sam here 1") // Parse trust anchor and profile ARNs if err := auth.initializeTrustAnchors(); err != nil { return nil, err } + auth.logger.Infof("sam here 2") initialSession, err := auth.createOrRefreshSession(ctx) if err != nil { @@ -448,15 +452,20 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, }} var mySession *session.Session - var err error - config := a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - mySession = session.Must(session.NewSession(config)) - rolesAnywhereClient := rolesanywhere.New(mySession, config) + var config *aws.Config + if a.cfg != nil { + config = a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + } - // Set up signing function and handlers - if err := a.setSigningFunction(rolesAnywhereClient); err != nil { - return nil, err + if a.rolesAnywhereClient == nil { + mySession = session.Must(session.NewSession(config)) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + a.rolesAnywhereClient = rolesAnywhereClient } var ( @@ -476,6 +485,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } else { + a.logger.Infof("sam setting 15min default for session duration") duration = 900 // 15 minutes in seconds by default and be autorefreshed createSessionRequest = rolesanywhere.CreateSessionInput{ @@ -488,8 +498,9 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } + a.logger.Infof("sam session time %v", *createSessionRequest.DurationSeconds) - output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + output, err := a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) if err != nil { return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) } @@ -497,17 +508,22 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er if output == nil || len(output.CredentialSet) != 1 { return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) } + a.logger.Infof("sam successfully created new session with iam roles anywhere client!") accessKey := output.CredentialSet[0].Credentials.AccessKeyId secretKey := output.CredentialSet[0].Credentials.SecretAccessKey sessionToken := output.CredentialSet[0].Credentials.SessionToken + + a.logger.Infof("the ak %v sk %v st %v", accessKey, secretKey, sessionToken) + a.logger.Infof("sam the len of credentials set %v", len(output.CredentialSet)) awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) sess := session.Must(session.NewSession(&aws.Config{ Credentials: awsCreds, }, config)) if sess == nil { - return nil, fmt.Errorf("session is nil: %v", sess) + return nil, fmt.Errorf("sam session is nil somehow %v", sess) } + a.logger.Infof("sam just set session in refreshorcreate func %v", a.session) return sess, nil } @@ -539,12 +555,13 @@ func (a *x509) startSessionRefresher() { for { select { case <-ticker.C: - a.logger.Debugf("Refreshing session as expiration is near") + a.logger.Infof("Refreshing session as expiration is near") newSession, err := a.createOrRefreshSession(a.internalContext) if err != nil { a.logger.Errorf("failed to refresh session: %w", err) return } + a.logger.Infof("sam in ticker after created refreshed session %v", newSession) a.Clients.refresh(newSession) diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go new file mode 100644 index 0000000000..26773890d7 --- /dev/null +++ b/common/authentication/aws/x509_test.go @@ -0,0 +1,148 @@ +package aws + +import ( + "context" + cryptoX509 "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "math/big" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + "github.com/dapr/kit/crypto/spiffe" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/crypto/test" + "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockRequestSVIDFn() ([]*cryptoX509.Certificate, error) { + spiffeID, err := url.Parse("spiffe://example.org/test") // create tester SPIFFE ID + if err != nil { + return nil, err + } + + // encode the URI as a SAN extension + uriSAN, err := asn1.Marshal(spiffeID.String()) + if err != nil { + return nil, err + } + + // create dummy certificate with the required URI SAN + cert := &cryptoX509.Certificate{ + Subject: pkix.Name{CommonName: "test-cert"}, + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + { + Id: asn1.ObjectIdentifier{2, 5, 29, 17}, // OID for subject alternative name + Critical: false, + Value: uriSAN, + }, + }, + } + + return []*cryptoX509.Certificate{cert}, nil +} + +type mockRolesAnywhereClient struct { + rolesanywhereiface.RolesAnywhereAPI + + CreateSessionOutput *rolesanywhere.CreateSessionOutput + CreateSessionError error +} + +func (m *mockRolesAnywhereClient) CreateSessionWithContext(ctx context.Context, input *rolesanywhere.CreateSessionInput, opts ...request.Option) (*rolesanywhere.CreateSessionOutput, error) { + return m.CreateSessionOutput, m.CreateSessionError +} + +func TestGetX509Client(t *testing.T) { + tests := []struct { + name string + mockOutput *rolesanywhere.CreateSessionOutput + mockError error + }{ + { + name: "valid x509 client", + mockOutput: &rolesanywhere.CreateSessionOutput{ + CredentialSet: []*rolesanywhere.CredentialResponse{ + { + Credentials: &rolesanywhere.Credentials{ + AccessKeyId: aws.String("mockAccessKeyId"), + SecretAccessKey: aws.String("mockSecretAccessKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.String(time.Now().Add(15 * time.Minute).Format(time.RFC3339)), + }, + }, + }, + }, + mockError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + duration := time.Duration(800) + mockSvc := &mockRolesAnywhereClient{ + CreateSessionOutput: tt.mockOutput, + CreateSessionError: tt.mockError, + } + mockAWS := &x509{ + logger: logger.NewLogger("testLogger"), + AssumeRoleArn: ptr.Of("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + TrustAnchorArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + TrustProfileArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), + SessionDuration: &duration, + rolesAnywhereClient: mockSvc, + } + pki := test.GenPKI(t, test.PKIOptions{ + LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"), + }) + + respCert := []*cryptoX509.Certificate{pki.LeafCert} + var respErr error + + var fetches atomic.Int32 + s := spiffe.New(spiffe.Options{ + Log: logger.NewLogger("test"), + RequestSVIDFn: func(context.Context, []byte) ([]*cryptoX509.Certificate, error) { + fetches.Add(1) + return respCert, respErr + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error) + go func() { + errCh <- s.Run(ctx) + }() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + err := s.Ready(ctx) + assert.NoError(t, err) + + // inject the SVID source into the context + ctx = spiffecontext.With(ctx, s) + session, err := mockAWS.createOrRefreshSession(ctx) + + assert.NoError(t, err) + assert.NotNil(t, session) + }) + } +} From 4b1ad205d2c675a4de88c5477acc772d787178c4 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 13:39:39 -0600 Subject: [PATCH 14/39] style: make linter happy Signed-off-by: Samantha Coyle --- bindings/aws/kinesis/kinesis.go | 6 +- bindings/aws/s3/s3.go | 1 - common/authentication/aws/aws.go | 4 +- common/authentication/aws/aws_test.go | 3 +- common/authentication/aws/client.go | 14 +- common/authentication/aws/client_test.go | 10 +- common/authentication/aws/static.go | 11 +- common/authentication/aws/static_test.go | 7 +- common/authentication/aws/x509.go | 22 +-- common/authentication/aws/x509_test.go | 45 +----- pubsub/aws/snssqs/snssqs.go | 3 +- .../aws/parameterstore/parameterstore_test.go | 1 + .../aws/secretmanager/secretmanager_test.go | 137 +++++++++++++----- state/aws/dynamodb/dynamodb_test.go | 92 ++++++------ 14 files changed, 190 insertions(+), 166 deletions(-) diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index f43cf4330f..8127fa8019 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -197,7 +197,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { - a.deregisterConsumer(stream, a.consumerARN) + a.deregisterConsumer(ctx, stream, a.consumerARN) } }() @@ -325,14 +325,14 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (* return consumer.Consumer.ConsumerARN, nil } -func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) error { +func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error { if a.consumerARN != nil { // Use a background context because the running context may have been canceled already - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) clients, err := a.authProvider.Kinesis(ctx) if err != nil { return fmt.Errorf("failed to get client: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) _, err = clients.Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index e57135133b..d9dac8631b 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -168,7 +168,6 @@ func (s *AWSS3) Operations() []bindings.OperationKind { } func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { - metadata, err := s.metadata.mergeWithRequestMetadata(req) if err != nil { return nil, fmt.Errorf("s3 binding error: error merging metadata: %w", err) diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 2cedd20b89..c6298a8ebb 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -24,9 +24,10 @@ import ( v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" - "github.com/dapr/kit/logger" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + + "github.com/dapr/kit/logger" ) type EnvironmentSettings struct { @@ -105,7 +106,6 @@ func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, return newX509(ctx, opts, cfg) } return newStaticIAM(ctx, opts, cfg) - } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go index a60fc6570e..24b6c5649c 100644 --- a/common/authentication/aws/aws_test.go +++ b/common/authentication/aws/aws_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewEnvironmentSettings(t *testing.T) { @@ -23,7 +24,7 @@ func TestNewEnvironmentSettings(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := NewEnvironmentSettings(tt.metadata) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) }) } diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 925f86542c..b14a97ec63 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -23,12 +23,9 @@ import ( "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/aws/aws-sdk-go/service/sts" - "github.com/dapr/kit/logger" "github.com/vmware/vmware-go-kcl/clientlibrary/config" ) -var log logger.Logger - type Clients struct { mu sync.RWMutex @@ -37,7 +34,7 @@ type Clients struct { sns *SnsClients sqs *SqsClients snssqs *SnsSqsClients - secret *SecretManagerClients + Secret *SecretManagerClients ParameterStore *ParameterStoreClients kinesis *KinesisClients ses *SesClients @@ -61,8 +58,8 @@ func (c *Clients) refresh(session *session.Session) { c.sqs.New(session) case c.snssqs != nil: c.snssqs.New(session) - case c.secret != nil: - c.secret.New(session) + case c.Secret != nil: + c.Secret.New(session) case c.ParameterStore != nil: c.ParameterStore.New(session) case c.kinesis != nil: @@ -93,8 +90,7 @@ type SnsClients struct { } type SqsClients struct { - Sqs sqsiface.SQSAPI - queueURL *string + Sqs sqsiface.SQSAPI } type SecretManagerClients struct { @@ -189,11 +185,9 @@ func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode s return kclConfig } } - } return nil - } func (c *SesClients) New(session *session.Session) { diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go index ac87e74859..e23d1244ca 100644 --- a/common/authentication/aws/client_test.go +++ b/common/authentication/aws/client_test.go @@ -20,11 +20,11 @@ import ( type mockedSQS struct { sqsiface.SQSAPI - GetQueueUrlFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) + GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) } -func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { - return m.GetQueueUrlFn(ctx, input) +func (m *mockedSQS) GetQueueURLWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { + return m.GetQueueURLFn(ctx, input) } type mockedKinesis struct { @@ -67,7 +67,7 @@ func TestSqsClients_QueueURL(t *testing.T) { name: "returns queue URL successfully", mockFn: func() *mockedSQS { return &mockedSQS{ - GetQueueUrlFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { return &sqs.GetQueueUrlOutput{ QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"), }, nil @@ -82,7 +82,7 @@ func TestSqsClients_QueueURL(t *testing.T) { name: "returns error when queue URL not found", mockFn: func() *mockedSQS { return &mockedSQS{ - GetQueueUrlFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { return nil, errors.New("unable to get stream arn due to empty client") }, } diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index c5454ae924..5030cec4fb 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + "github.com/dapr/kit/logger" ) @@ -200,8 +201,8 @@ func (a *StaticAuth) SecretManager(ctx context.Context) (*SecretManagerClients, a.mu.Lock() defer a.mu.Unlock() - if a.Clients.secret != nil { - return a.Clients.secret, nil + if a.Clients.Secret != nil { + return a.Clients.Secret, nil } // respect context cancellation while initializing client @@ -209,14 +210,14 @@ func (a *StaticAuth) SecretManager(ctx context.Context) (*SecretManagerClients, go func() { defer close(done) clients := SecretManagerClients{} - a.Clients.secret = &clients - a.Clients.secret.New(a.session) + a.Clients.Secret = &clients + a.Clients.Secret.New(a.session) }() // wait for new client or context to be canceled select { case <-done: - return a.Clients.secret, nil + return a.Clients.Secret, nil case <-ctx.Done(): return nil, ctx.Err() } diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go index 1c88eb0057..9c61b5a412 100644 --- a/common/authentication/aws/static_test.go +++ b/common/authentication/aws/static_test.go @@ -5,6 +5,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetConfigV2(t *testing.T) { @@ -29,7 +30,7 @@ func TestGetConfigV2(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { awsCfg, err := GetConfigV2(tt.accessKey, tt.secretKey, tt.sessionToken, tt.region, tt.endpoint) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, awsCfg) assert.Equal(t, tt.region, awsCfg.Region) assert.Equal(t, tt.endpoint, *awsCfg.BaseEndpoint) @@ -57,8 +58,8 @@ func TestGetTokenClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { session, err := tt.awsInstance.getTokenClient() - assert.NotNil(t, session) - assert.NoError(t, err) + require.NotNil(t, session) + require.NoError(t, err) assert.Equal(t, tt.awsInstance.region, *session.Config.Region) }) } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 07b99f6f8c..96b2e46de0 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -33,6 +33,7 @@ import ( awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + cryptopem "github.com/dapr/kit/crypto/pem" spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/logger" @@ -97,17 +98,16 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) Clients: newClients(), } - err := auth.getCertPEM(ctx) + var err error + err = auth.getCertPEM(ctx) if err != nil { return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) } - auth.logger.Infof("sam here 1") // Parse trust anchor and profile ARNs - if err := auth.initializeTrustAnchors(); err != nil { + if err = auth.initializeTrustAnchors(); err != nil { return nil, err } - auth.logger.Infof("sam here 2") initialSession, err := auth.createOrRefreshSession(ctx) if err != nil { @@ -134,7 +134,7 @@ func (a *x509) getCertPEM(ctx context.Context) error { // retrieve svid from spiffe context svid, ok := spiffecontext.From(ctx) if !ok { - return fmt.Errorf("no SVID found in context") + return errors.New("no SVID found in context") } // get x.509 svid svidx, err := svid.GetX509SVID() @@ -288,8 +288,8 @@ func (a *x509) SecretManager(ctx context.Context) (*SecretManagerClients, error) a.mu.Lock() defer a.mu.Unlock() - if a.Clients.secret != nil { - return a.Clients.secret, nil + if a.Clients.Secret != nil { + return a.Clients.Secret, nil } // respect context cancellation while initializing client @@ -297,14 +297,14 @@ func (a *x509) SecretManager(ctx context.Context) (*SecretManagerClients, error) go func() { defer close(done) clients := SecretManagerClients{} - a.Clients.secret = &clients - a.Clients.secret.New(a.session) + a.Clients.Secret = &clients + a.Clients.Secret.New(a.session) }() // wait for new client or context to be canceled select { case <-done: - return a.Clients.secret, nil + return a.Clients.Secret, nil case <-ctx.Done(): return nil, ctx.Err() } @@ -422,7 +422,7 @@ func (a *x509) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhe return err } - var ints []cryptoX509.Certificate + ints := make([]cryptoX509.Certificate, 0, len(certs)-1) for i := range certs[1:] { ints = append(ints, *certs[i+1]) } diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go index 26773890d7..43f57b0a30 100644 --- a/common/authentication/aws/x509_test.go +++ b/common/authentication/aws/x509_test.go @@ -3,10 +3,6 @@ package aws import ( "context" cryptoX509 "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "math/big" - "net/url" "sync/atomic" "testing" "time" @@ -15,46 +11,17 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/dapr/kit/crypto/spiffe" spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/crypto/test" "github.com/dapr/kit/logger" "github.com/dapr/kit/ptr" - "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func mockRequestSVIDFn() ([]*cryptoX509.Certificate, error) { - spiffeID, err := url.Parse("spiffe://example.org/test") // create tester SPIFFE ID - if err != nil { - return nil, err - } - - // encode the URI as a SAN extension - uriSAN, err := asn1.Marshal(spiffeID.String()) - if err != nil { - return nil, err - } - - // create dummy certificate with the required URI SAN - cert := &cryptoX509.Certificate{ - Subject: pkix.Name{CommonName: "test-cert"}, - SerialNumber: big.NewInt(1), - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - { - Id: asn1.ObjectIdentifier{2, 5, 29, 17}, // OID for subject alternative name - Critical: false, - Value: uriSAN, - }, - }, - } - - return []*cryptoX509.Certificate{cert}, nil -} - type mockRolesAnywhereClient struct { rolesanywhereiface.RolesAnywhereAPI @@ -135,13 +102,13 @@ func TestGetX509Client(t *testing.T) { } err := s.Ready(ctx) - assert.NoError(t, err) + require.NoError(t, err) // inject the SVID source into the context ctx = spiffecontext.With(ctx, s) session, err := mockAWS.createOrRefreshSession(ctx) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, session) }) } diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index e560e7f0c4..ca93c27f10 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -154,7 +154,8 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { SessionToken: m.SessionToken, } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, opts, aws.NewConfig()) if err != nil { return err } diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index b07e1467f3..200fa58030 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/aws/aws-sdk-go/aws/request" diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 85918237a3..6c6736f8b9 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -25,6 +25,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/secretstores" "github.com/dapr/kit/logger" ) @@ -60,21 +62,35 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("without version id and version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.Nil(t, input.VersionId) - assert.Nil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &mockedSM{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.Nil(t, input.VersionId) + assert.Nil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, @@ -85,20 +101,34 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionId) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &mockedSM{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionId) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -111,20 +141,34 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &mockedSM{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -138,13 +182,28 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &mockedSM{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{ + Clients: &mockedClients, + } + + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index ca0a97109a..8b20277ae8 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -76,7 +76,7 @@ func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, inpu func TestInit(t *testing.T) { m := state.Metadata{} - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ // We're adding this so we can pass the connection check on Init GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, nil @@ -84,7 +84,7 @@ func TestInit(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -153,14 +153,14 @@ func TestInit(t *testing.T) { "Region": "eu-west-1", } - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -184,7 +184,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -203,7 +203,7 @@ func TestGet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -232,7 +232,7 @@ func TestGet(t *testing.T) { assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -254,7 +254,7 @@ func TestGet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -286,7 +286,7 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -308,7 +308,7 @@ func TestGet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -337,14 +337,14 @@ func TestGet(t *testing.T) { assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return nil, errors.New("failed to retrieve data") }, } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -371,7 +371,7 @@ func TestGet(t *testing.T) { assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{}, @@ -380,7 +380,7 @@ func TestGet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -408,7 +408,7 @@ func TestGet(t *testing.T) { assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -421,7 +421,7 @@ func TestGet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -455,7 +455,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -476,7 +476,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -503,7 +503,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -528,7 +528,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -556,7 +556,7 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -576,7 +576,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -611,7 +611,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -633,7 +633,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -662,7 +662,7 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -679,7 +679,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -713,7 +713,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -734,7 +734,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -763,7 +763,7 @@ func TestSet(t *testing.T) { require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -784,7 +784,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -815,14 +815,14 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -847,7 +847,7 @@ func TestSet(t *testing.T) { require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -868,7 +868,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -896,7 +896,7 @@ func TestSet(t *testing.T) { require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -921,7 +921,7 @@ func TestSet(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -957,7 +957,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -970,7 +970,7 @@ func TestDelete(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -997,7 +997,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -1014,7 +1014,7 @@ func TestDelete(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -1041,7 +1041,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -1059,7 +1059,7 @@ func TestDelete(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -1085,14 +1085,14 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { return nil, errors.New("unable to delete item") }, } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ @@ -1140,7 +1140,7 @@ func TestMultiTx(t *testing.T) { }, } - mockedDb := &mockedDynamoDB{ + mockedDB := &mockedDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -1166,7 +1166,7 @@ func TestMultiTx(t *testing.T) { } dynamo := awsAuth.DynamoDBClients{ - DynamoDB: mockedDb, + DynamoDB: mockedDB, } mockedClients := awsAuth.Clients{ From ce94dfee06f0fc6912724986b33178cad5269269 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 13:40:52 -0600 Subject: [PATCH 15/39] style: clean up logs Signed-off-by: Samantha Coyle --- common/authentication/aws/x509.go | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 96b2e46de0..a9fcc17f05 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -485,7 +485,6 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } else { - a.logger.Infof("sam setting 15min default for session duration") duration = 900 // 15 minutes in seconds by default and be autorefreshed createSessionRequest = rolesanywhere.CreateSessionInput{ @@ -498,7 +497,6 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } - a.logger.Infof("sam session time %v", *createSessionRequest.DurationSeconds) output, err := a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) if err != nil { @@ -508,28 +506,23 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er if output == nil || len(output.CredentialSet) != 1 { return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) } - a.logger.Infof("sam successfully created new session with iam roles anywhere client!") accessKey := output.CredentialSet[0].Credentials.AccessKeyId secretKey := output.CredentialSet[0].Credentials.SecretAccessKey sessionToken := output.CredentialSet[0].Credentials.SessionToken - - a.logger.Infof("the ak %v sk %v st %v", accessKey, secretKey, sessionToken) - a.logger.Infof("sam the len of credentials set %v", len(output.CredentialSet)) awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) sess := session.Must(session.NewSession(&aws.Config{ Credentials: awsCreds, }, config)) if sess == nil { - return nil, fmt.Errorf("sam session is nil somehow %v", sess) + return nil, errors.New("session is nil") } - a.logger.Infof("sam just set session in refreshorcreate func %v", a.session) return sess, nil } func (a *x509) startSessionRefresher() { - a.logger.Infof("starting session refresher for x509 auth") + a.logger.Debugf("starting session refresher for x509 auth") // if there is a set session duration, then exit bc we will not auto refresh the session. if *a.SessionDuration != 0 { a.logger.Debugf("session duration was set, so there is no authentication refreshing") @@ -555,13 +548,12 @@ func (a *x509) startSessionRefresher() { for { select { case <-ticker.C: - a.logger.Infof("Refreshing session as expiration is near") + a.logger.Debugf("Refreshing session as expiration is near") newSession, err := a.createOrRefreshSession(a.internalContext) if err != nil { a.logger.Errorf("failed to refresh session: %w", err) return } - a.logger.Infof("sam in ticker after created refreshed session %v", newSession) a.Clients.refresh(newSession) From 30de3dba4c4c825acc1020894e4a130f9127ba02 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 14:10:21 -0600 Subject: [PATCH 16/39] style: more linter things and adjust for mocking client Signed-off-by: Samantha Coyle --- common/authentication/aws/client_test.go | 2 +- common/authentication/aws/x509.go | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go index e23d1244ca..20ed547006 100644 --- a/common/authentication/aws/client_test.go +++ b/common/authentication/aws/client_test.go @@ -23,7 +23,7 @@ type mockedSQS struct { GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) } -func (m *mockedSQS) GetQueueURLWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck return m.GetQueueURLFn(ctx, input) } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index a9fcc17f05..495edb522d 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -458,6 +458,9 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er config = a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) } + // this is needed for testing purposes to mock the client, + // so code never sets the client, but tests do. + var rolesClient *rolesanywhere.RolesAnywhere if a.rolesAnywhereClient == nil { mySession = session.Must(session.NewSession(config)) rolesAnywhereClient := rolesanywhere.New(mySession, config) @@ -465,7 +468,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er if err := a.setSigningFunction(rolesAnywhereClient); err != nil { return nil, err } - a.rolesAnywhereClient = rolesAnywhereClient + rolesClient = rolesAnywhereClient } var ( @@ -498,7 +501,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er } } - output, err := a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + output, err := rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) if err != nil { return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) } From f8e3567e2600ca71ca014c03fe96189618a12082 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 14:14:04 -0600 Subject: [PATCH 17/39] fix: make 1 hr default timeout Signed-off-by: Samantha Coyle --- .build-tools/builtin-authentication-profiles.yaml | 2 +- common/authentication/aws/x509.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 4b25dab0ff..86d412878e 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -52,7 +52,7 @@ aws: description: | Duration of the session using AWS IAM Roles Anywhere. If set to 0m, temporary credentials will automatically rotate. - default: '15m' + default: '1h' example: '0m' required: true diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 495edb522d..e6b0f91e79 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -488,7 +488,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } else { - duration = 900 // 15 minutes in seconds by default and be autorefreshed + duration = int64(time.Hour) createSessionRequest = rolesanywhere.CreateSessionInput{ Cert: ptr.Of(string(a.chainPEM)), From 3e6a471d381c6559bfd3de1d2224dad3befcec4c Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 15:39:02 -0600 Subject: [PATCH 18/39] fix: update more tests Signed-off-by: Samantha Coyle --- common/authentication/aws/static.go | 67 ++++++++++++------------ common/authentication/aws/static_test.go | 12 ++--- common/authentication/aws/x509.go | 32 ++++++----- common/authentication/aws/x509_test.go | 2 +- state/aws/dynamodb/dynamodb.go | 32 ++++++----- state/aws/dynamodb/dynamodb_test.go | 27 ++++++++-- 6 files changed, 100 insertions(+), 72 deletions(-) diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 5030cec4fb..c65136b2cf 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -31,28 +31,28 @@ import ( type StaticAuth struct { mu sync.RWMutex - logger logger.Logger + Logger logger.Logger - region string - endpoint *string - accessKey *string - secretKey *string - sessionToken *string + Region string + Endpoint *string + AccessKey *string + SecretKey *string + SessionToken *string Clients *Clients - session *session.Session - cfg *aws.Config + Session *session.Session + Cfg *aws.Config } func newStaticIAM(ctx context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { auth := &StaticAuth{ - logger: opts.Logger, - region: opts.Region, - endpoint: &opts.Endpoint, - accessKey: &opts.AccessKey, - secretKey: &opts.SecretKey, - sessionToken: &opts.SessionToken, - cfg: cfg, + Logger: opts.Logger, + Region: opts.Region, + Endpoint: &opts.Endpoint, + AccessKey: &opts.AccessKey, + SecretKey: &opts.SecretKey, + SessionToken: &opts.SessionToken, + Cfg: cfg, Clients: newClients(), } @@ -61,7 +61,7 @@ func newStaticIAM(ctx context.Context, opts Options, cfg *aws.Config) (*StaticAu return nil, fmt.Errorf("failed to get token client: %v", err) } - auth.session = initialSession + auth.Session = initialSession return auth, nil } @@ -80,8 +80,8 @@ func (a *StaticAuth) S3(ctx context.Context) (*S3Clients, error) { defer close(done) s3Clients := S3Clients{} a.Clients.s3 = &s3Clients - a.logger.Debugf("Initializing S3 clients with session %v", a.session) - a.Clients.s3.New(a.session) + a.Logger.Debugf("Initializing S3 clients with session %v", a.Session) + a.Clients.s3.New(a.Session) }() // wait for new client or context to be canceled @@ -96,8 +96,9 @@ func (a *StaticAuth) S3(ctx context.Context) (*S3Clients, error) { func (a *StaticAuth) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { a.mu.Lock() defer a.mu.Unlock() - + fmt.Printf("ready sam") if a.Clients.Dynamo != nil { + fmt.Printf("sam it is not nil so it's injected fine") return a.Clients.Dynamo, nil } @@ -107,7 +108,7 @@ func (a *StaticAuth) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { defer close(done) clients := DynamoDBClients{} a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.session) + a.Clients.Dynamo.New(a.Session) }() // wait for new client or context to be canceled @@ -133,7 +134,7 @@ func (a *StaticAuth) Sqs(ctx context.Context) (*SqsClients, error) { defer close(done) clients := SqsClients{} a.Clients.sqs = &clients - a.Clients.sqs.New(a.session) + a.Clients.sqs.New(a.Session) }() // wait for new client or context to be canceled @@ -159,7 +160,7 @@ func (a *StaticAuth) Sns(ctx context.Context) (*SnsClients, error) { defer close(done) clients := SnsClients{} a.Clients.sns = &clients - a.Clients.sns.New(a.session) + a.Clients.sns.New(a.Session) }() // wait for new client or context to be canceled @@ -185,7 +186,7 @@ func (a *StaticAuth) SnsSqs(ctx context.Context) (*SnsSqsClients, error) { defer close(done) clients := SnsSqsClients{} a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.session) + a.Clients.snssqs.New(a.Session) }() // wait for new client or context to be canceled @@ -211,7 +212,7 @@ func (a *StaticAuth) SecretManager(ctx context.Context) (*SecretManagerClients, defer close(done) clients := SecretManagerClients{} a.Clients.Secret = &clients - a.Clients.Secret.New(a.session) + a.Clients.Secret.New(a.Session) }() // wait for new client or context to be canceled @@ -237,7 +238,7 @@ func (a *StaticAuth) ParameterStore(ctx context.Context) (*ParameterStoreClients defer close(done) clients := ParameterStoreClients{} a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.session) + a.Clients.ParameterStore.New(a.Session) }() // wait for new client or context to be canceled @@ -263,7 +264,7 @@ func (a *StaticAuth) Kinesis(ctx context.Context) (*KinesisClients, error) { defer close(done) clients := KinesisClients{} a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.session) + a.Clients.kinesis.New(a.Session) }() // wait for new client or context to be canceled @@ -289,7 +290,7 @@ func (a *StaticAuth) Ses(ctx context.Context) (*SesClients, error) { defer close(done) clients := SesClients{} a.Clients.ses = &clients - a.Clients.ses.New(a.session) + a.Clients.ses.New(a.Session) }() // wait for new client or context to be canceled @@ -304,17 +305,17 @@ func (a *StaticAuth) Ses(ctx context.Context) (*SesClients, error) { func (a *StaticAuth) getTokenClient() (*session.Session, error) { awsConfig := aws.NewConfig() - if a.region != "" { - awsConfig = awsConfig.WithRegion(a.region) + if a.Region != "" { + awsConfig = awsConfig.WithRegion(a.Region) } - if a.accessKey != nil && a.secretKey != nil { + if a.AccessKey != nil && a.SecretKey != nil { // session token is an option field - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.AccessKey, *a.SecretKey, *a.SessionToken)) } - if a.endpoint != nil { - awsConfig = awsConfig.WithEndpoint(*a.endpoint) + if a.Endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.Endpoint) } awsSession, err := session.NewSessionWithOptions(session.Options{ diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go index 9c61b5a412..1f2191cfa4 100644 --- a/common/authentication/aws/static_test.go +++ b/common/authentication/aws/static_test.go @@ -46,11 +46,11 @@ func TestGetTokenClient(t *testing.T) { { name: "valid token client", awsInstance: &StaticAuth{ - accessKey: aws.String("testAccessKey"), - secretKey: aws.String("testSecretKey"), - sessionToken: aws.String("testSessionToken"), - region: "us-west-2", - endpoint: aws.String("https://test.endpoint.com"), + AccessKey: aws.String("testAccessKey"), + SecretKey: aws.String("testSecretKey"), + SessionToken: aws.String("testSessionToken"), + Region: "us-west-2", + Endpoint: aws.String("https://test.endpoint.com"), }, }, } @@ -60,7 +60,7 @@ func TestGetTokenClient(t *testing.T) { session, err := tt.awsInstance.getTokenClient() require.NotNil(t, session) require.NoError(t, err) - assert.Equal(t, tt.awsInstance.region, *session.Config.Region) + assert.Equal(t, tt.awsInstance.Region, *session.Config.Region) }) } } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index e6b0f91e79..99d8b1e8ef 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -53,7 +53,7 @@ type x509 struct { Clients *Clients rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests session *session.Session - cfg *aws.Config + Cfg *aws.Config chainPEM []byte keyPEM []byte @@ -78,11 +78,6 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, errors.New("trustAnchorArn is required") case x509Auth.AssumeRoleArn == nil: return nil, errors.New("assumeRoleArn is required") - - // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. - case x509Auth.SessionDuration == nil: - awsDefaultDuration := time.Hour // default 1 hour from AWS - x509Auth.SessionDuration = &awsDefaultDuration case *x509Auth.SessionDuration != 0 && (*x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12): return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } @@ -94,7 +89,7 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) TrustAnchorArn: x509Auth.TrustAnchorArn, AssumeRoleArn: x509Auth.AssumeRoleArn, SessionDuration: x509Auth.SessionDuration, - cfg: GetConfig(opts), + Cfg: GetConfig(opts), Clients: newClients(), } @@ -454,8 +449,8 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er var mySession *session.Session var config *aws.Config - if a.cfg != nil { - config = a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + if a.Cfg != nil { + config = a.Cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) } // this is needed for testing purposes to mock the client, @@ -488,7 +483,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } else { - duration = int64(time.Hour) + duration = int64(time.Hour.Seconds()) createSessionRequest = rolesanywhere.CreateSessionInput{ Cert: ptr.Of(string(a.chainPEM)), @@ -500,10 +495,19 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er SessionName: nil, } } - - output, err := rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + var output *rolesanywhere.CreateSessionOutput + if a.rolesAnywhereClient != nil { + var err error + output, err = a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } else { + var err error + output, err = rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } } if output == nil || len(output.CredentialSet) != 1 { diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go index 43f57b0a30..4e5752f33d 100644 --- a/common/authentication/aws/x509_test.go +++ b/common/authentication/aws/x509_test.go @@ -64,7 +64,7 @@ func TestGetX509Client(t *testing.T) { CreateSessionOutput: tt.mockOutput, CreateSessionError: tt.mockError, } - mockAWS := &x509{ + mockAWS := x509{ logger: logger.NewLogger("testLogger"), AssumeRoleArn: ptr.Of("arn:aws:iam:012345678910:role/exampleIAMRoleName"), TrustAnchorArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index cd1e5a2df2..42db3e2e99 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -23,7 +23,6 @@ import ( "strconv" "time" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" jsoniterator "github.com/json-iterator/go" @@ -82,20 +81,24 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { if err != nil { return err } - opts := awsAuth.Options{ - Logger: d.logger, - Properties: metadata.Properties, - Region: meta.Region, - Endpoint: meta.Endpoint, - AccessKey: meta.AccessKey, - SecretKey: meta.SecretKey, - SessionToken: meta.SessionToken, - } - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) - if err != nil { - return err + if d.authProvider == nil { + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + cfg := awsAuth.GetConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) + if err != nil { + return err + } + d.authProvider = provider } - d.authProvider = provider + d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName d.partitionKey = meta.PartitionKey @@ -123,6 +126,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get client: %v", err) } + _, err = clients.DynamoDB.GetItemWithContext(ctx, input) return err } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 8b20277ae8..3a103d88d2 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -22,9 +22,12 @@ import ( "time" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/kit/logger" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" @@ -148,9 +151,9 @@ func TestInit(t *testing.T) { }) t.Run("Init with bad table name or permissions", func(t *testing.T) { + table := "does-not-exist" m.Properties = map[string]string{ - "Table": "does-not-exist", - "Region": "eu-west-1", + "Table": table, } mockedDB := &mockedDynamoDB{ @@ -166,17 +169,33 @@ func TestInit(t *testing.T) { mockedClients := awsAuth.Clients{ Dynamo: &dynamo, } - + mockedSession, err := session.NewSession(&aws.Config{ + Region: aws.String("us-west-1"), + Credentials: credentials.AnonymousCredentials, + }) + require.NoError(t, err) mockAuthProvider := &awsAuth.StaticAuth{ + Logger: logger.NewLogger("test"), Clients: &mockedClients, + // mock client creds so we don't get the error -> access: EmptyStaticCreds: static credentials are empty" + Cfg: &aws.Config{ + Credentials: credentials.AnonymousCredentials, + }, + Region: "us-west-1", + AccessKey: aws.String("mocked"), + SecretKey: aws.String("mocked"), + SessionToken: aws.String("mocked"), + Endpoint: aws.String("mocked"), + Session: mockedSession, } s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, + table: table, } - err := s.Init(context.Background(), m) + err = s.Init(context.Background(), m) require.Error(t, err) require.EqualError(t, err, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found") }) From bb2245099d7f00283fcbceed13fb37f3827a48e1 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 17:22:15 -0600 Subject: [PATCH 19/39] fix: address final feedback Signed-off-by: Samantha Coyle --- bindings/aws/dynamodb/dynamodb.go | 7 +- bindings/aws/kinesis/kinesis.go | 47 +-- bindings/aws/s3/s3.go | 36 +-- bindings/aws/ses/ses.go | 6 +- bindings/aws/sns/sns.go | 6 +- bindings/aws/sqs/sqs.go | 18 +- common/authentication/aws/aws.go | 18 +- common/authentication/aws/static.go | 224 ++++---------- common/authentication/aws/x509.go | 285 ++++++------------ pubsub/aws/snssqs/snssqs.go | 102 +------ .../aws/parameterstore/parameterstore.go | 19 +- .../aws/secretmanager/secretmanager.go | 14 +- state/aws/dynamodb/dynamodb.go | 31 +- 13 files changed, 204 insertions(+), 609 deletions(-) diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 823894b1ed..4ad0620257 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -16,7 +16,6 @@ package dynamodb import ( "context" "encoding/json" - "fmt" "reflect" "github.com/aws/aws-sdk-go/aws" @@ -94,11 +93,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 8127fa8019..2e2f62897f 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -143,11 +143,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + _, err := a.authProvider.Kinesis().Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -160,19 +156,15 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.closed.Load() { return errors.New("binding is closed") } - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), clients.WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis().WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = clients.Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = a.authProvider.Kinesis().Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -182,7 +174,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } } - stream, err := clients.Stream(ctx, a.streamName) + stream, err := a.authProvider.Kinesis().Stream(ctx, a.streamName) if err != nil { return fmt.Errorf("failed to get kinesis stream arn: %v", err) } @@ -232,12 +224,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes return default: } - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - a.logger.Errorf("failed to get client: %v", err) - return - } - sub, err := clients.Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + sub, err := a.authProvider.Kinesis().Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -286,11 +273,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - consumer, err := clients.Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -302,11 +285,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - consumer, err := clients.Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -328,12 +307,8 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (* func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error { if a.consumerARN != nil { // Use a background context because the running context may have been canceled already - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err = clients.Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + _, err := a.authProvider.Kinesis().Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -364,11 +339,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - clients, err := a.authProvider.Kinesis(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - req, _ := clients.Kinesis.DescribeStreamConsumerRequest(inCpy) + req, _ := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index d9dac8631b..dd6f8906dc 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -207,12 +207,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi if metadata.StorageClass != "" { storageClass = aws.String(metadata.StorageClass) } - - clients, err := s.authProvider.S3(ctx) - if err != nil { - return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) - } - resultUpload, err := clients.Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + resultUpload, err := s.authProvider.S3().Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -287,11 +282,7 @@ func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (str if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - clients, err := s.authProvider.S3(ctx) - if err != nil { - return "", fmt.Errorf("s3 binding error: failed to get client: %v", err) - } - objReq, _ := clients.S3.GetObjectRequest(&s3.GetObjectInput{ + objReq, _ := s.authProvider.S3().S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -315,12 +306,7 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings } buff := &aws.WriteAtBuffer{} - - clients, err := s.authProvider.S3(ctx) - if err != nil { - return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) - } - _, err = clients.Downloader.DownloadWithContext(ctx, + _, err = s.authProvider.S3().Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -354,13 +340,7 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi if key == "" { return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - - clients, err := s.authProvider.S3(ctx) - if err != nil { - return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) - } - - _, err = clients.S3.DeleteObjectWithContext( + _, err := s.authProvider.S3().S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -389,13 +369,7 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding if payload.MaxResults < 1 { payload.MaxResults = defaultMaxResults } - - clients, err := s.authProvider.S3(ctx) - if err != nil { - return nil, fmt.Errorf("s3 binding error: failed to get client: %v", err) - } - - result, err := clients.S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + result, err := s.authProvider.S3().S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 211d434bad..eb7a1143e7 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -151,11 +151,7 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - clients, err := a.authProvider.Ses(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - result, err := clients.Ses.SendEmail(input) + result, err := a.authProvider.Ses().Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 990deffe92..763d2c29d3 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -109,11 +109,7 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - clients, err := a.authProvider.Sns(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.Sns.PublishWithContext(ctx, &sns.PublishInput{ + _, err = a.authProvider.Sns().Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index b12b22ad63..dde86b4e52 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -91,16 +91,12 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - clients, err := a.authProvider.Sqs(ctx) - if err != nil { - a.logger.Errorf("failed to get client: %v", err) - } - url, err := clients.QueueURL(ctx, a.queueName) + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) if err != nil { a.logger.Errorf("failed to get queue url: %v", err) } - _, err = clients.Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + _, err = a.authProvider.Sqs().Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, QueueUrl: url, }) @@ -122,16 +118,12 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } - clients, err := a.authProvider.Sqs(ctx) - if err != nil { - a.logger.Errorf("failed to get client: %v", err) - } - url, err := clients.QueueURL(ctx, a.queueName) + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) if err != nil { a.logger.Errorf("failed to get queue url: %v", err) } - result, err := clients.Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + result, err := a.authProvider.Sqs().Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", @@ -157,7 +149,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - clients.Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + a.authProvider.Sqs().Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ QueueUrl: url, ReceiptHandle: msgHandle, }) diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index c6298a8ebb..a45eb48277 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -81,15 +81,15 @@ func GetConfig(opts Options) *aws.Config { } type Provider interface { - S3(ctx context.Context) (*S3Clients, error) - DynamoDB(ctx context.Context) (*DynamoDBClients, error) - Sqs(ctx context.Context) (*SqsClients, error) - Sns(ctx context.Context) (*SnsClients, error) - SnsSqs(ctx context.Context) (*SnsSqsClients, error) - SecretManager(ctx context.Context) (*SecretManagerClients, error) - ParameterStore(ctx context.Context) (*ParameterStoreClients, error) - Kinesis(ctx context.Context) (*KinesisClients, error) - Ses(ctx context.Context) (*SesClients, error) + S3() *S3Clients + DynamoDB() *DynamoDBClients + Sqs() *SqsClients + Sns() *SnsClients + SnsSqs() *SnsSqsClients + SecretManager() *SecretManagerClients + ParameterStore() *ParameterStoreClients + Kinesis() *KinesisClients + Ses() *SesClients Close() error } diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index c65136b2cf..8d5d4ea302 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -44,7 +44,7 @@ type StaticAuth struct { Cfg *aws.Config } -func newStaticIAM(ctx context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { +func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { auth := &StaticAuth{ Logger: opts.Logger, Region: opts.Region, @@ -66,240 +66,132 @@ func newStaticIAM(ctx context.Context, opts Options, cfg *aws.Config) (*StaticAu return auth, nil } -func (a *StaticAuth) S3(ctx context.Context) (*S3Clients, error) { +func (a *StaticAuth) S3() *S3Clients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.s3 != nil { - return a.Clients.s3, nil + return a.Clients.s3 } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - s3Clients := S3Clients{} - a.Clients.s3 = &s3Clients - a.Logger.Debugf("Initializing S3 clients with session %v", a.Session) - a.Clients.s3.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.s3, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + s3Clients := S3Clients{} + a.Clients.s3 = &s3Clients + a.Clients.s3.New(a.Session) + return a.Clients.s3 } -func (a *StaticAuth) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { +func (a *StaticAuth) DynamoDB() *DynamoDBClients { a.mu.Lock() defer a.mu.Unlock() - fmt.Printf("ready sam") + if a.Clients.Dynamo != nil { - fmt.Printf("sam it is not nil so it's injected fine") - return a.Clients.Dynamo, nil + return a.Clients.Dynamo } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := DynamoDBClients{} - a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.Dynamo, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := DynamoDBClients{} + a.Clients.Dynamo = &clients + a.Clients.Dynamo.New(a.Session) + + return a.Clients.Dynamo } -func (a *StaticAuth) Sqs(ctx context.Context) (*SqsClients, error) { +func (a *StaticAuth) Sqs() *SqsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.sqs != nil { - return a.Clients.sqs, nil + return a.Clients.sqs } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SqsClients{} - a.Clients.sqs = &clients - a.Clients.sqs.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.sqs, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SqsClients{} + a.Clients.sqs = &clients + a.Clients.sqs.New(a.Session) + + return a.Clients.sqs } -func (a *StaticAuth) Sns(ctx context.Context) (*SnsClients, error) { +func (a *StaticAuth) Sns() *SnsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.sns != nil { - return a.Clients.sns, nil + return a.Clients.sns } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SnsClients{} - a.Clients.sns = &clients - a.Clients.sns.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.sns, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SnsClients{} + a.Clients.sns = &clients + a.Clients.sns.New(a.Session) + return a.Clients.sns } -func (a *StaticAuth) SnsSqs(ctx context.Context) (*SnsSqsClients, error) { +func (a *StaticAuth) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.snssqs != nil { - return a.Clients.snssqs, nil + return a.Clients.snssqs } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SnsSqsClients{} - a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.snssqs, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SnsSqsClients{} + a.Clients.snssqs = &clients + a.Clients.snssqs.New(a.Session) + return a.Clients.snssqs } -func (a *StaticAuth) SecretManager(ctx context.Context) (*SecretManagerClients, error) { +func (a *StaticAuth) SecretManager() *SecretManagerClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.Secret != nil { - return a.Clients.Secret, nil + return a.Clients.Secret } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SecretManagerClients{} - a.Clients.Secret = &clients - a.Clients.Secret.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.Secret, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SecretManagerClients{} + a.Clients.Secret = &clients + a.Clients.Secret.New(a.Session) + return a.Clients.Secret } -func (a *StaticAuth) ParameterStore(ctx context.Context) (*ParameterStoreClients, error) { +func (a *StaticAuth) ParameterStore() *ParameterStoreClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.ParameterStore != nil { - return a.Clients.ParameterStore, nil + return a.Clients.ParameterStore } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := ParameterStoreClients{} - a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.ParameterStore, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := ParameterStoreClients{} + a.Clients.ParameterStore = &clients + a.Clients.ParameterStore.New(a.Session) + return a.Clients.ParameterStore } -func (a *StaticAuth) Kinesis(ctx context.Context) (*KinesisClients, error) { +func (a *StaticAuth) Kinesis() *KinesisClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.kinesis != nil { - return a.Clients.kinesis, nil + return a.Clients.kinesis } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := KinesisClients{} - a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.kinesis, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := KinesisClients{} + a.Clients.kinesis = &clients + a.Clients.kinesis.New(a.Session) + return a.Clients.kinesis } -func (a *StaticAuth) Ses(ctx context.Context) (*SesClients, error) { +func (a *StaticAuth) Ses() *SesClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.ses != nil { - return a.Clients.ses, nil + return a.Clients.ses } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SesClients{} - a.Clients.ses = &clients - a.Clients.ses.New(a.Session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.ses, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SesClients{} + a.Clients.ses = &clients + a.Clients.ses.New(a.Session) + return a.Clients.ses } func (a *StaticAuth) getTokenClient() (*session.Session, error) { diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 99d8b1e8ef..bf395d1dcc 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -44,10 +44,8 @@ import ( type x509 struct { mu sync.RWMutex - wg sync.WaitGroup - // used for background session refresh logic that cannot use the context passed to the newx509 function - internalContext context.Context - internalContextCancel func() + wg sync.WaitGroup + closeCh chan struct{} logger logger.Logger Clients *Clients @@ -109,18 +107,13 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, fmt.Errorf("failed to create the initial session: %v", err) } auth.session = initialSession - - // This is needed to keep the session refresher on the background context, but still cancellable. - auth.internalContext, auth.internalContextCancel = context.WithCancel(context.Background()) auth.startSessionRefresher() return auth, nil } func (a *x509) Close() error { - if a.internalContextCancel != nil { - a.internalContextCancel() - } + close(a.closeCh) a.wg.Wait() return nil } @@ -148,239 +141,132 @@ func (a *x509) getCertPEM(ctx context.Context) error { return nil } -func (a *x509) S3(ctx context.Context) (*S3Clients, error) { +func (a *x509) S3() *S3Clients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.s3 != nil { - return a.Clients.s3, nil + return a.Clients.s3 } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - s3Clients := S3Clients{} - a.Clients.s3 = &s3Clients - a.logger.Debugf("Initializing S3 clients with session %v", a.session) - a.Clients.s3.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.s3, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + s3Clients := S3Clients{} + a.Clients.s3 = &s3Clients + a.Clients.s3.New(a.session) + return a.Clients.s3 } -func (a *x509) DynamoDB(ctx context.Context) (*DynamoDBClients, error) { +func (a *x509) DynamoDB() *DynamoDBClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.Dynamo != nil { - return a.Clients.Dynamo, nil + return a.Clients.Dynamo } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := DynamoDBClients{} - a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.session) - }() + clients := DynamoDBClients{} + a.Clients.Dynamo = &clients + a.Clients.Dynamo.New(a.session) - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.Dynamo, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + return a.Clients.Dynamo } -func (a *x509) Sqs(ctx context.Context) (*SqsClients, error) { +func (a *x509) Sqs() *SqsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.sqs != nil { - return a.Clients.sqs, nil + return a.Clients.sqs } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SqsClients{} - a.Clients.sqs = &clients - a.Clients.sqs.New(a.session) - }() + clients := SqsClients{} + a.Clients.sqs = &clients + a.Clients.sqs.New(a.session) - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.sqs, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + return a.Clients.sqs } -func (a *x509) Sns(ctx context.Context) (*SnsClients, error) { +func (a *x509) Sns() *SnsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.sns != nil { - return a.Clients.sns, nil + return a.Clients.sns } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SnsClients{} - a.Clients.sns = &clients - a.Clients.sns.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.sns, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SnsClients{} + a.Clients.sns = &clients + a.Clients.sns.New(a.session) + return a.Clients.sns } -func (a *x509) SnsSqs(ctx context.Context) (*SnsSqsClients, error) { +func (a *x509) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.snssqs != nil { - return a.Clients.snssqs, nil + return a.Clients.snssqs } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SnsSqsClients{} - a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.snssqs, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SnsSqsClients{} + a.Clients.snssqs = &clients + a.Clients.snssqs.New(a.session) + return a.Clients.snssqs } -func (a *x509) SecretManager(ctx context.Context) (*SecretManagerClients, error) { +func (a *x509) SecretManager() *SecretManagerClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.Secret != nil { - return a.Clients.Secret, nil + return a.Clients.Secret } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SecretManagerClients{} - a.Clients.Secret = &clients - a.Clients.Secret.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.Secret, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SecretManagerClients{} + a.Clients.Secret = &clients + a.Clients.Secret.New(a.session) + return a.Clients.Secret } -func (a *x509) ParameterStore(ctx context.Context) (*ParameterStoreClients, error) { +func (a *x509) ParameterStore() *ParameterStoreClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.ParameterStore != nil { - return a.Clients.ParameterStore, nil + return a.Clients.ParameterStore } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := ParameterStoreClients{} - a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.ParameterStore, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := ParameterStoreClients{} + a.Clients.ParameterStore = &clients + a.Clients.ParameterStore.New(a.session) + return a.Clients.ParameterStore } -func (a *x509) Kinesis(ctx context.Context) (*KinesisClients, error) { +func (a *x509) Kinesis() *KinesisClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.kinesis != nil { - return a.Clients.kinesis, nil + return a.Clients.kinesis } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := KinesisClients{} - a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.kinesis, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := KinesisClients{} + a.Clients.kinesis = &clients + a.Clients.kinesis.New(a.session) + return a.Clients.kinesis } -func (a *x509) Ses(ctx context.Context) (*SesClients, error) { +func (a *x509) Ses() *SesClients { a.mu.Lock() defer a.mu.Unlock() if a.Clients.ses != nil { - return a.Clients.ses, nil + return a.Clients.ses } - // respect context cancellation while initializing client - done := make(chan struct{}) - go func() { - defer close(done) - clients := SesClients{} - a.Clients.ses = &clients - a.Clients.ses.New(a.session) - }() - - // wait for new client or context to be canceled - select { - case <-done: - return a.Clients.ses, nil - case <-ctx.Done(): - return nil, ctx.Err() - } + clients := SesClients{} + a.Clients.ses = &clients + a.Clients.ses.New(a.session) + return a.Clients.ses } func (a *x509) initializeTrustAnchors() error { @@ -529,7 +415,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er } func (a *x509) startSessionRefresher() { - a.logger.Debugf("starting session refresher for x509 auth") + a.logger.Infof("starting session refresher for x509 auth") // if there is a set session duration, then exit bc we will not auto refresh the session. if *a.SessionDuration != 0 { a.logger.Debugf("session duration was set, so there is no authentication refreshing") @@ -539,36 +425,39 @@ func (a *x509) startSessionRefresher() { a.wg.Add(1) go func() { defer a.wg.Done() - - // renew at ~half the lifespan - expiration, err := a.session.Config.Credentials.ExpiresAt() - if err != nil { - a.logger.Errorf("failed to retrieve session expiration time: %w", err) - return - } - - timeUntilExpiration := time.Until(expiration) - refreshInterval := timeUntilExpiration / 2 - ticker := time.NewTicker(refreshInterval) - defer ticker.Stop() - for { + // renew at ~half the lifespan + expiration, err := a.session.Config.Credentials.ExpiresAt() + if err != nil { + a.logger.Errorf("Failed to retrieve session expiration time, using 30 minute interval: %w", err) + expiration = time.Now().Add(time.Hour) + } + timeUntilExpiration := time.Until(expiration) + refreshInterval := timeUntilExpiration / 2 select { - case <-ticker.C: - a.logger.Debugf("Refreshing session as expiration is near") - newSession, err := a.createOrRefreshSession(a.internalContext) - if err != nil { - a.logger.Errorf("failed to refresh session: %w", err) - return - } - - a.Clients.refresh(newSession) - - a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") - case <-a.internalContext.Done(): - a.logger.Debugf("Session refresher stopped due to context cancellation") + case <-time.After(refreshInterval): + a.refreshClient() + case <-a.closeCh: + a.logger.Debugf("Session refresher is stopped") return } } }() } + +func (a *x509) refreshClient() { + for { + newSession, err := a.createOrRefreshSession(context.Background()) + if err == nil { + a.Clients.refresh(newSession) + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + return + } + a.logger.Errorf("Failed to refresh session: %w", err) + select { + case <-time.After(time.Second * 5): + case <-a.closeCh: + return + } + } +} diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index ca93c27f10..71f45b4557 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -189,14 +189,8 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { if len(s.metadata.AccountID) == awsAccountIDLength { return nil } - - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := clients.Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := s.authProvider.SnsSqs().Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -221,13 +215,8 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} snsCreateTopicInput.SetAttributes(attributes) } - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return "", fmt.Errorf("failed to get client: %v", err) - } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := clients.Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := s.authProvider.SnsSqs().Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -237,13 +226,9 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e } func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return "", fmt.Errorf("failed to get client: %v", err) - } ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := clients.Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := s.authProvider.SnsSqs().Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -310,20 +295,15 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ sqsCreateQueueInput.SetAttributes(attributes) } - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := clients.Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := s.authProvider.SnsSqs().Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -339,13 +319,8 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ } func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := clients.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -353,7 +328,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -413,13 +388,8 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { } func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return "", fmt.Errorf("failed to get client: %v", err) - } - ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := clients.Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := s.authProvider.SnsSqs().Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -438,12 +408,8 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t } func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return "", fmt.Errorf("failed to get client: %v", err) - } ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := clients.Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := s.authProvider.SnsSqs().Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -491,12 +457,8 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to } func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } ctx, cancelFn := context.WithCancel(parentCtx) - _, err = clients.Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err := s.authProvider.SnsSqs().Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -509,14 +471,9 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, } func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err = clients.Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err := s.authProvider.SnsSqs().Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -643,17 +600,11 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters if ctx.Err() != nil { break } - - clients, err := s.authProvider.SnsSqs(ctx) - if err != nil { - s.logger.Errorf("failed to get client: %v", err) - } - // Internally, by default, aws go sdk performs 3 retries with exponential backoff to contact // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := clients.Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := s.authProvider.SnsSqs().Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -745,14 +696,8 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI return wrappedErr } - - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - s.logger.Errorf("failed to get client: %v", err) - } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = clients.Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -770,14 +715,9 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return nil } - clients, err := s.authProvider.SnsSqs(parentCtx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := clients.Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -803,13 +743,8 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return fmt.Errorf("failed serializing new sqs policy: %w", uerr) } - clients, err = s.authProvider.SnsSqs(parentCtx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = clients.Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -921,13 +856,8 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error snsPublishInput.MessageGroupId = s.getMessageGroupID(req) } - clients, err := s.authProvider.SnsSqs(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - // sns client has internal exponential backoffs. - _, err = clients.Sns.PublishWithContext(ctx, snsPublishInput) + _, err = s.authProvider.SnsSqs().Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index a80507f7d3..72c8def093 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -92,11 +92,8 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr versionID = value name = fmt.Sprintf("%s:%s", req.Name, versionID) } - clients, err := s.authProvider.ParameterStore(ctx) - if err != nil { - return secretstores.GetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) - } - output, err := clients.Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + + output, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -136,11 +133,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - clients, err := s.authProvider.ParameterStore(ctx) - if err != nil { - return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) - } - output, err := clients.Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + output, err := s.authProvider.ParameterStore().Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -150,11 +143,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - clients, err = s.authProvider.ParameterStore(ctx) - if err != nil { - return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) - } - params, err := clients.Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + params, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 3db576adbf..968483318d 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -86,11 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - clients, err := s.authProvider.SecretManager(ctx) - if err != nil { - return secretstores.GetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) - } - output, err := clients.Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -119,11 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - clients, err := s.authProvider.SecretManager(ctx) - if err != nil { - return secretstores.BulkGetSecretResponse{Data: nil}, fmt.Errorf("failed to get client: %v", err) - } - output, err := clients.Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -132,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := clients.Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 42db3e2e99..b85cd5e0e5 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -122,12 +122,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, }, } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - - _, err = clients.DynamoDB.GetItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) return err } @@ -159,11 +154,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, }, } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client: %v", err) - } - result, err := clients.DynamoDB.GetItemWithContext(ctx, input) + result, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -235,11 +226,7 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { condExpr := "attribute_not_exists(etag)" input.ConditionExpression = &condExpr } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.DynamoDB.PutItemWithContext(ctx, input) + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -270,11 +257,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } input.ExpressionAttributeValues = exprAttrValues } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.DynamoDB.DeleteItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -445,11 +428,7 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat } twinput.TransactItems = append(twinput.TransactItems, twi) } - clients, err := d.authProvider.DynamoDB(ctx) - if err != nil { - return fmt.Errorf("failed to get client: %v", err) - } - _, err = clients.DynamoDB.TransactWriteItemsWithContext(ctx, twinput) + _, err := d.authProvider.DynamoDB().DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } From 5a17558dd8a83ac8d82fde2193da093bb9349ba2 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 17:45:52 -0600 Subject: [PATCH 20/39] fix(conformance): try to inject mocked creds for session Signed-off-by: Samantha Coyle --- secretstores/aws/secretmanager/secretmanager.go | 12 ++++++------ secretstores/aws/secretmanager/secretmanager_test.go | 8 ++++---- tests/conformance/secretstores_test.go | 11 ++++++++++- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 968483318d..2f081af0a3 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -49,7 +49,7 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - authProvider awsAuth.Provider + AuthProvider awsAuth.Provider logger logger.Logger } @@ -72,7 +72,7 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata if err != nil { return err } - s.authProvider = provider + s.AuthProvider = provider return nil } @@ -86,7 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.AuthProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -115,7 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.AuthProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -124,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.AuthProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -170,5 +170,5 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return s.authProvider.Close() + return s.AuthProvider.Close() } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 6c6736f8b9..8c5f13b108 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -88,7 +88,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - authProvider: mockAuthProvider, + AuthProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -126,7 +126,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - authProvider: mockAuthProvider, + AuthProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -166,7 +166,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - authProvider: mockAuthProvider, + AuthProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -201,7 +201,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - authProvider: mockAuthProvider, + AuthProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ diff --git a/tests/conformance/secretstores_test.go b/tests/conformance/secretstores_test.go index 2b9e7c9cd5..8b65927f7f 100644 --- a/tests/conformance/secretstores_test.go +++ b/tests/conformance/secretstores_test.go @@ -20,6 +20,9 @@ import ( "path/filepath" "testing" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/secretstores" @@ -73,7 +76,13 @@ func loadSecretStore(name string) secretstores.SecretStore { case "hashicorp.vault": return ss_hashicorp_vault.NewHashiCorpVaultSecretStore(testLogger) case "aws.secretsmanager.docker": - return ss_aws.NewSecretManager(testLogger) + ss := ss_aws.NewSecretManager(testLogger) + mockedSession, err := session.NewSession(&aws.Config{ + Region: aws.String("us-west-1"), + Credentials: credentials.AnonymousCredentials, + }) + ss.AuthProvider.Session = mockedSession + return ss case "aws.secretsmanager.terraform": return ss_aws.NewSecretManager(testLogger) default: From 72fe803e099f03647fabd55312d85e8c8228a067 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 12 Nov 2024 18:01:06 -0600 Subject: [PATCH 21/39] style: make linter happy Signed-off-by: Samantha Coyle --- tests/conformance/secretstores_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conformance/secretstores_test.go b/tests/conformance/secretstores_test.go index 8b65927f7f..5df65e005c 100644 --- a/tests/conformance/secretstores_test.go +++ b/tests/conformance/secretstores_test.go @@ -77,10 +77,10 @@ func loadSecretStore(name string) secretstores.SecretStore { return ss_hashicorp_vault.NewHashiCorpVaultSecretStore(testLogger) case "aws.secretsmanager.docker": ss := ss_aws.NewSecretManager(testLogger) - mockedSession, err := session.NewSession(&aws.Config{ + mockedSession := session.Must(session.NewSession(&aws.Config{ Region: aws.String("us-west-1"), Credentials: credentials.AnonymousCredentials, - }) + })) ss.AuthProvider.Session = mockedSession return ss case "aws.secretsmanager.terraform": From d4f23ba5db2689b14917adb240e4a57a9db9e87a Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 08:26:48 -0600 Subject: [PATCH 22/39] fix: go back on conformance test changes Signed-off-by: Samantha Coyle --- secretstores/aws/secretmanager/secretmanager.go | 12 ++++++------ secretstores/aws/secretmanager/secretmanager_test.go | 8 ++++---- tests/conformance/secretstores_test.go | 11 +---------- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 2f081af0a3..968483318d 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -49,7 +49,7 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - AuthProvider awsAuth.Provider + authProvider awsAuth.Provider logger logger.Logger } @@ -72,7 +72,7 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata if err != nil { return err } - s.AuthProvider = provider + s.authProvider = provider return nil } @@ -86,7 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - output, err := s.AuthProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -115,7 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.AuthProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -124,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.AuthProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -170,5 +170,5 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return s.AuthProvider.Close() + return s.authProvider.Close() } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 8c5f13b108..6c6736f8b9 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -88,7 +88,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - AuthProvider: mockAuthProvider, + authProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -126,7 +126,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - AuthProvider: mockAuthProvider, + authProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -166,7 +166,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - AuthProvider: mockAuthProvider, + authProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ @@ -201,7 +201,7 @@ func TestGetSecret(t *testing.T) { } s := smSecretStore{ - AuthProvider: mockAuthProvider, + authProvider: mockAuthProvider, } req := secretstores.GetSecretRequest{ diff --git a/tests/conformance/secretstores_test.go b/tests/conformance/secretstores_test.go index 5df65e005c..2b9e7c9cd5 100644 --- a/tests/conformance/secretstores_test.go +++ b/tests/conformance/secretstores_test.go @@ -20,9 +20,6 @@ import ( "path/filepath" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/secretstores" @@ -76,13 +73,7 @@ func loadSecretStore(name string) secretstores.SecretStore { case "hashicorp.vault": return ss_hashicorp_vault.NewHashiCorpVaultSecretStore(testLogger) case "aws.secretsmanager.docker": - ss := ss_aws.NewSecretManager(testLogger) - mockedSession := session.Must(session.NewSession(&aws.Config{ - Region: aws.String("us-west-1"), - Credentials: credentials.AnonymousCredentials, - })) - ss.AuthProvider.Session = mockedSession - return ss + return ss_aws.NewSecretManager(testLogger) case "aws.secretsmanager.terraform": return ss_aws.NewSecretManager(testLogger) default: From 67ed5ecdcae0c0f9ccbecd858fb8e2da5e48d531 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 08:48:45 -0600 Subject: [PATCH 23/39] fix: try this for conformance Signed-off-by: Samantha Coyle --- secretstores/aws/secretmanager/secretmanager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 968483318d..6aabcdbf27 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -65,7 +65,7 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata Region: meta.Region, AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, - SessionToken: "", + SessionToken: meta.SessionToken, } provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) From f80c594874a132e1efc9334670155a517c7622d7 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 08:56:50 -0600 Subject: [PATCH 24/39] fix: try another tweak for secretmanager Signed-off-by: Samantha Coyle --- secretstores/aws/secretmanager/secretmanager.go | 1 + 1 file changed, 1 insertion(+) diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 6aabcdbf27..a963483f0a 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -66,6 +66,7 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata AccessKey: meta.AccessKey, SecretKey: meta.SecretKey, SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, } provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) From 29ae9f061d360de1f5f843c9af5d38e3862261fe Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 09:13:45 -0600 Subject: [PATCH 25/39] fix(test): fix dynamo unit test Signed-off-by: Samantha Coyle --- state/aws/dynamodb/dynamodb_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 3a103d88d2..d5dd43aeb8 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -767,6 +767,7 @@ func TestSet(t *testing.T) { s := StateStore{ authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", + partitionKey: defaultPartitionKeyName, } req := &state.SetRequest{ From a9ce95e716aff7d5a9a4d1f1ed9b8a357b63774f Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 10:34:48 -0600 Subject: [PATCH 26/39] fix(snssqs): see if this fixes snssqs conformance Signed-off-by: Samantha Coyle --- pubsub/aws/snssqs/metadata.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pubsub/aws/snssqs/metadata.yaml b/pubsub/aws/snssqs/metadata.yaml index 641ab4ea07..4b12c525bd 100644 --- a/pubsub/aws/snssqs/metadata.yaml +++ b/pubsub/aws/snssqs/metadata.yaml @@ -128,7 +128,7 @@ metadata: default: '"parallel"' example: '"single", "parallel"' type: string - - name: accountId + - name: accountID required: false description: | The AWS account ID. Resolved automatically if not provided. From 4f8c154345f5ee195e6a5e305eee5b5ac071e13d Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 10:41:40 -0600 Subject: [PATCH 27/39] fix: this is what i need for conformance :) Signed-off-by: Samantha Coyle --- common/authentication/aws/client.go | 1 + pubsub/aws/snssqs/metadata.yaml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index b14a97ec63..85864d23b2 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -129,6 +129,7 @@ func (c *SnsClients) New(session *session.Session) { func (c *SnsSqsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) c.Sqs = sqs.New(session, session.Config) + c.Sts = sts.New(session, session.Config) } func (c *SqsClients) New(session *session.Session) { diff --git a/pubsub/aws/snssqs/metadata.yaml b/pubsub/aws/snssqs/metadata.yaml index 4b12c525bd..641ab4ea07 100644 --- a/pubsub/aws/snssqs/metadata.yaml +++ b/pubsub/aws/snssqs/metadata.yaml @@ -128,7 +128,7 @@ metadata: default: '"parallel"' example: '"single", "parallel"' type: string - - name: accountID + - name: accountId required: false description: | The AWS account ID. Resolved automatically if not provided. From 77498263f485abca1bf3cd3e7cb91a12a92814fb Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 11:08:08 -0600 Subject: [PATCH 28/39] fix: update cfgs for aws Signed-off-by: Samantha Coyle --- common/authentication/aws/static.go | 5 ++++- common/authentication/aws/x509.go | 14 ++++++++------ tests/certification/bindings/aws/s3/s3_test.go | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 8d5d4ea302..0e05c8c268 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -195,7 +195,10 @@ func (a *StaticAuth) Ses() *SesClients { } func (a *StaticAuth) getTokenClient() (*session.Session, error) { - awsConfig := aws.NewConfig() + awsConfig := a.Cfg + if awsConfig != nil { + awsConfig = aws.NewConfig() + } if a.Region != "" { awsConfig = awsConfig.WithRegion(a.Region) diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index bf395d1dcc..224bb97133 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -334,17 +334,19 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er }} var mySession *session.Session - var config *aws.Config - if a.Cfg != nil { - config = a.Cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + awsConfig := a.Cfg + if awsConfig != nil { + awsConfig = a.Cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + } else { + awsConfig = aws.NewConfig().WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) } // this is needed for testing purposes to mock the client, // so code never sets the client, but tests do. var rolesClient *rolesanywhere.RolesAnywhere if a.rolesAnywhereClient == nil { - mySession = session.Must(session.NewSession(config)) - rolesAnywhereClient := rolesanywhere.New(mySession, config) + mySession = session.Must(session.NewSession(awsConfig)) + rolesAnywhereClient := rolesanywhere.New(mySession, awsConfig) // Set up signing function and handlers if err := a.setSigningFunction(rolesAnywhereClient); err != nil { return nil, err @@ -406,7 +408,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) sess := session.Must(session.NewSession(&aws.Config{ Credentials: awsCreds, - }, config)) + }, awsConfig)) if sess == nil { return nil, errors.New("session is nil") } diff --git a/tests/certification/bindings/aws/s3/s3_test.go b/tests/certification/bindings/aws/s3/s3_test.go index 16e41abc49..c815901f8c 100644 --- a/tests/certification/bindings/aws/s3/s3_test.go +++ b/tests/certification/bindings/aws/s3/s3_test.go @@ -279,7 +279,7 @@ func S3SForcePathStyle(t *testing.T) { Step(sidecar.Run(sidecarName, append(componentRuntimeOptions(), embedded.WithoutApp(), - embedded.WithComponentsPath("./components/forcePathStyleTrue"), + embedded.WithResourcesPath("./components/forcePathStyleTrue"), embedded.WithDaprGRPCPort(strconv.Itoa(currentGRPCPort)), embedded.WithDaprHTTPPort(strconv.Itoa(currentHTTPPort)), )..., From 690b3eca58b8c17e0d2c05111532c31a812695cf Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 11:34:43 -0600 Subject: [PATCH 29/39] fix(test): update for unit test Signed-off-by: Samantha Coyle --- common/authentication/aws/static.go | 5 ++++- common/authentication/aws/x509.go | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 0e05c8c268..1dae307d37 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -45,6 +45,9 @@ type StaticAuth struct { } func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { + if cfg == nil { + cfg = aws.NewConfig() + } auth := &StaticAuth{ Logger: opts.Logger, Region: opts.Region, @@ -196,7 +199,7 @@ func (a *StaticAuth) Ses() *SesClients { func (a *StaticAuth) getTokenClient() (*session.Session, error) { awsConfig := a.Cfg - if awsConfig != nil { + if a.Cfg == nil { awsConfig = aws.NewConfig() } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 224bb97133..0c1129f45d 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -80,6 +80,9 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } + if cfg == nil { + cfg = aws.NewConfig() + } auth := &x509{ wg: sync.WaitGroup{}, logger: opts.Logger, From a42e74246e59101fcb640191c969a68ff9c8c76d Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 12:04:53 -0600 Subject: [PATCH 30/39] fix(cfg): leverage opts in cfgs for aws Signed-off-by: Samantha Coyle --- bindings/aws/dynamodb/dynamodb.go | 2 +- bindings/aws/kinesis/kinesis.go | 2 +- bindings/aws/ses/ses.go | 2 +- bindings/aws/sns/sns.go | 3 +-- bindings/aws/sqs/sqs.go | 2 +- common/authentication/aws/static.go | 14 +++++++++----- common/authentication/aws/x509.go | 14 +++++++++----- pubsub/aws/snssqs/snssqs.go | 2 +- secretstores/aws/parameterstore/parameterstore.go | 2 +- secretstores/aws/secretmanager/secretmanager.go | 3 +-- 10 files changed, 26 insertions(+), 20 deletions(-) diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 4ad0620257..755b3158d3 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -67,7 +67,7 @@ func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { SessionToken: meta.SessionToken, } - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 2e2f62897f..7ede7ba245 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -126,7 +126,7 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error SessionToken: "", } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index eb7a1143e7..4cd752bac5 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -79,7 +79,7 @@ func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { SessionToken: "", } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 763d2c29d3..55e3ccefa5 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -19,7 +19,6 @@ import ( "fmt" "reflect" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sns" "github.com/dapr/components-contrib/bindings" @@ -75,7 +74,7 @@ func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { SessionToken: m.SessionToken, } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index dde86b4e52..d803bafc5a 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -75,7 +75,7 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { SessionToken: m.SessionToken, } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 1dae307d37..8ad052f30f 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -45,9 +45,6 @@ type StaticAuth struct { } func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { - if cfg == nil { - cfg = aws.NewConfig() - } auth := &StaticAuth{ Logger: opts.Logger, Region: opts.Region, @@ -55,8 +52,15 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth AccessKey: &opts.AccessKey, SecretKey: &opts.SecretKey, SessionToken: &opts.SessionToken, - Cfg: cfg, - Clients: newClients(), + Cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + Clients: newClients(), } initialSession, err := auth.getTokenClient() diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 0c1129f45d..ed8b2154cc 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -80,9 +80,6 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } - if cfg == nil { - cfg = aws.NewConfig() - } auth := &x509{ wg: sync.WaitGroup{}, logger: opts.Logger, @@ -90,8 +87,15 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) TrustAnchorArn: x509Auth.TrustAnchorArn, AssumeRoleArn: x509Auth.AssumeRoleArn, SessionDuration: x509Auth.SessionDuration, - Cfg: GetConfig(opts), - Clients: newClients(), + Cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + Clients: newClients(), } var err error diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 71f45b4557..89e8f5ef7a 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -155,7 +155,7 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { } // extra configs needed per component type var provider awsAuth.Provider - provider, err = awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err = awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 72c8def093..abf9c6c4de 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -73,7 +73,7 @@ func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadat SessionToken: "", } // extra configs needed per component type - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index a963483f0a..6faf1f1eab 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -19,7 +19,6 @@ import ( "fmt" "reflect" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/secretsmanager" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -69,7 +68,7 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata Endpoint: meta.Endpoint, } - provider, err := awsAuth.NewProvider(ctx, opts, aws.NewConfig()) + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } From b5e1d9778c9cdb8f13838ac91c24ee8dc7bb2820 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 14:46:27 -0600 Subject: [PATCH 31/39] fix: minor tweaks Signed-off-by: Samantha Coyle --- bindings/aws/s3/s3.go | 3 +-- common/authentication/aws/x509.go | 4 ++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index dd6f8906dc..13f8730e78 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -143,8 +143,7 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { SessionToken: m.SessionToken, } // extra configs needed per component type - cfg := s.getAWSConfig(opts) - provider, err := awsAuth.NewProvider(ctx, opts, cfg) + provider, err := awsAuth.NewProvider(ctx, opts, s.getAWSConfig(opts)) if err != nil { return err } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index ed8b2154cc..01238b37c1 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -69,6 +69,10 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, err } + if x509Auth.SessionDuration == nil { + x509Auth.SessionDuration = new(time.Duration) + } + switch { case x509Auth.TrustProfileArn == nil: return nil, errors.New("trustProfileArn is required") From 923ee9488a98ec4a1bb45c3792e6aa8d45675762 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 15:03:28 -0600 Subject: [PATCH 32/39] style: final tweaks Signed-off-by: Samantha Coyle --- common/authentication/aws/static.go | 4 +++- common/authentication/aws/x509.go | 12 +++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 8ad052f30f..43081fe57f 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -202,9 +202,11 @@ func (a *StaticAuth) Ses() *SesClients { } func (a *StaticAuth) getTokenClient() (*session.Session, error) { - awsConfig := a.Cfg + var awsConfig *aws.Config if a.Cfg == nil { awsConfig = aws.NewConfig() + } else { + awsConfig = a.Cfg } if a.Region != "" { diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 01238b37c1..fbc4dd67e8 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -345,13 +345,15 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er }} var mySession *session.Session - awsConfig := a.Cfg - if awsConfig != nil { - awsConfig = a.Cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + var awsConfig *aws.Config + if a.Cfg == nil { + awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) } else { - awsConfig = aws.NewConfig().WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + awsConfig = a.Cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + } + if a.region != nil { + awsConfig.WithRegion(*a.region) } - // this is needed for testing purposes to mock the client, // so code never sets the client, but tests do. var rolesClient *rolesanywhere.RolesAnywhere From 37021c5c8cf557a2ac9a8f5f7d4ccec86bdca1a0 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 15:31:46 -0600 Subject: [PATCH 33/39] fix: update default to be one hour with timeout Signed-off-by: Samantha Coyle --- common/authentication/aws/x509.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index fbc4dd67e8..f90d763ed6 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -70,7 +70,8 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) } if x509Auth.SessionDuration == nil { - x509Auth.SessionDuration = new(time.Duration) + defaultDuration := time.Hour + x509Auth.SessionDuration = &defaultDuration } switch { From b081bfe329ddab74d2241b483f3927d61fa01472 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 13 Nov 2024 15:37:23 -0600 Subject: [PATCH 34/39] style: session duration can default to 1h so not required Signed-off-by: Samantha Coyle --- .build-tools/builtin-authentication-profiles.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 86d412878e..9b37c895cd 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -54,7 +54,7 @@ aws: If set to 0m, temporary credentials will automatically rotate. default: '1h' example: '0m' - required: true + required: false azuread: - title: "Azure AD: Managed identity" From 6263e199a88e06f40886d2e90a735684c335f937 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 13 Nov 2024 15:38:46 -0600 Subject: [PATCH 35/39] Update builtin-authentication-profiles.yaml Signed-off-by: Sam --- .build-tools/builtin-authentication-profiles.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 86d412878e..ea991e1f44 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -51,10 +51,10 @@ aws: type: duration description: | Duration of the session using AWS IAM Roles Anywhere. - If set to 0m, temporary credentials will automatically rotate. + If set to 0m, temporary credentials will be created and automatically rotated. default: '1h' example: '0m' - required: true + required: false azuread: - title: "Azure AD: Managed identity" From e6f7699c72bcd6021709fc8c1a085e642824d92c Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 14 Nov 2024 09:47:08 -0600 Subject: [PATCH 36/39] fix: address final feedback Signed-off-by: Samantha Coyle --- .../builtin-authentication-profiles.yaml | 8 - common/authentication/aws/aws_test.go | 13 ++ common/authentication/aws/client_test.go | 13 ++ common/authentication/aws/static.go | 68 +++--- common/authentication/aws/static_test.go | 13 +- common/authentication/aws/x509.go | 198 ++++++++---------- common/authentication/aws/x509_test.go | 22 +- state/aws/dynamodb/dynamodb.go | 2 +- state/aws/dynamodb/dynamodb_test.go | 23 +- 9 files changed, 168 insertions(+), 192 deletions(-) diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9b37c895cd..373c2d230d 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -47,14 +47,6 @@ aws: ARN of the AWS IAM role to assume in the trusting AWS account. example: arn:aws:iam:012345678910:role/exampleIAMRoleName required: true - - name: sessionDuration - type: duration - description: | - Duration of the session using AWS IAM Roles Anywhere. - If set to 0m, temporary credentials will automatically rotate. - default: '1h' - example: '0m' - required: false azuread: - title: "Azure AD: Managed identity" diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go index 24b6c5649c..15aac78ad7 100644 --- a/common/authentication/aws/aws_test.go +++ b/common/authentication/aws/aws_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go index 20ed547006..67d2ac88f3 100644 --- a/common/authentication/aws/client_test.go +++ b/common/authentication/aws/client_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 43081fe57f..6997b62a1e 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -31,28 +31,28 @@ import ( type StaticAuth struct { mu sync.RWMutex - Logger logger.Logger + logger logger.Logger - Region string - Endpoint *string - AccessKey *string - SecretKey *string - SessionToken *string + region *string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string - Clients *Clients - Session *session.Session - Cfg *aws.Config + session *session.Session + cfg *aws.Config + Clients *Clients // exported to mock clients in unit tests } func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { auth := &StaticAuth{ - Logger: opts.Logger, - Region: opts.Region, - Endpoint: &opts.Endpoint, - AccessKey: &opts.AccessKey, - SecretKey: &opts.SecretKey, - SessionToken: &opts.SessionToken, - Cfg: func() *aws.Config { + logger: opts.Logger, + region: &opts.Region, + endpoint: &opts.Endpoint, + accessKey: &opts.AccessKey, + secretKey: &opts.SecretKey, + sessionToken: &opts.SessionToken, + cfg: func() *aws.Config { // if nil is passed or it's just a default cfg, // then we use the options to build the aws cfg. if cfg != nil && cfg != aws.NewConfig() { @@ -68,7 +68,7 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth return nil, fmt.Errorf("failed to get token client: %v", err) } - auth.Session = initialSession + auth.session = initialSession return auth, nil } @@ -83,7 +83,7 @@ func (a *StaticAuth) S3() *S3Clients { s3Clients := S3Clients{} a.Clients.s3 = &s3Clients - a.Clients.s3.New(a.Session) + a.Clients.s3.New(a.session) return a.Clients.s3 } @@ -97,7 +97,7 @@ func (a *StaticAuth) DynamoDB() *DynamoDBClients { clients := DynamoDBClients{} a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.Session) + a.Clients.Dynamo.New(a.session) return a.Clients.Dynamo } @@ -112,7 +112,7 @@ func (a *StaticAuth) Sqs() *SqsClients { clients := SqsClients{} a.Clients.sqs = &clients - a.Clients.sqs.New(a.Session) + a.Clients.sqs.New(a.session) return a.Clients.sqs } @@ -127,7 +127,7 @@ func (a *StaticAuth) Sns() *SnsClients { clients := SnsClients{} a.Clients.sns = &clients - a.Clients.sns.New(a.Session) + a.Clients.sns.New(a.session) return a.Clients.sns } @@ -141,7 +141,7 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients { clients := SnsSqsClients{} a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.Session) + a.Clients.snssqs.New(a.session) return a.Clients.snssqs } @@ -155,7 +155,7 @@ func (a *StaticAuth) SecretManager() *SecretManagerClients { clients := SecretManagerClients{} a.Clients.Secret = &clients - a.Clients.Secret.New(a.Session) + a.Clients.Secret.New(a.session) return a.Clients.Secret } @@ -169,7 +169,7 @@ func (a *StaticAuth) ParameterStore() *ParameterStoreClients { clients := ParameterStoreClients{} a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.Session) + a.Clients.ParameterStore.New(a.session) return a.Clients.ParameterStore } @@ -183,7 +183,7 @@ func (a *StaticAuth) Kinesis() *KinesisClients { clients := KinesisClients{} a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.Session) + a.Clients.kinesis.New(a.session) return a.Clients.kinesis } @@ -197,29 +197,29 @@ func (a *StaticAuth) Ses() *SesClients { clients := SesClients{} a.Clients.ses = &clients - a.Clients.ses.New(a.Session) + a.Clients.ses.New(a.session) return a.Clients.ses } func (a *StaticAuth) getTokenClient() (*session.Session, error) { var awsConfig *aws.Config - if a.Cfg == nil { + if a.cfg == nil { awsConfig = aws.NewConfig() } else { - awsConfig = a.Cfg + awsConfig = a.cfg } - if a.Region != "" { - awsConfig = awsConfig.WithRegion(a.Region) + if a.region != nil { + awsConfig = awsConfig.WithRegion(*a.region) } - if a.AccessKey != nil && a.SecretKey != nil { + if a.accessKey != nil && a.secretKey != nil { // session token is an option field - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.AccessKey, *a.SecretKey, *a.SessionToken)) + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) } - if a.Endpoint != nil { - awsConfig = awsConfig.WithEndpoint(*a.Endpoint) + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) } awsSession, err := session.NewSessionWithOptions(session.Options{ diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go index 1f2191cfa4..3a9b3a2d00 100644 --- a/common/authentication/aws/static_test.go +++ b/common/authentication/aws/static_test.go @@ -3,7 +3,6 @@ package aws import ( "testing" - "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,11 +45,11 @@ func TestGetTokenClient(t *testing.T) { { name: "valid token client", awsInstance: &StaticAuth{ - AccessKey: aws.String("testAccessKey"), - SecretKey: aws.String("testSecretKey"), - SessionToken: aws.String("testSessionToken"), - Region: "us-west-2", - Endpoint: aws.String("https://test.endpoint.com"), + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", }, }, } @@ -60,7 +59,7 @@ func TestGetTokenClient(t *testing.T) { session, err := tt.awsInstance.getTokenClient() require.NotNil(t, session) require.NoError(t, err) - assert.Equal(t, tt.awsInstance.Region, *session.Config.Region) + assert.Equal(t, tt.awsInstance.region, *session.Config.Region) }) } } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index f90d763ed6..cb1bafdeb3 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2024 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -41,39 +41,38 @@ import ( "github.com/dapr/kit/ptr" ) -type x509 struct { - mu sync.RWMutex +type x509Options struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} +type x509 struct { + mu sync.RWMutex wg sync.WaitGroup closeCh chan struct{} logger logger.Logger - Clients *Clients + clients *Clients rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests session *session.Session - Cfg *aws.Config + cfg *aws.Config chainPEM []byte keyPEM []byte region *string - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` - SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` + trustProfileArn *string + trustAnchorArn *string + assumeRoleArn *string } func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { - var x509Auth x509 + var x509Auth x509Options if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { return nil, err } - if x509Auth.SessionDuration == nil { - defaultDuration := time.Hour - x509Auth.SessionDuration = &defaultDuration - } - switch { case x509Auth.TrustProfileArn == nil: return nil, errors.New("trustProfileArn is required") @@ -81,18 +80,14 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return nil, errors.New("trustAnchorArn is required") case x509Auth.AssumeRoleArn == nil: return nil, errors.New("assumeRoleArn is required") - case *x509Auth.SessionDuration != 0 && (*x509Auth.SessionDuration < time.Minute*15 || *x509Auth.SessionDuration > time.Hour*12): - return nil, errors.New("sessionDuration must be greater than 15 minutes, and less than 12 hours") } auth := &x509{ - wg: sync.WaitGroup{}, logger: opts.Logger, - TrustProfileArn: x509Auth.TrustProfileArn, - TrustAnchorArn: x509Auth.TrustAnchorArn, - AssumeRoleArn: x509Auth.AssumeRoleArn, - SessionDuration: x509Auth.SessionDuration, - Cfg: func() *aws.Config { + trustProfileArn: x509Auth.TrustProfileArn, + trustAnchorArn: x509Auth.TrustAnchorArn, + assumeRoleArn: x509Auth.AssumeRoleArn, + cfg: func() *aws.Config { // if nil is passed or it's just a default cfg, // then we use the options to build the aws cfg. if cfg != nil && cfg != aws.NewConfig() { @@ -100,17 +95,15 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) } return GetConfig(opts) }(), - Clients: newClients(), + clients: newClients(), } - var err error - err = auth.getCertPEM(ctx) - if err != nil { + if err := auth.getCertPEM(ctx); err != nil { return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) } // Parse trust anchor and profile ARNs - if err = auth.initializeTrustAnchors(); err != nil { + if err := auth.initializeTrustAnchors(); err != nil { return nil, err } @@ -157,128 +150,128 @@ func (a *x509) S3() *S3Clients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.s3 != nil { - return a.Clients.s3 + if a.clients.s3 != nil { + return a.clients.s3 } s3Clients := S3Clients{} - a.Clients.s3 = &s3Clients - a.Clients.s3.New(a.session) - return a.Clients.s3 + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 } func (a *x509) DynamoDB() *DynamoDBClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Dynamo != nil { - return a.Clients.Dynamo + if a.clients.Dynamo != nil { + return a.clients.Dynamo } clients := DynamoDBClients{} - a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.session) + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) - return a.Clients.Dynamo + return a.clients.Dynamo } func (a *x509) Sqs() *SqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sqs != nil { - return a.Clients.sqs + if a.clients.sqs != nil { + return a.clients.sqs } clients := SqsClients{} - a.Clients.sqs = &clients - a.Clients.sqs.New(a.session) + a.clients.sqs = &clients + a.clients.sqs.New(a.session) - return a.Clients.sqs + return a.clients.sqs } func (a *x509) Sns() *SnsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sns != nil { - return a.Clients.sns + if a.clients.sns != nil { + return a.clients.sns } clients := SnsClients{} - a.Clients.sns = &clients - a.Clients.sns.New(a.session) - return a.Clients.sns + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns } func (a *x509) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.snssqs != nil { - return a.Clients.snssqs + if a.clients.snssqs != nil { + return a.clients.snssqs } clients := SnsSqsClients{} - a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.session) - return a.Clients.snssqs + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs } func (a *x509) SecretManager() *SecretManagerClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Secret != nil { - return a.Clients.Secret + if a.clients.Secret != nil { + return a.clients.Secret } clients := SecretManagerClients{} - a.Clients.Secret = &clients - a.Clients.Secret.New(a.session) - return a.Clients.Secret + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret } func (a *x509) ParameterStore() *ParameterStoreClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ParameterStore != nil { - return a.Clients.ParameterStore + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore } clients := ParameterStoreClients{} - a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.session) - return a.Clients.ParameterStore + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore } func (a *x509) Kinesis() *KinesisClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.kinesis != nil { - return a.Clients.kinesis + if a.clients.kinesis != nil { + return a.clients.kinesis } clients := KinesisClients{} - a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.session) - return a.Clients.kinesis + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis } func (a *x509) Ses() *SesClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ses != nil { - return a.Clients.ses + if a.clients.ses != nil { + return a.clients.ses } clients := SesClients{} - a.Clients.ses = &clients - a.Clients.ses.New(a.session) - return a.Clients.ses + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses } func (a *x509) initializeTrustAnchors() error { @@ -287,16 +280,16 @@ func (a *x509) initializeTrustAnchors() error { profile arn.ARN err error ) - if a.TrustAnchorArn != nil { - trustAnchor, err = arn.Parse(*a.TrustAnchorArn) + if a.trustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.trustAnchorArn) if err != nil { return err } a.region = &trustAnchor.Region } - if a.TrustProfileArn != nil { - profile, err = arn.Parse(*a.TrustProfileArn) + if a.trustProfileArn != nil { + profile, err = arn.Parse(*a.trustProfileArn) if err != nil { return err } @@ -347,10 +340,10 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er var mySession *session.Session var awsConfig *aws.Config - if a.Cfg == nil { + if a.cfg == nil { awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) } else { - awsConfig = a.Cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + awsConfig = a.cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) } if a.region != nil { awsConfig.WithRegion(*a.region) @@ -368,35 +361,17 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er rolesClient = rolesAnywhereClient } - var ( - duration int64 - createSessionRequest rolesanywhere.CreateSessionInput - ) - if *a.SessionDuration != 0 { - duration = int64(a.SessionDuration.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } - } else { - duration = int64(time.Hour.Seconds()) - - createSessionRequest = rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(a.chainPEM)), - ProfileArn: a.TrustProfileArn, - TrustAnchorArn: a.TrustAnchorArn, - RoleArn: a.AssumeRoleArn, - DurationSeconds: aws.Int64(duration), - InstanceProperties: nil, - SessionName: nil, - } + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.trustProfileArn, + TrustAnchorArn: a.trustAnchorArn, + RoleArn: a.assumeRoleArn, + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + DurationSeconds: aws.Int64(int64(time.Hour.Seconds())), // AWS default is 1hr timeout + InstanceProperties: nil, + SessionName: nil, } + var output *rolesanywhere.CreateSessionOutput if a.rolesAnywhereClient != nil { var err error @@ -432,11 +407,6 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er func (a *x509) startSessionRefresher() { a.logger.Infof("starting session refresher for x509 auth") - // if there is a set session duration, then exit bc we will not auto refresh the session. - if *a.SessionDuration != 0 { - a.logger.Debugf("session duration was set, so there is no authentication refreshing") - return - } a.wg.Add(1) go func() { @@ -465,11 +435,11 @@ func (a *x509) refreshClient() { for { newSession, err := a.createOrRefreshSession(context.Background()) if err == nil { - a.Clients.refresh(newSession) + a.clients.refresh(newSession) a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") return } - a.logger.Errorf("Failed to refresh session: %w", err) + a.logger.Errorf("Failed to refresh session, retrying in 5 seconds: %w", err) select { case <-time.After(time.Second * 5): case <-a.closeCh: diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go index 4e5752f33d..3f7d2189c3 100644 --- a/common/authentication/aws/x509_test.go +++ b/common/authentication/aws/x509_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( @@ -19,7 +32,6 @@ import ( spiffecontext "github.com/dapr/kit/crypto/spiffe/context" "github.com/dapr/kit/crypto/test" "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" ) type mockRolesAnywhereClient struct { @@ -59,17 +71,15 @@ func TestGetX509Client(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - duration := time.Duration(800) mockSvc := &mockRolesAnywhereClient{ CreateSessionOutput: tt.mockOutput, CreateSessionError: tt.mockError, } mockAWS := x509{ logger: logger.NewLogger("testLogger"), - AssumeRoleArn: ptr.Of("arn:aws:iam:012345678910:role/exampleIAMRoleName"), - TrustAnchorArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), - TrustProfileArn: ptr.Of("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), - SessionDuration: &duration, + assumeRoleArn: aws.String("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + trustAnchorArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + trustProfileArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), rolesAnywhereClient: mockSvc, } pki := test.GenPKI(t, test.PKIOptions{ diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index b85cd5e0e5..ae4ba7c5e9 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -275,7 +275,7 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return nil + return d.authProvider.Close() } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d5dd43aeb8..28a34c8af9 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -22,12 +22,9 @@ import ( "time" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" - "github.com/dapr/kit/logger" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" @@ -161,32 +158,14 @@ func TestInit(t *testing.T) { return nil, errors.New("Requested resource not found") }, } - dynamo := awsAuth.DynamoDBClients{ DynamoDB: mockedDB, } - mockedClients := awsAuth.Clients{ Dynamo: &dynamo, } - mockedSession, err := session.NewSession(&aws.Config{ - Region: aws.String("us-west-1"), - Credentials: credentials.AnonymousCredentials, - }) - require.NoError(t, err) mockAuthProvider := &awsAuth.StaticAuth{ - Logger: logger.NewLogger("test"), Clients: &mockedClients, - // mock client creds so we don't get the error -> access: EmptyStaticCreds: static credentials are empty" - Cfg: &aws.Config{ - Credentials: credentials.AnonymousCredentials, - }, - Region: "us-west-1", - AccessKey: aws.String("mocked"), - SecretKey: aws.String("mocked"), - SessionToken: aws.String("mocked"), - Endpoint: aws.String("mocked"), - Session: mockedSession, } s := StateStore{ @@ -195,7 +174,7 @@ func TestInit(t *testing.T) { table: table, } - err = s.Init(context.Background(), m) + err := s.Init(context.Background(), m) require.Error(t, err) require.EqualError(t, err, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found") }) From a300a8cc7fe7aaf8eed960f5c4ded8fbcdbdc902 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 14 Nov 2024 10:19:55 -0600 Subject: [PATCH 37/39] style: make linter happy Signed-off-by: Samantha Coyle --- common/authentication/aws/static_test.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go index 3a9b3a2d00..a1a17a093c 100644 --- a/common/authentication/aws/static_test.go +++ b/common/authentication/aws/static_test.go @@ -3,6 +3,7 @@ package aws import ( "testing" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -45,11 +46,11 @@ func TestGetTokenClient(t *testing.T) { { name: "valid token client", awsInstance: &StaticAuth{ - accessKey: "testAccessKey", - secretKey: "testSecretKey", - sessionToken: "testSessionToken", - region: "us-west-2", - endpoint: "https://test.endpoint.com", + accessKey: aws.String("testAccessKey"), + secretKey: aws.String("testSecretKey"), + sessionToken: aws.String("testSessionToken"), + region: aws.String("us-west-2"), + endpoint: aws.String("https://test.endpoint.com"), }, }, } @@ -59,7 +60,7 @@ func TestGetTokenClient(t *testing.T) { session, err := tt.awsInstance.getTokenClient() require.NotNil(t, session) require.NoError(t, err) - assert.Equal(t, tt.awsInstance.region, *session.Config.Region) + assert.Equal(t, tt.awsInstance.region, session.Config.Region) }) } } From 0b80a398a745a2a79828cd9d328a5b5c2e0df0de Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 14 Nov 2024 11:35:22 -0600 Subject: [PATCH 38/39] fix: allow for mocked clients without exported field Signed-off-by: Samantha Coyle --- common/authentication/aws/client.go | 13 ++ common/authentication/aws/client_fake.go | 79 +++++++ common/authentication/aws/static.go | 99 ++++---- .../aws/parameterstore/parameterstore_test.go | 80 ++----- .../aws/secretmanager/secretmanager_test.go | 43 +--- state/aws/dynamodb/dynamodb_test.go | 214 ++++++------------ 6 files changed, 249 insertions(+), 279 deletions(-) create mode 100644 common/authentication/aws/client_fake.go diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 85864d23b2..8d0e9de20b 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -1,3 +1,16 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package aws import ( diff --git a/common/authentication/aws/client_fake.go b/common/authentication/aws/client_fake.go new file mode 100644 index 0000000000..c9e23641ba --- /dev/null +++ b/common/authentication/aws/client_fake.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" +) + +type MockParameterStore struct { + GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) + DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) + ssmiface.SSMAPI +} + +func (m *MockParameterStore) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return m.GetParameterFn(ctx, input, option...) +} + +func (m *MockParameterStore) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { + return m.DescribeParametersFn(ctx, input, option...) +} + +type MockSecretManager struct { + GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) + secretsmanageriface.SecretsManagerAPI +} + +func (m *MockSecretManager) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return m.GetSecretValueFn(ctx, input, option...) +} + +type MockDynamoDB struct { + GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) + PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) + DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) + BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) + TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) + dynamodbiface.DynamoDBAPI +} + +func (m *MockDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return m.GetItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { + return m.PutItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { + return m.DeleteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { + return m.BatchWriteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { + return m.TransactWriteItemsWithContextFn(ctx, input, op...) +} diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 6997b62a1e..a66ef86e1e 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -41,7 +41,7 @@ type StaticAuth struct { session *session.Session cfg *aws.Config - Clients *Clients // exported to mock clients in unit tests + clients *Clients } func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { @@ -60,7 +60,7 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth } return GetConfig(opts) }(), - Clients: newClients(), + clients: newClients(), } initialSession, err := auth.getTokenClient() @@ -73,132 +73,137 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth return auth, nil } +// This is to be used only for test purposes to inject mocked clients +func (a *StaticAuth) WithMockClients(clients *Clients) { + a.clients = clients +} + func (a *StaticAuth) S3() *S3Clients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.s3 != nil { - return a.Clients.s3 + if a.clients.s3 != nil { + return a.clients.s3 } s3Clients := S3Clients{} - a.Clients.s3 = &s3Clients - a.Clients.s3.New(a.session) - return a.Clients.s3 + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 } func (a *StaticAuth) DynamoDB() *DynamoDBClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Dynamo != nil { - return a.Clients.Dynamo + if a.clients.Dynamo != nil { + return a.clients.Dynamo } clients := DynamoDBClients{} - a.Clients.Dynamo = &clients - a.Clients.Dynamo.New(a.session) + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) - return a.Clients.Dynamo + return a.clients.Dynamo } func (a *StaticAuth) Sqs() *SqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sqs != nil { - return a.Clients.sqs + if a.clients.sqs != nil { + return a.clients.sqs } clients := SqsClients{} - a.Clients.sqs = &clients - a.Clients.sqs.New(a.session) + a.clients.sqs = &clients + a.clients.sqs.New(a.session) - return a.Clients.sqs + return a.clients.sqs } func (a *StaticAuth) Sns() *SnsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.sns != nil { - return a.Clients.sns + if a.clients.sns != nil { + return a.clients.sns } clients := SnsClients{} - a.Clients.sns = &clients - a.Clients.sns.New(a.session) - return a.Clients.sns + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns } func (a *StaticAuth) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.snssqs != nil { - return a.Clients.snssqs + if a.clients.snssqs != nil { + return a.clients.snssqs } clients := SnsSqsClients{} - a.Clients.snssqs = &clients - a.Clients.snssqs.New(a.session) - return a.Clients.snssqs + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs } func (a *StaticAuth) SecretManager() *SecretManagerClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.Secret != nil { - return a.Clients.Secret + if a.clients.Secret != nil { + return a.clients.Secret } clients := SecretManagerClients{} - a.Clients.Secret = &clients - a.Clients.Secret.New(a.session) - return a.Clients.Secret + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret } func (a *StaticAuth) ParameterStore() *ParameterStoreClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ParameterStore != nil { - return a.Clients.ParameterStore + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore } clients := ParameterStoreClients{} - a.Clients.ParameterStore = &clients - a.Clients.ParameterStore.New(a.session) - return a.Clients.ParameterStore + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore } func (a *StaticAuth) Kinesis() *KinesisClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.kinesis != nil { - return a.Clients.kinesis + if a.clients.kinesis != nil { + return a.clients.kinesis } clients := KinesisClients{} - a.Clients.kinesis = &clients - a.Clients.kinesis.New(a.session) - return a.Clients.kinesis + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis } func (a *StaticAuth) Ses() *SesClients { a.mu.Lock() defer a.mu.Unlock() - if a.Clients.ses != nil { - return a.Clients.ses + if a.clients.ses != nil { + return a.clients.ses } clients := SesClients{} - a.Clients.ses = &clients - a.Clients.ses.New(a.session) - return a.Clients.ses + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses } func (a *StaticAuth) getTokenClient() (*session.Session, error) { diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 200fa58030..04c7a6995e 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -37,20 +36,6 @@ import ( const secretValue = "secret" -type mockedSSM struct { - GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) - DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) - ssmiface.SSMAPI -} - -func (m *mockedSSM) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return m.GetParameterFn(ctx, input, option...) -} - -func (m *mockedSSM) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { - return m.DescribeParametersFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewParameterStore(logger.NewLogger("test")) @@ -71,7 +56,7 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("with valid path", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { secret := secretValue return &ssm.GetParameterOutput{ @@ -90,10 +75,8 @@ func TestGetSecret(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, @@ -109,7 +92,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { secret := secretValue keys := strings.Split(*input.Name, ":") @@ -133,11 +116,8 @@ func TestGetSecret(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, } @@ -154,7 +134,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("with prefix", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) secret := secretValue @@ -175,10 +155,8 @@ func TestGetSecret(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, @@ -196,7 +174,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { return nil, errors.New("failed due to any reason") }, @@ -209,10 +187,8 @@ func TestGetSecret(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, @@ -230,7 +206,7 @@ func TestGetSecret(t *testing.T) { func TestGetBulkSecrets(t *testing.T) { t.Run("successfully retrieve bulk secrets", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ { @@ -260,10 +236,8 @@ func TestGetBulkSecrets(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, } @@ -278,7 +252,7 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("successfully retrieve bulk secrets with prefix", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ { @@ -308,10 +282,8 @@ func TestGetBulkSecrets(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, prefix: "/prefix", @@ -327,7 +299,7 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on get parameter", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ { @@ -350,10 +322,8 @@ func TestGetBulkSecrets(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, } @@ -366,7 +336,7 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on describe parameter", func(t *testing.T) { - mockSSM := &mockedSSM{ + mockSSM := &awsAuth.MockParameterStore{ DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { return nil, errors.New("failed due to any reason") }, @@ -379,10 +349,8 @@ func TestGetBulkSecrets(t *testing.T) { mockedClients := awsAuth.Clients{ ParameterStore: ¶mStore, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := ssmSecretStore{ authProvider: mockAuthProvider, } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 6c6736f8b9..7fbd8493af 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -21,7 +21,6 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,15 +32,6 @@ import ( const secretValue = "secret" -type mockedSM struct { - GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) - secretsmanageriface.SecretsManagerAPI -} - -func (m *mockedSM) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return m.GetSecretValueFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewSecretManager(logger.NewLogger("test")) @@ -62,7 +52,7 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("without version id and version stage", func(t *testing.T) { - mockSSM := &mockedSM{ + mockSSM := &awsAuth.MockSecretManager{ GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { assert.Nil(t, input.VersionId) assert.Nil(t, input.VersionStage) @@ -82,11 +72,8 @@ func TestGetSecret(t *testing.T) { mockedClients := awsAuth.Clients{ Secret: &secret, } - - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := smSecretStore{ authProvider: mockAuthProvider, } @@ -101,7 +88,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - mockSSM := &mockedSM{ + mockSSM := &awsAuth.MockSecretManager{ GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { assert.NotNil(t, input.VersionId) secret := secretValue @@ -121,10 +108,8 @@ func TestGetSecret(t *testing.T) { Secret: &secret, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := smSecretStore{ authProvider: mockAuthProvider, } @@ -141,7 +126,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version stage", func(t *testing.T) { - mockSSM := &mockedSM{ + mockSSM := &awsAuth.MockSecretManager{ GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { assert.NotNil(t, input.VersionStage) secret := secretValue @@ -161,10 +146,8 @@ func TestGetSecret(t *testing.T) { Secret: &secret, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := smSecretStore{ authProvider: mockAuthProvider, } @@ -182,7 +165,7 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - mockSSM := &mockedSM{ + mockSSM := &awsAuth.MockSecretManager{ GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { return nil, errors.New("failed due to any reason") }, @@ -196,10 +179,8 @@ func TestGetSecret(t *testing.T) { Secret: &secret, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := smSecretStore{ authProvider: mockAuthProvider, } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 28a34c8af9..7b667b6c78 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -27,22 +27,12 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/state" ) -type mockedDynamoDB struct { - GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) - PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) - DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) - BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) - TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) - dynamodbiface.DynamoDBAPI -} - type DynamoDBItem struct { Key string `json:"key"` Value string `json:"value"` @@ -54,29 +44,9 @@ const ( pkey = "partitionKey" ) -func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return m.GetItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { - return m.PutItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { - return m.DeleteItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { - return m.BatchWriteItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { - return m.TransactWriteItemsWithContextFn(ctx, input, op...) -} - func TestInit(t *testing.T) { m := state.Metadata{} - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ // We're adding this so we can pass the connection check on Init GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, nil @@ -91,10 +61,8 @@ func TestInit(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -153,7 +121,7 @@ func TestInit(t *testing.T) { "Table": table, } - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, @@ -164,10 +132,8 @@ func TestInit(t *testing.T) { mockedClients := awsAuth.Clients{ Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -182,7 +148,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -208,10 +174,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -230,7 +194,7 @@ func TestGet(t *testing.T) { assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -259,10 +223,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", @@ -284,7 +246,7 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -313,10 +275,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", @@ -335,7 +295,7 @@ func TestGet(t *testing.T) { assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return nil, errors.New("failed to retrieve data") }, @@ -349,10 +309,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, } @@ -369,7 +327,7 @@ func TestGet(t *testing.T) { assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{}, @@ -385,10 +343,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, } @@ -406,7 +362,7 @@ func TestGet(t *testing.T) { assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -426,10 +382,8 @@ func TestGet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, } @@ -453,7 +407,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -481,10 +435,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -501,7 +453,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -533,10 +485,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -554,7 +504,7 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -581,10 +531,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -609,7 +557,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -638,10 +586,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -660,7 +606,7 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -684,10 +630,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -711,7 +655,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -739,10 +683,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", @@ -762,7 +704,7 @@ func TestSet(t *testing.T) { require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -790,10 +732,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -814,7 +754,7 @@ func TestSet(t *testing.T) { }) t.Run("Unsuccessfully set item", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, @@ -828,10 +768,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -846,7 +784,7 @@ func TestSet(t *testing.T) { require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -874,10 +812,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -895,7 +831,7 @@ func TestSet(t *testing.T) { require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -927,10 +863,8 @@ func TestSet(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", @@ -956,7 +890,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -976,10 +910,8 @@ func TestDelete(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -996,7 +928,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -1020,10 +952,8 @@ func TestDelete(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -1040,7 +970,7 @@ func TestDelete(t *testing.T) { Key: "key", } - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -1065,10 +995,8 @@ func TestDelete(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, @@ -1084,7 +1012,7 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { return nil, errors.New("unable to delete item") }, @@ -1098,10 +1026,8 @@ func TestDelete(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, } @@ -1139,7 +1065,7 @@ func TestMultiTx(t *testing.T) { }, } - mockedDB := &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -1172,10 +1098,8 @@ func TestMultiTx(t *testing.T) { Dynamo: &dynamo, } - mockAuthProvider := &awsAuth.StaticAuth{ - Clients: &mockedClients, - } - + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) s := StateStore{ authProvider: mockAuthProvider, table: tableName, From cde5a1011800e6e297284d3b5e3a8aa911bc6496 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 14 Nov 2024 11:40:55 -0600 Subject: [PATCH 39/39] fix: add one last closer Signed-off-by: Samantha Coyle --- pubsub/aws/snssqs/snssqs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 89e8f5ef7a..93481fb733 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -875,7 +875,7 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } - return nil + return s.authProvider.Close() } func (s *snsSqs) Features() []pubsub.Feature {