diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9113cf286b..cccb195a44 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -29,7 +29,25 @@ aws: type: string - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment - + - title: "AWS: IAM Roles Anywhere" + description: Use X.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. + metadata: + - name: trustAnchorArn + description: | + ARN of the AWS Trust Anchor in the AWS account granting trust to the Dapr Certificate Authority. + example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901 + required: true + - name: trustProfileArn + description: | + ARN of the AWS IAM Profile in the trusting AWS account. + example: arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901 + required: true + - name: assumeRoleArn + description: | + ARN of the AWS IAM role to assume in the trusting AWS account. + example: arn:aws:iam:012345678910:role/exampleIAMRoleName + required: true + azuread: - title: "Azure AD: Managed identity" description: Authenticate using Azure AD and a managed identity. diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index bd882e7b55..755b3158d3 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -31,9 +31,9 @@ import ( // DynamoDB allows performing stateful operations on AWS DynamoDB. type DynamoDB struct { - client *dynamodb.DynamoDB - table string - logger logger.Logger + authProvider awsAuth.Provider + table string + logger logger.Logger } type dynamoDBMetadata struct { @@ -51,18 +51,27 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing for DynamoDB. -func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error { +func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err } - client, err := d.getClient(meta) + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - d.client = client + d.authProvider = provider d.table = meta.Table return nil @@ -84,7 +93,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - _, err = d.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) @@ -105,16 +114,6 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat return &meta, nil } -func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := dynamoDBMetadata{} @@ -123,5 +122,5 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { - return nil + return d.authProvider.Close() } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index dbe0ceb918..7ede7ba245 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "github.com/vmware/vmware-go-kcl/clientlibrary/config" "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces" "github.com/vmware/vmware-go-kcl/clientlibrary/worker" @@ -40,15 +39,16 @@ import ( // AWSKinesis allows receiving and sending data to/from AWS Kinesis stream. type AWSKinesis struct { - client *kinesis.Kinesis - metadata *kinesisMetadata + authProvider awsAuth.Provider + metadata *kinesisMetadata - worker *worker.Worker - workerConfig *config.KinesisClientLibConfiguration + worker *worker.Worker - streamARN *string - consumerARN *string - logger logger.Logger + streamName string + consumerName string + consumerARN *string + logger logger.Logger + consumerMode string closed atomic.Bool closeCh chan struct{} @@ -112,30 +112,25 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - client, err := a.getClient(m) - if err != nil { - return err - } + a.consumerMode = m.KinesisConsumerMode + a.streamName = m.StreamName + a.consumerName = m.ConsumerName + a.metadata = m - streamName := aws.String(m.StreamName) - stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ - StreamName: streamName, - }) + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - if m.KinesisConsumerMode == SharedThroughput { - kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName, - m.StreamName, m.Region, m.ConsumerName, - client.Config.Credentials) - a.workerConfig = kclConfig - } - - a.streamARN = stream.StreamDescription.StreamARN - a.metadata = m - a.client = client - + a.authProvider = provider return nil } @@ -148,7 +143,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - _, err := a.client.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + _, err := a.authProvider.Kinesis().Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -161,16 +156,15 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.closed.Load() { return errors.New("binding is closed") } - if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis().WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = a.client.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = a.authProvider.Kinesis().Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -180,6 +174,10 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } } + stream, err := a.authProvider.Kinesis().Stream(ctx, a.streamName) + if err != nil { + return fmt.Errorf("failed to get kinesis stream arn: %v", err) + } // Wait for context cancelation then stop a.wg.Add(1) go func() { @@ -191,7 +189,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { - a.deregisterConsumer(a.streamARN, a.consumerARN) + a.deregisterConsumer(ctx, stream, a.consumerARN) } }() @@ -226,8 +224,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes return default: } - - sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + sub, err := a.authProvider.Kinesis().Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -269,14 +266,14 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -288,7 +285,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - consumer, err := a.client.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -307,11 +304,11 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (* return consumer.Consumer.ConsumerARN, nil } -func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) error { +func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error { if a.consumerARN != nil { // Use a background context because the running context may have been canceled already ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err := a.client.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + _, err := a.authProvider.Kinesis().Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -342,7 +339,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - req, _ := a.client.DescribeStreamConsumerRequest(inCpy) + req, _ := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) @@ -354,16 +351,6 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des return w.WaitWithContext(ctx) } -func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - k := kinesis.New(sess) - - return k, nil -} - func (a *AWSKinesis) parseMetadata(meta bindings.Metadata) (*kinesisMetadata, error) { var m kinesisMetadata err := kitmd.DecodeMetadata(meta.Properties, &m) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index cc67cec94f..13f8730e78 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,9 +29,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/google/uuid" @@ -61,11 +59,9 @@ const ( // AWSS3 is a binding for an AWS S3 storage bucket. type AWSS3 struct { - metadata *s3Metadata - s3Client *s3.S3 - uploader *s3manager.Uploader - downloader *s3manager.Downloader - logger logger.Logger + metadata *s3Metadata + authProvider awsAuth.Provider + logger logger.Logger } type s3Metadata struct { @@ -109,23 +105,11 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } -// Init does metadata parsing and connection creation. -func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { - m, err := s.parseMetadata(metadata) - if err != nil { - return err - } - session, err := s.getSession(m) - if err != nil { - return err - } - - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) +func (s *AWSS3) getAWSConfig(opts awsAuth.Options) *aws.Config { + cfg := awsAuth.GetConfig(opts).WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { + if s.metadata.InsecureSSL { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = &tls.Config{ //nolint:gosec @@ -138,17 +122,38 @@ func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") } + return cfg +} +// Init does metadata parsing and connection creation. +func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { + m, err := s.parseMetadata(metadata) + if err != nil { + return err + } s.metadata = m - s.s3Client = s3.New(session, cfg) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) + + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, s.getAWSConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider return nil } func (s *AWSS3) Close() error { - return nil + return s.authProvider.Close() } func (s *AWSS3) Operations() []bindings.OperationKind { @@ -201,8 +206,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi if metadata.StorageClass != "" { storageClass = aws.String(metadata.StorageClass) } - - resultUpload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + resultUpload, err := s.authProvider.S3().Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -215,7 +219,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi var presignURL string if metadata.PresignTTL != "" { - url, presignErr := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, presignErr := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if presignErr != nil { return nil, fmt.Errorf("s3 binding error: %s", presignErr) } @@ -255,7 +259,7 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataPresignTTL) } - url, err := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, err := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if err != nil { return nil, fmt.Errorf("s3 binding error: %w", err) } @@ -272,13 +276,12 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind }, nil } -func (s *AWSS3) presignObject(bucket, key, ttl string) (string, error) { +func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (string, error) { d, err := time.ParseDuration(ttl) if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - - objReq, _ := s.s3Client.GetObjectRequest(&s3.GetObjectInput{ + objReq, _ := s.authProvider.S3().S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -302,8 +305,7 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings } buff := &aws.WriteAtBuffer{} - - _, err = s.downloader.DownloadWithContext(ctx, + _, err = s.authProvider.S3().Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -337,8 +339,7 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi if key == "" { return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - - _, err := s.s3Client.DeleteObjectWithContext( + _, err := s.authProvider.S3().S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -367,8 +368,7 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding if payload.MaxResults < 1 { payload.MaxResults = defaultMaxResults } - - result, err := s.s3Client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + result, err := s.authProvider.S3().S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), @@ -415,15 +415,6 @@ func (s *AWSS3) parseMetadata(md bindings.Metadata) (*s3Metadata, error) { return &m, nil } -func (s *AWSS3) getSession(metadata *s3Metadata) (*session.Session, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return sess, nil -} - // Helper to merge config and request metadata. func (metadata s3Metadata) mergeWithRequestMetadata(req *bindings.InvokeRequest) (s3Metadata, error) { merged := metadata diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 483fde8c64..4cd752bac5 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -38,9 +38,9 @@ const ( // AWSSES is an AWS SNS binding. type AWSSES struct { - metadata *sesMetadata - logger logger.Logger - svc *ses.SES + authProvider awsAuth.Provider + metadata *sesMetadata + logger logger.Logger } type sesMetadata struct { @@ -61,19 +61,29 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { // Parse input metadata - meta, err := a.parseMetadata(metadata) + m, err := a.parseMetadata(metadata) if err != nil { return err } - svc, err := a.getClient(meta) + a.metadata = m + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.metadata = meta - a.svc = svc + a.authProvider = provider return nil } @@ -141,7 +151,7 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - result, err := a.svc.SendEmail(input) + result, err := a.authProvider.Ses().Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } @@ -158,18 +168,6 @@ func (metadata sesMetadata) mergeWithRequestMetadata(req *bindings.InvokeRequest return merged } -func (a *AWSSES) getClient(metadata *sesMetadata) (*ses.SES, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, fmt.Errorf("SES binding error: error creating AWS session %w", err) - } - - // Create an SES instance - svc := ses.New(sess) - - return svc, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMap) { metadataStruct := sesMetadata{} @@ -178,5 +176,5 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 43b63cd2b1..55e3ccefa5 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -30,8 +30,8 @@ import ( // AWSSNS is an AWS SNS binding. type AWSSNS struct { - client *sns.SNS - topicARN string + authProvider awsAuth.Provider + topicARN string logger logger.Logger } @@ -58,16 +58,27 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err } - client, err := a.getClient(m) + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.client = client + a.authProvider = provider a.topicARN = m.TopicArn return nil @@ -83,16 +94,6 @@ func (a *AWSSNS) parseMetadata(meta bindings.Metadata) (*snsMetadata, error) { return &m, nil } -func (a *AWSSNS) getClient(metadata *snsMetadata) (*sns.SNS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sns.New(sess) - - return c, nil -} - func (a *AWSSNS) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } @@ -107,7 +108,7 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - _, err = a.client.PublishWithContext(ctx, &sns.PublishInput{ + _, err = a.authProvider.Sns().Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, @@ -127,5 +128,5 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 465e061b61..d803bafc5a 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -33,13 +33,12 @@ import ( // AWSSQS allows receiving and sending data to/from AWS SQS. type AWSSQS struct { - Client *sqs.SQS - QueueURL *string - - logger logger.Logger - wg sync.WaitGroup - closeCh chan struct{} - closed atomic.Bool + authProvider awsAuth.Provider + queueName string + logger logger.Logger + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type sqsMetadata struct { @@ -66,21 +65,22 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - client, err := a.getClient(m) - if err != nil { - return err + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } - - queueName := m.QueueName - resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - a.QueueURL = resultURL.QueueUrl - a.Client = client + a.authProvider = provider + a.queueName = m.QueueName return nil } @@ -91,9 +91,14 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - _, err := a.Client.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } + + _, err = a.authProvider.Sqs().Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, - QueueUrl: a.QueueURL, + QueueUrl: url, }) return nil, err @@ -113,9 +118,13 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } - result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ - QueueUrl: a.QueueURL, + result, err := a.authProvider.Sqs().Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", }), @@ -126,7 +135,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { WaitTimeSeconds: aws.Int64(20), }) if err != nil { - a.logger.Errorf("Unable to receive message from queue %q, %v.", *a.QueueURL, err) + a.logger.Errorf("Unable to receive message from queue %q, %v.", url, err) } if len(result.Messages) > 0 { @@ -140,8 +149,8 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - a.Client.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ - QueueUrl: a.QueueURL, + a.authProvider.Sqs().Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: url, ReceiptHandle: msgHandle, }) } @@ -164,7 +173,7 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) { @@ -177,16 +186,6 @@ func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) return &m, nil } -func (a *AWSSQS) getClient(metadata *sqsMetadata) (*sqs.SQS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sqs.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSQS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := sqsMetadata{} diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 48c8b209a4..a45eb48277 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -20,14 +20,10 @@ import ( "strconv" "time" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -38,59 +34,78 @@ type EnvironmentSettings struct { Metadata map[string]string } -func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { - optFns := []func(*config.LoadOptions) error{} - if region != "" { - optFns = append(optFns, config.WithRegion(region)) - } +type AWSIAM struct { + // Ignored by metadata parser because included in built-in authentication profile + // Access key to use for accessing PostgreSQL. + AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + // Secret key to use for accessing PostgreSQL. + AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` + // AWS region in which PostgreSQL is deployed. + AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` +} - if accessKey != "" && secretKey != "" { - provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) - optFns = append(optFns, config.WithCredentialsProvider(provider)) - } +type AWSIAMAuthOptions struct { + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` +} - awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) - if err != nil { - return awsv2.Config{}, err - } +type Options struct { + Logger logger.Logger + Properties map[string]string - if endpoint != "" { - awsCfg.BaseEndpoint = &endpoint - } + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - return awsCfg, nil + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` + + Endpoint string + SessionToken string } -func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) { - awsConfig := aws.NewConfig() +func GetConfig(opts Options) *aws.Config { + cfg := aws.NewConfig() - if region != "" { - awsConfig = awsConfig.WithRegion(region) + switch { + case opts.Region != "": + cfg.WithRegion(opts.Region) + case opts.Endpoint != "": + cfg.WithEndpoint(opts.Endpoint) } - if accessKey != "" && secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)) - } + return cfg +} - if endpoint != "" { - awsConfig = awsConfig.WithEndpoint(endpoint) - } +type Provider interface { + S3() *S3Clients + DynamoDB() *DynamoDBClients + Sqs() *SqsClients + Sns() *SnsClients + SnsSqs() *SnsSqsClients + SecretManager() *SecretManagerClients + ParameterStore() *ParameterStoreClients + Kinesis() *KinesisClients + Ses() *SesClients + + Close() error +} - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } +func isX509Auth(m map[string]string) bool { + tp, _ := m["trustProfileArn"] + ta, _ := m["trustAnchorArn"] + ar, _ := m["assumeRoleArn"] + return tp != "" && ta != "" && ar != "" +} - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), +func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) { + if isX509Auth(opts.Properties) { + return newX509(ctx, opts, cfg) } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - - return awsSession, nil + return newStaticIAM(ctx, opts, cfg) } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. @@ -102,25 +117,7 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } -type AWSIAM struct { - // Ignored by metadata parser because included in built-in authentication profile - // Access key to use for accessing PostgreSQL. - AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` - // Secret key to use for accessing PostgreSQL. - AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` - // AWS region in which PostgreSQL is deployed. - AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` -} - -type AWSIAMAuthOptions struct { - PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - Region string `json:"region" mapstructure:"region"` - AccessKey string `json:"accessKey" mapstructure:"accessKey"` - SecretKey string `json:"secretKey" mapstructure:"secretKey"` -} - -func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, error) { +func (opts *Options) GetAccessToken(ctx context.Context) (string, error) { dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) var authenticationToken string @@ -160,7 +157,7 @@ func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, erro return authenticationToken, nil } -func (opts *AWSIAMAuthOptions) InitiateAWSIAMAuth() error { +func (opts *Options) InitiateAWSIAMAuth() error { // Set max connection lifetime to 8 minutes in postgres connection pool configuration. // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, // while leveraging the BeforeConnect hook to recreate the token in time dynamically. diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go new file mode 100644 index 0000000000..15aac78ad7 --- /dev/null +++ b/common/authentication/aws/aws_test.go @@ -0,0 +1,44 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewEnvironmentSettings(t *testing.T) { + tests := []struct { + name string + metadata map[string]string + }{ + { + name: "valid metadata", + metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewEnvironmentSettings(tt.metadata) + require.NoError(t, err) + assert.NotNil(t, result) + }) + } +} diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go new file mode 100644 index 0000000000..8d0e9de20b --- /dev/null +++ b/common/authentication/aws/client.go @@ -0,0 +1,209 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type Clients struct { + mu sync.RWMutex + + s3 *S3Clients + Dynamo *DynamoDBClients + sns *SnsClients + sqs *SqsClients + snssqs *SnsSqsClients + Secret *SecretManagerClients + ParameterStore *ParameterStoreClients + kinesis *KinesisClients + ses *SesClients +} + +func newClients() *Clients { + return new(Clients) +} + +func (c *Clients) refresh(session *session.Session) { + c.mu.Lock() + defer c.mu.Unlock() + switch { + case c.s3 != nil: + c.s3.New(session) + case c.Dynamo != nil: + c.Dynamo.New(session) + case c.sns != nil: + c.sns.New(session) + case c.sqs != nil: + c.sqs.New(session) + case c.snssqs != nil: + c.snssqs.New(session) + case c.Secret != nil: + c.Secret.New(session) + case c.ParameterStore != nil: + c.ParameterStore.New(session) + case c.kinesis != nil: + c.kinesis.New(session) + case c.ses != nil: + c.ses.New(session) + } +} + +type S3Clients struct { + S3 *s3.S3 + Uploader *s3manager.Uploader + Downloader *s3manager.Downloader +} + +type DynamoDBClients struct { + DynamoDB dynamodbiface.DynamoDBAPI +} + +type SnsSqsClients struct { + Sns *sns.SNS + Sqs *sqs.SQS + Sts *sts.STS +} + +type SnsClients struct { + Sns *sns.SNS +} + +type SqsClients struct { + Sqs sqsiface.SQSAPI +} + +type SecretManagerClients struct { + Manager secretsmanageriface.SecretsManagerAPI +} + +type ParameterStoreClients struct { + Store ssmiface.SSMAPI +} + +type KinesisClients struct { + Kinesis kinesisiface.KinesisAPI + Region string + Credentials *credentials.Credentials +} + +type SesClients struct { + Ses *ses.SES +} + +func (c *S3Clients) New(session *session.Session) { + refreshedS3 := s3.New(session, session.Config) + c.S3 = refreshedS3 + c.Uploader = s3manager.NewUploaderWithClient(refreshedS3) + c.Downloader = s3manager.NewDownloaderWithClient(refreshedS3) +} + +func (c *DynamoDBClients) New(session *session.Session) { + c.DynamoDB = dynamodb.New(session, session.Config) +} + +func (c *SnsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) +} + +func (c *SnsSqsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) + c.Sqs = sqs.New(session, session.Config) + c.Sts = sts.New(session, session.Config) +} + +func (c *SqsClients) New(session *session.Session) { + c.Sqs = sqs.New(session, session.Config) +} + +func (c *SqsClients) QueueURL(ctx context.Context, queueName string) (*string, error) { + if c.Sqs != nil { + resultURL, err := c.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + QueueName: aws.String(queueName), + }) + if resultURL != nil { + return resultURL.QueueUrl, err + } + } + return nil, errors.New("unable to get queue url due to empty client") +} + +func (c *SecretManagerClients) New(session *session.Session) { + c.Manager = secretsmanager.New(session, session.Config) +} + +func (c *ParameterStoreClients) New(session *session.Session) { + c.Store = ssm.New(session, session.Config) +} + +func (c *KinesisClients) New(session *session.Session) { + c.Kinesis = kinesis.New(session, session.Config) + c.Region = *session.Config.Region + c.Credentials = session.Config.Credentials +} + +func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string, error) { + if c.Kinesis != nil { + stream, err := c.Kinesis.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ + StreamName: aws.String(streamName), + }) + if stream != nil { + return stream.StreamDescription.StreamARN, err + } + } + + return nil, errors.New("unable to get stream arn due to empty client") +} + +func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode string) *config.KinesisClientLibConfiguration { + const sharedMode = "shared" + if c.Kinesis != nil { + if mode == sharedMode { + if c.Credentials != nil { + kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, + stream, c.Region, consumer, + c.Credentials) + return kclConfig + } + } + } + + return nil +} + +func (c *SesClients) New(session *session.Session) { + c.Ses = ses.New(session, session.Config) +} diff --git a/common/authentication/aws/client_fake.go b/common/authentication/aws/client_fake.go new file mode 100644 index 0000000000..c9e23641ba --- /dev/null +++ b/common/authentication/aws/client_fake.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" +) + +type MockParameterStore struct { + GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) + DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) + ssmiface.SSMAPI +} + +func (m *MockParameterStore) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return m.GetParameterFn(ctx, input, option...) +} + +func (m *MockParameterStore) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { + return m.DescribeParametersFn(ctx, input, option...) +} + +type MockSecretManager struct { + GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) + secretsmanageriface.SecretsManagerAPI +} + +func (m *MockSecretManager) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return m.GetSecretValueFn(ctx, input, option...) +} + +type MockDynamoDB struct { + GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) + PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) + DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) + BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) + TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) + dynamodbiface.DynamoDBAPI +} + +func (m *MockDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return m.GetItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { + return m.PutItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { + return m.DeleteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { + return m.BatchWriteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { + return m.TransactWriteItemsWithContextFn(ctx, input, op...) +} diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go new file mode 100644 index 0000000000..67d2ac88f3 --- /dev/null +++ b/common/authentication/aws/client_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type mockedSQS struct { + sqsiface.SQSAPI + GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) +} + +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck + return m.GetQueueURLFn(ctx, input) +} + +type mockedKinesis struct { + kinesisiface.KinesisAPI + DescribeStreamFn func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) +} + +func (m *mockedKinesis) DescribeStreamWithContext(ctx context.Context, input *kinesis.DescribeStreamInput, opts ...request.Option) (*kinesis.DescribeStreamOutput, error) { + return m.DescribeStreamFn(ctx, input) +} + +func TestS3Clients_New(t *testing.T) { + tests := []struct { + name string + s3Client *S3Clients + session *session.Session + }{ + {"initializes S3 client", &S3Clients{}, session.Must(session.NewSession())}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s3Client.New(tt.session) + require.NotNil(t, tt.s3Client.S3) + require.NotNil(t, tt.s3Client.Uploader) + require.NotNil(t, tt.s3Client.Downloader) + }) + } +} + +func TestSqsClients_QueueURL(t *testing.T) { + tests := []struct { + name string + mockFn func() *mockedSQS + queueName string + expectedURL *string + expectError bool + }{ + { + name: "returns queue URL successfully", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return &sqs.GetQueueUrlOutput{ + QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"), + }, nil + }, + } + }, + queueName: "valid-queue", + expectedURL: aws.String("https://sqs.aws.com/123456789012/queue"), + expectError: false, + }, + { + name: "returns error when queue URL not found", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return nil, errors.New("unable to get stream arn due to empty client") + }, + } + }, + queueName: "missing-queue", + expectedURL: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSQS := tt.mockFn() + + // Initialize SqsClients with the mocked SQS client + client := &SqsClients{ + Sqs: mockSQS, + } + + url, err := client.QueueURL(context.Background(), tt.queueName) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedURL, url) + } + }) + } +} + +func TestKinesisClients_Stream(t *testing.T) { + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + mockStreamARN *string + mockError error + expectedStream *string + expectedErr error + }{ + { + name: "successfully retrieves stream ARN", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + }, + }, nil + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "some-stream", + expectedStream: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + expectedErr: nil, + }, + { + name: "returns error when stream not found", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return nil, errors.New("stream not found") + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "nonexistent-stream", + expectedStream: nil, + expectedErr: errors.New("unable to get stream arn due to empty client"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kinesisClient.Stream(context.Background(), tt.streamName) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedStream, got) + } + }) + } +} + +func TestKinesisClients_WorkerCfg(t *testing.T) { + testCreds := credentials.NewStaticCredentials("accessKey", "secretKey", "") + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + consumer string + mode string + expectedConfig *config.KinesisClientLibConfiguration + }{ + { + name: "successfully creates shared mode worker config", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: config.NewKinesisClientLibConfigWithCredential( + "consumer1", "existing-stream", "us-west-1", "consumer1", testCreds, + ), + }, + { + name: "returns nil when mode is not shared", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "exclusive", + expectedConfig: nil, + }, + { + name: "returns nil when client is nil", + kinesisClient: &KinesisClients{ + Kinesis: nil, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.kinesisClient.WorkerCfg(context.Background(), tt.streamName, tt.consumer, tt.mode) + if tt.expectedConfig == nil { + assert.Equal(t, tt.expectedConfig, cfg) + return + } + assert.Equal(t, tt.expectedConfig.StreamName, cfg.StreamName) + assert.Equal(t, tt.expectedConfig.EnhancedFanOutConsumerName, cfg.EnhancedFanOutConsumerName) + assert.Equal(t, tt.expectedConfig.EnableEnhancedFanOutConsumer, cfg.EnableEnhancedFanOutConsumer) + assert.Equal(t, tt.expectedConfig.RegionName, cfg.RegionName) + }) + } +} diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go new file mode 100644 index 0000000000..a66ef86e1e --- /dev/null +++ b/common/authentication/aws/static.go @@ -0,0 +1,272 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "fmt" + "sync" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + + "github.com/dapr/kit/logger" +) + +type StaticAuth struct { + mu sync.RWMutex + logger logger.Logger + + region *string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string + + session *session.Session + cfg *aws.Config + clients *Clients +} + +func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { + auth := &StaticAuth{ + logger: opts.Logger, + region: &opts.Region, + endpoint: &opts.Endpoint, + accessKey: &opts.AccessKey, + secretKey: &opts.SecretKey, + sessionToken: &opts.SessionToken, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + } + + initialSession, err := auth.getTokenClient() + if err != nil { + return nil, fmt.Errorf("failed to get token client: %v", err) + } + + auth.session = initialSession + + return auth, nil +} + +// This is to be used only for test purposes to inject mocked clients +func (a *StaticAuth) WithMockClients(clients *Clients) { + a.clients = clients +} + +func (a *StaticAuth) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *StaticAuth) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *StaticAuth) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *StaticAuth) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *StaticAuth) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *StaticAuth) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *StaticAuth) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *StaticAuth) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *StaticAuth) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +func (a *StaticAuth) getTokenClient() (*session.Session, error) { + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig() + } else { + awsConfig = a.cfg + } + + if a.region != nil { + awsConfig = awsConfig.WithRegion(*a.region) + } + + if a.accessKey != nil && a.secretKey != nil { + // session token is an option field + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) + } + + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *StaticAuth) Close() error { + return nil +} + +func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { + optFns := []func(*config.LoadOptions) error{} + if region != "" { + optFns = append(optFns, config.WithRegion(region)) + } + + if accessKey != "" && secretKey != "" { + provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) + optFns = append(optFns, config.WithCredentialsProvider(provider)) + } + + awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) + if err != nil { + return awsv2.Config{}, err + } + + if endpoint != "" { + awsCfg.BaseEndpoint = &endpoint + } + + return awsCfg, nil +} diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go new file mode 100644 index 0000000000..a1a17a093c --- /dev/null +++ b/common/authentication/aws/static_test.go @@ -0,0 +1,66 @@ +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigV2(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + sessionToken string + region string + endpoint string + }{ + { + name: "valid config", + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + awsCfg, err := GetConfigV2(tt.accessKey, tt.secretKey, tt.sessionToken, tt.region, tt.endpoint) + require.NoError(t, err) + assert.NotNil(t, awsCfg) + assert.Equal(t, tt.region, awsCfg.Region) + assert.Equal(t, tt.endpoint, *awsCfg.BaseEndpoint) + }) + } +} + +func TestGetTokenClient(t *testing.T) { + tests := []struct { + name string + awsInstance *StaticAuth + }{ + { + name: "valid token client", + awsInstance: &StaticAuth{ + accessKey: aws.String("testAccessKey"), + secretKey: aws.String("testSecretKey"), + sessionToken: aws.String("testSessionToken"), + region: aws.String("us-west-2"), + endpoint: aws.String("https://test.endpoint.com"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session, err := tt.awsInstance.getTokenClient() + require.NotNil(t, session) + require.NoError(t, err) + assert.Equal(t, tt.awsInstance.region, session.Config.Region) + }) + } +} diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go new file mode 100644 index 0000000000..cb1bafdeb3 --- /dev/null +++ b/common/authentication/aws/x509.go @@ -0,0 +1,449 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + cryptoX509 "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" +) + +type x509Options struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} + +type x509 struct { + mu sync.RWMutex + wg sync.WaitGroup + closeCh chan struct{} + + logger logger.Logger + clients *Clients + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests + session *session.Session + cfg *aws.Config + + chainPEM []byte + keyPEM []byte + + region *string + trustProfileArn *string + trustAnchorArn *string + assumeRoleArn *string +} + +func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { + var x509Auth x509Options + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return nil, err + } + + switch { + case x509Auth.TrustProfileArn == nil: + return nil, errors.New("trustProfileArn is required") + case x509Auth.TrustAnchorArn == nil: + return nil, errors.New("trustAnchorArn is required") + case x509Auth.AssumeRoleArn == nil: + return nil, errors.New("assumeRoleArn is required") + } + + auth := &x509{ + logger: opts.Logger, + trustProfileArn: x509Auth.TrustProfileArn, + trustAnchorArn: x509Auth.TrustAnchorArn, + assumeRoleArn: x509Auth.AssumeRoleArn, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + } + + if err := auth.getCertPEM(ctx); err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + // Parse trust anchor and profile ARNs + if err := auth.initializeTrustAnchors(); err != nil { + return nil, err + } + + initialSession, err := auth.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create the initial session: %v", err) + } + auth.session = initialSession + auth.startSessionRefresher() + + return auth, nil +} + +func (a *x509) Close() error { + close(a.closeCh) + a.wg.Wait() + return nil +} + +func (a *x509) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return errors.New("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.chainPEM = chainPEM + a.keyPEM = keyPEM + return nil +} + +func (a *x509) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *x509) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *x509) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *x509) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *x509) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *x509) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *x509) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *x509) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *x509) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +func (a *x509) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.trustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.trustAnchorArn) + if err != nil { + return err + } + a.region = &trustAnchor.Region + } + + if a.trustProfileArn != nil { + profile, err = arn.Parse(*a.trustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *x509) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) + if err != nil { + return err + } + + ints := make([]cryptoX509.Certificate, 0, len(certs)-1) + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, error) { + a.mu.Lock() + defer a.mu.Unlock() + + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + var mySession *session.Session + + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) + } else { + awsConfig = a.cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + } + if a.region != nil { + awsConfig.WithRegion(*a.region) + } + // this is needed for testing purposes to mock the client, + // so code never sets the client, but tests do. + var rolesClient *rolesanywhere.RolesAnywhere + if a.rolesAnywhereClient == nil { + mySession = session.Must(session.NewSession(awsConfig)) + rolesAnywhereClient := rolesanywhere.New(mySession, awsConfig) + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + rolesClient = rolesAnywhereClient + } + + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.trustProfileArn, + TrustAnchorArn: a.trustAnchorArn, + RoleArn: a.assumeRoleArn, + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + DurationSeconds: aws.Int64(int64(time.Hour.Seconds())), // AWS default is 1hr timeout + InstanceProperties: nil, + SessionName: nil, + } + + var output *rolesanywhere.CreateSessionOutput + if a.rolesAnywhereClient != nil { + var err error + output, err = a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } else { + var err error + output, err = rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } + + if output == nil || len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + accessKey := output.CredentialSet[0].Credentials.AccessKeyId + secretKey := output.CredentialSet[0].Credentials.SecretAccessKey + sessionToken := output.CredentialSet[0].Credentials.SessionToken + awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) + sess := session.Must(session.NewSession(&aws.Config{ + Credentials: awsCreds, + }, awsConfig)) + if sess == nil { + return nil, errors.New("session is nil") + } + + return sess, nil +} + +func (a *x509) startSessionRefresher() { + a.logger.Infof("starting session refresher for x509 auth") + + a.wg.Add(1) + go func() { + defer a.wg.Done() + for { + // renew at ~half the lifespan + expiration, err := a.session.Config.Credentials.ExpiresAt() + if err != nil { + a.logger.Errorf("Failed to retrieve session expiration time, using 30 minute interval: %w", err) + expiration = time.Now().Add(time.Hour) + } + timeUntilExpiration := time.Until(expiration) + refreshInterval := timeUntilExpiration / 2 + select { + case <-time.After(refreshInterval): + a.refreshClient() + case <-a.closeCh: + a.logger.Debugf("Session refresher is stopped") + return + } + } + }() +} + +func (a *x509) refreshClient() { + for { + newSession, err := a.createOrRefreshSession(context.Background()) + if err == nil { + a.clients.refresh(newSession) + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + return + } + a.logger.Errorf("Failed to refresh session, retrying in 5 seconds: %w", err) + select { + case <-time.After(time.Second * 5): + case <-a.closeCh: + return + } + } +} diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go new file mode 100644 index 0000000000..3f7d2189c3 --- /dev/null +++ b/common/authentication/aws/x509_test.go @@ -0,0 +1,125 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + cryptoX509 "crypto/x509" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dapr/kit/crypto/spiffe" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/crypto/test" + "github.com/dapr/kit/logger" +) + +type mockRolesAnywhereClient struct { + rolesanywhereiface.RolesAnywhereAPI + + CreateSessionOutput *rolesanywhere.CreateSessionOutput + CreateSessionError error +} + +func (m *mockRolesAnywhereClient) CreateSessionWithContext(ctx context.Context, input *rolesanywhere.CreateSessionInput, opts ...request.Option) (*rolesanywhere.CreateSessionOutput, error) { + return m.CreateSessionOutput, m.CreateSessionError +} + +func TestGetX509Client(t *testing.T) { + tests := []struct { + name string + mockOutput *rolesanywhere.CreateSessionOutput + mockError error + }{ + { + name: "valid x509 client", + mockOutput: &rolesanywhere.CreateSessionOutput{ + CredentialSet: []*rolesanywhere.CredentialResponse{ + { + Credentials: &rolesanywhere.Credentials{ + AccessKeyId: aws.String("mockAccessKeyId"), + SecretAccessKey: aws.String("mockSecretAccessKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.String(time.Now().Add(15 * time.Minute).Format(time.RFC3339)), + }, + }, + }, + }, + mockError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := &mockRolesAnywhereClient{ + CreateSessionOutput: tt.mockOutput, + CreateSessionError: tt.mockError, + } + mockAWS := x509{ + logger: logger.NewLogger("testLogger"), + assumeRoleArn: aws.String("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + trustAnchorArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + trustProfileArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), + rolesAnywhereClient: mockSvc, + } + pki := test.GenPKI(t, test.PKIOptions{ + LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"), + }) + + respCert := []*cryptoX509.Certificate{pki.LeafCert} + var respErr error + + var fetches atomic.Int32 + s := spiffe.New(spiffe.Options{ + Log: logger.NewLogger("test"), + RequestSVIDFn: func(context.Context, []byte) ([]*cryptoX509.Certificate, error) { + fetches.Add(1) + return respCert, respErr + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error) + go func() { + errCh <- s.Run(ctx) + }() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + err := s.Ready(ctx) + require.NoError(t, err) + + // inject the SVID source into the context + ctx = spiffecontext.With(ctx, s) + session, err := mockAWS.createOrRefreshSession(ctx) + + require.NoError(t, err) + assert.NotNil(t, session) + }) + } +} diff --git a/common/authentication/postgresql/metadata.go b/common/authentication/postgresql/metadata.go index 7cacecfaa4..4b2135ba6a 100644 --- a/common/authentication/postgresql/metadata.go +++ b/common/authentication/postgresql/metadata.go @@ -162,7 +162,7 @@ func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) { return nil, err } - awsOpts := aws.AWSIAMAuthOptions{ + awsOpts := aws.Options{ PoolConfig: config, ConnectionString: m.ConnectionString, Region: awsRegion, diff --git a/go.mod b/go.mod index 13a6af3ab6..a8ece053a6 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.37 github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3 + github.com/aws/rolesanywhere-credential-helper v1.0.4 github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 github.com/camunda/zeebe/clients/go/v8 v8.2.12 github.com/cenkalti/backoff/v4 v4.2.1 @@ -106,6 +107,7 @@ require ( github.com/sendgrid/sendgrid-go v3.13.0+incompatible github.com/sijms/go-ora/v2 v2.7.18 github.com/spf13/cast v1.5.1 + github.com/spiffe/go-spiffe/v2 v2.1.7 github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 github.com/stretchr/testify v1.9.0 github.com/supplyon/gremcos v0.1.40 @@ -379,6 +381,7 @@ require ( github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect + github.com/zeebo/errs v1.3.0 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.opencensus.io v0.24.0 // indirect @@ -402,6 +405,7 @@ require ( google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/couchbase/gocbcore.v7 v7.1.18 // indirect gopkg.in/couchbaselabs/gocbconnstr.v1 v1.0.4 // indirect diff --git a/go.sum b/go.sum index 54c28416d8..9bafb4a502 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXY github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d h1:wvStE9wLpws31NiWUx+38wny1msZ/tm+eL5xmm4Y7So= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d/go.mod h1:9XMFaCeRyW7fC9XJOWQ+NdAv8VLG7ys7l3x4ozEGLUQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -289,6 +291,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -604,6 +608,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= @@ -1517,6 +1523,8 @@ github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5q github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= +github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= +github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 h1:mTC4gyv3lcJ1XpzZMAckqkvWUqeT5Bva4RAT1IoHAAA= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58/go.mod h1:ZAYCOqLJkc9P6fcq14TV4cf+gJ2fHthp9kCGxBViagE= github.com/stealthrocket/wazergo v0.19.1 h1:BPrITETPgSFwiytwmToO0MbUC/+RGC39JScz1JmmG6c= @@ -1652,6 +1660,8 @@ github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7 github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= +github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= github.com/zouyx/agollo/v3 v3.4.5 h1:7YCxzY9ZYaH9TuVUBvmI6Tk0mwMggikah+cfbYogcHQ= github.com/zouyx/agollo/v3 v3.4.5/go.mod h1:LJr3kDmm23QSW+F1Ol4TMHDa7HvJvscMdVxJ2IpUTVc= go.einride.tech/aip v0.66.0 h1:XfV+NQX6L7EOYK11yoHHFtndeaWh3KbD9/cN/6iWEt8= @@ -2316,6 +2326,8 @@ google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACu google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index db45fb8d84..4b469106b7 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -67,7 +67,7 @@ func maskLeft(s string) string { return string(rs) } -func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, error) { +func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error) { md := &snsSqsMetadata{ AssetsManagementTimeoutSeconds: assetsManagementDefaultTimeoutSeconds, MessageVisibilityTimeout: 10, diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 357cfcabb9..93481fb733 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -49,9 +49,7 @@ type snsSqs struct { queues map[string]*sqsQueueInfo // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. subscriptions map[string]string - snsClient *sns.SNS - sqsClient *sqs.SQS - stsClient *sts.STS + authProvider awsAuth.Provider metadata *snsSqsMetadata logger logger.Logger id string @@ -138,23 +136,33 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { } func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { - md, err := s.getSnsSqsMetatdata(metadata) + m, err := s.getSnsSqsMetadata(metadata) if err != nil { return err } - s.metadata = md + s.metadata = m - sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) - if err != nil { - return fmt.Errorf("error creating an AWS client: %w", err) + if s.authProvider == nil { + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider } - // AWS sns,sqs,sts client. - s.snsClient = sns.New(sess) - s.sqsClient = sqs.New(sess) - s.stsClient = sts.New(sess) - s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second)) + s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) if err != nil { @@ -181,9 +189,8 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { if len(s.metadata.AccountID) == awsAccountIDLength { return nil } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := s.authProvider.SnsSqs().Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -208,9 +215,8 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} snsCreateTopicInput.SetAttributes(attributes) } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := s.authProvider.SnsSqs().Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -222,7 +228,7 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := s.authProvider.SnsSqs().Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -288,15 +294,16 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} sqsCreateQueueInput.SetAttributes(attributes) } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := s.authProvider.SnsSqs().Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -313,7 +320,7 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -321,7 +328,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -382,7 +389,7 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := s.authProvider.SnsSqs().Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -402,7 +409,7 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := s.authProvider.SnsSqs().Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -451,7 +458,7 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) - _, err := s.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err := s.authProvider.SnsSqs().Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -466,7 +473,7 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err := s.authProvider.SnsSqs().Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -593,12 +600,11 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters if ctx.Err() != nil { break } - - // Internally, by default, aws go sdk performs 3 retires with exponential backoff to contact + // Internally, by default, aws go sdk performs 3 retries with exponential backoff to contact // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := s.authProvider.SnsSqs().Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -690,9 +696,8 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI return wrappedErr } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -712,7 +717,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -739,7 +744,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, } ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = s.sqsClient.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -852,7 +857,7 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error } // sns client has internal exponential backoffs. - _, err = s.snsClient.PublishWithContext(ctx, snsPublishInput) + _, err = s.authProvider.SnsSqs().Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) @@ -870,7 +875,7 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } - return nil + return s.authProvider.Close() } func (s *snsSqs) Features() []pubsub.Feature { diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 1c789b67be..f396ddef47 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -38,7 +38,7 @@ func Test_parseTopicArn(t *testing.T) { } // Verify that all metadata ends up in the correct spot. -func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { +func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -47,7 +47,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "Endpoint": "endpoint", "concurrencyMode": string(pubsub.Single), @@ -80,7 +80,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { r.Equal(int64(6), md.MessageReceiveLimit) } -func Test_getSnsSqsMetatdata_defaults(t *testing.T) { +func Test_getSnsSqsMetadata_defaults(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -89,7 +89,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -114,7 +114,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { r.False(md.DisableDeleteOnRetryLimit) } -func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { +func Test_getSnsSqsMetadata_legacyaliases(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -123,7 +123,7 @@ func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "awsAccountID": "acctId", "awsSecret": "secret", @@ -151,13 +151,13 @@ func testMetadataParsingShouldFail(t *testing.T, metadata pubsub.Metadata, l log logger: l, } - md, err := ps.getSnsSqsMetatdata(metadata) + md, err := ps.getSnsSqsMetadata(metadata) r.Error(err) r.Nil(md) } -func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) { +func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) { t.Parallel() fixtures := []testUnitFixture{ @@ -432,7 +432,7 @@ func Test_buildARN_DefaultPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -455,7 +455,7 @@ func Test_buildARN_StandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -478,7 +478,7 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 1f82031ba8..abf9c6c4de 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -53,23 +52,33 @@ type ParameterStoreMetaData struct { } type ssmSecretStore struct { - client ssmiface.SSMAPI - prefix string - logger logger.Logger + authProvider awsAuth.Provider + prefix string + logger logger.Logger } // Init creates an AWS secret manager client. func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error { - meta, err := s.getSecretManagerMetadata(metadata) + m, err := s.getSecretManagerMetadata(metadata) if err != nil { return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - s.prefix = meta.Prefix + s.authProvider = provider + s.prefix = m.Prefix return nil } @@ -84,7 +93,7 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr name = fmt.Sprintf("%s:%s", req.Name, versionID) } - output, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + output, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -124,7 +133,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - output, err := s.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + output, err := s.authProvider.ParameterStore().Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -134,7 +143,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - params, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + params, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) @@ -155,15 +164,6 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul return resp, nil } -func (s *ssmSecretStore) getClient(metadata *ParameterStoreMetaData) (*ssm.SSM, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, err - } - - return ssm.New(sess), nil -} - func (s *ssmSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*ParameterStoreMetaData, error) { meta := ParameterStoreMetaData{} err := kitmd.DecodeMetadata(spec.Properties, &meta) @@ -182,5 +182,5 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 8d9bcf6065..04c7a6995e 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -22,9 +22,11 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,20 +36,6 @@ import ( const secretValue = "secret" -type mockedSSM struct { - GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) - DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) - ssmiface.SSMAPI -} - -func (m *mockedSSM) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return m.GetParameterFn(ctx, input, option...) -} - -func (m *mockedSSM) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { - return m.DescribeParametersFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewParameterStore(logger.NewLogger("test")) @@ -68,21 +56,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("with valid path", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -93,25 +92,36 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - keys := strings.Split(*input.Name, ":") - assert.NotNil(t, keys) - assert.Len(t, keys, 2) - assert.Equalf(t, "1", keys[1], "Version IDs are same") - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: &keys[0], - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + keys := strings.Split(*input.Name, ":") + assert.NotNil(t, keys) + assert.Len(t, keys, 2) + assert.Equalf(t, "1", keys[1], "Version IDs are same") + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: &keys[0], + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{ @@ -124,21 +134,33 @@ func TestGetSecret(t *testing.T) { }) t.Run("with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) + secret := secretValue + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.GetSecretRequest{ @@ -152,13 +174,27 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -170,31 +206,42 @@ func TestGetSecret(t *testing.T) { func TestGetBulkSecrets(t *testing.T) { t.Run("successfully retrieve bulk secrets", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -205,30 +252,41 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("successfully retrieve bulk secrets with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/prefix/aws/dev/secret1"), - }, - { - Name: aws.String("/prefix/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/prefix/aws/dev/secret1"), + }, + { + Name: aws.String("/prefix/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.BulkGetSecretRequest{ @@ -241,23 +299,35 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on get parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") + }, + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -266,13 +336,25 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on describe parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 54ed329d35..6faf1f1eab 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -20,7 +20,6 @@ import ( "reflect" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -49,8 +48,8 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - client secretsmanageriface.SecretsManagerAPI - logger logger.Logger + authProvider awsAuth.Provider + logger logger.Logger } // Init creates an AWS secret manager client. @@ -60,11 +59,20 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - + s.authProvider = provider return nil } @@ -78,8 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - - output, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -108,7 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.client.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -117,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -136,15 +143,6 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk return resp, nil } -func (s *smSecretStore) getClient(metadata *SecretManagerMetaData) (*secretsmanager.SecretsManager, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return secretsmanager.New(sess), nil -} - func (s *smSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*SecretManagerMetaData, error) { b, err := json.Marshal(spec.Properties) if err != nil { @@ -172,5 +170,5 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 85918237a3..7fbd8493af 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -21,25 +21,17 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/secretstores" "github.com/dapr/kit/logger" ) const secretValue = "secret" -type mockedSM struct { - GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) - secretsmanageriface.SecretsManagerAPI -} - -func (m *mockedSM) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return m.GetSecretValueFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewSecretManager(logger.NewLogger("test")) @@ -60,21 +52,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("without version id and version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.Nil(t, input.VersionId) - assert.Nil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.Nil(t, input.VersionId) + assert.Nil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, @@ -85,20 +88,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionId) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionId) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -111,20 +126,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -138,13 +165,26 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 503d7082c7..ae4ba7c5e9 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -25,7 +25,6 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" jsoniterator "github.com/json-iterator/go" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -41,7 +40,8 @@ import ( type StateStore struct { state.BulkStore - client dynamodbiface.DynamoDBAPI + authProvider awsAuth.Provider + logger logger.Logger table string ttlAttributeName string partitionKey string @@ -66,9 +66,10 @@ const ( ) // NewDynamoDBStateStore returns a new dynamoDB state store. -func NewDynamoDBStateStore(_ logger.Logger) state.Store { +func NewDynamoDBStateStore(logger logger.Logger) state.Store { s := &StateStore{ partitionKey: defaultPartitionKeyName, + logger: logger, } s.BulkStore = state.NewDefaultBulkStore(s) return s @@ -80,14 +81,24 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { if err != nil { return err } - - // This check is needed because d.client is set to a mock in tests - if d.client == nil { - d.client, err = d.getClient(meta) + if d.authProvider == nil { + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + cfg := awsAuth.GetConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) if err != nil { return err } + d.authProvider = provider } + d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName d.partitionKey = meta.PartitionKey @@ -111,8 +122,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, }, } - - _, err := d.client.GetItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) return err } @@ -144,8 +154,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, }, } - - result, err := d.client.GetItemWithContext(ctx, input) + result, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -217,8 +226,7 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { condExpr := "attribute_not_exists(etag)" input.ConditionExpression = &condExpr } - - _, err = d.client.PutItemWithContext(ctx, input) + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -249,8 +257,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } input.ExpressionAttributeValues = exprAttrValues } - - _, err := d.client.DeleteItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -268,7 +275,7 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return nil + return d.authProvider.Close() } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { @@ -281,16 +288,6 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata return &m, err } -func (d *StateStore) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // getItemFromReq converts a dapr state.SetRequest into an dynamodb item func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) { value, err := d.marshalToString(req.Value) @@ -431,8 +428,7 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat } twinput.TransactItems = append(twinput.TransactItems, twi) } - - _, err := d.client.TransactWriteItemsWithContext(ctx, twinput) + _, err := d.authProvider.DynamoDB().DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d1f98b70ba..7b667b6c78 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -21,26 +21,18 @@ import ( "testing" "time" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/state" ) -type mockedDynamoDB struct { - GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) - PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) - DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) - BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) - TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) - dynamodbiface.DynamoDBAPI -} - type DynamoDBItem struct { Key string `json:"key"` Value string `json:"value"` @@ -52,36 +44,28 @@ const ( pkey = "partitionKey" ) -func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return m.GetItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { - return m.PutItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { - return m.DeleteItemWithContextFn(ctx, input, op...) -} +func TestInit(t *testing.T) { + m := state.Metadata{} + mockedDB := &awsAuth.MockDynamoDB{ + // We're adding this so we can pass the connection check on Init + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return nil, nil + }, + } -func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { - return m.BatchWriteItemWithContextFn(ctx, input, op...) -} + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } -func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { - return m.TransactWriteItemsWithContextFn(ctx, input, op...) -} + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } -func TestInit(t *testing.T) { - m := state.Metadata{} - s := &StateStore{ + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, - client: &mockedDynamoDB{ - // We're adding this so we can pass the connection check on Init - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return nil, nil - }, - }, } t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) { @@ -132,16 +116,29 @@ func TestInit(t *testing.T) { }) t.Run("Init with bad table name or permissions", func(t *testing.T) { + table := "does-not-exist" m.Properties = map[string]string{ - "Table": "does-not-exist", - "Region": "eu-west-1", + "Table": table, } - s.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + table: table, + } err := s.Init(context.Background(), m) require.Error(t, err) @@ -151,10 +148,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -172,6 +166,20 @@ func TestGet(t *testing.T) { }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.GetRequest{ Key: "someKey", Metadata: nil, @@ -179,34 +187,46 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("4074862051"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), }, - }, nil - }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("4074862051"), + }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -216,7 +236,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -226,27 +246,39 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("35489251"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), + }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("35489251"), }, - }, nil - }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -256,20 +288,33 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return nil, errors.New("failed to retrieve data") - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return nil, errors.New("failed to retrieve data") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -277,20 +322,32 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.Error(t, err) assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{}, - }, nil - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -298,26 +355,38 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "value2": { - S: aws.String("value"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "value2": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -325,7 +394,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Empty(t, out.Data) assert.Nil(t, out.ETag) @@ -338,10 +407,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -360,21 +426,34 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -397,6 +476,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "1bdead4badc0ffee" req := &state.SetRequest{ ETag: &etag, @@ -405,15 +499,12 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -431,6 +522,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "bogusetag" req := &state.SetRequest{ ETag: &etag, @@ -440,7 +546,7 @@ func TestSet(t *testing.T) { }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -451,10 +557,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -474,6 +577,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -483,15 +601,12 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -506,6 +621,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -515,7 +645,7 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch err.(type) { case *state.ETagError: @@ -525,10 +655,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -547,7 +674,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + ttlAttributeName: "testAttributeName", + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", @@ -558,14 +700,11 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -584,7 +723,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + ttlAttributeName: "testAttributeName", + } req := &state.SetRequest{ Key: "someKey", @@ -595,33 +749,42 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -640,6 +803,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", Value: value{ @@ -649,34 +827,46 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("somekey"), - }, - "value": { - S: aws.String(`{"Value":"somevalue"}`), - }, - "ttlInSeconds": { - N: aws.String("180"), - }, - }, input.Item) + mockedDB := &awsAuth.MockDynamoDB{ + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("somekey"), + }, + "value": { + S: aws.String(`{"Value":"somevalue"}`), + }, + "ttlInSeconds": { + N: aws.String("180"), + }, + }, input.Item) - return &dynamodb.PutItemOutput{ - Attributes: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("value"), - }, + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.SetRequest{ @@ -688,7 +878,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -700,10 +890,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -715,7 +902,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -725,10 +927,8 @@ func TestDelete(t *testing.T) { ETag: &etag, Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -744,7 +944,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -755,10 +970,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -775,7 +987,21 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + err := s.Delete(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -786,26 +1012,36 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { - return nil, errors.New("unable to delete item") - }, + mockedDB := &awsAuth.MockDynamoDB{ + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { + return nil, errors.New("unable to delete item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(context.Background(), req) + err := s.Delete(context.Background(), req) require.Error(t, err) }) } func TestMultiTx(t *testing.T) { t.Run("Successfully Multiple Transaction Operations", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } firstKey := "key1" secondKey := "key2" secondValue := "value2" @@ -829,7 +1065,7 @@ func TestMultiTx(t *testing.T) { }, } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -853,13 +1089,28 @@ func TestMultiTx(t *testing.T) { return &dynamodb.TransactWriteItemsOutput{}, nil }, } - ss.table = tableName + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + table: tableName, + partitionKey: defaultPartitionKeyName, + } req := &state.TransactionalStateRequest{ Operations: ops, Metadata: map[string]string{}, } - err := ss.Multi(context.Background(), req) + err := s.Multi(context.Background(), req) require.NoError(t, err) }) } diff --git a/tests/certification/bindings/aws/s3/s3_test.go b/tests/certification/bindings/aws/s3/s3_test.go index 16e41abc49..c815901f8c 100644 --- a/tests/certification/bindings/aws/s3/s3_test.go +++ b/tests/certification/bindings/aws/s3/s3_test.go @@ -279,7 +279,7 @@ func S3SForcePathStyle(t *testing.T) { Step(sidecar.Run(sidecarName, append(componentRuntimeOptions(), embedded.WithoutApp(), - embedded.WithComponentsPath("./components/forcePathStyleTrue"), + embedded.WithResourcesPath("./components/forcePathStyleTrue"), embedded.WithDaprGRPCPort(strconv.Itoa(currentGRPCPort)), embedded.WithDaprHTTPPort(strconv.Itoa(currentHTTPPort)), )..., diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 1dc9c0ad44..3fe60b4224 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -98,6 +98,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect + github.com/aws/rolesanywhere-credential-helper v1.0.4 // indirect github.com/aws/smithy-go v1.21.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -290,6 +291,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect + github.com/vmware/vmware-go-kcl v1.5.1 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect @@ -333,6 +335,7 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.30.2 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 145f05a305..d21aec5c3c 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -234,6 +234,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -1381,6 +1383,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/vmware/vmware-go-kcl v1.5.1 h1:1rJLfAX4sDnCyatNoD/WJzVafkwST6u/cgY/Uf2VgHk= +github.com/vmware/vmware-go-kcl v1.5.1/go.mod h1:kXJmQ6h0dRMRrp1uWU9XbIXvwelDpTxSPquvQUBdpbo= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=